Skip to content

Commit 01264ba

Browse files
committed
stream option
1 parent fbb6a7f commit 01264ba

5 files changed

Lines changed: 389 additions & 22 deletions

File tree

chat.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ func (c *Client) openaiChatCompletionStream(ctx context.Context, req *ChatReques
100100
r := req.clone()
101101
r.Stream = true
102102

103+
// Request usage data in the final stream chunk if not already set.
104+
if r.StreamOptions == nil {
105+
r.StreamOptions = &StreamOptions{IncludeUsage: true}
106+
}
107+
103108
c.applyDefaultModel(&r)
104109

105110
resp, err := c.doRequest(ctx, &r)

openai_chat_test.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,159 @@ func TestChatCompletionStreamAPIError(t *testing.T) {
291291
}
292292
}
293293

294+
func TestChatCompletionStreamAutoInjectsStreamOptions(t *testing.T) {
295+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
296+
var req ChatRequest
297+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
298+
t.Errorf("decode request: %v", err)
299+
}
300+
301+
if req.StreamOptions == nil {
302+
t.Error("StreamOptions = nil, want non-nil")
303+
} else if !req.StreamOptions.IncludeUsage {
304+
t.Error("StreamOptions.IncludeUsage = false, want true")
305+
}
306+
307+
w.Header().Set("Content-Type", "text/event-stream")
308+
flusher, _ := w.(http.Flusher)
309+
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
310+
flusher.Flush()
311+
}))
312+
defer srv.Close()
313+
314+
c, err := NewClient(WithAPIKey("sk-test"), WithBaseURL(srv.URL))
315+
if err != nil {
316+
t.Fatalf("NewClient: %v", err)
317+
}
318+
319+
// Caller does NOT set StreamOptions.
320+
stream, err := c.ChatCompletionStream(context.Background(), &ChatRequest{
321+
Model: ModelOpenaiGPT4o,
322+
Messages: []Message{{Role: RoleUser, Content: NewTextContent("Hi")}},
323+
})
324+
if err != nil {
325+
t.Fatalf("ChatCompletionStream: %v", err)
326+
}
327+
defer func() { _ = stream.Close() }()
328+
329+
// Drain the stream.
330+
for {
331+
_, err := stream.Recv()
332+
if errors.Is(err, io.EOF) {
333+
break
334+
}
335+
if err != nil {
336+
t.Fatalf("Recv: %v", err)
337+
}
338+
}
339+
}
340+
341+
func TestChatCompletionStreamPreservesExplicitStreamOptions(t *testing.T) {
342+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
343+
var req ChatRequest
344+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
345+
t.Errorf("decode request: %v", err)
346+
}
347+
348+
if req.StreamOptions == nil {
349+
t.Error("StreamOptions = nil, want non-nil")
350+
} else if req.StreamOptions.IncludeUsage {
351+
t.Error("StreamOptions.IncludeUsage = true, want false (caller explicitly set false)")
352+
}
353+
354+
w.Header().Set("Content-Type", "text/event-stream")
355+
flusher, _ := w.(http.Flusher)
356+
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
357+
flusher.Flush()
358+
}))
359+
defer srv.Close()
360+
361+
c, err := NewClient(WithAPIKey("sk-test"), WithBaseURL(srv.URL))
362+
if err != nil {
363+
t.Fatalf("NewClient: %v", err)
364+
}
365+
366+
// Caller explicitly sets IncludeUsage: false.
367+
stream, err := c.ChatCompletionStream(context.Background(), &ChatRequest{
368+
Model: ModelOpenaiGPT4o,
369+
Messages: []Message{{Role: RoleUser, Content: NewTextContent("Hi")}},
370+
StreamOptions: &StreamOptions{IncludeUsage: false},
371+
})
372+
if err != nil {
373+
t.Fatalf("ChatCompletionStream: %v", err)
374+
}
375+
defer func() { _ = stream.Close() }()
376+
377+
for {
378+
_, err := stream.Recv()
379+
if errors.Is(err, io.EOF) {
380+
break
381+
}
382+
if err != nil {
383+
t.Fatalf("Recv: %v", err)
384+
}
385+
}
386+
}
387+
388+
func TestChatCompletionStreamUsage(t *testing.T) {
389+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
390+
w.Header().Set("Content-Type", "text/event-stream")
391+
flusher, _ := w.(http.Flusher)
392+
393+
chunks := []string{
394+
`{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}`,
395+
`{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":8,"completion_tokens":16,"total_tokens":24}}`,
396+
}
397+
398+
for _, chunk := range chunks {
399+
_, _ = fmt.Fprintf(w, "data: %s\n\n", chunk)
400+
flusher.Flush()
401+
}
402+
403+
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
404+
flusher.Flush()
405+
}))
406+
defer srv.Close()
407+
408+
c, err := NewClient(WithAPIKey("sk-test"), WithBaseURL(srv.URL))
409+
if err != nil {
410+
t.Fatalf("NewClient: %v", err)
411+
}
412+
413+
stream, err := c.ChatCompletionStream(context.Background(), &ChatRequest{
414+
Model: ModelOpenaiGPT4o,
415+
Messages: []Message{{Role: RoleUser, Content: NewTextContent("Hi")}},
416+
})
417+
if err != nil {
418+
t.Fatalf("ChatCompletionStream: %v", err)
419+
}
420+
defer func() { _ = stream.Close() }()
421+
422+
for {
423+
_, err := stream.Recv()
424+
if errors.Is(err, io.EOF) {
425+
break
426+
}
427+
if err != nil {
428+
t.Fatalf("Recv: %v", err)
429+
}
430+
}
431+
432+
usage := stream.Usage()
433+
if usage == nil {
434+
t.Fatal("Usage() = nil, want non-nil")
435+
}
436+
if usage.PromptTokens != 8 {
437+
t.Errorf("prompt_tokens = %d, want 8", usage.PromptTokens)
438+
}
439+
if usage.CompletionTokens != 16 {
440+
t.Errorf("completion_tokens = %d, want 16", usage.CompletionTokens)
441+
}
442+
if usage.TotalTokens != 24 {
443+
t.Errorf("total_tokens = %d, want 24", usage.TotalTokens)
444+
}
445+
}
446+
294447
func TestChatCompletionWithTools(t *testing.T) {
295448
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
296449
var req ChatRequest

openai_stream_test.go

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,169 @@ func TestStreamAPIError(t *testing.T) {
174174
t.Errorf("code = %q", apiErr.Code)
175175
}
176176
}
177+
178+
func TestStreamUsageFromFinalChunk(t *testing.T) {
179+
body := ""
180+
body += "data: " + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}` + "\n\n"
181+
body += "data: " + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":20,"total_tokens":30}}` + "\n\n"
182+
body += "data: [DONE]\n\n"
183+
184+
s := newStream(io.NopCloser(strings.NewReader(body)))
185+
186+
for {
187+
_, err := s.Recv()
188+
if errors.Is(err, io.EOF) {
189+
break
190+
}
191+
if err != nil {
192+
t.Fatalf("Recv: %v", err)
193+
}
194+
}
195+
196+
usage := s.Usage()
197+
if usage == nil {
198+
t.Fatal("Usage() = nil, want non-nil")
199+
}
200+
if usage.PromptTokens != 10 {
201+
t.Errorf("prompt_tokens = %d, want 10", usage.PromptTokens)
202+
}
203+
if usage.CompletionTokens != 20 {
204+
t.Errorf("completion_tokens = %d, want 20", usage.CompletionTokens)
205+
}
206+
if usage.TotalTokens != 30 {
207+
t.Errorf("total_tokens = %d, want 30", usage.TotalTokens)
208+
}
209+
}
210+
211+
func TestStreamUsageNilWhenNoUsageChunk(t *testing.T) {
212+
body := ""
213+
body += "data: " + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}` + "\n\n"
214+
body += "data: " + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}` + "\n\n"
215+
body += "data: [DONE]\n\n"
216+
217+
s := newStream(io.NopCloser(strings.NewReader(body)))
218+
219+
for {
220+
_, err := s.Recv()
221+
if errors.Is(err, io.EOF) {
222+
break
223+
}
224+
if err != nil {
225+
t.Fatalf("Recv: %v", err)
226+
}
227+
}
228+
229+
if usage := s.Usage(); usage != nil {
230+
t.Errorf("Usage() = %+v, want nil", usage)
231+
}
232+
}
233+
234+
func TestStreamCloseOnCloseCallback(t *testing.T) {
235+
body := ""
236+
body += "data: " + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}` + "\n\n"
237+
body += "data: " + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":10,"total_tokens":15}}` + "\n\n"
238+
body += "data: [DONE]\n\n"
239+
240+
var callbackUsage *Usage
241+
var callbackCalled bool
242+
243+
s := newStream(io.NopCloser(strings.NewReader(body)))
244+
s = WrapStream(s, func(u *Usage) {
245+
callbackCalled = true
246+
callbackUsage = u
247+
})
248+
249+
for {
250+
_, err := s.Recv()
251+
if errors.Is(err, io.EOF) {
252+
break
253+
}
254+
if err != nil {
255+
t.Fatalf("Recv: %v", err)
256+
}
257+
}
258+
259+
if err := s.Close(); err != nil {
260+
t.Fatalf("Close: %v", err)
261+
}
262+
263+
if !callbackCalled {
264+
t.Fatal("onClose callback was not called")
265+
}
266+
if callbackUsage == nil {
267+
t.Fatal("callback received nil usage, want non-nil")
268+
}
269+
if callbackUsage.PromptTokens != 5 {
270+
t.Errorf("prompt_tokens = %d, want 5", callbackUsage.PromptTokens)
271+
}
272+
if callbackUsage.CompletionTokens != 10 {
273+
t.Errorf("completion_tokens = %d, want 10", callbackUsage.CompletionTokens)
274+
}
275+
if callbackUsage.TotalTokens != 15 {
276+
t.Errorf("total_tokens = %d, want 15", callbackUsage.TotalTokens)
277+
}
278+
}
279+
280+
func TestStreamCloseOnCloseCallbackWithoutUsage(t *testing.T) {
281+
body := ""
282+
body += "data: " + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}` + "\n\n"
283+
body += "data: [DONE]\n\n"
284+
285+
var callbackCalled bool
286+
var callbackUsage *Usage
287+
288+
s := newStream(io.NopCloser(strings.NewReader(body)))
289+
s = WrapStream(s, func(u *Usage) {
290+
callbackCalled = true
291+
callbackUsage = u
292+
})
293+
294+
for {
295+
_, err := s.Recv()
296+
if errors.Is(err, io.EOF) {
297+
break
298+
}
299+
if err != nil {
300+
t.Fatalf("Recv: %v", err)
301+
}
302+
}
303+
304+
if err := s.Close(); err != nil {
305+
t.Fatalf("Close: %v", err)
306+
}
307+
308+
if !callbackCalled {
309+
t.Fatal("onClose callback was not called")
310+
}
311+
if callbackUsage != nil {
312+
t.Errorf("callback usage = %+v, want nil", callbackUsage)
313+
}
314+
}
315+
316+
func TestWrapStreamNil(t *testing.T) {
317+
var callbackCalled bool
318+
var callbackUsage *Usage
319+
320+
result := WrapStream(nil, func(u *Usage) {
321+
callbackCalled = true
322+
callbackUsage = u
323+
})
324+
325+
if result != nil {
326+
t.Errorf("WrapStream(nil, cb) = %v, want nil", result)
327+
}
328+
if !callbackCalled {
329+
t.Fatal("onClose callback was not called immediately for nil stream")
330+
}
331+
if callbackUsage != nil {
332+
t.Errorf("callback usage = %+v, want nil", callbackUsage)
333+
}
334+
}
335+
336+
func TestWrapStreamNilCallbackNil(t *testing.T) {
337+
// Should not panic and should return nil.
338+
result := WrapStream(nil, nil)
339+
if result != nil {
340+
t.Errorf("WrapStream(nil, nil) = %v, want nil", result)
341+
}
342+
}

schema.go

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,31 @@ type Thinking struct {
4949
BudgetTokens int `json:"budget_tokens,omitempty"`
5050
}
5151

52+
// StreamOptions configures streaming behavior.
53+
type StreamOptions struct {
54+
IncludeUsage bool `json:"include_usage"`
55+
}
56+
5257
// ChatRequest represents a request to the chat completions API.
5358
type ChatRequest struct {
54-
Model string `json:"model"`
55-
Messages []Message `json:"messages"`
56-
Temperature *float64 `json:"temperature,omitempty"`
57-
MaxTokens *int `json:"max_tokens,omitempty"`
58-
TopP *float64 `json:"top_p,omitempty"`
59-
N *int `json:"n,omitempty"`
60-
Stop []string `json:"stop,omitempty"`
61-
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
62-
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
63-
Seed *int `json:"seed,omitempty"`
64-
User string `json:"user,omitempty"`
65-
ResponseFormat any `json:"response_format,omitempty"`
66-
Stream bool `json:"stream,omitempty"`
67-
Tools []Tool `json:"tools,omitempty"`
68-
ToolChoice any `json:"tool_choice,omitempty"`
69-
Thinking *Thinking `json:"thinking,omitempty"`
70-
ReasoningEffort string `json:"reasoning_effort,omitempty"`
59+
Model string `json:"model"`
60+
Messages []Message `json:"messages"`
61+
Temperature *float64 `json:"temperature,omitempty"`
62+
MaxTokens *int `json:"max_tokens,omitempty"`
63+
TopP *float64 `json:"top_p,omitempty"`
64+
N *int `json:"n,omitempty"`
65+
Stop []string `json:"stop,omitempty"`
66+
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
67+
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
68+
Seed *int `json:"seed,omitempty"`
69+
User string `json:"user,omitempty"`
70+
ResponseFormat any `json:"response_format,omitempty"`
71+
Stream bool `json:"stream,omitempty"`
72+
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
73+
Tools []Tool `json:"tools,omitempty"`
74+
ToolChoice any `json:"tool_choice,omitempty"`
75+
Thinking *Thinking `json:"thinking,omitempty"`
76+
ReasoningEffort string `json:"reasoning_effort,omitempty"`
7177
}
7278

7379
// ChatResponse represents a response from the chat completions API.

0 commit comments

Comments
 (0)