Skip to content

Commit a3e4acb

Browse files
authored
Merge pull request #2182 from dgageot/modelerrs
Improve modelsdev package
2 parents f8a2243 + 2040013 commit a3e4acb

File tree

10 files changed

+409
-220
lines changed

10 files changed

+409
-220
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ require (
1919
github.com/aws/aws-sdk-go-v2/credentials v1.19.12
2020
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.2
2121
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9
22+
github.com/aws/smithy-go v1.24.2
2223
github.com/aymanbagabas/go-udiff v0.4.1
2324
github.com/blevesearch/bleve/v2 v2.5.7
2425
github.com/bmatcuk/doublestar/v4 v4.10.0
@@ -90,7 +91,6 @@ require (
9091
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect
9192
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect
9293
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect
93-
github.com/aws/smithy-go v1.24.2 // indirect
9494
github.com/aymerick/douceur v0.2.0 // indirect
9595
github.com/bahlo/generic-list-go v0.2.0 // indirect
9696
github.com/bits-and-blooms/bitset v1.24.4 // indirect

pkg/backoff/backoff.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Package backoff provides exponential backoff calculation and
2+
// context-aware sleep utilities.
3+
package backoff
4+
5+
import (
6+
"context"
7+
"math/rand/v2"
8+
"time"
9+
)
10+
11+
// Configuration constants for exponential backoff.
12+
const (
13+
baseDelay = 200 * time.Millisecond
14+
maxDelay = 2 * time.Second
15+
factor = 2.0
16+
jitter = 0.1
17+
18+
// MaxRetryAfterWait caps how long we'll honor a Retry-After header to prevent
19+
// a misbehaving server from blocking the agent for an unreasonable amount of time.
20+
MaxRetryAfterWait = 60 * time.Second
21+
)
22+
23+
// Calculate returns the backoff duration for a given attempt (0-indexed).
24+
// Uses exponential backoff with jitter.
25+
func Calculate(attempt int) time.Duration {
26+
if attempt < 0 {
27+
attempt = 0
28+
}
29+
30+
// Calculate exponential delay
31+
delay := float64(baseDelay)
32+
for range attempt {
33+
delay *= factor
34+
}
35+
36+
// Cap at max delay
37+
if delay > float64(maxDelay) {
38+
delay = float64(maxDelay)
39+
}
40+
41+
// Add jitter (±10%)
42+
j := delay * jitter * (2*rand.Float64() - 1)
43+
delay += j
44+
45+
return time.Duration(delay)
46+
}
47+
48+
// SleepWithContext blocks for d or until ctx is done, whichever happens first.
// It reports true when the full duration elapsed and false when the wait was
// cut short because the context was cancelled.
func SleepWithContext(ctx context.Context, d time.Duration) bool {
	wake := time.NewTimer(d)
	defer wake.Stop()

	completed := false
	select {
	case <-ctx.Done():
		// Interrupted: leave completed as false.
	case <-wake.C:
		completed = true
	}
	return completed
}

pkg/backoff/backoff_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package backoff
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestCalculate(t *testing.T) {
13+
t.Parallel()
14+
15+
tests := []struct {
16+
attempt int
17+
minExpected time.Duration
18+
maxExpected time.Duration
19+
}{
20+
{attempt: 0, minExpected: 180 * time.Millisecond, maxExpected: 220 * time.Millisecond},
21+
{attempt: 1, minExpected: 360 * time.Millisecond, maxExpected: 440 * time.Millisecond},
22+
{attempt: 2, minExpected: 720 * time.Millisecond, maxExpected: 880 * time.Millisecond},
23+
{attempt: 3, minExpected: 1440 * time.Millisecond, maxExpected: 1760 * time.Millisecond},
24+
{attempt: 10, minExpected: 1800 * time.Millisecond, maxExpected: 2200 * time.Millisecond}, // capped at 2s
25+
}
26+
27+
for _, tt := range tests {
28+
t.Run(fmt.Sprintf("attempt_%d", tt.attempt), func(t *testing.T) {
29+
t.Parallel()
30+
b := Calculate(tt.attempt)
31+
assert.GreaterOrEqual(t, b, tt.minExpected, "backoff should be at least %v", tt.minExpected)
32+
assert.LessOrEqual(t, b, tt.maxExpected, "backoff should be at most %v", tt.maxExpected)
33+
})
34+
}
35+
36+
t.Run("negative attempt treated as 0", func(t *testing.T) {
37+
t.Parallel()
38+
b := Calculate(-1)
39+
assert.GreaterOrEqual(t, b, 180*time.Millisecond)
40+
assert.LessOrEqual(t, b, 220*time.Millisecond)
41+
})
42+
}
43+
44+
func TestSleepWithContext(t *testing.T) {
45+
t.Parallel()
46+
47+
t.Run("completes normally", func(t *testing.T) {
48+
t.Parallel()
49+
ctx := t.Context()
50+
start := time.Now()
51+
completed := SleepWithContext(ctx, 10*time.Millisecond)
52+
elapsed := time.Since(start)
53+
54+
assert.True(t, completed, "should complete normally")
55+
assert.GreaterOrEqual(t, elapsed, 10*time.Millisecond)
56+
})
57+
58+
t.Run("interrupted by context", func(t *testing.T) {
59+
t.Parallel()
60+
ctx, cancel := context.WithCancel(t.Context())
61+
time.AfterFunc(10*time.Millisecond, cancel)
62+
63+
start := time.Now()
64+
completed := SleepWithContext(ctx, 1*time.Second)
65+
elapsed := time.Since(start)
66+
67+
assert.False(t, completed, "should be interrupted")
68+
assert.Less(t, elapsed, 100*time.Millisecond, "should return quickly after cancel")
69+
})
70+
}

pkg/model/provider/bedrock/adapter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
7171
// Check for errors
7272
if err := a.stream.Err(); err != nil {
7373
slog.Debug("Bedrock stream: error on channel close", "error", err)
74-
return chat.MessageStreamResponse{}, err
74+
return chat.MessageStreamResponse{}, wrapBedrockError(err)
7575
}
7676
// If we have a pending finish reason but never got metadata, emit it now
7777
if a.pendingFinishReason != "" {

pkg/model/provider/bedrock/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ func (c *Client) CreateChatCompletionStream(
219219
output, err := c.bedrockClient.ConverseStream(ctx, input)
220220
if err != nil {
221221
slog.Error("Bedrock ConverseStream failed", "error", err)
222-
return nil, fmt.Errorf("bedrock converse stream failed: %w", err)
222+
return nil, wrapBedrockError(fmt.Errorf("bedrock converse stream failed: %w", err))
223223
}
224224

225225
trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage

pkg/model/provider/bedrock/wrap.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package bedrock
2+
3+
import (
4+
"errors"
5+
6+
smithyhttp "github.com/aws/smithy-go/transport/http"
7+
8+
"github.com/docker/docker-agent/pkg/modelerrors"
9+
)
10+
11+
// wrapBedrockError wraps an AWS Bedrock SDK error in a *modelerrors.StatusError
12+
// to carry HTTP status code metadata for the retry loop.
13+
// The AWS SDK v2 exposes HTTP status via smithyhttp.ResponseError.
14+
// Non-AWS errors (e.g., io.EOF, network errors) pass through unchanged.
15+
func wrapBedrockError(err error) error {
16+
if err == nil {
17+
return nil
18+
}
19+
20+
var respErr *smithyhttp.ResponseError
21+
if !errors.As(err, &respErr) {
22+
return err
23+
}
24+
25+
var resp *smithyhttp.Response
26+
if respErr.HTTPResponse() != nil {
27+
resp = respErr.HTTPResponse()
28+
}
29+
30+
statusCode := respErr.HTTPStatusCode()
31+
if resp != nil {
32+
return modelerrors.WrapHTTPError(statusCode, resp.Response, err)
33+
}
34+
return modelerrors.WrapHTTPError(statusCode, nil, err)
35+
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package bedrock
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"net/http"
7+
"testing"
8+
"time"
9+
10+
smithy "github.com/aws/smithy-go"
11+
smithyhttp "github.com/aws/smithy-go/transport/http"
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
14+
15+
"github.com/docker/docker-agent/pkg/modelerrors"
16+
)
17+
18+
func makeTestBedrockError(statusCode int, retryAfterValue string) error {
19+
header := http.Header{}
20+
if retryAfterValue != "" {
21+
header.Set("Retry-After", retryAfterValue)
22+
}
23+
24+
httpResp := &http.Response{
25+
StatusCode: statusCode,
26+
Header: header,
27+
}
28+
resp := &smithyhttp.Response{Response: httpResp}
29+
30+
return &smithy.OperationError{
31+
ServiceID: "BedrockRuntime",
32+
OperationName: "ConverseStream",
33+
Err: &smithyhttp.ResponseError{
34+
Response: resp,
35+
Err: &smithy.GenericAPIError{
36+
Code: "ThrottlingException",
37+
Message: "Rate exceeded",
38+
},
39+
},
40+
}
41+
}
42+
43+
func TestWrapBedrockError(t *testing.T) {
44+
t.Parallel()
45+
46+
t.Run("nil returns nil", func(t *testing.T) {
47+
t.Parallel()
48+
assert.NoError(t, wrapBedrockError(nil))
49+
})
50+
51+
t.Run("non-AWS error passes through unchanged", func(t *testing.T) {
52+
t.Parallel()
53+
orig := errors.New("some network error")
54+
result := wrapBedrockError(orig)
55+
assert.Equal(t, orig, result)
56+
var se *modelerrors.StatusError
57+
assert.NotErrorAs(t, result, &se)
58+
})
59+
60+
t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) {
61+
t.Parallel()
62+
awsErr := makeTestBedrockError(429, "")
63+
result := wrapBedrockError(awsErr)
64+
var se *modelerrors.StatusError
65+
require.ErrorAs(t, result, &se)
66+
assert.Equal(t, 429, se.StatusCode)
67+
assert.Equal(t, time.Duration(0), se.RetryAfter)
68+
// Original error still accessible
69+
assert.ErrorIs(t, result, awsErr)
70+
})
71+
72+
t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) {
73+
t.Parallel()
74+
awsErr := makeTestBedrockError(429, "20")
75+
result := wrapBedrockError(awsErr)
76+
var se *modelerrors.StatusError
77+
require.ErrorAs(t, result, &se)
78+
assert.Equal(t, 429, se.StatusCode)
79+
assert.Equal(t, 20*time.Second, se.RetryAfter)
80+
})
81+
82+
t.Run("500 wraps with correct status code", func(t *testing.T) {
83+
t.Parallel()
84+
awsErr := makeTestBedrockError(500, "")
85+
result := wrapBedrockError(awsErr)
86+
var se *modelerrors.StatusError
87+
require.ErrorAs(t, result, &se)
88+
assert.Equal(t, 500, se.StatusCode)
89+
assert.Equal(t, time.Duration(0), se.RetryAfter)
90+
})
91+
92+
t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) {
93+
t.Parallel()
94+
awsErr := makeTestBedrockError(429, "15")
95+
result := wrapBedrockError(awsErr)
96+
retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result)
97+
assert.False(t, retryable)
98+
assert.True(t, rateLimited)
99+
assert.Equal(t, 15*time.Second, retryAfter)
100+
})
101+
102+
t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) {
103+
t.Parallel()
104+
awsErr := makeTestBedrockError(500, "")
105+
wrapped := fmt.Errorf("bedrock converse stream failed: %w", wrapBedrockError(awsErr))
106+
retryable, rateLimited, _ := modelerrors.ClassifyModelError(wrapped)
107+
assert.True(t, retryable)
108+
assert.False(t, rateLimited)
109+
})
110+
}

0 commit comments

Comments
 (0)