Skip to content

Commit c380d50

Browse files
authored
Implement the fine-tunes API (#130)
- Add FineTune Structs and Requests - Add CRUD Methods
1 parent c46ebb2 commit c380d50

File tree

1 file changed

+137
-0
lines changed

1 file changed

+137
-0
lines changed

fine_tunes.go

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"net/http"
9+
)
10+
11+
type FineTuneRequest struct {
12+
TrainingFile string `json:"training_file"`
13+
ValidationFile string `json:"validation_file,omitempty"`
14+
Model string `json:"model,omitempty"`
15+
Epochs int `json:"n_epochs,omitempty"`
16+
BatchSize int `json:"batch_size,omitempty"`
17+
LearningRateMultiplier float32 `json:"learning_rate_multiplier,omitempty"`
18+
PromptLossRate float32 `json:"prompt_loss_rate,omitempty"`
19+
ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"`
20+
ClassificationClasses int `json:"classification_n_classes,omitempty"`
21+
ClassificationPositiveClass string `json:"classification_positive_class,omitempty"`
22+
ClassificationBetas []float32 `json:"classification_betas,omitempty"`
23+
Suffix string `json:"suffix,omitempty"`
24+
}
25+
26+
type FineTune struct {
27+
ID string `json:"id"`
28+
Object string `json:"object"`
29+
Model string `json:"model"`
30+
CreatedAt int `json:"created_at"`
31+
FineTunedModel string `json:"fine_tuned_model"`
32+
Hyperparams FineTuneHyperParams `json:"hyperparams"`
33+
OrganizationID string `json:"organization_id"`
34+
ResultFiles []File `json:"result_files"`
35+
Status string `json:"status"`
36+
ValidationFiles []File `json:"validation_files"`
37+
TrainingFiles []File `json:"training_files"`
38+
UpdatedAt int `json:"updated_at"`
39+
}
40+
41+
type FineTuneEvent struct {
42+
Object string `json:"object"`
43+
CreatedAt int `json:"created_at"`
44+
Level string `json:"level"`
45+
Message string `json:"message"`
46+
}
47+
48+
type FineTuneHyperParams struct {
49+
BatchSize int `json:"batch_size"`
50+
LearningRateMultiplier float64 `json:"learning_rate_multiplier"`
51+
Epochs int `json:"n_epochs"`
52+
PromptLossWeight float64 `json:"prompt_loss_weight"`
53+
}
54+
55+
type FineTuneList struct {
56+
Object string `json:"object"`
57+
Data []FineTune `json:"data"`
58+
}
59+
type FineTuneEventList struct {
60+
Object string `json:"object"`
61+
Data []FineTuneEvent `json:"data"`
62+
}
63+
64+
type FineTuneDeleteResponse struct {
65+
ID string `json:"id"`
66+
Object string `json:"object"`
67+
Deleted bool `json:"deleted"`
68+
}
69+
70+
func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) {
71+
var reqBytes []byte
72+
reqBytes, err = json.Marshal(request)
73+
if err != nil {
74+
return
75+
}
76+
77+
urlSuffix := "/fine-tunes"
78+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
79+
if err != nil {
80+
return
81+
}
82+
83+
err = c.sendRequest(req, &response)
84+
return
85+
}
86+
87+
// Cancel a fine-tune job.
88+
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
89+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil)
90+
if err != nil {
91+
return
92+
}
93+
94+
err = c.sendRequest(req, &response)
95+
return
96+
}
97+
98+
func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) {
99+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil)
100+
if err != nil {
101+
return
102+
}
103+
104+
err = c.sendRequest(req, &response)
105+
return
106+
}
107+
108+
func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) {
109+
urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID)
110+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
111+
if err != nil {
112+
return
113+
}
114+
115+
err = c.sendRequest(req, &response)
116+
return
117+
}
118+
119+
func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) {
120+
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil)
121+
if err != nil {
122+
return
123+
}
124+
125+
err = c.sendRequest(req, &response)
126+
return
127+
}
128+
129+
func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) {
130+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil)
131+
if err != nil {
132+
return
133+
}
134+
135+
err = c.sendRequest(req, &response)
136+
return
137+
}

0 commit comments

Comments
 (0)