Skip to content

Commit 8dc3b57

Browse files
authored
feat: add Bedrock reasoning support (#142)
* feat(reasoning): add Bedrock reasoning support * fix(test): check fmt.Fprint error in rerank mock handlers * fix(bedrock): fix live Claude test Bedrock plugin reuse
1 parent 2865e1b commit 8dc3b57

9 files changed

Lines changed: 913 additions & 50 deletions

File tree

bedrock.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,10 @@ func (b *Bedrock) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.
133133

134134
// Create model metadata
135135
meta := &ai.ModelOptions{
136-
Label: provider + "-" + model.Name,
137-
Supports: info.Supports,
138-
Versions: info.Versions,
136+
Label: provider + "-" + model.Name,
137+
Supports: info.Supports,
138+
Versions: info.Versions,
139+
ConfigSchema: configSchema(),
139140
}
140141

141142
// Create the model function based on model type

bedrock_live_test.go

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
// Copyright 2025 Xavier Portilla Edo
2+
// Copyright 2025 Google LLC
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//
16+
// SPDX-License-Identifier: Apache-2.0
17+
18+
package bedrock
19+
20+
// Live tests exercise reasoning ("thinking") against a real Bedrock endpoint.
21+
// They are skipped by default and only run when the required model flags are
22+
// passed, e.g.:
23+
//
24+
// go test -run TestBedrockLive_ClaudeReasoning \
25+
// -test-bedrock-region=us-east-1 \
26+
// -test-bedrock-model-claude=us.anthropic.claude-haiku-4-5-20251001-v1:0
27+
//
28+
// They require AWS credentials in the environment and model access granted in
29+
// the target region. Reasoning support is region- and model-scoped on Bedrock;
30+
// these tests validate that the plugin's request/response shape round-trips,
31+
// not that any particular model is granted.
32+
33+
import (
34+
"context"
35+
"flag"
36+
"testing"
37+
38+
"github.com/firebase/genkit/go/ai"
39+
"github.com/firebase/genkit/go/genkit"
40+
)
41+
42+
var (
43+
testRegion = flag.String("test-bedrock-region", "", "AWS region for Bedrock live tests (e.g. us-east-1)")
44+
testModelClaude = flag.String("test-bedrock-model-claude", "", "Thinking-capable Claude model ID (e.g. us.anthropic.claude-haiku-4-5-20251001-v1:0)")
45+
)
46+
47+
// reasoningBudgetTokens is the extended-thinking budget. Bedrock requires it to
48+
// be at least 1024, and MaxTokens must exceed it.
49+
const reasoningBudgetTokens = 1024
50+
51+
// requireLiveClaude asserts the live-test prerequisites and skips otherwise. It
52+
// returns a Genkit instance with the Bedrock plugin and a defined Claude model.
53+
func requireLiveClaude(t *testing.T) (context.Context, *genkit.Genkit, ai.Model) {
54+
t.Helper()
55+
if *testRegion == "" {
56+
t.Skip("bedrock live tests skipped; pass -test-bedrock-region=<region>")
57+
}
58+
if *testModelClaude == "" {
59+
t.Skip("pass -test-bedrock-model-claude=<thinking-capable-model-id> to run")
60+
}
61+
ctx := context.Background()
62+
pb := &Bedrock{Region: *testRegion}
63+
g := genkit.Init(ctx, genkit.WithPlugins(pb))
64+
m := pb.DefineModel(g, ModelDefinition{
65+
Name: *testModelClaude,
66+
Type: "chat",
67+
}, nil)
68+
return ctx, g, m
69+
}
70+
71+
// thinkingConfig enables Claude extended thinking via AdditionalModelRequestFields.
72+
// Temperature is intentionally left unset — Bedrock rejects thinking requests
73+
// that also set a custom temperature.
74+
func thinkingConfig() *Config {
75+
return &Config{
76+
MaxTokens: reasoningBudgetTokens + 1024,
77+
AdditionalModelRequestFields: map[string]any{
78+
"thinking": map[string]any{
79+
"type": "enabled",
80+
"budget_tokens": reasoningBudgetTokens,
81+
},
82+
},
83+
}
84+
}
85+
86+
// firstReasoning returns the first reasoning part in a message, or nil.
87+
func firstReasoning(msg *ai.Message) *ai.Part {
88+
if msg == nil {
89+
return nil
90+
}
91+
for _, p := range msg.Content {
92+
if p.IsReasoning() {
93+
return p
94+
}
95+
}
96+
return nil
97+
}
98+
99+
// TestBedrockLive_ClaudeReasoningSync confirms a thinking-enabled request comes
100+
// back with a signed reasoning part, and that the plain text answer is still
101+
// surfaced via Text() (i.e. reasoning doesn't leak into normal output).
102+
func TestBedrockLive_ClaudeReasoningSync(t *testing.T) {
103+
ctx, g, m := requireLiveClaude(t)
104+
105+
resp, err := genkit.Generate(ctx, g,
106+
ai.WithModel(m),
107+
ai.WithPrompt("What is 17 * 24? Think it through step by step, then give the answer."),
108+
ai.WithConfig(thinkingConfig()),
109+
)
110+
if err != nil {
111+
t.Fatal(err)
112+
}
113+
114+
reasoning := firstReasoning(resp.Message)
115+
if reasoning == nil {
116+
t.Fatal("expected a reasoning part in the response; got none")
117+
}
118+
if sig := metadataBytes(reasoning.Metadata, reasoningSignatureMetadataKey); len(sig) == 0 {
119+
t.Error("reasoning part is missing its Bedrock signature")
120+
}
121+
if resp.Text() == "" {
122+
t.Error("final response text is empty")
123+
}
124+
}
125+
126+
// TestBedrockLive_ClaudeReasoningRoundTrip is the real proof of the feature: it
127+
// feeds a thinking response back as conversation history and confirms the
128+
// follow-up turn is accepted. If the signed/redacted reasoning weren't
129+
// round-tripped verbatim, Bedrock rejects the request.
130+
func TestBedrockLive_ClaudeReasoningRoundTrip(t *testing.T) {
131+
ctx, g, m := requireLiveClaude(t)
132+
133+
turn1 := ai.NewUserTextMessage("What is 17 * 24? Show your reasoning, then state the result.")
134+
resp1, err := genkit.Generate(ctx, g,
135+
ai.WithModel(m),
136+
ai.WithMessages(turn1),
137+
ai.WithConfig(thinkingConfig()),
138+
)
139+
if err != nil {
140+
t.Fatal(err)
141+
}
142+
if firstReasoning(resp1.Message) == nil {
143+
t.Fatal("first turn produced no reasoning part; cannot exercise round-trip")
144+
}
145+
146+
// Replay the assistant turn (reasoning included) plus a follow-up question.
147+
resp2, err := genkit.Generate(ctx, g,
148+
ai.WithModel(m),
149+
ai.WithMessages(
150+
turn1,
151+
resp1.Message,
152+
ai.NewUserTextMessage("Now multiply that result by 2."),
153+
),
154+
ai.WithConfig(thinkingConfig()),
155+
)
156+
if err != nil {
157+
t.Fatalf("follow-up turn rejected (reasoning round-trip likely broken): %v", err)
158+
}
159+
if resp2.Text() == "" {
160+
t.Error("follow-up response text is empty")
161+
}
162+
}
163+
164+
// TestBedrockLive_ClaudeReasoningStream confirms reasoning deltas stream through
165+
// to the callback and the final response carries an assembled reasoning part.
166+
func TestBedrockLive_ClaudeReasoningStream(t *testing.T) {
167+
ctx, g, m := requireLiveClaude(t)
168+
169+
var reasoningChunks, textChunks int
170+
resp, err := genkit.Generate(ctx, g,
171+
ai.WithModel(m),
172+
ai.WithPrompt("What is 17 * 24? Think it through, then answer."),
173+
ai.WithConfig(thinkingConfig()),
174+
ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error {
175+
for _, p := range c.Content {
176+
switch {
177+
case p.IsReasoning():
178+
reasoningChunks++
179+
case p.IsText():
180+
textChunks++
181+
}
182+
}
183+
return nil
184+
}),
185+
)
186+
if err != nil {
187+
t.Fatal(err)
188+
}
189+
if reasoningChunks == 0 {
190+
t.Error("expected at least one reasoning chunk")
191+
}
192+
if firstReasoning(resp.Message) == nil {
193+
t.Error("final response is missing the assembled reasoning part")
194+
}
195+
if resp.Text() == "" {
196+
t.Error("final response text is empty")
197+
}
198+
}

