Skip to content

Commit ddf8b1c

Browse files
committed
Add Bedrock auth support for OpenAI providers
1 parent dbedfe3 commit ddf8b1c

12 files changed

Lines changed: 660 additions & 57 deletions

File tree

docs/providers.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ model := provider.ChatModel("gpt-4o-mini")
9696
| Option | Default | Description |
9797
|--------|---------|-------------|
9898
| `WithAPIKey(key)` | `""` | API key sent as `Authorization: Bearer <key>` |
99+
| `WithBedrockRegion(region)` | disabled | Use AWS SigV4 with the default AWS credential chain for Bedrock's OpenAI-compatible endpoint |
100+
| `WithBedrockCredentials(region, accessKeyID, secretAccessKey, sessionToken)` | disabled | Use AWS SigV4 with static AWS credentials for Bedrock |
99101
| `WithBaseURL(url)` | `https://api.openai.com/v1` | Base URL for API requests |
100102
| `WithHTTPClient(client)` | `&http.Client{}` | Custom HTTP client (for proxies, timeouts, etc.) |
101103

@@ -130,6 +132,12 @@ provider := completions.New(
130132
completions.WithBaseURL("https://your-resource.openai.azure.com/openai/deployments/gpt-4o"),
131133
)
132134

135+
// Amazon Bedrock OpenAI-compatible endpoint (AWS credentials)
136+
provider := completions.New(
137+
completions.WithBedrockRegion("us-east-1"),
138+
completions.WithBaseURL("https://bedrock-mantle.us-east-1.api.aws/v1"),
139+
)
140+
133141
// Local (Ollama, vLLM, etc.)
134142
provider := completions.New(
135143
completions.WithBaseURL("http://localhost:11434/v1"),
@@ -198,6 +206,8 @@ model := provider.ChatModel("gpt-4o-mini")
198206
| Option | Default | Description |
199207
|--------|---------|-------------|
200208
| `WithAPIKey(key)` | `""` | API key sent as `Authorization: Bearer <key>` |
209+
| `WithBedrockRegion(region)` | disabled | Use AWS SigV4 with the default AWS credential chain for Bedrock's OpenAI-compatible endpoint |
210+
| `WithBedrockCredentials(region, accessKeyID, secretAccessKey, sessionToken)` | disabled | Use AWS SigV4 with static AWS credentials for Bedrock |
201211
| `WithBaseURL(url)` | `https://api.openai.com/v1` | Base URL for API requests |
202212
| `WithHTTPClient(client)` | `&http.Client{}` | Custom HTTP client |
203213

@@ -218,6 +228,8 @@ provider := responses.New(
218228
responses.WithAPIKey("sk-or-v1-..."),
219229
responses.WithBaseURL("https://openrouter.ai/api/v1"),
220230
)
231+
232+
Amazon Bedrock's OpenAI-compatible Responses endpoint also works with the same provider when you configure a Bedrock Mantle base URL and either `WithBedrockRegion(...)` or `WithBedrockCredentials(...)`.
221233
model := provider.ChatModel("openai/o4-mini")
222234
```
223235
@@ -777,6 +789,8 @@ vec, err := sdk.Embed(ctx, "Hello world",
777789
| Option | Default | Description |
778790
|--------|---------|-------------|
779791
| `WithAPIKey(key)` | `""` | API key sent as `Authorization: Bearer <key>` |
792+
| `WithBedrockRegion(region)` | disabled | Use AWS SigV4 with the default AWS credential chain for Bedrock's OpenAI-compatible endpoint |
793+
| `WithBedrockCredentials(region, accessKeyID, secretAccessKey, sessionToken)` | disabled | Use AWS SigV4 with static AWS credentials for Bedrock |
780794
| `WithBaseURL(url)` | `https://api.openai.com/v1` | Base URL for API requests |
781795
| `WithHTTPClient(client)` | `&http.Client{}` | Custom HTTP client |
782796

go.mod

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,27 @@ module github.com/memohai/twilight-ai
33
go 1.25.7
44

55
require (
6+
github.com/aws/aws-sdk-go-v2 v1.41.5
7+
github.com/aws/aws-sdk-go-v2/config v1.32.14
8+
github.com/aws/aws-sdk-go-v2/credentials v1.19.14
69
github.com/google/jsonschema-go v0.4.2
710
github.com/google/uuid v1.6.0
811
github.com/gorilla/websocket v1.5.3
912
github.com/modelcontextprotocol/go-sdk v1.5.0
1013
)
1114

1215
require (
16+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
17+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
18+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
19+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
20+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
21+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
22+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
23+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
24+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
25+
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
26+
github.com/aws/smithy-go v1.24.2 // indirect
1327
github.com/segmentio/asm v1.2.1 // indirect
1428
github.com/segmentio/encoding v0.5.4 // indirect
1529
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect

go.sum

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,31 @@
1+
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
2+
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
3+
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
4+
github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
5+
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
6+
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
7+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
8+
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
9+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
10+
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
11+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
12+
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
13+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
14+
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
15+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
16+
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
17+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
18+
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
19+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
20+
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
21+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
22+
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
23+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
24+
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
25+
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
26+
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
27+
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
28+
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
129
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
230
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
331
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=

internal/utils/aws.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package utils
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"crypto/sha256"
7+
"encoding/hex"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"sync"
12+
"time"
13+
14+
"github.com/aws/aws-sdk-go-v2/aws"
15+
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
16+
"github.com/aws/aws-sdk-go-v2/config"
17+
"github.com/aws/aws-sdk-go-v2/credentials"
18+
)
19+
20+
const (
21+
emptySHA256 = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
22+
23+
bedrockService = "bedrock"
24+
)
25+
26+
type lazyCredentialsProvider struct {
27+
region string
28+
29+
once sync.Once
30+
cp aws.CredentialsProvider
31+
err error
32+
}
33+
34+
func NewBedrockDefaultCredentialsPreparer(region string) func(*http.Request) error {
35+
return NewSigV4Preparer(bedrockService, region, &lazyCredentialsProvider{region: region})
36+
}
37+
38+
func NewBedrockStaticCredentialsPreparer(region, accessKeyID, secretAccessKey, sessionToken string) func(*http.Request) error {
39+
cp := aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, sessionToken))
40+
return NewSigV4Preparer(bedrockService, region, cp)
41+
}
42+
43+
func NewSigV4Preparer(service, region string, cp aws.CredentialsProvider) func(*http.Request) error {
44+
signer := v4.NewSigner()
45+
46+
return func(req *http.Request) error {
47+
payloadHash, err := requestPayloadHash(req)
48+
if err != nil {
49+
return err
50+
}
51+
52+
creds, err := cp.Retrieve(req.Context())
53+
if err != nil {
54+
return fmt.Errorf("retrieve AWS credentials: %w", err)
55+
}
56+
57+
if err := signer.SignHTTP(req.Context(), creds, req, payloadHash, service, region, time.Now().UTC()); err != nil {
58+
return fmt.Errorf("sign AWS request: %w", err)
59+
}
60+
61+
return nil
62+
}
63+
}
64+
65+
func (p *lazyCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
66+
p.once.Do(func() {
67+
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(p.region))
68+
if err != nil {
69+
p.err = err
70+
return
71+
}
72+
p.cp = cfg.Credentials
73+
})
74+
if p.err != nil {
75+
return aws.Credentials{}, p.err
76+
}
77+
return p.cp.Retrieve(ctx)
78+
}
79+
80+
func requestPayloadHash(req *http.Request) (string, error) {
81+
if req.Body == nil {
82+
return emptySHA256, nil
83+
}
84+
85+
if req.GetBody != nil {
86+
body, err := req.GetBody()
87+
if err != nil {
88+
return "", fmt.Errorf("clone request body: %w", err)
89+
}
90+
defer body.Close()
91+
return hashReader(body)
92+
}
93+
94+
data, err := io.ReadAll(req.Body)
95+
if err != nil {
96+
return "", fmt.Errorf("read request body: %w", err)
97+
}
98+
99+
req.Body = io.NopCloser(bytes.NewReader(data))
100+
req.GetBody = func() (io.ReadCloser, error) {
101+
return io.NopCloser(bytes.NewReader(data)), nil
102+
}
103+
104+
return hashBytes(data), nil
105+
}
106+
107+
func hashReader(r io.Reader) (string, error) {
108+
h := sha256.New()
109+
if _, err := io.Copy(h, r); err != nil {
110+
return "", fmt.Errorf("hash request body: %w", err)
111+
}
112+
return hex.EncodeToString(h.Sum(nil)), nil
113+
}
114+
115+
func hashBytes(data []byte) string {
116+
sum := sha256.Sum256(data)
117+
return hex.EncodeToString(sum[:])
118+
}

