Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,19 @@ func sendStreamRequest[T responseStream[R], R any](ctx context.Context, ac *apiC
var cancel context.CancelFunc
if timeout != nil && *timeout > 0*time.Second && isTimeoutBeforeDeadline(ctx, *timeout) {
requestContext, cancel = context.WithTimeout(ctx, *timeout)
defer cancel()
}
req = req.WithContext(requestContext)

resp, err := doRequest(ac, req)
if err != nil {
if cancel != nil {
cancel()
}
return err
}

// Transfer cancel ownership to the stream; the iterator will call it when done.
output.cancel = cancel
// resp.Body will be closed by the iterator
return deserializeStreamResponse(resp, output)
}
Expand Down Expand Up @@ -375,15 +379,19 @@ func deserializeUnaryResponse(resp *http.Response) (map[string]any, error) {
}

type responseStream[R any] struct {
r *bufio.Scanner
rc io.ReadCloser
h http.Header
r *bufio.Scanner
rc io.ReadCloser
h http.Header
cancel context.CancelFunc // cancels the request timeout context; called when iterator completes
}

func iterateResponseStream[R any](rs *responseStream[R], responseConverter func(responseMap map[string]any) (*R, error)) iter.Seq2[*R, error] {
return func(yield func(*R, error) bool) {
defer func() {
// Close the response body range over function is done.
// Cancel the request timeout context first, then close the response body.
if rs.cancel != nil {
rs.cancel()
}
if err := rs.rc.Close(); err != nil {
log.Printf("Error closing response body: %v", err)
}
Expand Down
43 changes: 43 additions & 0 deletions api_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,49 @@ func TestSendStreamRequest(t *testing.T) {
}
}

func TestSendStreamRequestTimeoutNotCancelledEarly(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
flusher := w.(http.Flusher)
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
for i := 1; i <= 4; i++ {
fmt.Fprintf(w, "data:{\"chunk\":%d}\n\n", i)
flusher.Flush()
time.Sleep(10 * time.Millisecond)
}
}))
defer ts.Close()

ac := &apiClient{clientConfig: &ClientConfig{
Backend: BackendGeminiAPI,
HTTPOptions: HTTPOptions{
BaseURL: ts.URL,
APIVersion: "v0",
Headers: http.Header{"User-Agent": {"test-user-agent"}, "X-Goog-Api-Key": {"test-api-key"}},
},
HTTPClient: ts.Client(),
}}

requestTimeout := 5 * time.Second
var output responseStream[map[string]any]
if err := sendStreamRequest(context.Background(), ac, "test", "POST", map[string]any{"key": "value"}, &HTTPOptions{Timeout: &requestTimeout, BaseURL: ac.clientConfig.HTTPOptions.BaseURL}, &output); err != nil {
t.Fatalf("sendStreamRequest() error = %v", err)
}

var got []map[string]any
for resp, err := range iterateResponseStream(&output, func(m map[string]any) (*map[string]any, error) { return &m, nil }) {
if err != nil {
t.Fatalf("iterateResponseStream() error = %v", err)
}
got = append(got, *resp)
}

want := []map[string]any{{"chunk": float64(1)}, {"chunk": float64(2)}, {"chunk": float64(3)}, {"chunk": float64(4)}}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("stream response mismatch (-want +got):\n%s", diff)
}
}

func TestMapToStruct(t *testing.T) {
testCases := []struct {
name string
Expand Down