bedrock_plugin_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@
1818
package bedrock
1919

2020
import (
21+
"context"
2122
"encoding/base64"
2223
"testing"
2324

25+
"github.com/aws/aws-sdk-go-v2/aws"
26+
"github.com/aws/aws-sdk-go-v2/credentials"
2427
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
2528
"github.com/firebase/genkit/go/ai"
29+
"github.com/firebase/genkit/go/genkit"
2630
)
2731

2832
func TestInferModelCapabilities_WithInferenceProfiles(t *testing.T) {
@@ -186,6 +190,50 @@ func TestInferenceProfilePrefixes_Coverage(t *testing.T) {
186190
}
187191
}
188192

193+
func TestDefineModelRequiresInitializedPluginInstance(t *testing.T) {
194+
ctx := context.Background()
195+
b := &Bedrock{
196+
Region: "us-east-1",
197+
AWSConfig: &aws.Config{
198+
Region: "us-east-1",
199+
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(
200+
"test-access-key",
201+
"test-secret-key",
202+
"",
203+
)),
204+
},
205+
}
206+
g := genkit.Init(ctx, genkit.WithPlugins(b))
207+
208+
if got := b.DefineModel(g, ModelDefinition{
209+
Name: "anthropic.claude-3-haiku-20240307-v1:0",
210+
Type: "chat",
211+
}, nil); got == nil {
212+
t.Fatal("DefineModel returned nil for initialized plugin")
213+
}
214+
215+
assertPanicsWith(t, "bedrock: Init not called", func() {
216+
(&Bedrock{Region: "us-east-1"}).DefineModel(g, ModelDefinition{
217+
Name: "anthropic.claude-3-haiku-20240307-v1:0",
218+
Type: "chat",
219+
}, nil)
220+
})
221+
}
222+
223+
func assertPanicsWith(t *testing.T, want string, fn func()) {
224+
t.Helper()
225+
defer func() {
226+
got := recover()
227+
if got == nil {
228+
t.Fatalf("expected panic %q, got none", want)
229+
}
230+
if got != want {
231+
t.Fatalf("panic = %v, want %q", got, want)
232+
}
233+
}()
234+
fn()
235+
}
236+
189237
func TestModelCapabilities_KnownModels(t *testing.T) {
190238
// Verify some known models are in the capability map with correct values
191239
tests := []struct {

0 commit comments

Comments
 (0)