internal/utils/fetch.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ type RequestOptions struct {
1717
Headers map[string]string
1818
Query map[string]string
1919
Body any
20+
Prepare func(*http.Request) error
2021
}
2122

2223
type APIError struct {
@@ -86,6 +87,12 @@ func BuildRequest(ctx context.Context, opts *RequestOptions) (*http.Request, err
8687
req.Header.Set(k, v)
8788
}
8889

90+
if opts.Prepare != nil {
91+
if err := opts.Prepare(req); err != nil {
92+
return nil, fmt.Errorf("prepare request: %w", err)
93+
}
94+
}
95+
8996
return req, nil
9097
}
9198

provider/openai/completions/completions.go

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ import (
1515
const defaultBaseURL = "https://api.openai.com/v1"
1616

1717
type Provider struct {
18-
apiKey string
19-
baseURL string
20-
httpClient *http.Client
18+
apiKey string
19+
baseURL string
20+
httpClient *http.Client
21+
prepareRequest func(*http.Request) error
2122
}
2223

2324
type Option func(*Provider)
@@ -28,6 +29,22 @@ func WithAPIKey(apiKey string) Option {
2829
}
2930
}
3031

32+
// WithBedrockRegion enables AWS SigV4 authentication for Amazon Bedrock's
33+
// OpenAI-compatible endpoint using the default AWS credential chain.
34+
func WithBedrockRegion(region string) Option {
35+
return func(p *Provider) {
36+
p.prepareRequest = utils.NewBedrockDefaultCredentialsPreparer(region)
37+
}
38+
}
39+
40+
// WithBedrockCredentials enables AWS SigV4 authentication for Amazon Bedrock's
41+
// OpenAI-compatible endpoint using static credentials.
42+
func WithBedrockCredentials(region, accessKeyID, secretAccessKey, sessionToken string) Option {
43+
return func(p *Provider) {
44+
p.prepareRequest = utils.NewBedrockStaticCredentialsPreparer(region, accessKeyID, secretAccessKey, sessionToken)
45+
}
46+
}
47+
3148
func WithBaseURL(baseURL string) Option {
3249
return func(p *Provider) {
3350
p.baseURL = baseURL
@@ -60,7 +77,8 @@ func (p *Provider) ListModels(ctx context.Context) ([]sdk.Model, error) {
6077
Method: http.MethodGet,
6178
BaseURL: p.baseURL,
6279
Path: "/models",
63-
Headers: utils.AuthHeader(p.apiKey),
80+
Headers: p.authHeaders(),
81+
Prepare: p.prepareRequest,
6482
})
6583
if err != nil {
6684
return nil, fmt.Errorf("openai: list models request failed: %w", err)
@@ -83,7 +101,8 @@ func (p *Provider) Test(ctx context.Context) *sdk.ProviderTestResult {
83101
BaseURL: p.baseURL,
84102
Path: "/models",
85103
Query: map[string]string{"limit": "1"},
86-
Headers: utils.AuthHeader(p.apiKey),
104+
Headers: p.authHeaders(),
105+
Prepare: p.prepareRequest,
87106
})
88107
if err != nil {
89108
return classifyError(err)
@@ -96,7 +115,8 @@ func (p *Provider) TestModel(ctx context.Context, modelID string) (*sdk.ModelTes
96115
Method: http.MethodGet,
97116
BaseURL: p.baseURL,
98117
Path: "/models/" + modelID,
99-
Headers: utils.AuthHeader(p.apiKey),
118+
Headers: p.authHeaders(),
119+
Prepare: p.prepareRequest,
100120
})
101121
if err == nil {
102122
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
@@ -112,7 +132,8 @@ func (p *Provider) TestModel(ctx context.Context, modelID string) (*sdk.ModelTes
112132
Method: http.MethodPost,
113133
BaseURL: p.baseURL,
114134
Path: "/chat/completions",
115-
Headers: utils.AuthHeader(p.apiKey),
135+
Headers: p.authHeaders(),
136+
Prepare: p.prepareRequest,
116137
Body: map[string]any{
117138
"model": modelID,
118139
"messages": []map[string]string{{"role": "user", "content": "hi"}},
@@ -147,7 +168,8 @@ func (p *Provider) DoGenerate(ctx context.Context, params sdk.GenerateParams) (*
147168
Method: http.MethodPost,
148169
BaseURL: p.baseURL,
149170
Path: "/chat/completions",
150-
Headers: utils.AuthHeader(p.apiKey),
171+
Headers: p.authHeaders(),
172+
Prepare: p.prepareRequest,
151173
Body: req,
152174
})
153175
if err != nil {
@@ -407,7 +429,8 @@ func (p *Provider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sd
407429
Method: http.MethodPost,
408430
BaseURL: p.baseURL,
409431
Path: "/chat/completions",
410-
Headers: utils.AuthHeader(p.apiKey),
432+
Headers: p.authHeaders(),
433+
Prepare: p.prepareRequest,
411434
Body: req,
412435
}, func(ev *utils.SSEEvent) error {
413436
if ev.Data == "[DONE]" {
@@ -444,6 +467,16 @@ func (p *Provider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sd
444467
return &sdk.StreamResult{Stream: ch}, nil
445468
}
446469

470+
func (p *Provider) authHeaders() map[string]string {
471+
if p.prepareRequest != nil {
472+
return nil
473+
}
474+
if p.apiKey == "" {
475+
return nil
476+
}
477+
return utils.AuthHeader(p.apiKey)
478+
}
479+
447480
type streamingToolCall struct {
448481
id string
449482
name string

0 commit comments

Comments
 (0)