@@ -15,9 +15,10 @@ import (
1515const defaultBaseURL = "https://api.openai.com/v1"
1616
1717type Provider struct {
18- apiKey string
19- baseURL string
20- httpClient * http.Client
18+ apiKey string
19+ baseURL string
20+ httpClient * http.Client
21+ prepareRequest func (* http.Request ) error
2122}
2223
2324type Option func (* Provider )
@@ -28,6 +29,22 @@ func WithAPIKey(apiKey string) Option {
2829 }
2930}
3031
32+ // WithBedrockRegion enables AWS SigV4 authentication for Amazon Bedrock's
33+ // OpenAI-compatible endpoint using the default AWS credential chain.
34+ func WithBedrockRegion (region string ) Option {
35+ return func (p * Provider ) {
36+ p .prepareRequest = utils .NewBedrockDefaultCredentialsPreparer (region )
37+ }
38+ }
39+
40+ // WithBedrockCredentials enables AWS SigV4 authentication for Amazon Bedrock's
41+ // OpenAI-compatible endpoint using static credentials.
42+ func WithBedrockCredentials (region , accessKeyID , secretAccessKey , sessionToken string ) Option {
43+ return func (p * Provider ) {
44+ p .prepareRequest = utils .NewBedrockStaticCredentialsPreparer (region , accessKeyID , secretAccessKey , sessionToken )
45+ }
46+ }
47+
3148func WithBaseURL (baseURL string ) Option {
3249 return func (p * Provider ) {
3350 p .baseURL = baseURL
@@ -60,7 +77,8 @@ func (p *Provider) ListModels(ctx context.Context) ([]sdk.Model, error) {
6077 Method : http .MethodGet ,
6178 BaseURL : p .baseURL ,
6279 Path : "/models" ,
63- Headers : utils .AuthHeader (p .apiKey ),
80+ Headers : p .authHeaders (),
81+ Prepare : p .prepareRequest ,
6482 })
6583 if err != nil {
6684 return nil , fmt .Errorf ("openai: list models request failed: %w" , err )
@@ -83,7 +101,8 @@ func (p *Provider) Test(ctx context.Context) *sdk.ProviderTestResult {
83101 BaseURL : p .baseURL ,
84102 Path : "/models" ,
85103 Query : map [string ]string {"limit" : "1" },
86- Headers : utils .AuthHeader (p .apiKey ),
104+ Headers : p .authHeaders (),
105+ Prepare : p .prepareRequest ,
87106 })
88107 if err != nil {
89108 return classifyError (err )
@@ -96,7 +115,8 @@ func (p *Provider) TestModel(ctx context.Context, modelID string) (*sdk.ModelTes
96115 Method : http .MethodGet ,
97116 BaseURL : p .baseURL ,
98117 Path : "/models/" + modelID ,
99- Headers : utils .AuthHeader (p .apiKey ),
118+ Headers : p .authHeaders (),
119+ Prepare : p .prepareRequest ,
100120 })
101121 if err == nil {
102122 return & sdk.ModelTestResult {Supported : true , Message : "supported" }, nil
@@ -112,7 +132,8 @@ func (p *Provider) TestModel(ctx context.Context, modelID string) (*sdk.ModelTes
112132 Method : http .MethodPost ,
113133 BaseURL : p .baseURL ,
114134 Path : "/chat/completions" ,
115- Headers : utils .AuthHeader (p .apiKey ),
135+ Headers : p .authHeaders (),
136+ Prepare : p .prepareRequest ,
116137 Body : map [string ]any {
117138 "model" : modelID ,
118139 "messages" : []map [string ]string {{"role" : "user" , "content" : "hi" }},
@@ -147,7 +168,8 @@ func (p *Provider) DoGenerate(ctx context.Context, params sdk.GenerateParams) (*
147168 Method : http .MethodPost ,
148169 BaseURL : p .baseURL ,
149170 Path : "/chat/completions" ,
150- Headers : utils .AuthHeader (p .apiKey ),
171+ Headers : p .authHeaders (),
172+ Prepare : p .prepareRequest ,
151173 Body : req ,
152174 })
153175 if err != nil {
@@ -407,7 +429,8 @@ func (p *Provider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sd
407429 Method : http .MethodPost ,
408430 BaseURL : p .baseURL ,
409431 Path : "/chat/completions" ,
410- Headers : utils .AuthHeader (p .apiKey ),
432+ Headers : p .authHeaders (),
433+ Prepare : p .prepareRequest ,
411434 Body : req ,
412435 }, func (ev * utils.SSEEvent ) error {
413436 if ev .Data == "[DONE]" {
@@ -444,6 +467,16 @@ func (p *Provider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sd
444467 return & sdk.StreamResult {Stream : ch }, nil
445468}
446469
470+ func (p * Provider ) authHeaders () map [string ]string {
471+ if p .prepareRequest != nil {
472+ return nil
473+ }
474+ if p .apiKey == "" {
475+ return nil
476+ }
477+ return utils .AuthHeader (p .apiKey )
478+ }
479+
447480type streamingToolCall struct {
448481 id string
449482 name string
0 commit comments