Skip to content

Commit a5a945a

Browse files
authored
fix: stream return EOF when openai return error (#184)
* fix: stream return EOF when openai return error * perf: add error accumulator * fix: golangci-lint * fix: unmarshal error possibly null * fix: error accumulator * test: error accumulator use interface and add test code * test: error accumulator add test code * refactor: use stream reader to re-use stream code * refactor: stream reader use generics
1 parent aa149c1 commit a5a945a

8 files changed

+372
-107
lines changed

chat_stream.go

+8-53
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@ package openai
22

33
import (
44
"bufio"
5-
"bytes"
65
"context"
7-
"encoding/json"
8-
"io"
9-
"net/http"
106
)
117

128
type ChatCompletionStreamChoiceDelta struct {
@@ -30,52 +26,7 @@ type ChatCompletionStreamResponse struct {
3026
// ChatCompletionStream
3127
// Note: Perhaps it is more elegant to abstract Stream using generics.
3228
type ChatCompletionStream struct {
33-
emptyMessagesLimit uint
34-
isFinished bool
35-
36-
reader *bufio.Reader
37-
response *http.Response
38-
}
39-
40-
func (stream *ChatCompletionStream) Recv() (response ChatCompletionStreamResponse, err error) {
41-
if stream.isFinished {
42-
err = io.EOF
43-
return
44-
}
45-
46-
var emptyMessagesCount uint
47-
48-
waitForData:
49-
line, err := stream.reader.ReadBytes('\n')
50-
if err != nil {
51-
return
52-
}
53-
54-
var headerData = []byte("data: ")
55-
line = bytes.TrimSpace(line)
56-
if !bytes.HasPrefix(line, headerData) {
57-
emptyMessagesCount++
58-
if emptyMessagesCount > stream.emptyMessagesLimit {
59-
err = ErrTooManyEmptyStreamMessages
60-
return
61-
}
62-
63-
goto waitForData
64-
}
65-
66-
line = bytes.TrimPrefix(line, headerData)
67-
if string(line) == "[DONE]" {
68-
stream.isFinished = true
69-
err = io.EOF
70-
return
71-
}
72-
73-
err = json.Unmarshal(line, &response)
74-
return
75-
}
76-
77-
func (stream *ChatCompletionStream) Close() {
78-
stream.response.Body.Close()
29+
*streamReader[ChatCompletionStreamResponse]
7930
}
8031

8132
// CreateChatCompletionStream — API call to create a chat completion w/ streaming
@@ -98,9 +49,13 @@ func (c *Client) CreateChatCompletionStream(
9849
}
9950

10051
stream = &ChatCompletionStream{
101-
emptyMessagesLimit: c.config.EmptyMessagesLimit,
102-
reader: bufio.NewReader(resp.Body),
103-
response: resp,
52+
streamReader: &streamReader[ChatCompletionStreamResponse]{
53+
emptyMessagesLimit: c.config.EmptyMessagesLimit,
54+
reader: bufio.NewReader(resp.Body),
55+
response: resp,
56+
errAccumulator: newErrorAccumulator(),
57+
unmarshaler: &jsonUnmarshaler{},
58+
},
10459
}
10560
return
10661
}

chat_stream_test.go

+67
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,73 @@ func TestCreateChatCompletionStream(t *testing.T) {
123123
}
124124
}
125125

126+
func TestCreateChatCompletionStreamError(t *testing.T) {
127+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
128+
w.Header().Set("Content-Type", "text/event-stream")
129+
130+
// Send test responses
131+
dataBytes := []byte{}
132+
dataStr := []string{
133+
`{`,
134+
`"error": {`,
135+
`"message": "Incorrect API key provided: sk-***************************************",`,
136+
`"type": "invalid_request_error",`,
137+
`"param": null,`,
138+
`"code": "invalid_api_key"`,
139+
`}`,
140+
`}`,
141+
}
142+
for _, str := range dataStr {
143+
dataBytes = append(dataBytes, []byte(str+"\n")...)
144+
}
145+
146+
_, err := w.Write(dataBytes)
147+
if err != nil {
148+
t.Errorf("Write error: %s", err)
149+
}
150+
}))
151+
defer server.Close()
152+
153+
// Client portion of the test
154+
config := DefaultConfig(test.GetTestToken())
155+
config.BaseURL = server.URL + "/v1"
156+
config.HTTPClient.Transport = &tokenRoundTripper{
157+
test.GetTestToken(),
158+
http.DefaultTransport,
159+
}
160+
161+
client := NewClientWithConfig(config)
162+
ctx := context.Background()
163+
164+
request := ChatCompletionRequest{
165+
MaxTokens: 5,
166+
Model: GPT3Dot5Turbo,
167+
Messages: []ChatCompletionMessage{
168+
{
169+
Role: ChatMessageRoleUser,
170+
Content: "Hello!",
171+
},
172+
},
173+
Stream: true,
174+
}
175+
176+
stream, err := client.CreateChatCompletionStream(ctx, request)
177+
if err != nil {
178+
t.Errorf("CreateCompletionStream returned error: %v", err)
179+
}
180+
defer stream.Close()
181+
182+
_, streamErr := stream.Recv()
183+
if streamErr == nil {
184+
t.Errorf("stream.Recv() did not return error")
185+
}
186+
var apiErr *APIError
187+
if !errors.As(streamErr, &apiErr) {
188+
t.Errorf("stream.Recv() did not return APIError")
189+
}
190+
t.Logf("%+v\n", apiErr)
191+
}
192+
126193
// Helper funcs.
127194
func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool {
128195
if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model {

error_accumulator.go

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"io"
7+
)
8+
9+
type errorAccumulator interface {
10+
write(p []byte) error
11+
unmarshalError() (*ErrorResponse, error)
12+
}
13+
14+
type errorBuffer interface {
15+
io.Writer
16+
Len() int
17+
Bytes() []byte
18+
}
19+
20+
type errorAccumulate struct {
21+
buffer errorBuffer
22+
unmarshaler unmarshaler
23+
}
24+
25+
func newErrorAccumulator() errorAccumulator {
26+
return &errorAccumulate{
27+
buffer: &bytes.Buffer{},
28+
unmarshaler: &jsonUnmarshaler{},
29+
}
30+
}
31+
32+
func (e *errorAccumulate) write(p []byte) error {
33+
_, err := e.buffer.Write(p)
34+
if err != nil {
35+
return fmt.Errorf("error accumulator write error, %w", err)
36+
}
37+
return nil
38+
}
39+
40+
func (e *errorAccumulate) unmarshalError() (*ErrorResponse, error) {
41+
var err error
42+
if e.buffer.Len() > 0 {
43+
var errRes ErrorResponse
44+
err = e.unmarshaler.unmarshal(e.buffer.Bytes(), &errRes)
45+
if err != nil {
46+
return nil, err
47+
}
48+
return &errRes, nil
49+
}
50+
return nil, err
51+
}

error_accumulator_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package openai //nolint:testpackage // testing private field
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"errors"
7+
"testing"
8+
9+
"github.com/sashabaranov/go-openai/internal/test"
10+
)
11+
12+
var (
13+
errTestUnmarshalerFailed = errors.New("test unmarshaler failed")
14+
errTestErrorAccumulatorWriteFailed = errors.New("test error accumulator failed")
15+
)
16+
17+
type (
18+
failingUnMarshaller struct{}
19+
failingErrorBuffer struct{}
20+
)
21+
22+
func (b *failingErrorBuffer) Write(_ []byte) (n int, err error) {
23+
return 0, errTestErrorAccumulatorWriteFailed
24+
}
25+
26+
func (b *failingErrorBuffer) Len() int {
27+
return 0
28+
}
29+
30+
func (b *failingErrorBuffer) Bytes() []byte {
31+
return []byte{}
32+
}
33+
34+
func (*failingUnMarshaller) unmarshal(_ []byte, _ any) error {
35+
return errTestUnmarshalerFailed
36+
}
37+
38+
func TestErrorAccumulatorReturnsUnmarshalerErrors(t *testing.T) {
39+
accumulator := &errorAccumulate{
40+
buffer: &bytes.Buffer{},
41+
unmarshaler: &failingUnMarshaller{},
42+
}
43+
44+
err := accumulator.write([]byte("{"))
45+
if err != nil {
46+
t.Fatalf("%+v", err)
47+
}
48+
_, err = accumulator.unmarshalError()
49+
if !errors.Is(err, errTestUnmarshalerFailed) {
50+
t.Fatalf("Did not return error when unmarshaler failed: %v", err)
51+
}
52+
}
53+
54+
func TestErrorByteWriteErrors(t *testing.T) {
55+
accumulator := &errorAccumulate{
56+
buffer: &failingErrorBuffer{},
57+
unmarshaler: &jsonUnmarshaler{},
58+
}
59+
err := accumulator.write([]byte("{"))
60+
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
61+
t.Fatalf("Did not return error when write failed: %v", err)
62+
}
63+
}
64+
65+
func TestErrorAccumulatorWriteErrors(t *testing.T) {
66+
var err error
67+
ts := test.NewTestServer().OpenAITestServer()
68+
ts.Start()
69+
defer ts.Close()
70+
71+
config := DefaultConfig(test.GetTestToken())
72+
config.BaseURL = ts.URL + "/v1"
73+
client := NewClientWithConfig(config)
74+
75+
ctx := context.Background()
76+
77+
stream, err := client.CreateChatCompletionStream(ctx, ChatCompletionRequest{})
78+
if err != nil {
79+
t.Fatal(err)
80+
}
81+
stream.errAccumulator = &errorAccumulate{
82+
buffer: &failingErrorBuffer{},
83+
unmarshaler: &jsonUnmarshaler{},
84+
}
85+
86+
_, err = stream.Recv()
87+
if !errors.Is(err, errTestErrorAccumulatorWriteFailed) {
88+
t.Fatalf("Did not return error when write failed: %v", err)
89+
}
90+
}

