Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions gollm/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,19 @@ func (cs *openAIChatSession) SendStreaming(ctx context.Context, contents ...any)
for stream.Next() {
chunk := stream.Current()

// Update the accumulator with the new chunk
acc.AddChunk(chunk)
// Update the accumulator with the new chunk, guarding against
// panics from the upstream client when tool call indices are
// missing. See https://github.com/openai/openai-go/issues/464 for more details.
if ok, err := addChunkSafely(&acc, chunk); err != nil {
klog.ErrorS(err, "Recovered from OpenAI accumulator panic while processing chunk")
yield(nil, fmt.Errorf("OpenAI streaming error while accumulating chunk: %w", err))
return
} else if !ok {
err := errors.New("OpenAI accumulator rejected stream chunk")
klog.Error(err)
yield(nil, fmt.Errorf("OpenAI streaming error while accumulating chunk: %w", err))
return
}

// Handle content completion
if _, ok := acc.JustFinishedContent(); ok {
Expand Down Expand Up @@ -415,6 +426,23 @@ func (cs *openAIChatSession) SendStreaming(ctx context.Context, contents ...any)
}, nil
}

// addChunkSafely wraps the accumulator's AddChunk method to protect against
// panics triggered by malformed tool call indices returned by the upstream
// openai-go client. Returning an error allows the caller to surface a regular
// failure instead of crashing the entire process while the upstream fix is in
// flight.
func addChunkSafely(acc *openai.ChatCompletionAccumulator, chunk openai.ChatCompletionChunk) (added bool, err error) {
defer func() {
if r := recover(); r != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets hold for now. I think we can wait for this to be fixed in the upstream.

err = fmt.Errorf("openai accumulator panic: %v", r)
added = false
}
}()

added = acc.AddChunk(chunk)
return
}

// IsRetryableError determines if an error from the OpenAI API should be retried.
func (cs *openAIChatSession) IsRetryableError(err error) bool {
if err == nil {
Expand Down
70 changes: 70 additions & 0 deletions gollm/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,73 @@ func TestConvertToolCallsToFunctionCalls(t *testing.T) {
})
}
}

// TestAddChunkSafely checks that addChunkSafely passes well-formed chunks
// straight through to the accumulator, and that a chunk known to panic inside
// the upstream client (negative tool call index, openai-go issue #464) is
// converted into an error instead of crashing the test process.
func TestAddChunkSafely(t *testing.T) {
	cases := []struct {
		name             string
		chunk            openai.ChatCompletionChunk
		expectError      bool
		expectAdded      bool
		setupAccumulator func() *openai.ChatCompletionAccumulator
		unmarshalJSON    bool
		jsonString       string
	}{
		{
			name: "valid chunk with valid accumulator",
			chunk: openai.ChatCompletionChunk{
				Choices: []openai.ChatCompletionChunkChoice{
					{
						Delta: openai.ChatCompletionChunkChoiceDelta{
							Content: "hello",
						},
					},
				},
			},
			expectError: false,
			expectAdded: true,
			setupAccumulator: func() *openai.ChatCompletionAccumulator {
				return &openai.ChatCompletionAccumulator{}
			},
		},
		{
			name:        "chunk with negative tool call index should panic and be recovered",
			expectError: true,
			expectAdded: false,
			setupAccumulator: func() *openai.ChatCompletionAccumulator {
				return &openai.ChatCompletionAccumulator{}
			},
			unmarshalJSON: true,
			// Captured from a real response that triggers the upstream panic:
			// note "index":-1 inside tool_calls.
			jsonString:    `{"id":"gen-1753952129163910909","object":"chat.completion.chunk","created":1753952129,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"id":"call_goVJu3KUQtuahQuqW6wgzQ","index":-1,"type":"function","function":{"name":"doc_search","arguments":""}}]},"logprobs":null,"finish_reason":null,"matched_stop":null}],"usage":null}`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			acc := tc.setupAccumulator()

			// Some fixtures are easier to express as raw JSON than as struct
			// literals; decode those into the chunk under test.
			chunk := tc.chunk
			if tc.unmarshalJSON {
				if err := json.Unmarshal([]byte(tc.jsonString), &chunk); err != nil {
					t.Fatalf("Failed to unmarshal test JSON: %v", err)
				}
			}

			added, err := addChunkSafely(acc, chunk)

			switch {
			case tc.expectError && err == nil:
				t.Error("expected an error, but got none")
			case !tc.expectError && err != nil:
				t.Errorf("unexpected error: %v", err)
			}

			if added != tc.expectAdded {
				t.Errorf("expected added to be %v, but got %v", tc.expectAdded, added)
			}
		})
	}
}
Loading