Skip to content

Commit b865103

Browse files
authored
refactor(bedrock): split bedrock plugin implementation by concern (#140)
* refactor(bedrock): split bedrock plugin implementation by concern * fix(image.go, generate.go, stream.go): address code review
1 parent 95dc2b9 commit b865103

10 files changed

Lines changed: 1901 additions & 1711 deletions

File tree

SECURITY.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ Our security response team consists of:
100100

101101
This security policy applies to:
102102

103-
- The core plugin code (`bedrock_plugin.go`)
103+
- The core plugin code (`bedrock.go`)
104104
- All example applications
105105
- Documentation that could impact security
106106
- Dependencies and their configurations

bedrock.go

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
// Copyright 2025 Xavier Portilla Edo
2+
// Copyright 2025 Google LLC
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//
16+
// SPDX-License-Identifier: Apache-2.0
17+
18+
// Package bedrock provides a comprehensive AWS Bedrock plugin for Genkit Go.
19+
// This plugin supports text generation, image generation, and embedding capabilities
20+
// using AWS Bedrock foundation models via the Converse API.
21+
//
22+
// This implementation follows the same patterns as the existing Genkit plugins:
23+
// - ollama: https://github.com/firebase/genkit/blob/main/go/plugins/ollama/ollama.go
24+
// - gemini: https://github.com/firebase/genkit/blob/main/go/plugins/googlegenai/gemini.go
25+
package bedrock
26+
27+
import (
28+
"context"
29+
"fmt"
30+
"sync"
31+
"time"
32+
33+
"github.com/aws/aws-sdk-go-v2/aws"
34+
"github.com/aws/aws-sdk-go-v2/config"
35+
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
36+
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
37+
"github.com/firebase/genkit/go/ai"
38+
"github.com/firebase/genkit/go/core/api"
39+
"github.com/firebase/genkit/go/genkit"
40+
)
41+
42+
// Bedrock provides configuration options for the AWS Bedrock plugin.
43+
type Bedrock struct {
44+
Region string // AWS region (optional, uses AWS_REGION or us-east-1)
45+
MaxRetries int // Maximum number of retries (default: 3)
46+
RequestTimeout time.Duration // Request timeout (default: 30s)
47+
AWSConfig *aws.Config // Custom AWS config (optional)
48+
49+
mu sync.Mutex // Mutex to control access
50+
client BedrockClient
51+
initted bool // Whether the plugin has been initialized
52+
}
53+
54+
// Name returns the provider name.
55+
func (b *Bedrock) Name() string {
56+
return provider
57+
}
58+
59+
// Init initializes the AWS Bedrock plugin.
60+
// This method follows the same pattern as the Ollama plugin.
61+
func (b *Bedrock) Init(ctx context.Context) []api.Action {
62+
b.mu.Lock()
63+
64+
if b.initted {
65+
b.mu.Unlock()
66+
panic("bedrock: Init already called")
67+
}
68+
69+
// Set defaults
70+
if b.Region == "" {
71+
b.Region = "us-east-1" // Default region
72+
}
73+
if b.MaxRetries == 0 {
74+
b.MaxRetries = 3
75+
}
76+
if b.RequestTimeout == 0 {
77+
b.RequestTimeout = 30 * time.Second
78+
}
79+
80+
// Load AWS configuration
81+
var awsConfig aws.Config
82+
var err error
83+
84+
if b.AWSConfig != nil {
85+
awsConfig = *b.AWSConfig
86+
} else {
87+
// Load default AWS configuration
88+
awsConfig, err = config.LoadDefaultConfig(ctx,
89+
config.WithRegion(b.Region),
90+
config.WithRetryMaxAttempts(b.MaxRetries),
91+
)
92+
if err != nil {
93+
panic(fmt.Sprintf("bedrock: failed to load AWS config: %v", err))
94+
}
95+
}
96+
97+
// Create Bedrock Runtime client
98+
b.client = bedrockruntime.NewFromConfig(awsConfig)
99+
100+
b.initted = true
101+
102+
// Release the mutex
103+
b.mu.Unlock()
104+
105+
// Don't defer unlock since we already unlocked manually
106+
return []api.Action{}
107+
}
108+
109+
// DefineModel defines a model in the registry.
110+
// This follows the same pattern as the Anthropic plugin's DefineModel method.
111+
func (b *Bedrock) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai.Model {
112+
b.mu.Lock()
113+
defer b.mu.Unlock()
114+
115+
if !b.initted {
116+
panic("bedrock: Init not called")
117+
}
118+
119+
// Auto-detect model capabilities if not provided
120+
if info == nil {
121+
info = b.inferModelCapabilities(model.Name, model.Type)
122+
}
123+
124+
// Create model metadata
125+
meta := &ai.ModelOptions{
126+
Label: provider + "-" + model.Name,
127+
Supports: info.Supports,
128+
Versions: info.Versions,
129+
}
130+
131+
// Create the model function based on model type
132+
switch model.Type {
133+
case "image":
134+
return genkit.DefineModel(g, api.NewName(provider, model.Name), meta, func(
135+
ctx context.Context,
136+
input *ai.ModelRequest,
137+
cb func(context.Context, *ai.ModelResponseChunk) error,
138+
) (*ai.ModelResponse, error) {
139+
return b.generateImage(ctx, model.Name, input, cb)
140+
})
141+
default:
142+
return genkit.DefineModel(g, api.NewName(provider, model.Name), meta, func(
143+
ctx context.Context,
144+
input *ai.ModelRequest,
145+
cb func(context.Context, *ai.ModelResponseChunk) error,
146+
) (*ai.ModelResponse, error) {
147+
return b.generateText(ctx, model.Name, input, cb)
148+
})
149+
}
150+
}
151+
152+
// DefineEmbedder defines an embedder in the registry.
153+
func (b *Bedrock) DefineEmbedder(g *genkit.Genkit, modelName string) ai.Embedder {
154+
b.mu.Lock()
155+
defer b.mu.Unlock()
156+
157+
if !b.initted {
158+
panic("bedrock: Init not called")
159+
}
160+
161+
return genkit.DefineEmbedder(g, api.NewName(provider, modelName), nil, func(
162+
ctx context.Context,
163+
req *ai.EmbedRequest,
164+
) (*ai.EmbedResponse, error) {
165+
return b.embed(ctx, modelName, req)
166+
})
167+
}
168+
169+
// IsDefinedModel reports whether a model is defined.
170+
func IsDefinedModel(g *genkit.Genkit, name string) bool {
171+
return genkit.LookupModel(g, api.NewName(provider, name)) != nil
172+
}
173+
174+
// Model returns the Model with the given name.
175+
func Model(g *genkit.Genkit, name string) ai.Model {
176+
return genkit.LookupModel(g, api.NewName(provider, name))
177+
}
178+
179+
// DefineCommonModels is a helper to define commonly used models
180+
func DefineCommonModels(b *Bedrock, g *genkit.Genkit) map[string]ai.Model {
181+
models := make(map[string]ai.Model)
182+
183+
// Text generation models
184+
claudeHaiku := b.DefineModel(g, ModelDefinition{
185+
Name: "anthropic.claude-3-haiku-20240307-v1:0",
186+
Type: "chat",
187+
}, nil)
188+
models["claude-haiku"] = claudeHaiku
189+
190+
claudeSonnet := b.DefineModel(g, ModelDefinition{
191+
Name: "anthropic.claude-3-5-sonnet-20241022-v2:0",
192+
Type: "chat",
193+
}, nil)
194+
models["claude-sonnet"] = claudeSonnet
195+
196+
// Claude 4 models
197+
claudeOpus4 := b.DefineModel(g, ModelDefinition{
198+
Name: "anthropic.claude-opus-4-20250514-v1:0",
199+
Type: "chat",
200+
}, nil)
201+
models["claude-opus-4"] = claudeOpus4
202+
203+
claudeSonnet4 := b.DefineModel(g, ModelDefinition{
204+
Name: "anthropic.claude-sonnet-4-20250514-v1:0",
205+
Type: "chat",
206+
}, nil)
207+
models["claude-sonnet-4"] = claudeSonnet4
208+
209+
// Claude 3.7 Sonnet
210+
claude37Sonnet := b.DefineModel(g, ModelDefinition{
211+
Name: "anthropic.claude-3-7-sonnet-20250219-v1:0",
212+
Type: "chat",
213+
}, nil)
214+
models["claude-3-7-sonnet"] = claude37Sonnet
215+
216+
// Amazon Nova models
217+
novaMicro := b.DefineModel(g, ModelDefinition{
218+
Name: "amazon.nova-micro-v1:0",
219+
Type: "chat",
220+
}, nil)
221+
models["nova-micro"] = novaMicro
222+
223+
novaLite := b.DefineModel(g, ModelDefinition{
224+
Name: "amazon.nova-lite-v1:0",
225+
Type: "chat",
226+
}, nil)
227+
models["nova-lite"] = novaLite
228+
229+
novaPro := b.DefineModel(g, ModelDefinition{
230+
Name: "amazon.nova-pro-v1:0",
231+
Type: "chat",
232+
}, nil)
233+
models["nova-pro"] = novaPro
234+
235+
// Legacy models for backward compatibility
236+
titanText := b.DefineModel(g, ModelDefinition{
237+
Name: "amazon.titan-text-premier-v1:0",
238+
Type: "chat",
239+
}, nil)
240+
models["titan-text"] = titanText
241+
242+
// Meta Llama models
243+
llama3_8b := b.DefineModel(g, ModelDefinition{
244+
Name: "meta.llama3-8b-instruct-v1:0",
245+
Type: "chat",
246+
}, nil)
247+
models["llama3-8b"] = llama3_8b
248+
249+
llama3_1_8b := b.DefineModel(g, ModelDefinition{
250+
Name: "meta.llama3-1-8b-instruct-v1:0",
251+
Type: "chat",
252+
}, nil)
253+
models["llama3-1-8b"] = llama3_1_8b
254+
255+
llama3_2_3b := b.DefineModel(g, ModelDefinition{
256+
Name: "meta.llama3-2-3b-instruct-v1:0",
257+
Type: "chat",
258+
}, nil)
259+
models["llama3-2-3b"] = llama3_2_3b
260+
261+
// New Llama 4 models
262+
llama4Maverick := b.DefineModel(g, ModelDefinition{
263+
Name: "meta.llama4-maverick-17b-instruct-v1:0",
264+
Type: "chat",
265+
}, nil)
266+
models["llama4-maverick"] = llama4Maverick
267+
268+
llama4Scout := b.DefineModel(g, ModelDefinition{
269+
Name: "meta.llama4-scout-17b-instruct-v1:0",
270+
Type: "chat",
271+
}, nil)
272+
models["llama4-scout"] = llama4Scout
273+
274+
// DeepSeek R1 model
275+
deepseekR1 := b.DefineModel(g, ModelDefinition{
276+
Name: "deepseek.r1-v1:0",
277+
Type: "chat",
278+
}, nil)
279+
models["deepseek-r1"] = deepseekR1
280+
281+
// Image generation models
282+
titanImage := b.DefineModel(g, ModelDefinition{
283+
Name: "amazon.titan-image-generator-v1",
284+
Type: "image",
285+
}, nil)
286+
models["titan-image"] = titanImage
287+
288+
novaCanvas := b.DefineModel(g, ModelDefinition{
289+
Name: "amazon.nova-canvas-v1:0",
290+
Type: "image",
291+
}, nil)
292+
models["nova-canvas"] = novaCanvas
293+
294+
return models
295+
}
296+
297+
// DefineCommonEmbedders is a helper to define commonly used embedders
298+
func DefineCommonEmbedders(b *Bedrock, g *genkit.Genkit) map[string]ai.Embedder {
299+
embedders := make(map[string]ai.Embedder)
300+
301+
// Amazon Titan Embeddings
302+
titanEmbed := b.DefineEmbedder(g, "amazon.titan-embed-text-v1")
303+
embedders["titan-embed"] = titanEmbed
304+
305+
titanEmbedV2 := b.DefineEmbedder(g, "amazon.titan-embed-text-v2:0")
306+
embedders["titan-embed-v2"] = titanEmbedV2
307+
308+
titanMultimodal := b.DefineEmbedder(g, "amazon.titan-embed-image-v1")
309+
embedders["titan-multimodal"] = titanMultimodal
310+
311+
// Cohere Embeddings
312+
cohereEmbed := b.DefineEmbedder(g, "cohere.embed-english-v3")
313+
embedders["cohere-embed"] = cohereEmbed
314+
315+
cohereMultilingual := b.DefineEmbedder(g, "cohere.embed-multilingual-v3")
316+
embedders["cohere-multilingual"] = cohereMultilingual
317+
318+
return embedders
319+
}
320+
321+
// NewCachePointPart creates and returns a new ai.Part instance representing a cache point part
322+
// with the default cache point type. A cache point should be inserted after a big static prompt
323+
// that is reused across multiple requests to optimize token usage.
324+
func NewCachePointPart() *ai.Part {
325+
return ai.NewCustomPart(map[string]any{
326+
bedrockCachePointTypeKey: types.CachePointTypeDefault,
327+
})
328+
}
329+
330+
// CachePointType retrieves the CachePointType value from the Custom field of the given ai.Part.
331+
// It returns the CachePointType and a boolean indicating whether the value was found and successfully asserted.
332+
func CachePointType(part *ai.Part) (types.CachePointType, bool) {
333+
cachePointTypeVal, ok := part.Custom[bedrockCachePointTypeKey]
334+
if !ok {
335+
return "", false
336+
}
337+
cpt, ok := cachePointTypeVal.(types.CachePointType)
338+
return cpt, ok
339+
}

0 commit comments

Comments
 (0)