Skip to content

Commit 99cc170

Browse files
authored
feat: support batches api (#746)
* feat: support batches api * update batch_test.go * fix golangci-lint check * fix golangci-lint check * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix: create batch api * update batch_test.go * feat: add `CreateBatchWithUploadFile` * feat: add `UploadBatchFile` * optimize variable and type naming * expose `BatchLineItem` interface * update batches const
1 parent c69c3bb commit 99cc170

File tree

4 files changed

+655
-0
lines changed

4 files changed

+655
-0
lines changed

batch.go

+275
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"net/http"
10+
"net/url"
11+
)
12+
13+
const batchesSuffix = "/batches"
14+
15+
type BatchEndpoint string
16+
17+
const (
18+
BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions"
19+
BatchEndpointCompletions BatchEndpoint = "/v1/completions"
20+
BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings"
21+
)
22+
23+
type BatchLineItem interface {
24+
MarshalBatchLineItem() []byte
25+
}
26+
27+
type BatchChatCompletionRequest struct {
28+
CustomID string `json:"custom_id"`
29+
Body ChatCompletionRequest `json:"body"`
30+
Method string `json:"method"`
31+
URL BatchEndpoint `json:"url"`
32+
}
33+
34+
func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte {
35+
marshal, _ := json.Marshal(r)
36+
return marshal
37+
}
38+
39+
type BatchCompletionRequest struct {
40+
CustomID string `json:"custom_id"`
41+
Body CompletionRequest `json:"body"`
42+
Method string `json:"method"`
43+
URL BatchEndpoint `json:"url"`
44+
}
45+
46+
func (r BatchCompletionRequest) MarshalBatchLineItem() []byte {
47+
marshal, _ := json.Marshal(r)
48+
return marshal
49+
}
50+
51+
type BatchEmbeddingRequest struct {
52+
CustomID string `json:"custom_id"`
53+
Body EmbeddingRequest `json:"body"`
54+
Method string `json:"method"`
55+
URL BatchEndpoint `json:"url"`
56+
}
57+
58+
func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte {
59+
marshal, _ := json.Marshal(r)
60+
return marshal
61+
}
62+
63+
type Batch struct {
64+
ID string `json:"id"`
65+
Object string `json:"object"`
66+
Endpoint BatchEndpoint `json:"endpoint"`
67+
Errors *struct {
68+
Object string `json:"object,omitempty"`
69+
Data struct {
70+
Code string `json:"code,omitempty"`
71+
Message string `json:"message,omitempty"`
72+
Param *string `json:"param,omitempty"`
73+
Line *int `json:"line,omitempty"`
74+
} `json:"data"`
75+
} `json:"errors"`
76+
InputFileID string `json:"input_file_id"`
77+
CompletionWindow string `json:"completion_window"`
78+
Status string `json:"status"`
79+
OutputFileID *string `json:"output_file_id"`
80+
ErrorFileID *string `json:"error_file_id"`
81+
CreatedAt int `json:"created_at"`
82+
InProgressAt *int `json:"in_progress_at"`
83+
ExpiresAt *int `json:"expires_at"`
84+
FinalizingAt *int `json:"finalizing_at"`
85+
CompletedAt *int `json:"completed_at"`
86+
FailedAt *int `json:"failed_at"`
87+
ExpiredAt *int `json:"expired_at"`
88+
CancellingAt *int `json:"cancelling_at"`
89+
CancelledAt *int `json:"cancelled_at"`
90+
RequestCounts BatchRequestCounts `json:"request_counts"`
91+
Metadata map[string]any `json:"metadata"`
92+
}
93+
94+
type BatchRequestCounts struct {
95+
Total int `json:"total"`
96+
Completed int `json:"completed"`
97+
Failed int `json:"failed"`
98+
}
99+
100+
type CreateBatchRequest struct {
101+
InputFileID string `json:"input_file_id"`
102+
Endpoint BatchEndpoint `json:"endpoint"`
103+
CompletionWindow string `json:"completion_window"`
104+
Metadata map[string]any `json:"metadata"`
105+
}
106+
107+
type BatchResponse struct {
108+
httpHeader
109+
Batch
110+
}
111+
112+
var ErrUploadBatchFileFailed = errors.New("upload batch file failed")
113+
114+
// CreateBatch — API call to Create batch.
115+
func (c *Client) CreateBatch(
116+
ctx context.Context,
117+
request CreateBatchRequest,
118+
) (response BatchResponse, err error) {
119+
if request.CompletionWindow == "" {
120+
request.CompletionWindow = "24h"
121+
}
122+
123+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request))
124+
if err != nil {
125+
return
126+
}
127+
128+
err = c.sendRequest(req, &response)
129+
return
130+
}
131+
132+
type UploadBatchFileRequest struct {
133+
FileName string
134+
Lines []BatchLineItem
135+
}
136+
137+
func (r *UploadBatchFileRequest) MarshalJSONL() []byte {
138+
buff := bytes.Buffer{}
139+
for i, line := range r.Lines {
140+
if i != 0 {
141+
buff.Write([]byte("\n"))
142+
}
143+
buff.Write(line.MarshalBatchLineItem())
144+
}
145+
return buff.Bytes()
146+
}
147+
148+
func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) {
149+
r.Lines = append(r.Lines, BatchChatCompletionRequest{
150+
CustomID: customerID,
151+
Body: body,
152+
Method: "POST",
153+
URL: BatchEndpointChatCompletions,
154+
})
155+
}
156+
157+
func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) {
158+
r.Lines = append(r.Lines, BatchCompletionRequest{
159+
CustomID: customerID,
160+
Body: body,
161+
Method: "POST",
162+
URL: BatchEndpointCompletions,
163+
})
164+
}
165+
166+
func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) {
167+
r.Lines = append(r.Lines, BatchEmbeddingRequest{
168+
CustomID: customerID,
169+
Body: body,
170+
Method: "POST",
171+
URL: BatchEndpointEmbeddings,
172+
})
173+
}
174+
175+
// UploadBatchFile — upload batch file.
176+
func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) {
177+
if request.FileName == "" {
178+
request.FileName = "@batchinput.jsonl"
179+
}
180+
return c.CreateFileBytes(ctx, FileBytesRequest{
181+
Name: request.FileName,
182+
Bytes: request.MarshalJSONL(),
183+
Purpose: PurposeBatch,
184+
})
185+
}
186+
187+
type CreateBatchWithUploadFileRequest struct {
188+
Endpoint BatchEndpoint `json:"endpoint"`
189+
CompletionWindow string `json:"completion_window"`
190+
Metadata map[string]any `json:"metadata"`
191+
UploadBatchFileRequest
192+
}
193+
194+
// CreateBatchWithUploadFile — API call to Create batch with upload file.
195+
func (c *Client) CreateBatchWithUploadFile(
196+
ctx context.Context,
197+
request CreateBatchWithUploadFileRequest,
198+
) (response BatchResponse, err error) {
199+
var file File
200+
file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{
201+
FileName: request.FileName,
202+
Lines: request.Lines,
203+
})
204+
if err != nil {
205+
err = errors.Join(ErrUploadBatchFileFailed, err)
206+
return
207+
}
208+
return c.CreateBatch(ctx, CreateBatchRequest{
209+
InputFileID: file.ID,
210+
Endpoint: request.Endpoint,
211+
CompletionWindow: request.CompletionWindow,
212+
Metadata: request.Metadata,
213+
})
214+
}
215+
216+
// RetrieveBatch — API call to Retrieve batch.
217+
func (c *Client) RetrieveBatch(
218+
ctx context.Context,
219+
batchID string,
220+
) (response BatchResponse, err error) {
221+
urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID)
222+
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
223+
if err != nil {
224+
return
225+
}
226+
err = c.sendRequest(req, &response)
227+
return
228+
}
229+
230+
// CancelBatch — API call to Cancel batch.
231+
func (c *Client) CancelBatch(
232+
ctx context.Context,
233+
batchID string,
234+
) (response BatchResponse, err error) {
235+
urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID)
236+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix))
237+
if err != nil {
238+
return
239+
}
240+
err = c.sendRequest(req, &response)
241+
return
242+
}
243+
244+
type ListBatchResponse struct {
245+
httpHeader
246+
Object string `json:"object"`
247+
Data []Batch `json:"data"`
248+
FirstID string `json:"first_id"`
249+
LastID string `json:"last_id"`
250+
HasMore bool `json:"has_more"`
251+
}
252+
253+
// ListBatch API call to List batch.
254+
func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) {
255+
urlValues := url.Values{}
256+
if limit != nil {
257+
urlValues.Add("limit", fmt.Sprintf("%d", *limit))
258+
}
259+
if after != nil {
260+
urlValues.Add("after", *after)
261+
}
262+
encodedValues := ""
263+
if len(urlValues) > 0 {
264+
encodedValues = "?" + urlValues.Encode()
265+
}
266+
267+
urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues)
268+
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
269+
if err != nil {
270+
return
271+
}
272+
273+
err = c.sendRequest(req, &response)
274+
return
275+
}

0 commit comments

Comments
 (0)