@@ -388,6 +388,120 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
}
}

+func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		// Send test responses; with include_usage the final chunk has empty choices and usage stats.
+		var dataBytes []byte
+		//nolint:lll
+		data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}`
+		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+		//nolint:lll
+		data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}`
+		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+		//nolint:lll
+		data = `{"id":"3","object":"completion","created":1598069256,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`
+		dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+		dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 5,
+		Model:     openai.GPT3Dot5Turbo,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+		StreamOptions: &openai.StreamOptions{
+			IncludeUsage: true,
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletionStream returned error")
+	defer stream.Close()
+
+	expectedResponses := []openai.ChatCompletionStreamResponse{
+		{
+			ID:                "1",
+			Object:            "completion",
+			Created:           1598069254,
+			Model:             openai.GPT3Dot5Turbo,
+			SystemFingerprint: "fp_d9767fc5b9",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "response1",
+					},
+					FinishReason: "max_tokens",
+				},
+			},
+		},
+		{
+			ID:                "2",
+			Object:            "completion",
+			Created:           1598069255,
+			Model:             openai.GPT3Dot5Turbo,
+			SystemFingerprint: "fp_d9767fc5b9",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "response2",
+					},
+					FinishReason: "max_tokens",
+				},
+			},
+		},
+		{
+			ID:                "3",
+			Object:            "completion",
+			Created:           1598069256,
+			Model:             openai.GPT3Dot5Turbo,
+			SystemFingerprint: "fp_d9767fc5b9",
+			Choices:           []openai.ChatCompletionStreamChoice{},
+			Usage: &openai.Usage{
+				PromptTokens:     1,
+				CompletionTokens: 1,
+				TotalTokens:      2,
+			},
+		},
+	}
+
+	for ix, expectedResponse := range expectedResponses {
+		b, _ := json.Marshal(expectedResponse)
+		t.Logf("%d: %s", ix, string(b))
+
+		receivedResponse, streamErr := stream.Recv()
+		checks.NoError(t, streamErr, "stream.Recv() failed")
+		if !compareChatResponses(expectedResponse, receivedResponse) {
+			t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
+		}
+	}
+
+	_, streamErr := stream.Recv()
+	if !errors.Is(streamErr, io.EOF) {
+		t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+	}
+
+	// A second Recv after the stream has ended must keep returning EOF.
+	_, streamErr = stream.Recv()
+
+	if !errors.Is(streamErr, io.EOF) {
+		t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr)
+	}
+}
+
// Helper funcs.
func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
	if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {
@@ -401,6 +515,15 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
			return false
		}
	}
+	if r1.Usage != nil || r2.Usage != nil {
+		if r1.Usage == nil || r2.Usage == nil {
+			return false
+		}
+		if r1.Usage.PromptTokens != r2.Usage.PromptTokens || r1.Usage.CompletionTokens != r2.Usage.CompletionTokens ||
+			r1.Usage.TotalTokens != r2.Usage.TotalTokens {
+			return false
+		}
+	}
	return true
}
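
For context, here is a minimal, self-contained sketch of a client consuming a stream with the new IncludeUsage option, using only the API surface exercised by the test above (CreateChatCompletionStream, Recv, the new Usage field on the stream response). The import path assumes the usual go-openai module; the model choice and error handling are illustrative, not part of this PR.

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
		Model:    openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: "Hello!"}},
		Stream:   true,
		// Ask the server to append a final usage-only chunk to the stream.
		StreamOptions: &openai.StreamOptions{IncludeUsage: true},
	})
	if err != nil {
		log.Fatal(err)
	}
	defer stream.Close()

	for {
		resp, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			break // server sent [DONE]
		}
		if recvErr != nil {
			log.Fatal(recvErr)
		}
		if len(resp.Choices) > 0 {
			fmt.Print(resp.Choices[0].Delta.Content)
		}
		// With IncludeUsage, the last data chunk has empty Choices and a
		// non-nil Usage, exactly as the test above asserts.
		if resp.Usage != nil {
			fmt.Printf("\n[usage] prompt=%d completion=%d total=%d\n",
				resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
		}
	}
}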