Skip to content

Commit cfe2303

Browse files
committed
fix(ci): golang lint
1 parent 9e4d5d1 commit cfe2303

2 files changed

Lines changed: 64 additions & 31 deletions

File tree

provider/anthropic/messages/messages_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1068,7 +1068,7 @@ func TestDoGenerate_CacheControl_DetailedUsage(t *testing.T) {
10681068
}
10691069
}
10701070

1071-
1071+
func TestDoGenerate_CacheControl_BasicUsage(t *testing.T) {
10721072
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
10731073
w.Header().Set("Content-Type", "application/json")
10741074
json.NewEncoder(w).Encode(map[string]any{

provider/openai/responses/responses.go

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ const (
2222
)
2323

2424
type Provider struct {
25-
apiKey string
26-
baseURL string
27-
httpClient *http.Client
25+
apiKey string
26+
baseURL string
27+
httpClient *http.Client
28+
prepareRequest func(*http.Request) error
2829
}
2930

3031
type Option func(*Provider)
@@ -41,6 +42,22 @@ func WithHTTPClient(client *http.Client) Option {
4142
return func(p *Provider) { p.httpClient = client }
4243
}
4344

45+
// WithBedrockRegion enables AWS SigV4 authentication for Amazon Bedrock's
46+
// OpenAI-compatible Responses endpoint using the default AWS credential chain.
47+
func WithBedrockRegion(region string) Option {
48+
return func(p *Provider) {
49+
p.prepareRequest = utils.NewBedrockDefaultCredentialsPreparer(region)
50+
}
51+
}
52+
53+
// WithBedrockCredentials enables AWS SigV4 authentication for Amazon Bedrock's
54+
// OpenAI-compatible Responses endpoint using static credentials.
55+
func WithBedrockCredentials(region, accessKeyID, secretAccessKey, sessionToken string) Option {
56+
return func(p *Provider) {
57+
p.prepareRequest = utils.NewBedrockStaticCredentialsPreparer(region, accessKeyID, secretAccessKey, sessionToken)
58+
}
59+
}
60+
4461
func New(options ...Option) *Provider {
4562
p := &Provider{
4663
baseURL: defaultBaseURL,
@@ -56,10 +73,11 @@ func (p *Provider) Name() string { return "openai-responses" }
5673

5774
func (p *Provider) ListModels(ctx context.Context) ([]sdk.Model, error) {
5875
resp, err := utils.FetchJSON[modelsListResponse](ctx, p.httpClient, &utils.RequestOptions{
59-
Method: http.MethodGet,
60-
BaseURL: p.baseURL,
61-
Path: "/models",
62-
Headers: utils.AuthHeader(p.apiKey),
76+
Method: http.MethodGet,
77+
BaseURL: p.baseURL,
78+
Path: "/models",
79+
Headers: p.authHeaders(),
80+
Prepare: p.prepareRequest,
6381
})
6482
if err != nil {
6583
return nil, fmt.Errorf("openai-responses: list models request failed: %w", err)
@@ -78,11 +96,12 @@ func (p *Provider) ListModels(ctx context.Context) ([]sdk.Model, error) {
7896

7997
func (p *Provider) Test(ctx context.Context) *sdk.ProviderTestResult {
8098
_, err := utils.FetchJSON[modelsListResponse](ctx, p.httpClient, &utils.RequestOptions{
81-
Method: http.MethodGet,
82-
BaseURL: p.baseURL,
83-
Path: "/models",
84-
Query: map[string]string{"limit": "1"},
85-
Headers: utils.AuthHeader(p.apiKey),
99+
Method: http.MethodGet,
100+
BaseURL: p.baseURL,
101+
Path: "/models",
102+
Query: map[string]string{"limit": "1"},
103+
Headers: p.authHeaders(),
104+
Prepare: p.prepareRequest,
86105
})
87106
if err != nil {
88107
return classifyError(err)
@@ -92,10 +111,11 @@ func (p *Provider) Test(ctx context.Context) *sdk.ProviderTestResult {
92111

93112
func (p *Provider) TestModel(ctx context.Context, modelID string) (*sdk.ModelTestResult, error) {
94113
_, err := utils.FetchJSON[modelObject](ctx, p.httpClient, &utils.RequestOptions{
95-
Method: http.MethodGet,
96-
BaseURL: p.baseURL,
97-
Path: "/models/" + modelID,
98-
Headers: utils.AuthHeader(p.apiKey),
114+
Method: http.MethodGet,
115+
BaseURL: p.baseURL,
116+
Path: "/models/" + modelID,
117+
Headers: p.authHeaders(),
118+
Prepare: p.prepareRequest,
99119
})
100120
if err == nil {
101121
return &sdk.ModelTestResult{Supported: true, Message: "supported"}, nil
@@ -106,10 +126,11 @@ func (p *Provider) TestModel(ctx context.Context, modelID string) (*sdk.ModelTes
106126
}
107127

108128
status, probeErr := utils.ProbeStatus(ctx, p.httpClient, &utils.RequestOptions{
109-
Method: http.MethodPost,
110-
BaseURL: p.baseURL,
111-
Path: "/responses",
112-
Headers: utils.AuthHeader(p.apiKey),
129+
Method: http.MethodPost,
130+
BaseURL: p.baseURL,
131+
Path: "/responses",
132+
Headers: p.authHeaders(),
133+
Prepare: p.prepareRequest,
113134
Body: map[string]any{
114135
"model": modelID,
115136
"input": "hi",
@@ -130,6 +151,16 @@ func (p *Provider) ChatModel(id string) *sdk.Model {
130151
}
131152
}
132153

154+
func (p *Provider) authHeaders() map[string]string {
155+
if p.prepareRequest != nil {
156+
return nil
157+
}
158+
if p.apiKey == "" {
159+
return nil
160+
}
161+
return utils.AuthHeader(p.apiKey)
162+
}
163+
133164
// ---------- DoGenerate ----------
134165

135166
func (p *Provider) DoGenerate(ctx context.Context, params sdk.GenerateParams) (*sdk.GenerateResult, error) { //nolint:gocritic // interface method
@@ -140,11 +171,12 @@ func (p *Provider) DoGenerate(ctx context.Context, params sdk.GenerateParams) (*
140171
req := p.buildRequest(&params)
141172

142173
resp, err := utils.FetchJSON[responsesResponse](ctx, p.httpClient, &utils.RequestOptions{
143-
Method: http.MethodPost,
144-
BaseURL: p.baseURL,
145-
Path: "/responses",
146-
Headers: utils.AuthHeader(p.apiKey),
147-
Body: req,
174+
Method: http.MethodPost,
175+
BaseURL: p.baseURL,
176+
Path: "/responses",
177+
Headers: p.authHeaders(),
178+
Prepare: p.prepareRequest,
179+
Body: req,
148180
})
149181
if err != nil {
150182
var apiErr *utils.APIError
@@ -488,11 +520,12 @@ func (p *Provider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sd
488520
}
489521

490522
err := utils.FetchSSE(ctx, p.httpClient, &utils.RequestOptions{
491-
Method: http.MethodPost,
492-
BaseURL: p.baseURL,
493-
Path: "/responses",
494-
Headers: utils.AuthHeader(p.apiKey),
495-
Body: req,
523+
Method: http.MethodPost,
524+
BaseURL: p.baseURL,
525+
Path: "/responses",
526+
Headers: p.authHeaders(),
527+
Prepare: p.prepareRequest,
528+
Body: req,
496529
}, func(ev *utils.SSEEvent) error {
497530
eventType := ev.Event
498531
if eventType == "" {

0 commit comments

Comments
 (0)