Skip to content

Commit 97a81c4

Browse files
committed
Harden toolcall leak interception for function-style payloads
1 parent b0a09df commit 97a81c4

20 files changed

+336
-222
lines changed

internal/adapter/claude/handler_stream_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,40 @@ func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.
358358
}
359359
}
360360

361+
func TestHandleClaudeStreamRealtimeDetectsToolUseWithLeadingProse(t *testing.T) {
362+
h := &Handler{}
363+
payload := "I'll call a tool now.\\n<tool_use><tool_name>write_file</tool_name><parameters>{\\\"path\\\":\\\"/tmp/a.txt\\\",\\\"content\\\":\\\"abc\\\"}</parameters></tool_use>"
364+
resp := makeClaudeSSEHTTPResponse(
365+
`data: {"p":"response/content","v":"`+payload+`"}`,
366+
`data: [DONE]`,
367+
)
368+
rec := httptest.NewRecorder()
369+
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
370+
371+
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"write_file"})
372+
373+
frames := parseClaudeFrames(t, rec.Body.String())
374+
foundToolUse := false
375+
for _, f := range findClaudeFrames(frames, "content_block_start") {
376+
contentBlock, _ := f.Payload["content_block"].(map[string]any)
377+
if contentBlock["type"] == "tool_use" && contentBlock["name"] == "write_file" {
378+
foundToolUse = true
379+
break
380+
}
381+
}
382+
if !foundToolUse {
383+
t.Fatalf("expected tool_use block with leading prose payload, body=%s", rec.Body.String())
384+
}
385+
386+
for _, f := range findClaudeFrames(frames, "message_delta") {
387+
delta, _ := f.Payload["delta"].(map[string]any)
388+
if delta["stop_reason"] == "tool_use" {
389+
return
390+
}
391+
}
392+
t.Fatalf("expected stop_reason=tool_use, body=%s", rec.Body.String())
393+
}
394+
361395
func TestHandleClaudeStreamRealtimeIgnoresUnclosedFencedToolExample(t *testing.T) {
362396
h := &Handler{}
363397
resp := makeClaudeSSEHTTPResponse(

internal/adapter/claude/standard_request.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
3838
}
3939
finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
4040
toolNames := extractClaudeToolNames(toolsRequested)
41+
if len(toolNames) == 0 && len(toolsRequested) > 0 {
42+
toolNames = []string{"__any_tool__"}
43+
}
4144

4245
return claudeNormalizedRequest{
4346
Standard: util.StandardRequest{

internal/adapter/claude/stream_runtime_core.go

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88

99
"ds2api/internal/sse"
1010
streamengine "ds2api/internal/stream"
11-
"ds2api/internal/util"
1211
)
1312

1413
type claudeStreamRuntime struct {
@@ -120,15 +119,6 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
120119
if hasUnclosedCodeFence(s.text.String()) {
121120
continue
122121
}
123-
detected := util.ParseToolCalls(s.text.String(), s.toolNames)
124-
if len(detected) > 0 {
125-
s.finalize("tool_use")
126-
return streamengine.ParsedDecision{
127-
ContentSeen: true,
128-
Stop: true,
129-
StopReason: streamengine.StopReason("tool_use_detected"),
130-
}
131-
}
132122
continue
133123
}
134124
s.closeThinkingBlock()

internal/adapter/claude/stream_runtime_finalize.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
4545
finalText := s.text.String()
4646

4747
if s.bufferToolContent {
48-
detected := util.ParseToolCalls(finalText, s.toolNames)
48+
detected := util.ParseStandaloneToolCalls(finalText, s.toolNames)
4949
if len(detected) == 0 && finalText == "" && finalThinking != "" {
50-
detected = util.ParseToolCalls(finalThinking, s.toolNames)
50+
detected = util.ParseStandaloneToolCalls(finalThinking, s.toolNames)
5151
}
5252
if len(detected) > 0 {
5353
stopReason = "tool_use"

internal/adapter/openai/handler_toolcall_format.go

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,28 +111,21 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNam
111111
if len(deltas) == 0 {
112112
return nil
113113
}
114-
allowed := namesToSet(allowedNames)
115-
if len(allowed) == 0 {
116-
for _, d := range deltas {
117-
if d.Name != "" {
118-
seenNames[d.Index] = "__blocked__"
119-
}
120-
}
121-
return nil
122-
}
123114
out := make([]toolCallDelta, 0, len(deltas))
124115
for _, d := range deltas {
125116
if d.Name != "" {
126-
if _, ok := allowed[d.Name]; !ok {
127-
seenNames[d.Index] = "__blocked__"
128-
continue
117+
if seenNames != nil {
118+
seenNames[d.Index] = d.Name
129119
}
130-
seenNames[d.Index] = d.Name
120+
out = append(out, d)
121+
continue
122+
}
123+
if seenNames == nil {
131124
out = append(out, d)
132125
continue
133126
}
134127
name := strings.TrimSpace(seenNames[d.Index])
135-
if name == "" || name == "__blocked__" {
128+
if name == "" {
136129
continue
137130
}
138131
out = append(out, d)

internal/adapter/openai/handler_toolcall_test.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
182182
}
183183
}
184184

185-
func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
185+
func TestHandleNonStreamUnknownToolIntercepted(t *testing.T) {
186186
h := &Handler{}
187187
resp := makeSSEHTTPResponse(
188188
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
@@ -198,16 +198,13 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
198198
out := decodeJSONBody(t, rec.Body.String())
199199
choices, _ := out["choices"].([]any)
200200
choice, _ := choices[0].(map[string]any)
201-
if choice["finish_reason"] != "stop" {
202-
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
201+
if choice["finish_reason"] != "tool_calls" {
202+
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
203203
}
204204
msg, _ := choice["message"].(map[string]any)
205-
if _, ok := msg["tool_calls"]; ok {
206-
t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"])
207-
}
208-
content, _ := msg["content"].(string)
209-
if !strings.Contains(content, `"tool_calls"`) {
210-
t.Fatalf("expected unknown tool json to pass through as text, got %#v", content)
205+
toolCalls, _ := msg["tool_calls"].([]any)
206+
if len(toolCalls) != 1 {
207+
t.Fatalf("expected tool_calls for unknown schema name, got %#v", msg["tool_calls"])
211208
}
212209
}
213210

@@ -413,7 +410,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.
413410
}
414411
}
415412

416-
func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) {
413+
func TestHandleStreamUnknownToolEmitsToolCall(t *testing.T) {
417414
h := &Handler{}
418415
resp := makeSSEHTTPResponse(
419416
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
@@ -428,18 +425,18 @@ func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) {
428425
if !done {
429426
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
430427
}
431-
if streamHasToolCallsDelta(frames) {
432-
t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String())
428+
if !streamHasToolCallsDelta(frames) {
429+
t.Fatalf("expected tool_calls delta for unknown schema name, body=%s", rec.Body.String())
433430
}
434431
if streamHasRawToolJSONContent(frames) {
435432
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name: %s", rec.Body.String())
436433
}
437-
if streamFinishReason(frames) != "stop" {
438-
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
434+
if streamFinishReason(frames) != "tool_calls" {
435+
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
439436
}
440437
}
441438

442-
func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) {
439+
func TestHandleStreamUnknownToolNoArgsEmitsToolCall(t *testing.T) {
443440
h := &Handler{}
444441
resp := makeSSEHTTPResponse(
445442
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\"}]}"}`,
@@ -454,14 +451,14 @@ func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) {
454451
if !done {
455452
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
456453
}
457-
if streamHasToolCallsDelta(frames) {
458-
t.Fatalf("did not expect tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
454+
if !streamHasToolCallsDelta(frames) {
455+
t.Fatalf("expected tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
459456
}
460457
if streamHasRawToolJSONContent(frames) {
461458
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name (no args): %s", rec.Body.String())
462459
}
463-
if streamFinishReason(frames) != "stop" {
464-
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
460+
if streamFinishReason(frames) != "tool_calls" {
461+
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
465462
}
466463
}
467464

internal/adapter/openai/responses_stream_test.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *te
354354
}
355355
}
356356

357-
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
357+
func TestHandleResponsesStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
358358
h := &Handler{}
359359
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
360360
rec := httptest.NewRecorder()
@@ -376,8 +376,8 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
376376

377377
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "")
378378
body := rec.Body.String()
379-
if strings.Contains(body, "event: response.function_call_arguments.done") {
380-
t.Fatalf("did not expect function_call events for tool_choice=none, body=%s", body)
379+
if !strings.Contains(body, "event: response.function_call_arguments.done") {
380+
t.Fatalf("expected function_call events for tool_choice=none, body=%s", body)
381381
}
382382
}
383383

@@ -518,7 +518,7 @@ func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
518518
}
519519
}
520520

521-
func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
521+
func TestHandleResponsesStreamAllowsUnknownToolName(t *testing.T) {
522522
h := &Handler{}
523523
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
524524
rec := httptest.NewRecorder()
@@ -539,8 +539,8 @@ func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
539539

540540
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
541541
body := rec.Body.String()
542-
if strings.Contains(body, "event: response.function_call_arguments.done") {
543-
t.Fatalf("did not expect function_call events for unknown tool, body=%s", body)
542+
if !strings.Contains(body, "event: response.function_call_arguments.done") {
543+
t.Fatalf("expected function_call events for unknown tool, body=%s", body)
544544
}
545545
}
546546

@@ -597,7 +597,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t
597597
}
598598
}
599599

600-
func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
600+
func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
601601
h := &Handler{}
602602
rec := httptest.NewRecorder()
603603
resp := &http.Response{
@@ -611,16 +611,20 @@ func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T)
611611

612612
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "")
613613
if rec.Code != http.StatusOK {
614-
t.Fatalf("expected 200 for tool_choice=none passthrough text, got %d body=%s", rec.Code, rec.Body.String())
614+
t.Fatalf("expected 200 for tool_choice=none handling, got %d body=%s", rec.Code, rec.Body.String())
615615
}
616616
out := decodeJSONBody(t, rec.Body.String())
617617
output, _ := out["output"].([]any)
618+
foundFunctionCall := false
618619
for _, item := range output {
619620
m, _ := item.(map[string]any)
620621
if m != nil && m["type"] == "function_call" {
621-
t.Fatalf("did not expect function_call output item for tool_choice=none, got %#v", output)
622+
foundFunctionCall = true
622623
}
623624
}
625+
if !foundFunctionCall {
626+
t.Fatalf("expected function_call output item for tool_choice=none, got %#v", output)
627+
}
624628
}
625629

626630
func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {

internal/adapter/openai/standard_request.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
2525
}
2626
toolPolicy := util.DefaultToolChoicePolicy()
2727
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
28+
toolNames = ensureToolDetectionEnabled(toolNames, req["tools"])
2829
passThrough := collectOpenAIChatPassThrough(req)
2930

3031
return util.StandardRequest{
@@ -74,10 +75,8 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
7475
return util.StandardRequest{}, err
7576
}
7677
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
77-
if toolPolicy.IsNone() {
78-
toolNames = nil
79-
toolPolicy.Allowed = nil
80-
} else {
78+
toolNames = ensureToolDetectionEnabled(toolNames, req["tools"])
79+
if !toolPolicy.IsNone() {
8180
toolPolicy.Allowed = namesToSet(toolNames)
8281
}
8382
passThrough := collectOpenAIChatPassThrough(req)
@@ -98,6 +97,20 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
9897
}, nil
9998
}
10099

100+
func ensureToolDetectionEnabled(toolNames []string, toolsRaw any) []string {
101+
if len(toolNames) > 0 {
102+
return toolNames
103+
}
104+
tools, _ := toolsRaw.([]any)
105+
if len(tools) == 0 {
106+
return toolNames
107+
}
108+
// Keep stream sieve/tool buffering enabled even when client tool schemas
109+
// are malformed or lack explicit names; parsed tool payload names are no
110+
// longer filtered by this list.
111+
return []string{"__any_tool__"}
112+
}
113+
101114
func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
102115
out := map[string]any{}
103116
for _, k := range []string{

internal/adapter/openai/standard_request_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testi
152152
}
153153
}
154154

155-
func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T) {
155+
func TestNormalizeOpenAIResponsesRequestToolChoiceNoneKeepsToolDetectionEnabled(t *testing.T) {
156156
store := newEmptyStoreForNormalizeTest(t)
157157
req := map[string]any{
158158
"model": "gpt-4o",
@@ -174,7 +174,7 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T
174174
if n.ToolChoice.Mode != util.ToolChoiceNone {
175175
t.Fatalf("expected tool choice mode none, got %q", n.ToolChoice.Mode)
176176
}
177-
if len(n.ToolNames) != 0 {
178-
t.Fatalf("expected no tool names when tool_choice=none, got %#v", n.ToolNames)
177+
if len(n.ToolNames) == 0 {
178+
t.Fatalf("expected tool detection sentinel when tool_choice=none, got %#v", n.ToolNames)
179179
}
180180
}

internal/adapter/openai/tool_sieve_core.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func findToolSegmentStart(s string) int {
167167
return -1
168168
}
169169
lower := strings.ToLower(s)
170-
keywords := []string{"tool_calls", "function.name:", "[tool_call_history]", "[tool_result_history]"}
170+
keywords := []string{"tool_calls", "\"function\"", "function.name:", "[tool_call_history]", "[tool_result_history]"}
171171
bestKeyIdx := -1
172172
for _, kw := range keywords {
173173
idx := strings.Index(lower, kw)
@@ -195,7 +195,7 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
195195
}
196196
lower := strings.ToLower(captured)
197197
keyIdx := -1
198-
keywords := []string{"tool_calls", "function.name:", "[tool_call_history]", "[tool_result_history]"}
198+
keywords := []string{"tool_calls", "\"function\"", "function.name:", "[tool_call_history]", "[tool_result_history]"}
199199
for _, kw := range keywords {
200200
idx := strings.Index(lower, kw)
201201
if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {

0 commit comments

Comments
 (0)