stream.go

+8-54
Original file line numberDiff line numberDiff line change
@@ -2,65 +2,16 @@ package openai
22

33
import (
44
"bufio"
5-
"bytes"
65
"context"
7-
"encoding/json"
86
"errors"
9-
"io"
10-
"net/http"
117
)
128

139
var (
1410
ErrTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
1511
)
1612

1713
type CompletionStream struct {
18-
emptyMessagesLimit uint
19-
isFinished bool
20-
21-
reader *bufio.Reader
22-
response *http.Response
23-
}
24-
25-
func (stream *CompletionStream) Recv() (response CompletionResponse, err error) {
26-
if stream.isFinished {
27-
err = io.EOF
28-
return
29-
}
30-
31-
var emptyMessagesCount uint
32-
33-
waitForData:
34-
line, err := stream.reader.ReadBytes('\n')
35-
if err != nil {
36-
return
37-
}
38-
39-
var headerData = []byte("data: ")
40-
line = bytes.TrimSpace(line)
41-
if !bytes.HasPrefix(line, headerData) {
42-
emptyMessagesCount++
43-
if emptyMessagesCount > stream.emptyMessagesLimit {
44-
err = ErrTooManyEmptyStreamMessages
45-
return
46-
}
47-
48-
goto waitForData
49-
}
50-
51-
line = bytes.TrimPrefix(line, headerData)
52-
if string(line) == "[DONE]" {
53-
stream.isFinished = true
54-
err = io.EOF
55-
return
56-
}
57-
58-
err = json.Unmarshal(line, &response)
59-
return
60-
}
61-
62-
func (stream *CompletionStream) Close() {
63-
stream.response.Body.Close()
14+
*streamReader[CompletionResponse]
6415
}
6516

6617
// CreateCompletionStream — API call to create a completion w/ streaming
@@ -83,10 +34,13 @@ func (c *Client) CreateCompletionStream(
8334
}
8435

8536
stream = &CompletionStream{
86-
emptyMessagesLimit: c.config.EmptyMessagesLimit,
87-
88-
reader: bufio.NewReader(resp.Body),
89-
response: resp,
37+
streamReader: &streamReader[CompletionResponse]{
38+
emptyMessagesLimit: c.config.EmptyMessagesLimit,
39+
reader: bufio.NewReader(resp.Body),
40+
response: resp,
41+
errAccumulator: newErrorAccumulator(),
42+
unmarshaler: &jsonUnmarshaler{},
43+
},
9044
}
9145
return
9246
}

0 commit comments

Comments
 (0)