diff --git a/bedrock.go b/bedrock.go index b940a20..1cbb3f3 100644 --- a/bedrock.go +++ b/bedrock.go @@ -133,9 +133,10 @@ func (b *Bedrock) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai. // Create model metadata meta := &ai.ModelOptions{ - Label: provider + "-" + model.Name, - Supports: info.Supports, - Versions: info.Versions, + Label: provider + "-" + model.Name, + Supports: info.Supports, + Versions: info.Versions, + ConfigSchema: configSchema(), } // Create the model function based on model type diff --git a/bedrock_live_test.go b/bedrock_live_test.go new file mode 100644 index 0000000..a32e98c --- /dev/null +++ b/bedrock_live_test.go @@ -0,0 +1,198 @@ +// Copyright 2025 Xavier Portilla Edo +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package bedrock + +// Live tests exercise reasoning ("thinking") against a real Bedrock endpoint. +// They are skipped by default and only run when the required model flags are +// passed, e.g.: +// +// go test -run TestBedrockLive_ClaudeReasoning \ +// -test-bedrock-region=us-east-1 \ +// -test-bedrock-model-claude=us.anthropic.claude-haiku-4-5-20251001-v1:0 +// +// They require AWS credentials in the environment and model access granted in +// the target region. 
Reasoning support is region- and model-scoped on Bedrock; +// these tests validate that the plugin's request/response shape round-trips, +// not that any particular model is granted. + +import ( + "context" + "flag" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" +) + +var ( + testRegion = flag.String("test-bedrock-region", "", "AWS region for Bedrock live tests (e.g. us-east-1)") + testModelClaude = flag.String("test-bedrock-model-claude", "", "Thinking-capable Claude model ID (e.g. us.anthropic.claude-haiku-4-5-20251001-v1:0)") +) + +// reasoningBudgetTokens is the extended-thinking budget. Bedrock requires it to +// be at least 1024, and MaxTokens must exceed it. +const reasoningBudgetTokens = 1024 + +// requireLiveClaude asserts the live-test prerequisites and skips otherwise. It +// returns a Genkit instance with the Bedrock plugin and a defined Claude model. +func requireLiveClaude(t *testing.T) (context.Context, *genkit.Genkit, ai.Model) { + t.Helper() + if *testRegion == "" { + t.Skip("bedrock live tests skipped; pass -test-bedrock-region=") + } + if *testModelClaude == "" { + t.Skip("pass -test-bedrock-model-claude= to run") + } + ctx := context.Background() + pb := &Bedrock{Region: *testRegion} + g := genkit.Init(ctx, genkit.WithPlugins(pb)) + m := pb.DefineModel(g, ModelDefinition{ + Name: *testModelClaude, + Type: "chat", + }, nil) + return ctx, g, m +} + +// thinkingConfig enables Claude extended thinking via AdditionalModelRequestFields. +// Temperature is intentionally left unset — Bedrock rejects thinking requests +// that also set a custom temperature. +func thinkingConfig() *Config { + return &Config{ + MaxTokens: reasoningBudgetTokens + 1024, + AdditionalModelRequestFields: map[string]any{ + "thinking": map[string]any{ + "type": "enabled", + "budget_tokens": reasoningBudgetTokens, + }, + }, + } +} + +// firstReasoning returns the first reasoning part in a message, or nil. 
+func firstReasoning(msg *ai.Message) *ai.Part { + if msg == nil { + return nil + } + for _, p := range msg.Content { + if p.IsReasoning() { + return p + } + } + return nil +} + +// TestBedrockLive_ClaudeReasoningSync confirms a thinking-enabled request comes +// back with a signed reasoning part, and that the plain text answer is still +// surfaced via Text() (i.e. reasoning doesn't leak into normal output). +func TestBedrockLive_ClaudeReasoningSync(t *testing.T) { + ctx, g, m := requireLiveClaude(t) + + resp, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithPrompt("What is 17 * 24? Think it through step by step, then give the answer."), + ai.WithConfig(thinkingConfig()), + ) + if err != nil { + t.Fatal(err) + } + + reasoning := firstReasoning(resp.Message) + if reasoning == nil { + t.Fatal("expected a reasoning part in the response; got none") + } + if sig := metadataBytes(reasoning.Metadata, reasoningSignatureMetadataKey); len(sig) == 0 { + t.Error("reasoning part is missing its Bedrock signature") + } + if resp.Text() == "" { + t.Error("final response text is empty") + } +} + +// TestBedrockLive_ClaudeReasoningRoundTrip is the real proof of the feature: it +// feeds a thinking response back as conversation history and confirms the +// follow-up turn is accepted. If the signed/redacted reasoning weren't +// round-tripped verbatim, Bedrock rejects the request. +func TestBedrockLive_ClaudeReasoningRoundTrip(t *testing.T) { + ctx, g, m := requireLiveClaude(t) + + turn1 := ai.NewUserTextMessage("What is 17 * 24? Show your reasoning, then state the result.") + resp1, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithMessages(turn1), + ai.WithConfig(thinkingConfig()), + ) + if err != nil { + t.Fatal(err) + } + if firstReasoning(resp1.Message) == nil { + t.Fatal("first turn produced no reasoning part; cannot exercise round-trip") + } + + // Replay the assistant turn (reasoning included) plus a follow-up question. 
+ resp2, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithMessages( + turn1, + resp1.Message, + ai.NewUserTextMessage("Now multiply that result by 2."), + ), + ai.WithConfig(thinkingConfig()), + ) + if err != nil { + t.Fatalf("follow-up turn rejected (reasoning round-trip likely broken): %v", err) + } + if resp2.Text() == "" { + t.Error("follow-up response text is empty") + } +} + +// TestBedrockLive_ClaudeReasoningStream confirms reasoning deltas stream through +// to the callback and the final response carries an assembled reasoning part. +func TestBedrockLive_ClaudeReasoningStream(t *testing.T) { + ctx, g, m := requireLiveClaude(t) + + var reasoningChunks, textChunks int + resp, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithPrompt("What is 17 * 24? Think it through, then answer."), + ai.WithConfig(thinkingConfig()), + ai.WithStreaming(func(ctx context.Context, c *ai.ModelResponseChunk) error { + for _, p := range c.Content { + switch { + case p.IsReasoning(): + reasoningChunks++ + case p.IsText(): + textChunks++ + } + } + return nil + }), + ) + if err != nil { + t.Fatal(err) + } + if reasoningChunks == 0 { + t.Error("expected at least one reasoning chunk") + } + if firstReasoning(resp.Message) == nil { + t.Error("final response is missing the assembled reasoning part") + } + if resp.Text() == "" { + t.Error("final response text is empty") + } +} diff --git a/bedrock_plugin_test.go b/bedrock_plugin_test.go index 1a7e163..1afc529 100644 --- a/bedrock_plugin_test.go +++ b/bedrock_plugin_test.go @@ -18,11 +18,15 @@ package bedrock import ( + "context" "encoding/base64" "testing" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" ) func TestInferModelCapabilities_WithInferenceProfiles(t *testing.T) { @@ -186,6 +190,50 @@ func TestInferenceProfilePrefixes_Coverage(t *testing.T) { } 
} +func TestDefineModelRequiresInitializedPluginInstance(t *testing.T) { + ctx := context.Background() + b := &Bedrock{ + Region: "us-east-1", + AWSConfig: &aws.Config{ + Region: "us-east-1", + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider( + "test-access-key", + "test-secret-key", + "", + )), + }, + } + g := genkit.Init(ctx, genkit.WithPlugins(b)) + + if got := b.DefineModel(g, ModelDefinition{ + Name: "anthropic.claude-3-haiku-20240307-v1:0", + Type: "chat", + }, nil); got == nil { + t.Fatal("DefineModel returned nil for initialized plugin") + } + + assertPanicsWith(t, "bedrock: Init not called", func() { + (&Bedrock{Region: "us-east-1"}).DefineModel(g, ModelDefinition{ + Name: "anthropic.claude-3-haiku-20240307-v1:0", + Type: "chat", + }, nil) + }) +} + +func assertPanicsWith(t *testing.T, want string, fn func()) { + t.Helper() + defer func() { + got := recover() + if got == nil { + t.Fatalf("expected panic %q, got none", want) + } + if got != want { + t.Fatalf("panic = %v, want %q", got, want) + } + }() + fn() +} + func TestModelCapabilities_KnownModels(t *testing.T) { // Verify some known models are in the capability map with correct values tests := []struct { diff --git a/generate.go b/generate.go index 98b54dc..fdcf65a 100644 --- a/generate.go +++ b/generate.go @@ -52,7 +52,7 @@ func (b *Bedrock) buildConverseInput(modelName string, input *ai.ModelRequest) ( if input == nil { return nil, fmt.Errorf("model request is nil") } - + converseInput := &bedrockruntime.ConverseInput{ ModelId: aws.String(modelName), } @@ -244,6 +244,12 @@ func (b *Bedrock) buildConverseInput(modelName string, input *ai.ModelRequest) ( }, }) } + } else if part.Kind == ai.PartReasoning { + // Round-trip Bedrock reasoning (thinking) content so the + // signed text and any redacted block survive into the + // follow-up request. Reasoning parts without Bedrock + // metadata produce no blocks (see reasoningPartToContentBlocks). 
+ contentBlocks = append(contentBlocks, reasoningPartToContentBlocks(part)...) } } @@ -279,31 +285,18 @@ func (b *Bedrock) buildConverseInput(modelName string, input *ai.ModelRequest) ( } } - // Set inference configuration - if input.Config != nil { - if configMap, ok := input.Config.(map[string]interface{}); ok { - inferenceConfig := &types.InferenceConfiguration{} - - if maxTokens, ok := configMap["maxOutputTokens"].(int); ok { - inferenceConfig.MaxTokens = aws.Int32(int32(maxTokens)) - } else if maxTokens, ok := configMap["max_tokens"].(int); ok { - inferenceConfig.MaxTokens = aws.Int32(int32(maxTokens)) - } - - if temp, ok := configMap["temperature"].(float64); ok { - inferenceConfig.Temperature = aws.Float32(float32(temp)) - } - - if topP, ok := configMap["topP"].(float64); ok { - inferenceConfig.TopP = aws.Float32(float32(topP)) - } - - if stopSequences, ok := configMap["stopSequences"].([]string); ok { - inferenceConfig.StopSequences = stopSequences - } - + // Set inference configuration and any model-specific request fields. + cfg, err := configFromRequest(input) + if err != nil { + return nil, err + } + if cfg != nil { + if inferenceConfig := buildInferenceConfig(cfg); inferenceConfig != nil { converseInput.InferenceConfig = inferenceConfig } + if len(cfg.AdditionalModelRequestFields) > 0 { + converseInput.AdditionalModelRequestFields = document.NewLazyDocument(cfg.AdditionalModelRequestFields) + } } // Handle tools @@ -406,6 +399,13 @@ func (b *Bedrock) convertResponse(response *bedrockruntime.ConverseOutput, origi modelResponse.Message.Content = append(modelResponse.Message.Content, ai.NewToolRequestPart(toolRequest)) + + case *types.ContentBlockMemberReasoningContent: + // Reasoning ("thinking") content: carry the signed text and + // any redacted block so it can be replayed on the next turn. 
+ if part, err := reasoningBlockToPart(block.Value); err == nil && part != nil { + modelResponse.Message.Content = append(modelResponse.Message.Content, part) + } } } } @@ -434,6 +434,158 @@ func (b *Bedrock) convertResponse(response *bedrockruntime.ConverseOutput, origi return modelResponse } +// configFromRequest decodes input.Config into a *Config. It accepts the typed +// *Config/Config, *ai.GenerationCommonConfig/ai.GenerationCommonConfig, and the +// historical map[string]any shape (used on resumed/serialized flows). It +// returns (nil, nil) when no config is provided. +func configFromRequest(input *ai.ModelRequest) (*Config, error) { + if input == nil || input.Config == nil { + return nil, nil + } + switch v := input.Config.(type) { + case *Config: + return v, nil + case Config: + return &v, nil + case *ai.GenerationCommonConfig: + return configFromGenerationCommonConfig(v), nil + case ai.GenerationCommonConfig: + return configFromGenerationCommonConfig(&v), nil + case map[string]interface{}: + b, err := json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("bedrock: marshal config: %w", err) + } + var c Config + if err := json.Unmarshal(b, &c); err != nil { + return nil, fmt.Errorf("bedrock: decode config: %w", err) + } + // Preserve the historical max-token keys, which differ from Config's + // json tag ("maxTokens"). The map values may decode as float64 (JSON) + // or int (a directly-constructed Go map). + if c.MaxTokens == 0 { + if mt, ok := mapInt(v, "maxOutputTokens"); ok { + c.MaxTokens = mt + } else if mt, ok := mapInt(v, "max_tokens"); ok { + c.MaxTokens = mt + } + } + return &c, nil + default: + return nil, fmt.Errorf("bedrock: unexpected config type %T, want *bedrock.Config, *ai.GenerationCommonConfig, or map[string]any", input.Config) + } +} + +// mapInt reads an integer-valued key from a config map, tolerating the float64 +// that JSON-decoded numbers arrive as alongside a plain int. 
+func mapInt(m map[string]interface{}, key string) (int, bool) {
+	switch v := m[key].(type) {
+	case int:
+		return v, true
+	case int32:
+		return int(v), true
+	case int64:
+		return int(v), true
+	case float64:
+		// int(v) is implementation-defined for NaN/out-of-range; reject those.
+		if n := int(v); v-float64(n) < 1 && float64(n)-v < 1 {
+			return n, true
+		}
+	}
+	return 0, false
+}
+
+// configFromGenerationCommonConfig adapts the generic genkit config to *Config.
+// A zero Temperature or TopP is treated as unset and left nil.
+func configFromGenerationCommonConfig(v *ai.GenerationCommonConfig) *Config {
+	if v == nil {
+		return nil
+	}
+	cfg := &Config{
+		MaxTokens:     v.MaxOutputTokens,
+		StopSequences: v.StopSequences,
+	}
+	if v.Temperature != 0 {
+		t := float32(v.Temperature)
+		cfg.Temperature = &t
+	}
+	if v.TopP != 0 {
+		p := float32(v.TopP)
+		cfg.TopP = &p
+	}
+	return cfg
+}
+
+// buildInferenceConfig maps a *Config onto Bedrock's InferenceConfiguration. It
+// returns nil when nothing is set, leaving Bedrock to apply its own defaults.
+// MaxTokens is sent only when positive and is clamped to the int32 range.
+func buildInferenceConfig(cfg *Config) *types.InferenceConfiguration {
+	if cfg == nil {
+		return nil
+	}
+	const maxInt32 = 1<<31 - 1
+	ic := &types.InferenceConfiguration{
+		Temperature: cfg.Temperature,
+		TopP:        cfg.TopP,
+	}
+	if n := cfg.MaxTokens; n > 0 {
+		if n > maxInt32 {
+			n = maxInt32 // clamp: int32(n) would otherwise wrap
+		}
+		ic.MaxTokens = aws.Int32(int32(n))
+	}
+	if len(cfg.StopSequences) > 0 {
+		ic.StopSequences = cfg.StopSequences
+	}
+	if ic.MaxTokens == nil && ic.Temperature == nil && ic.TopP == nil && len(ic.StopSequences) == 0 {
+		return nil
+	}
+	return ic
+}
+
+// reasoningPartToContentBlocks converts a reasoning ai.Part back into Bedrock
+// reasoning content blocks. Only Bedrock-originated reasoning (carrying the
+// signature and/or redacted metadata) is emitted; a generic reasoning part
+// produces no blocks so it cannot corrupt the follow-up request.
+func reasoningPartToContentBlocks(p *ai.Part) []types.ContentBlock { + var blocks []types.ContentBlock + if redacted := metadataBytes(p.Metadata, redactedReasoningMetadataKey); len(redacted) > 0 { + blocks = append(blocks, &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberRedactedContent{Value: redacted}, + }) + } + if signature := metadataBytes(p.Metadata, reasoningSignatureMetadataKey); p.Text != "" && len(signature) > 0 { + blocks = append(blocks, &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberReasoningText{ + Value: types.ReasoningTextBlock{ + Text: aws.String(p.Text), + Signature: aws.String(string(signature)), + }, + }, + }) + } + return blocks +} + +// reasoningBlockToPart converts a Bedrock reasoning content block into an ai +// reasoning Part, or (nil, nil) when the block is empty. +func reasoningBlockToPart(block types.ReasoningContentBlock) (*ai.Part, error) { + switch rc := block.(type) { + case *types.ReasoningContentBlockMemberReasoningText: + if rc.Value.Text == nil && rc.Value.Signature == nil { + return nil, nil + } + return newBedrockReasoningPart(aws.ToString(rc.Value.Text), aws.ToString(rc.Value.Signature), nil), nil + case *types.ReasoningContentBlockMemberRedactedContent: + if len(rc.Value) == 0 { + return nil, nil + } + return newBedrockReasoningPart("", "", rc.Value), nil + default: + return nil, fmt.Errorf("bedrock: unhandled reasoning content variant %T", block) + } +} + // convertToolInputTypes converts tool input parameters to the correct types based on the tool schema func (b *Bedrock) convertToolInputTypes(inputMap map[string]interface{}, toolName string, tools []*ai.ToolDefinition) interface{} { // Find the tool definition for this tool call diff --git a/image.go b/image.go index 06cb93b..ef8649d 100644 --- a/image.go +++ b/image.go @@ -33,7 +33,7 @@ func (b *Bedrock) generateImage(ctx context.Context, modelName string, input *ai if input == nil { return nil, 
fmt.Errorf("model request is nil") } - + // Extract prompt from the first message var prompt string if len(input.Messages) > 0 && len(input.Messages[0].Content) > 0 { diff --git a/reasoning_test.go b/reasoning_test.go new file mode 100644 index 0000000..ac9b96f --- /dev/null +++ b/reasoning_test.go @@ -0,0 +1,311 @@ +// Copyright 2025 Xavier Portilla Edo +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package bedrock + +import ( + "encoding/base64" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/firebase/genkit/go/ai" +) + +// --- Request round-trip ----------------------------------------------------- + +func TestReasoningPartToContentBlocks_RoundTripsBedrockReasoning(t *testing.T) { + // Signed reasoning text → ReasoningContentBlockMemberReasoningText. 
+ signed := newBedrockReasoningPart("signed thought", "sig", nil) + blocks := reasoningPartToContentBlocks(signed) + if len(blocks) != 1 { + t.Fatalf("len(blocks) = %d, want 1", len(blocks)) + } + rc, ok := blocks[0].(*types.ContentBlockMemberReasoningContent) + if !ok { + t.Fatalf("blocks[0] = %T, want reasoning content", blocks[0]) + } + text, ok := rc.Value.(*types.ReasoningContentBlockMemberReasoningText) + if !ok { + t.Fatalf("blocks[0].Value = %T, want reasoning text", rc.Value) + } + if aws.ToString(text.Value.Text) != "signed thought" { + t.Errorf("text = %q, want signed thought", aws.ToString(text.Value.Text)) + } + if aws.ToString(text.Value.Signature) != "sig" { + t.Errorf("signature = %q, want sig", aws.ToString(text.Value.Signature)) + } + + // Redacted reasoning stored as a base64 string (JSON round-trip shape) → + // ReasoningContentBlockMemberRedactedContent with the decoded bytes. + redacted := ai.NewReasoningPart("", nil) + redacted.Metadata[redactedReasoningMetadataKey] = base64.StdEncoding.EncodeToString([]byte("encrypted")) + blocks = reasoningPartToContentBlocks(redacted) + if len(blocks) != 1 { + t.Fatalf("redacted len(blocks) = %d, want 1", len(blocks)) + } + rc, ok = blocks[0].(*types.ContentBlockMemberReasoningContent) + if !ok { + t.Fatalf("redacted blocks[0] = %T, want reasoning content", blocks[0]) + } + red, ok := rc.Value.(*types.ReasoningContentBlockMemberRedactedContent) + if !ok { + t.Fatalf("redacted blocks[0].Value = %T, want redacted content", rc.Value) + } + if string(red.Value) != "encrypted" { + t.Errorf("redacted = %q, want encrypted", string(red.Value)) + } +} + +func TestReasoningPartToContentBlocks_SkipsGenericReasoning(t *testing.T) { + // A generic reasoning part (signature under the framework "signature" key, + // not the Bedrock-specific key) must not round-trip into request blocks. 
+ p := ai.NewReasoningPart("signed elsewhere", []byte("foreign-sig")) + if blocks := reasoningPartToContentBlocks(p); len(blocks) != 0 { + t.Fatalf("len(blocks) = %d, want 0", len(blocks)) + } +} + +func TestBuildConverseInput_SkipsGenericReasoning(t *testing.T) { + b := &Bedrock{} + input := &ai.ModelRequest{ + Messages: []*ai.Message{ + { + Role: ai.RoleModel, + Content: []*ai.Part{ + ai.NewTextPart("question"), + ai.NewReasoningPart("internal monologue from a prior turn", nil), + ai.NewTextPart("more question"), + }, + }, + }, + } + + out, err := b.buildConverseInput("anthropic.claude-3-sonnet", input) + if err != nil { + t.Fatal(err) + } + if len(out.Messages) != 1 { + t.Fatalf("len(messages) = %d, want 1", len(out.Messages)) + } + blocks := out.Messages[0].Content + if len(blocks) != 2 { + t.Fatalf("len(blocks) = %d, want 2 (no reasoning leakage)", len(blocks)) + } + for i, blk := range blocks { + text, ok := blk.(*types.ContentBlockMemberText) + if !ok { + t.Fatalf("blocks[%d] = %T, want text", i, blk) + } + if text.Value == "internal monologue from a prior turn" { + t.Errorf("reasoning text leaked into block %d", i) + } + } +} + +// --- Response parse --------------------------------------------------------- + +func TestConvertResponse_ReasoningSignatureAndRedacted(t *testing.T) { + b := &Bedrock{} + redacted := []byte("encrypted") + resp := &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberReasoningText{ + Value: types.ReasoningTextBlock{ + Text: aws.String("thinking"), + Signature: aws.String("sig"), + }, + }, + }, + &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberRedactedContent{Value: redacted}, + }, + }, + }, + }, + } + + got := b.convertResponse(resp, &ai.ModelRequest{}) + parts := got.Message.Content + if len(parts) != 2 { + 
t.Fatalf("len(parts) = %d, want 2", len(parts)) + } + + if !parts[0].IsReasoning() || parts[0].Text != "thinking" { + t.Fatalf("parts[0] = %+v, want reasoning text", parts[0]) + } + if sig, ok := parts[0].Metadata["signature"].([]byte); !ok || string(sig) != "sig" { + t.Errorf("generic signature = %v, want sig", parts[0].Metadata["signature"]) + } + if sig, ok := parts[0].Metadata[reasoningSignatureMetadataKey].([]byte); !ok || string(sig) != "sig" { + t.Errorf("bedrock signature = %v, want sig", parts[0].Metadata[reasoningSignatureMetadataKey]) + } + + if !parts[1].IsReasoning() { + t.Fatalf("parts[1] kind = %v, want reasoning", parts[1].Kind) + } + if red, ok := parts[1].Metadata[redactedReasoningMetadataKey].([]byte); !ok || string(red) != string(redacted) { + t.Errorf("redacted = %v, want %q", parts[1].Metadata[redactedReasoningMetadataKey], string(redacted)) + } +} + +// TestConvertResponse_TextSkipsReasoning is the proposal's sanity check: a +// response carrying both reasoning and text must expose only the text via +// Text(), so existing callers don't suddenly see thinking content. 
+func TestConvertResponse_TextSkipsReasoning(t *testing.T) { + b := &Bedrock{} + resp := &bedrockruntime.ConverseOutput{ + Output: &types.ConverseOutputMemberMessage{ + Value: types.Message{ + Content: []types.ContentBlock{ + &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberReasoningText{ + Value: types.ReasoningTextBlock{ + Text: aws.String("let me think about this"), + Signature: aws.String("sig"), + }, + }, + }, + &types.ContentBlockMemberText{Value: "the answer is 42"}, + }, + }, + }, + } + + got := b.convertResponse(resp, &ai.ModelRequest{}) + if text := got.Text(); text != "the answer is 42" { + t.Errorf("Text() = %q, want %q (reasoning leaked)", text, "the answer is 42") + } +} + +// --- Streaming -------------------------------------------------------------- + +func TestAppendReasoningDelta(t *testing.T) { + acc := &streamAccumulator{} + + part, err := appendReasoningDelta(acc, &types.ReasoningContentBlockDeltaMemberText{Value: "thinking"}) + if err != nil { + t.Fatal(err) + } + if part == nil || !part.IsReasoning() || part.Text != "thinking" { + t.Fatalf("text delta part = %+v, want reasoning %q", part, "thinking") + } + if got := acc.reasoning.String(); got != "thinking" { + t.Errorf("accumulated reasoning = %q, want thinking", got) + } + + part, err = appendReasoningDelta(acc, &types.ReasoningContentBlockDeltaMemberSignature{Value: "sig"}) + if err != nil { + t.Fatal(err) + } + if part != nil { + t.Fatalf("signature delta returned part %v, want nil", part) + } + if acc.reasoningSignature != "sig" { + t.Errorf("signature = %q, want sig", acc.reasoningSignature) + } + + part, err = appendReasoningDelta(acc, &types.ReasoningContentBlockDeltaMemberRedactedContent{Value: []byte("encrypted")}) + if err != nil { + t.Fatal(err) + } + if part != nil { + t.Fatalf("redacted delta returned part %v, want nil", part) + } + if string(acc.redactedReasoning) != "encrypted" { + t.Errorf("redacted = %q, want encrypted", 
string(acc.redactedReasoning)) + } +} + +func TestAppendReasoningDelta_UnknownErrors(t *testing.T) { + if _, err := appendReasoningDelta(&streamAccumulator{}, &types.UnknownUnionMember{Tag: "future_reasoning_delta"}); err == nil { + t.Fatal("expected error for unknown reasoning delta") + } +} + +func TestStreamFinalContent_ReasoningReassembly(t *testing.T) { + acc := &streamAccumulator{reasoningSignature: "sig", redactedReasoning: []byte("encrypted")} + acc.reasoning.WriteString("First thought. ") + acc.reasoning.WriteString("Second thought.") + acc.text.WriteString("Final answer.") + + parts := acc.finalContent() + if len(parts) != 2 { + t.Fatalf("len(parts) = %d, want 2", len(parts)) + } + + // Reasoning precedes text so the assistant turn replays in order. + if !parts[0].IsReasoning() || parts[0].Text != "First thought. Second thought." { + t.Fatalf("parts[0] = %+v, want assembled reasoning", parts[0]) + } + if sig, ok := parts[0].Metadata["signature"].([]byte); !ok || string(sig) != "sig" { + t.Errorf("generic signature = %v, want sig", parts[0].Metadata["signature"]) + } + if sig, ok := parts[0].Metadata[reasoningSignatureMetadataKey].([]byte); !ok || string(sig) != "sig" { + t.Errorf("bedrock signature = %v, want sig", parts[0].Metadata[reasoningSignatureMetadataKey]) + } + if red, ok := parts[0].Metadata[redactedReasoningMetadataKey].([]byte); !ok || string(red) != "encrypted" { + t.Errorf("redacted = %v, want encrypted", parts[0].Metadata[redactedReasoningMetadataKey]) + } + if !parts[1].IsText() || parts[1].Text != "Final answer." 
{ + t.Errorf("parts[1] = %+v, want text Final answer.", parts[1]) + } +} + +// --- Config decode ---------------------------------------------------------- + +func TestConfigFromRequest_TypedAndAdditionalFields(t *testing.T) { + thinking := map[string]any{"type": "enabled", "budget_tokens": 5000} + input := &ai.ModelRequest{Config: &Config{ + MaxTokens: 8000, + AdditionalModelRequestFields: map[string]any{"thinking": thinking}, + }} + cfg, err := configFromRequest(input) + if err != nil { + t.Fatal(err) + } + if cfg.MaxTokens != 8000 { + t.Errorf("MaxTokens = %d, want 8000", cfg.MaxTokens) + } + if cfg.AdditionalModelRequestFields["thinking"] == nil { + t.Error("thinking field dropped") + } +} + +// TestConfigFromRequest_LegacyMapKeys guards backward compatibility: the +// historical map config (with maxOutputTokens) must still drive MaxTokens. +func TestConfigFromRequest_LegacyMapKeys(t *testing.T) { + cases := []map[string]interface{}{ + {"maxOutputTokens": 1024, "temperature": 0.5}, + {"max_tokens": 1024}, + {"maxOutputTokens": float64(1024)}, // JSON-decoded shape + } + for _, m := range cases { + cfg, err := configFromRequest(&ai.ModelRequest{Config: m}) + if err != nil { + t.Fatalf("config %v: %v", m, err) + } + ic := buildInferenceConfig(cfg) + if ic == nil || ic.MaxTokens == nil || *ic.MaxTokens != 1024 { + t.Errorf("config %v: MaxTokens = %v, want 1024", m, ic.MaxTokens) + } + } +} diff --git a/rerank_test.go b/rerank_test.go index 787d842..a466ef1 100644 --- a/rerank_test.go +++ b/rerank_test.go @@ -58,7 +58,9 @@ func TestRerankInvokesCohereRerankAndMapsScores(t *testing.T) { } w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"results":[{"index":1,"relevance_score":0.94},{"index":0,"relevance_score":0.42}]}`) + if _, err := fmt.Fprint(w, `{"results":[{"index":1,"relevance_score":0.94},{"index":0,"relevance_score":0.42}]}`); err != nil { + t.Errorf("failed to write mock rerank response: %v", err) + } })) defer server.Close() @@ -139,7 
+141,9 @@ func TestRerankDefaultsTopNToDocumentCount(t *testing.T) { return } w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"results":[]}`) + if _, err := fmt.Fprint(w, `{"results":[]}`); err != nil { + t.Errorf("failed to write mock rerank response: %v", err) + } })) defer server.Close() diff --git a/stream.go b/stream.go index 7188911..a6b310b 100644 --- a/stream.go +++ b/stream.go @@ -58,8 +58,10 @@ func (b *Bedrock) generateTextStream(ctx context.Context, input *bedrockruntime. } }() - // Build final response - var fullText strings.Builder + // Accumulate streamed content. Reasoning ("thinking") deltas are tracked + // separately from plain text so the signed/redacted reasoning can be + // attached to the final response and replayed on the next turn. + acc := &streamAccumulator{} var finalResponse *ai.ModelResponse var stopReason types.StopReason @@ -68,19 +70,30 @@ func (b *Bedrock) generateTextStream(ctx context.Context, input *bedrockruntime. switch e := event.(type) { case *types.ConverseStreamOutputMemberContentBlockDelta: - // Text delta received deltaEvent := e.Value - if deltaEvent.Delta != nil { - if textDelta, ok := deltaEvent.Delta.(*types.ContentBlockDeltaMemberText); ok { - text := textDelta.Value - fullText.WriteString(text) + if deltaEvent.Delta == nil { + continue + } + switch delta := deltaEvent.Delta.(type) { + case *types.ContentBlockDeltaMemberText: + acc.text.WriteString(delta.Value) + chunk := &ai.ModelResponseChunk{ + Index: 0, + Content: []*ai.Part{ai.NewTextPart(delta.Value)}, + } + if err := cb(ctx, chunk); err != nil { + return nil, fmt.Errorf("callback error: %w", err) + } - // Send chunk to callback + case *types.ContentBlockDeltaMemberReasoningContent: + part, err := appendReasoningDelta(acc, delta.Value) + if err != nil { + return nil, err + } + if part != nil { chunk := &ai.ModelResponseChunk{ - Index: 0, - Content: []*ai.Part{ - ai.NewTextPart(text), - }, + Index: 0, + Content: []*ai.Part{part}, } if err 
:= cb(ctx, chunk); err != nil { return nil, fmt.Errorf("callback error: %w", err) @@ -95,10 +108,8 @@ func (b *Bedrock) generateTextStream(ctx context.Context, input *bedrockruntime. finalResponse = &ai.ModelResponse{ Message: &ai.Message{ - Role: ai.RoleModel, - Content: []*ai.Part{ - ai.NewTextPart(fullText.String()), - }, + Role: ai.RoleModel, + Content: acc.finalContent(), }, FinishReason: convertStopReasonToGenkit(stopReason), } @@ -113,10 +124,8 @@ func (b *Bedrock) generateTextStream(ctx context.Context, input *bedrockruntime. if finalResponse == nil { finalResponse = &ai.ModelResponse{ Message: &ai.Message{ - Role: ai.RoleModel, - Content: []*ai.Part{ - ai.NewTextPart(fullText.String()), - }, + Role: ai.RoleModel, + Content: acc.finalContent(), }, FinishReason: ai.FinishReasonStop, } @@ -124,3 +133,45 @@ func (b *Bedrock) generateTextStream(ctx context.Context, input *bedrockruntime. return finalResponse, nil } + +// streamAccumulator collects streamed content across delta events. The current +// streaming path reconstructs a single text block plus any reasoning; +// block-indexed reassembly (e.g. for streamed tool-use) is a separate concern. +type streamAccumulator struct { + text strings.Builder + reasoning strings.Builder + reasoningSignature string + redactedReasoning []byte +} + +// appendReasoningDelta folds a reasoning delta into the accumulator. Text deltas +// are accumulated and returned as an emittable chunk; signature and redacted +// deltas are accumulated silently (nil part) since they only matter for the +// final, replayable reasoning part. 
+func appendReasoningDelta(acc *streamAccumulator, delta types.ReasoningContentBlockDelta) (*ai.Part, error) { + switch d := delta.(type) { + case *types.ReasoningContentBlockDeltaMemberText: + acc.reasoning.WriteString(d.Value) + return newBedrockReasoningPart(d.Value, "", nil), nil + case *types.ReasoningContentBlockDeltaMemberSignature: + acc.reasoningSignature = d.Value + return nil, nil + case *types.ReasoningContentBlockDeltaMemberRedactedContent: + acc.redactedReasoning = append(acc.redactedReasoning, d.Value...) + return nil, nil + default: + return nil, fmt.Errorf("bedrock: unhandled stream reasoning delta variant %T", delta) + } +} + +// finalContent assembles the accumulated stream state into response parts. Any +// reasoning precedes the text so the assistant turn replays in the order +// thinking models require. +func (acc *streamAccumulator) finalContent() []*ai.Part { + var parts []*ai.Part + if acc.reasoning.Len() > 0 || len(acc.redactedReasoning) > 0 { + parts = append(parts, newBedrockReasoningPart(acc.reasoning.String(), acc.reasoningSignature, acc.redactedReasoning)) + } + parts = append(parts, ai.NewTextPart(acc.text.String())) + return parts +} diff --git a/types.go b/types.go index 008e353..66a65c9 100644 --- a/types.go +++ b/types.go @@ -18,8 +18,11 @@ package bedrock import ( + "encoding/base64" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" ) // Type aliases for better readability @@ -65,6 +68,101 @@ const ( const bedrockCachePointTypeKey = "bedrockCachePointType" +// Metadata keys used to round-trip Bedrock reasoning ("thinking") content back +// into a follow-up request. Bedrock returns signed and sometimes redacted +// reasoning that must be replayed verbatim on the next turn or the model +// rejects it, so the signature and redacted bytes are stashed on the +// ai.Part metadata. 
These keys are Bedrock-specific: a generic reasoning part +// created via ai.NewReasoningPart (without these) is intentionally NOT +// round-tripped, so foreign reasoning can't corrupt a Bedrock conversation. +const ( + reasoningSignatureMetadataKey = "bedrockReasoningSignature" + redactedReasoningMetadataKey = "bedrockRedactedContent" +) + +// Config is the per-call configuration for Bedrock Converse models. Pass it +// via [ai.WithConfig]. +// +// It is fully optional and additive: callers may still pass configuration as a +// map[string]any (the historical shape) or as *ai.GenerationCommonConfig; see +// configFromRequest. The typed form exists mainly so model-specific knobs like +// Claude extended thinking can be enabled through AdditionalModelRequestFields. +type Config struct { + // MaxTokens is the upper bound on the generated response length. When 0 the + // plugin leaves it unset and Bedrock applies its own per-model default. + MaxTokens int `json:"maxTokens,omitempty"` + + // Temperature controls sampling randomness. nil leaves it to the model default. + Temperature *float32 `json:"temperature,omitempty"` + + // TopP is the nucleus-sampling cutoff. nil leaves it to the model default. + TopP *float32 `json:"topP,omitempty"` + + // StopSequences are strings that, when generated, halt generation. + StopSequences []string `json:"stopSequences,omitempty"` + + // ToolChoice selects how the model should pick tools. It is accepted here + // for forward compatibility; wiring it through the Converse request is a + // separate change and it is currently not applied. + ToolChoice string `json:"toolChoice,omitempty"` + + // AdditionalModelRequestFields is forwarded verbatim as the Converse API's + // AdditionalModelRequestFields document. Use it for model-specific knobs not + // covered by the inference-config surface, e.g. 
Claude extended thinking: + // + // &bedrock.Config{ + // MaxTokens: 8000, + // AdditionalModelRequestFields: map[string]any{ + // "thinking": map[string]any{"type": "enabled", "budget_tokens": 5000}, + // }, + // } + AdditionalModelRequestFields map[string]any `json:"additionalModelRequestFields,omitempty"` +} + +// configSchema returns the JSON schema for [Config], used as the per-call +// ConfigSchema on every defined Converse model. +func configSchema() map[string]any { return core.InferSchemaMap(Config{}) } + +// newBedrockReasoningPart builds an ai reasoning part carrying the Bedrock +// signature and/or redacted bytes needed to replay it on the next turn. The +// signature is also stored under the generic "signature" key (via +// ai.NewReasoningPart) so framework-level consumers see it too. +func newBedrockReasoningPart(text, signature string, redacted []byte) *ai.Part { + var sig []byte + if signature != "" { + sig = []byte(signature) + } + p := ai.NewReasoningPart(text, sig) + if len(sig) > 0 { + p.Metadata[reasoningSignatureMetadataKey] = sig + } + if len(redacted) > 0 { + p.Metadata[redactedReasoningMetadataKey] = redacted + } + return p +} + +// metadataBytes reads a []byte value from part metadata, also accepting a +// base64-encoded string (which is how []byte survives a JSON round-trip on +// resumed/serialized flows). +func metadataBytes(metadata map[string]any, key string) []byte { + if metadata == nil { + return nil + } + switch v := metadata[key].(type) { + case []byte: + return v + case string: + b, err := base64.StdEncoding.DecodeString(v) + if err != nil { + return nil + } + return b + default: + return nil + } +} + // ModelDefinition represents a model with its name and type. type ModelDefinition struct { Name string // Model ID as used in AWS Bedrock