Skip to content

Commit c203ca0

Browse files
authored
feat: add RecvRaw (#896)
1 parent 21fa42c commit c203ca0

File tree

2 files changed

+35
-17
lines changed

2 files changed

+35
-17
lines changed

stream_reader.go

+22-17
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,28 @@ type streamReader[T streamable] struct {
3232
}
3333

3434
func (stream *streamReader[T]) Recv() (response T, err error) {
35-
if stream.isFinished {
36-
err = io.EOF
35+
rawLine, err := stream.RecvRaw()
36+
if err != nil {
3737
return
3838
}
3939

40-
response, err = stream.processLines()
41-
return
40+
err = stream.unmarshaler.Unmarshal(rawLine, &response)
41+
if err != nil {
42+
return
43+
}
44+
return response, nil
45+
}
46+
47+
func (stream *streamReader[T]) RecvRaw() ([]byte, error) {
48+
if stream.isFinished {
49+
return nil, io.EOF
50+
}
51+
52+
return stream.processLines()
4253
}
4354

4455
//nolint:gocognit
45-
func (stream *streamReader[T]) processLines() (T, error) {
56+
func (stream *streamReader[T]) processLines() ([]byte, error) {
4657
var (
4758
emptyMessagesCount uint
4859
hasErrorPrefix bool
@@ -53,9 +64,9 @@ func (stream *streamReader[T]) processLines() (T, error) {
5364
if readErr != nil || hasErrorPrefix {
5465
respErr := stream.unmarshalError()
5566
if respErr != nil {
56-
return *new(T), fmt.Errorf("error, %w", respErr.Error)
67+
return nil, fmt.Errorf("error, %w", respErr.Error)
5768
}
58-
return *new(T), readErr
69+
return nil, readErr
5970
}
6071

6172
noSpaceLine := bytes.TrimSpace(rawLine)
@@ -68,11 +79,11 @@ func (stream *streamReader[T]) processLines() (T, error) {
6879
}
6980
writeErr := stream.errAccumulator.Write(noSpaceLine)
7081
if writeErr != nil {
71-
return *new(T), writeErr
82+
return nil, writeErr
7283
}
7384
emptyMessagesCount++
7485
if emptyMessagesCount > stream.emptyMessagesLimit {
75-
return *new(T), ErrTooManyEmptyStreamMessages
86+
return nil, ErrTooManyEmptyStreamMessages
7687
}
7788

7889
continue
@@ -81,16 +92,10 @@ func (stream *streamReader[T]) processLines() (T, error) {
8192
noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData)
8293
if string(noPrefixLine) == "[DONE]" {
8394
stream.isFinished = true
84-
return *new(T), io.EOF
85-
}
86-
87-
var response T
88-
unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response)
89-
if unmarshalErr != nil {
90-
return *new(T), unmarshalErr
95+
return nil, io.EOF
9196
}
9297

93-
return response, nil
98+
return noPrefixLine, nil
9499
}
95100
}
96101

stream_reader_test.go

+13
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
6363
_, err := stream.Recv()
6464
checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error())
6565
}
66+
67+
func TestStreamReaderRecvRaw(t *testing.T) {
68+
stream := &streamReader[ChatCompletionStreamResponse]{
69+
reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))),
70+
}
71+
rawLine, err := stream.RecvRaw()
72+
if err != nil {
73+
t.Fatalf("Did not return raw line: %v", err)
74+
}
75+
if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) {
76+
t.Fatalf("Did not return raw line: %v", string(rawLine))
77+
}
78+
}

0 commit comments

Comments
 (0)