Skip to content

Commit d10f1b8

Browse files
Yu0ugenglixia and genglixia authored
add chatcompletion stream delta refusal and logprobs (#882)
* add chatcompletion stream refusal and logprobs * fix slice to struct * add integration test * fix lint * fix lint * fix: the object should be pointer --------- Co-authored-by: genglixia <[email protected]>
1 parent 6e08732 commit d10f1b8

File tree

2 files changed

+289
-4
lines changed

2 files changed

+289
-4
lines changed

chat_stream.go

+24-4
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,33 @@ type ChatCompletionStreamChoiceDelta struct {
1010
Role string `json:"role,omitempty"`
1111
FunctionCall *FunctionCall `json:"function_call,omitempty"`
1212
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
13+
Refusal string `json:"refusal,omitempty"`
14+
}
15+
16+
type ChatCompletionStreamChoiceLogprobs struct {
17+
Content []ChatCompletionTokenLogprob `json:"content,omitempty"`
18+
Refusal []ChatCompletionTokenLogprob `json:"refusal,omitempty"`
19+
}
20+
21+
type ChatCompletionTokenLogprob struct {
22+
Token string `json:"token"`
23+
Bytes []int64 `json:"bytes,omitempty"`
24+
Logprob float64 `json:"logprob,omitempty"`
25+
TopLogprobs []ChatCompletionTokenLogprobTopLogprob `json:"top_logprobs"`
26+
}
27+
28+
type ChatCompletionTokenLogprobTopLogprob struct {
29+
Token string `json:"token"`
30+
Bytes []int64 `json:"bytes"`
31+
Logprob float64 `json:"logprob"`
1332
}
1433

1534
type ChatCompletionStreamChoice struct {
16-
Index int `json:"index"`
17-
Delta ChatCompletionStreamChoiceDelta `json:"delta"`
18-
FinishReason FinishReason `json:"finish_reason"`
19-
ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"`
35+
Index int `json:"index"`
36+
Delta ChatCompletionStreamChoiceDelta `json:"delta"`
37+
Logprobs *ChatCompletionStreamChoiceLogprobs `json:"logprobs,omitempty"`
38+
FinishReason FinishReason `json:"finish_reason"`
39+
ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"`
2040
}
2141

2242
type PromptFilterResult struct {

chat_stream_test.go

+265
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,271 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
358358
t.Logf("%+v\n", apiErr)
359359
}
360360

361+
func TestCreateChatCompletionStreamWithRefusal(t *testing.T) {
362+
client, server, teardown := setupOpenAITestServer()
363+
defer teardown()
364+
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
365+
w.Header().Set("Content-Type", "text/event-stream")
366+
367+
dataBytes := []byte{}
368+
369+
//nolint:lll
370+
dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`)...)
371+
dataBytes = append(dataBytes, []byte("\n\n")...)
372+
373+
//nolint:lll
374+
dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":"Hello"},"finish_reason":null}]}`)...)
375+
dataBytes = append(dataBytes, []byte("\n\n")...)
376+
377+
//nolint:lll
378+
dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":" World"},"finish_reason":null}]}`)...)
379+
dataBytes = append(dataBytes, []byte("\n\n")...)
380+
381+
//nolint:lll
382+
dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...)
383+
dataBytes = append(dataBytes, []byte("\n\n")...)
384+
385+
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
386+
387+
_, err := w.Write(dataBytes)
388+
checks.NoError(t, err, "Write error")
389+
})
390+
391+
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
392+
MaxTokens: 2000,
393+
Model: openai.GPT4oMini20240718,
394+
Messages: []openai.ChatCompletionMessage{
395+
{
396+
Role: openai.ChatMessageRoleUser,
397+
Content: "Hello!",
398+
},
399+
},
400+
Stream: true,
401+
})
402+
checks.NoError(t, err, "CreateCompletionStream returned error")
403+
defer stream.Close()
404+
405+
expectedResponses := []openai.ChatCompletionStreamResponse{
406+
{
407+
ID: "1",
408+
Object: "chat.completion.chunk",
409+
Created: 1729585728,
410+
Model: openai.GPT4oMini20240718,
411+
SystemFingerprint: "fp_d9767fc5b9",
412+
Choices: []openai.ChatCompletionStreamChoice{
413+
{
414+
Index: 0,
415+
Delta: openai.ChatCompletionStreamChoiceDelta{},
416+
},
417+
},
418+
},
419+
{
420+
ID: "2",
421+
Object: "chat.completion.chunk",
422+
Created: 1729585728,
423+
Model: openai.GPT4oMini20240718,
424+
SystemFingerprint: "fp_d9767fc5b9",
425+
Choices: []openai.ChatCompletionStreamChoice{
426+
{
427+
Index: 0,
428+
Delta: openai.ChatCompletionStreamChoiceDelta{
429+
Refusal: "Hello",
430+
},
431+
},
432+
},
433+
},
434+
{
435+
ID: "3",
436+
Object: "chat.completion.chunk",
437+
Created: 1729585728,
438+
Model: openai.GPT4oMini20240718,
439+
SystemFingerprint: "fp_d9767fc5b9",
440+
Choices: []openai.ChatCompletionStreamChoice{
441+
{
442+
Index: 0,
443+
Delta: openai.ChatCompletionStreamChoiceDelta{
444+
Refusal: " World",
445+
},
446+
},
447+
},
448+
},
449+
{
450+
ID: "4",
451+
Object: "chat.completion.chunk",
452+
Created: 1729585728,
453+
Model: openai.GPT4oMini20240718,
454+
SystemFingerprint: "fp_d9767fc5b9",
455+
Choices: []openai.ChatCompletionStreamChoice{
456+
{
457+
Index: 0,
458+
FinishReason: "stop",
459+
},
460+
},
461+
},
462+
}
463+
464+
for ix, expectedResponse := range expectedResponses {
465+
b, _ := json.Marshal(expectedResponse)
466+
t.Logf("%d: %s", ix, string(b))
467+
468+
receivedResponse, streamErr := stream.Recv()
469+
checks.NoError(t, streamErr, "stream.Recv() failed")
470+
if !compareChatResponses(expectedResponse, receivedResponse) {
471+
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
472+
}
473+
}
474+
475+
_, streamErr := stream.Recv()
476+
if !errors.Is(streamErr, io.EOF) {
477+
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
478+
}
479+
}
480+
481+
func TestCreateChatCompletionStreamWithLogprobs(t *testing.T) {
482+
client, server, teardown := setupOpenAITestServer()
483+
defer teardown()
484+
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
485+
w.Header().Set("Content-Type", "text/event-stream")
486+
487+
// Send test responses
488+
dataBytes := []byte{}
489+
490+
//nolint:lll
491+
dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":{"content":[],"refusal":null},"finish_reason":null}]}`)...)
492+
dataBytes = append(dataBytes, []byte("\n\n")...)
493+
494+
//nolint:lll
495+
dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":{"content":[{"token":"Hello","logprob":-0.000020458236,"bytes":[72,101,108,108,111],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...)
496+
dataBytes = append(dataBytes, []byte("\n\n")...)
497+
498+
//nolint:lll
499+
dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":" World"},"logprobs":{"content":[{"token":" World","logprob":-0.00055303273,"bytes":[32,87,111,114,108,100],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...)
500+
dataBytes = append(dataBytes, []byte("\n\n")...)
501+
502+
//nolint:lll
503+
dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`)...)
504+
dataBytes = append(dataBytes, []byte("\n\n")...)
505+
506+
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
507+
508+
_, err := w.Write(dataBytes)
509+
checks.NoError(t, err, "Write error")
510+
})
511+
512+
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
513+
MaxTokens: 2000,
514+
Model: openai.GPT3Dot5Turbo,
515+
Messages: []openai.ChatCompletionMessage{
516+
{
517+
Role: openai.ChatMessageRoleUser,
518+
Content: "Hello!",
519+
},
520+
},
521+
Stream: true,
522+
})
523+
checks.NoError(t, err, "CreateCompletionStream returned error")
524+
defer stream.Close()
525+
526+
expectedResponses := []openai.ChatCompletionStreamResponse{
527+
{
528+
ID: "1",
529+
Object: "chat.completion.chunk",
530+
Created: 1729585728,
531+
Model: openai.GPT4oMini20240718,
532+
SystemFingerprint: "fp_d9767fc5b9",
533+
Choices: []openai.ChatCompletionStreamChoice{
534+
{
535+
Index: 0,
536+
Delta: openai.ChatCompletionStreamChoiceDelta{},
537+
Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{
538+
Content: []openai.ChatCompletionTokenLogprob{},
539+
},
540+
},
541+
},
542+
},
543+
{
544+
ID: "2",
545+
Object: "chat.completion.chunk",
546+
Created: 1729585728,
547+
Model: openai.GPT4oMini20240718,
548+
SystemFingerprint: "fp_d9767fc5b9",
549+
Choices: []openai.ChatCompletionStreamChoice{
550+
{
551+
Index: 0,
552+
Delta: openai.ChatCompletionStreamChoiceDelta{
553+
Content: "Hello",
554+
},
555+
Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{
556+
Content: []openai.ChatCompletionTokenLogprob{
557+
{
558+
Token: "Hello",
559+
Logprob: -0.000020458236,
560+
Bytes: []int64{72, 101, 108, 108, 111},
561+
TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{},
562+
},
563+
},
564+
},
565+
},
566+
},
567+
},
568+
{
569+
ID: "3",
570+
Object: "chat.completion.chunk",
571+
Created: 1729585728,
572+
Model: openai.GPT4oMini20240718,
573+
SystemFingerprint: "fp_d9767fc5b9",
574+
Choices: []openai.ChatCompletionStreamChoice{
575+
{
576+
Index: 0,
577+
Delta: openai.ChatCompletionStreamChoiceDelta{
578+
Content: " World",
579+
},
580+
Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{
581+
Content: []openai.ChatCompletionTokenLogprob{
582+
{
583+
Token: " World",
584+
Logprob: -0.00055303273,
585+
Bytes: []int64{32, 87, 111, 114, 108, 100},
586+
TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{},
587+
},
588+
},
589+
},
590+
},
591+
},
592+
},
593+
{
594+
ID: "4",
595+
Object: "chat.completion.chunk",
596+
Created: 1729585728,
597+
Model: openai.GPT4oMini20240718,
598+
SystemFingerprint: "fp_d9767fc5b9",
599+
Choices: []openai.ChatCompletionStreamChoice{
600+
{
601+
Index: 0,
602+
Delta: openai.ChatCompletionStreamChoiceDelta{},
603+
FinishReason: "stop",
604+
},
605+
},
606+
},
607+
}
608+
609+
for ix, expectedResponse := range expectedResponses {
610+
b, _ := json.Marshal(expectedResponse)
611+
t.Logf("%d: %s", ix, string(b))
612+
613+
receivedResponse, streamErr := stream.Recv()
614+
checks.NoError(t, streamErr, "stream.Recv() failed")
615+
if !compareChatResponses(expectedResponse, receivedResponse) {
616+
t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
617+
}
618+
}
619+
620+
_, streamErr := stream.Recv()
621+
if !errors.Is(streamErr, io.EOF) {
622+
t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
623+
}
624+
}
625+
361626
func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) {
362627
wantCode := "429"
363628
wantMessage := "Requests to the Creates a completion for the chat message Operation under Azure OpenAI API " +

0 commit comments

Comments (0)