Skip to content

Commit 53ed766

Browse files
authored
fix: normalize LLM tool requests (#2142)
1 parent 397d5e1 commit 53ed766

7 files changed

Lines changed: 273 additions & 7 deletions

File tree

internal/llm/provider.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,11 @@ func NewProvider(providerType ProviderType, cfg Config) (Provider, error) {
173173
cfg.BaseURL = DefaultBaseURL(providerType)
174174
}
175175

176-
return factory(cfg)
176+
provider, err := factory(cfg)
177+
if err != nil {
178+
return nil, err
179+
}
180+
return normalizedProvider{Provider: provider}, nil
177181
}
178182

179183
// NewProviderWithAPIKey creates a new Provider with the given API key.

internal/llm/providers/anthropic/anthropic.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ import (
1717
)
1818

1919
const (
20-
providerName = "anthropic"
21-
defaultMessagesPath = "/v1/messages"
22-
anthropicAPIVersion = "2023-06-01"
23-
streamPrefix = "data: "
20+
providerName = "anthropic"
21+
defaultMessagesPath = "/v1/messages"
22+
anthropicAPIVersion = "2023-06-01"
23+
streamPrefix = "data: "
24+
anthropicWebSearchToolType = "web_search_20260209"
2425

2526
// Thinking budget token limits for different effort levels.
2627
// Note: Anthropic recommends budgets <= 32K to avoid timeout issues.
@@ -160,8 +161,8 @@ func (p *Provider) buildRequestBody(req *llm.ChatRequest, stream bool) ([]byte,
160161
// Append provider-native web search tool if enabled.
161162
if req.WebSearch != nil && req.WebSearch.Enabled {
162163
wsEntry := map[string]any{
163-
"type": "web_search_20260209",
164-
"name": "web_search",
164+
"type": anthropicWebSearchToolType,
165+
"name": llm.WebSearchToolName,
165166
}
166167
// Haiku 4.5 only supports allowed_callers=["direct"] for web_search.
167168
// Other models accept any value, so "direct" is safe as a universal default.

internal/llm/request.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright (C) 2026 Yota Hamada
2+
// SPDX-License-Identifier: GPL-3.0-or-later
3+
4+
package llm
5+
6+
import "context"
7+
8+
// WebSearchToolName is the canonical explicit tool name for web search.
9+
const WebSearchToolName = "web_search"
10+
11+
// NormalizeChatRequest applies provider-independent request invariants before a
12+
// request reaches a concrete provider.
13+
func NormalizeChatRequest(req *ChatRequest) *ChatRequest {
14+
if req == nil {
15+
return nil
16+
}
17+
18+
normalized := *req
19+
normalized.Tools = dedupeTools(req.Tools)
20+
normalized.WebSearch = cloneWebSearchRequest(req.WebSearch)
21+
22+
if normalized.WebSearch != nil && normalized.WebSearch.Enabled && hasToolNamed(normalized.Tools, WebSearchToolName) {
23+
normalized.WebSearch.Enabled = false
24+
}
25+
26+
return &normalized
27+
}
28+
29+
// cloneWebSearchRequest returns an independent copy of provider-native web
30+
// search settings so normalization never mutates caller-owned request state.
31+
func cloneWebSearchRequest(req *WebSearchRequest) *WebSearchRequest {
32+
if req == nil {
33+
return nil
34+
}
35+
36+
out := *req
37+
out.AllowedDomains = append([]string(nil), req.AllowedDomains...)
38+
out.BlockedDomains = append([]string(nil), req.BlockedDomains...)
39+
if req.UserLocation != nil {
40+
location := *req.UserLocation
41+
out.UserLocation = &location
42+
}
43+
return &out
44+
}
45+
46+
// dedupeTools preserves the first definition for each tool name and drops later
47+
// duplicates so providers never receive invalid same-name tool lists.
48+
func dedupeTools(tools []Tool) []Tool {
49+
if len(tools) == 0 {
50+
return nil
51+
}
52+
53+
out := make([]Tool, 0, len(tools))
54+
seen := make(map[string]struct{}, len(tools))
55+
for _, tool := range tools {
56+
name := tool.Function.Name
57+
if _, ok := seen[name]; ok {
58+
continue
59+
}
60+
seen[name] = struct{}{}
61+
out = append(out, tool)
62+
}
63+
return out
64+
}
65+
66+
// hasToolNamed reports whether tools contains a definition for name.
67+
func hasToolNamed(tools []Tool, name string) bool {
68+
for _, tool := range tools {
69+
if tool.Function.Name == name {
70+
return true
71+
}
72+
}
73+
return false
74+
}
75+
76+
// normalizedProvider enforces shared request normalization around every
77+
// concrete provider implementation returned by the factory.
78+
type normalizedProvider struct {
79+
Provider
80+
}
81+
82+
// Chat normalizes each request before delegating to the concrete provider.
83+
func (p normalizedProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
84+
return p.Provider.Chat(ctx, NormalizeChatRequest(req))
85+
}
86+
87+
// ChatStream normalizes each streaming request before delegating to the
88+
// concrete provider.
89+
func (p normalizedProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan StreamEvent, error) {
90+
return p.Provider.ChatStream(ctx, NormalizeChatRequest(req))
91+
}

internal/llm/request_test.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright (C) 2026 Yota Hamada
2+
// SPDX-License-Identifier: GPL-3.0-or-later
3+
4+
package llm
5+
6+
import (
7+
"context"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
// TestNormalizeChatRequest verifies that request normalization preserves caller
15+
// state while enforcing provider-independent tool invariants.
16+
func TestNormalizeChatRequest(t *testing.T) {
17+
t.Parallel()
18+
19+
t.Run("deduplicates tools by name and disables native web search on collision", func(t *testing.T) {
20+
t.Parallel()
21+
22+
req := &ChatRequest{
23+
Model: "test",
24+
Tools: []Tool{
25+
testTool(WebSearchToolName, "first"),
26+
testTool(WebSearchToolName, "second"),
27+
testTool("read", "read"),
28+
},
29+
WebSearch: &WebSearchRequest{Enabled: true},
30+
}
31+
32+
normalized := NormalizeChatRequest(req)
33+
34+
require.NotSame(t, req, normalized)
35+
require.Len(t, normalized.Tools, 2)
36+
assert.Equal(t, "first", normalized.Tools[0].Function.Description)
37+
assert.Equal(t, "read", normalized.Tools[1].Function.Name)
38+
require.NotNil(t, normalized.WebSearch)
39+
assert.False(t, normalized.WebSearch.Enabled)
40+
41+
require.Len(t, req.Tools, 3, "normalization must not mutate caller tools")
42+
assert.True(t, req.WebSearch.Enabled, "normalization must not mutate caller web search config")
43+
})
44+
45+
t.Run("keeps native web search when no explicit web_search tool exists", func(t *testing.T) {
46+
t.Parallel()
47+
48+
req := &ChatRequest{
49+
Model: "test",
50+
Tools: []Tool{testTool("read", "read")},
51+
WebSearch: &WebSearchRequest{Enabled: true},
52+
}
53+
54+
normalized := NormalizeChatRequest(req)
55+
56+
require.NotNil(t, normalized.WebSearch)
57+
assert.True(t, normalized.WebSearch.Enabled)
58+
})
59+
}
60+
61+
// TestNewProviderNormalizesRequests verifies that factory-created providers
62+
// normalize both synchronous and streaming chat requests.
63+
func TestNewProviderNormalizesRequests(t *testing.T) {
64+
orig := registry
65+
defer func() { registry = orig }()
66+
registry = make(map[ProviderType]ProviderFactory)
67+
68+
capturedChat := make(chan *ChatRequest, 1)
69+
capturedStream := make(chan *ChatRequest, 1)
70+
RegisterProvider(ProviderType("normalize-test"), func(_ Config) (Provider, error) {
71+
return &normalizingTestProvider{
72+
chat: capturedChat,
73+
stream: capturedStream,
74+
}, nil
75+
})
76+
77+
provider, err := NewProvider(ProviderType("normalize-test"), Config{})
78+
require.NoError(t, err)
79+
80+
req := &ChatRequest{
81+
Model: "test",
82+
Tools: []Tool{
83+
testTool(WebSearchToolName, "first"),
84+
testTool(WebSearchToolName, "second"),
85+
},
86+
WebSearch: &WebSearchRequest{Enabled: true},
87+
}
88+
89+
_, err = provider.Chat(context.Background(), req)
90+
require.NoError(t, err)
91+
chatReq := <-capturedChat
92+
require.Len(t, chatReq.Tools, 1)
93+
assert.False(t, chatReq.WebSearch.Enabled)
94+
95+
events, err := provider.ChatStream(context.Background(), req)
96+
require.NoError(t, err)
97+
for range events {
98+
}
99+
streamReq := <-capturedStream
100+
require.Len(t, streamReq.Tools, 1)
101+
assert.False(t, streamReq.WebSearch.Enabled)
102+
}
103+
104+
// testTool builds a minimal function tool definition for normalization tests.
105+
func testTool(name, description string) Tool {
106+
return Tool{
107+
Type: "function",
108+
Function: ToolFunction{
109+
Name: name,
110+
Description: description,
111+
Parameters: map[string]any{"type": "object"},
112+
},
113+
}
114+
}
115+
116+
// normalizingTestProvider captures requests after the provider wrapper has
117+
// normalized them.
118+
type normalizingTestProvider struct {
119+
chat chan<- *ChatRequest
120+
stream chan<- *ChatRequest
121+
}
122+
123+
// Chat captures the normalized request received by the test provider.
124+
func (p *normalizingTestProvider) Chat(_ context.Context, req *ChatRequest) (*ChatResponse, error) {
125+
p.chat <- req
126+
return &ChatResponse{Content: "ok"}, nil
127+
}
128+
129+
// ChatStream captures the normalized request received by the streaming path.
130+
func (p *normalizingTestProvider) ChatStream(_ context.Context, req *ChatRequest) (<-chan StreamEvent, error) {
131+
p.stream <- req
132+
ch := make(chan StreamEvent, 1)
133+
ch <- StreamEvent{Done: true}
134+
close(ch)
135+
return ch, nil
136+
}
137+
138+
// Name returns the test provider's registry name.
139+
func (p *normalizingTestProvider) Name() string {
140+
return "normalize-test"
141+
}

internal/llm/retry.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ func DefaultLogicalRetryConfig() LogicalRetryConfig {
4545
// transient request failures.
4646
func ChatWithRetry(ctx context.Context, provider Provider, req *ChatRequest, cfg LogicalRetryConfig) (*ChatResponse, error) {
4747
cfg = normalizeLogicalRetryConfig(cfg)
48+
req = NormalizeChatRequest(req)
4849

4950
var lastErr error
5051
for attempt := 1; attempt <= cfg.MaxAttempts; attempt++ {

internal/llm/retry_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,33 @@ func TestChatWithRetry(t *testing.T) {
9090
require.ErrorIs(t, err, context.Canceled)
9191
assert.Equal(t, int32(1), calls.Load())
9292
})
93+
94+
t.Run("normalizes request before provider call", func(t *testing.T) {
95+
t.Parallel()
96+
97+
var captured *ChatRequest
98+
provider := &retryTestProvider{
99+
chatFunc: func(_ context.Context, req *ChatRequest) (*ChatResponse, error) {
100+
captured = req
101+
return &ChatResponse{Content: "ok", FinishReason: "stop"}, nil
102+
},
103+
}
104+
req := &ChatRequest{
105+
Model: "test",
106+
Tools: []Tool{
107+
testTool(WebSearchToolName, "first"),
108+
testTool(WebSearchToolName, "second"),
109+
},
110+
WebSearch: &WebSearchRequest{Enabled: true},
111+
}
112+
113+
_, err := ChatWithRetry(context.Background(), provider, req, DefaultLogicalRetryConfig())
114+
require.NoError(t, err)
115+
require.NotNil(t, captured)
116+
require.Len(t, captured.Tools, 1)
117+
assert.False(t, captured.WebSearch.Enabled)
118+
assert.True(t, req.WebSearch.Enabled, "normalization must not mutate caller request")
119+
})
93120
}
94121

95122
func TestShouldRetryRequest(t *testing.T) {

internal/runtime/builtin/chat/executor.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,7 @@ func (e *Executor) executeToolStep(
719719

720720
func (e *Executor) runStreamForModel(ctx context.Context, provider llmpkg.Provider, req *llmpkg.ChatRequest) (string, *llmpkg.Usage, error) {
721721
retryCfg := llmpkg.DefaultLogicalRetryConfig()
722+
req = llmpkg.NormalizeChatRequest(req)
722723

723724
for attempt := 1; attempt <= retryCfg.MaxAttempts; attempt++ {
724725
events, err := provider.ChatStream(ctx, req)

0 commit comments

Comments
 (0)