Skip to content

Commit f644ee7

Browse files
committed
fix: dedupe tool outputs before auth rotation
1 parent e512a67 commit f644ee7

6 files changed

Lines changed: 310 additions & 87 deletions

File tree

internal/runtime/executor/codex_executor.go

Lines changed: 2 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
186186
// 防御性去重:翻译链中可能因多 Key / 多层处理导致 input 数组里
187187
// 同一个 call_id 的 function_call_output 或 tool_search_output 被重复写入。
188188
// 在所有请求体变换完成后做一次最终去重,避免模型收到重复工具结果。
189-
body = dedupeToolOutputs(body)
189+
body = cliproxyexecutor.DedupeToolOutputs(body)
190190

191191
url := strings.TrimSuffix(baseURL, "/") + "/responses"
192192
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -432,7 +432,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
432432
body = ensureImageGenerationTool(body, baseModel, auth)
433433

434434
// 防御性去重:同 Execute 方法,防止 input 中工具输出重复。
435-
body = dedupeToolOutputs(body)
435+
body = cliproxyexecutor.DedupeToolOutputs(body)
436436

437437
url := strings.TrimSuffix(baseURL, "/") + "/responses"
438438
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -1048,85 +1048,3 @@ func codexConfigLookupAttrs(auth *cliproxyauth.Auth) (apiKey, baseURL string) {
10481048
}
10491049
return strings.TrimSpace(auth.Attributes["api_key"]), strings.TrimSpace(auth.Attributes["base_url"])
10501050
}
1051-
1052-
// dedupeToolOutputs 移除工具结果数组中 call_id / tool_call_id 重复的项(保留最后一次出现)。
1053-
// 同时兼容两种上游格式:
1054-
// - Responses 格式:input[].{type, call_id} (type 为 function_call_output / tool_search_output)
1055-
// - Chat 格式: messages[].{role, tool_call_id}(role 为 tool)
1056-
func dedupeToolOutputs(body []byte) []byte {
1057-
// ---- Responses 格式 ----
1058-
if inputItems := gjson.GetBytes(body, "input"); inputItems.IsArray() {
1059-
if deduped, changed := dedupeInputArray(inputItems, "type", "call_id", "function_call_output", "tool_search_output"); changed {
1060-
body, _ = sjson.SetRawBytes(body, "input", deduped)
1061-
}
1062-
}
1063-
1064-
// ---- Chat Completions 格式 ----
1065-
if messages := gjson.GetBytes(body, "messages"); messages.IsArray() {
1066-
if deduped, changed := dedupeInputArray(messages, "role", "tool_call_id", "tool"); changed {
1067-
body, _ = sjson.SetRawBytes(body, "messages", deduped)
1068-
}
1069-
}
1070-
1071-
return body
1072-
}
1073-
1074-
// dedupeInputArray 对数组中匹配指定 keyField 和 matchTypes 的项按 idField 去重,
1075-
// 保留最后一次出现。返回重建后的 JSON 数组字节和是否有变更。
1076-
func dedupeInputArray(arr gjson.Result, typeField, idField string, matchTypes ...string) ([]byte, bool) {
1077-
items := arr.Array()
1078-
lastIdxByID := make(map[string]int, len(items))
1079-
outputIdx := make([]int, 0, len(items))
1080-
1081-
matchSet := make(map[string]struct{}, len(matchTypes))
1082-
for _, t := range matchTypes {
1083-
matchSet[t] = struct{}{}
1084-
}
1085-
1086-
for i, item := range items {
1087-
typ := item.Get(typeField).String()
1088-
if _, ok := matchSet[typ]; !ok {
1089-
continue
1090-
}
1091-
id := strings.TrimSpace(item.Get(idField).String())
1092-
if id == "" {
1093-
continue
1094-
}
1095-
outputIdx = append(outputIdx, i)
1096-
lastIdxByID[id] = i // 最后出现覆盖前面的
1097-
}
1098-
1099-
keep := make(map[int]bool, len(lastIdxByID))
1100-
for _, idx := range lastIdxByID {
1101-
keep[idx] = true
1102-
}
1103-
1104-
dupes := make(map[int]bool)
1105-
for _, idx := range outputIdx {
1106-
if !keep[idx] {
1107-
dupes[idx] = true
1108-
}
1109-
}
1110-
1111-
if len(dupes) == 0 {
1112-
return nil, false
1113-
}
1114-
1115-
// 重建数组,跳过重复项
1116-
filtered := make([]byte, 0, len(arr.Raw))
1117-
filtered = append(filtered, '[')
1118-
first := true
1119-
for i, item := range items {
1120-
if dupes[i] {
1121-
continue
1122-
}
1123-
if !first {
1124-
filtered = append(filtered, ',')
1125-
}
1126-
filtered = append(filtered, []byte(item.Raw)...)
1127-
first = false
1128-
}
1129-
filtered = append(filtered, ']')
1130-
1131-
return filtered, true
1132-
}

internal/runtime/executor/openai_compat_executor.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
334334
}
335335

336336
// 防御性去重:多 Key / 重试时翻译链可能重复写入 tool output。
337-
translated = dedupeToolOutputs(translated)
337+
translated = cliproxyexecutor.DedupeToolOutputs(translated)
338338

339339
upstreamURL := strings.TrimSuffix(baseURL, "/") + endpoint
340340
parsedURL, err := url.Parse(upstreamURL)
@@ -447,7 +447,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
447447
}
448448

449449
// 防御性去重:多 Key / 重试时翻译链可能重复写入 tool output。
450-
translated = dedupeToolOutputs(translated)
450+
translated = cliproxyexecutor.DedupeToolOutputs(translated)
451451

452452
upstreamURL := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
453453
parsedURL, err := url.Parse(upstreamURL)

sdk/api/handlers/openai/openai_responses_handlers_stream_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
6060
t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1)
6161
}
6262

63-
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
63+
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"function_call\",\"arguments\":\"{}\"}]}}"
6464
if parts[1] != expectedPart2 {
6565
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
6666
}

sdk/cliproxy/auth/conductor.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,7 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
12401240
if len(normalized) == 0 {
12411241
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
12421242
}
1243+
req, opts = cliproxyexecutor.DedupeRequestToolOutputs(req, opts)
12431244

12441245
_, maxRetryCredentials, maxWait := m.retrySettings()
12451246

@@ -1271,6 +1272,7 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
12711272
if len(normalized) == 0 {
12721273
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
12731274
}
1275+
req, opts = cliproxyexecutor.DedupeRequestToolOutputs(req, opts)
12741276

12751277
_, maxRetryCredentials, maxWait := m.retrySettings()
12761278

@@ -1302,6 +1304,7 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
13021304
if len(normalized) == 0 {
13031305
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
13041306
}
1307+
req, opts = cliproxyexecutor.DedupeRequestToolOutputs(req, opts)
13051308

13061309
_, maxRetryCredentials, maxWait := m.retrySettings()
13071310

sdk/cliproxy/executor/payload.go

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
package executor
2+
3+
import (
4+
"strings"
5+
6+
"github.com/tidwall/gjson"
7+
"github.com/tidwall/sjson"
8+
)
9+
10+
// DedupeRequestToolOutputs 在进入具体 provider executor 前清理请求历史中的重复工具结果。
11+
// 多 key / 多 auth 重试会让同一个客户端请求被多次翻译和尝试;如果上游历史里已经带了重复的
12+
// tool output,这里先按 call_id/tool_call_id/tool_use_id 保留最后一次,保证所有 provider 共用同一份干净历史。
13+
func DedupeRequestToolOutputs(req Request, opts Options) (Request, Options) {
14+
req.Payload = DedupeToolOutputs(req.Payload)
15+
if len(opts.OriginalRequest) > 0 {
16+
opts.OriginalRequest = DedupeToolOutputs(opts.OriginalRequest)
17+
}
18+
return req, opts
19+
}
20+
21+
// DedupeToolOutputs 移除常见 API 形态中的重复工具结果,保留最后一次出现的结果。
22+
// 覆盖 Responses input[].call_id、Chat messages[].tool_call_id/call_id,以及 Claude tool_result.tool_use_id。
23+
func DedupeToolOutputs(body []byte) []byte {
24+
if len(body) == 0 {
25+
return body
26+
}
27+
28+
body = dedupeRootArray(body, "input", func(item gjson.Result) (string, bool) {
29+
itemType := strings.TrimSpace(item.Get("type").String())
30+
if !isResponsesToolOutputType(itemType) {
31+
return "", false
32+
}
33+
id := strings.TrimSpace(item.Get("call_id").String())
34+
return id, id != ""
35+
})
36+
37+
body = dedupeRootArray(body, "messages", func(item gjson.Result) (string, bool) {
38+
if strings.TrimSpace(item.Get("role").String()) != "tool" {
39+
return "", false
40+
}
41+
id := strings.TrimSpace(item.Get("tool_call_id").String())
42+
if id == "" {
43+
id = strings.TrimSpace(item.Get("call_id").String())
44+
}
45+
return id, id != ""
46+
})
47+
48+
body = dedupeClaudeToolResultParts(body)
49+
return body
50+
}
51+
52+
// isResponsesToolOutputType reports whether itemType names a Responses-format
// input item that carries a tool result.
//
// A type qualifies when it is exactly "tool_search_output" or when it ends in
// "_call_output" (function_call_output, web_search_call_output,
// computer_call_output, custom_tool_call_output, local_shell_call_output, and
// any other type with that suffix).
func isResponsesToolOutputType(itemType string) bool {
	// tool_search_output is the only accepted type without the "_call_output"
	// suffix, so it is the only one that needs an explicit comparison.
	return itemType == "tool_search_output" || strings.HasSuffix(itemType, "_call_output")
}
60+
61+
func dedupeRootArray(body []byte, path string, keyFn func(gjson.Result) (string, bool)) []byte {
62+
arr := gjson.GetBytes(body, path)
63+
if !arr.IsArray() {
64+
return body
65+
}
66+
deduped, changed := dedupeJSONResults(arr.Array(), keyFn)
67+
if !changed {
68+
return body
69+
}
70+
updated, err := sjson.SetRawBytes(body, path, deduped)
71+
if err != nil {
72+
return body
73+
}
74+
return updated
75+
}
76+
77+
func dedupeJSONResults(items []gjson.Result, keyFn func(gjson.Result) (string, bool)) ([]byte, bool) {
78+
lastIdxByKey := make(map[string]int, len(items))
79+
candidateIdx := make([]int, 0, len(items))
80+
for i, item := range items {
81+
key, ok := keyFn(item)
82+
if !ok {
83+
continue
84+
}
85+
candidateIdx = append(candidateIdx, i)
86+
lastIdxByKey[key] = i
87+
}
88+
89+
if len(lastIdxByKey) == len(candidateIdx) {
90+
return nil, false
91+
}
92+
93+
keep := make(map[int]struct{}, len(lastIdxByKey))
94+
for _, idx := range lastIdxByKey {
95+
keep[idx] = struct{}{}
96+
}
97+
candidates := make(map[int]struct{}, len(candidateIdx))
98+
for _, idx := range candidateIdx {
99+
candidates[idx] = struct{}{}
100+
}
101+
102+
filtered := make([]byte, 0)
103+
filtered = append(filtered, '[')
104+
first := true
105+
for i, item := range items {
106+
if _, isCandidate := candidates[i]; isCandidate {
107+
if _, ok := keep[i]; !ok {
108+
continue
109+
}
110+
}
111+
if !first {
112+
filtered = append(filtered, ',')
113+
}
114+
filtered = append(filtered, item.Raw...)
115+
first = false
116+
}
117+
filtered = append(filtered, ']')
118+
return filtered, true
119+
}
120+
121+
func dedupeClaudeToolResultParts(body []byte) []byte {
122+
messages := gjson.GetBytes(body, "messages")
123+
if !messages.IsArray() {
124+
return body
125+
}
126+
127+
messageItems := messages.Array()
128+
type partRef struct {
129+
messageIdx int
130+
partIdx int
131+
}
132+
lastRefByID := make(map[string]partRef)
133+
candidateRefs := make([]partRef, 0)
134+
135+
for msgIdx, msg := range messageItems {
136+
content := msg.Get("content")
137+
if !content.IsArray() {
138+
continue
139+
}
140+
for partIdx, part := range content.Array() {
141+
if strings.TrimSpace(part.Get("type").String()) != "tool_result" {
142+
continue
143+
}
144+
id := strings.TrimSpace(part.Get("tool_use_id").String())
145+
if id == "" {
146+
continue
147+
}
148+
ref := partRef{messageIdx: msgIdx, partIdx: partIdx}
149+
candidateRefs = append(candidateRefs, ref)
150+
lastRefByID[id] = ref
151+
}
152+
}
153+
154+
if len(lastRefByID) == len(candidateRefs) {
155+
return body
156+
}
157+
158+
keep := make(map[partRef]struct{}, len(lastRefByID))
159+
for _, ref := range lastRefByID {
160+
keep[ref] = struct{}{}
161+
}
162+
drop := make(map[partRef]struct{}, len(candidateRefs)-len(lastRefByID))
163+
for _, ref := range candidateRefs {
164+
if _, ok := keep[ref]; !ok {
165+
drop[ref] = struct{}{}
166+
}
167+
}
168+
169+
filteredMessages := make([]byte, 0, len(messages.Raw))
170+
filteredMessages = append(filteredMessages, '[')
171+
firstMessage := true
172+
for msgIdx, msg := range messageItems {
173+
content := msg.Get("content")
174+
messageRaw := []byte(msg.Raw)
175+
if content.IsArray() {
176+
parts := content.Array()
177+
filteredParts := make([]byte, 0, len(content.Raw))
178+
filteredParts = append(filteredParts, '[')
179+
firstPart := true
180+
for partIdx, part := range parts {
181+
if _, shouldDrop := drop[partRef{messageIdx: msgIdx, partIdx: partIdx}]; shouldDrop {
182+
continue
183+
}
184+
if !firstPart {
185+
filteredParts = append(filteredParts, ',')
186+
}
187+
filteredParts = append(filteredParts, part.Raw...)
188+
firstPart = false
189+
}
190+
filteredParts = append(filteredParts, ']')
191+
if len(gjson.ParseBytes(filteredParts).Array()) == 0 {
192+
continue
193+
}
194+
updated, err := sjson.SetRawBytes(messageRaw, "content", filteredParts)
195+
if err == nil {
196+
messageRaw = updated
197+
}
198+
}
199+
if !firstMessage {
200+
filteredMessages = append(filteredMessages, ',')
201+
}
202+
filteredMessages = append(filteredMessages, messageRaw...)
203+
firstMessage = false
204+
}
205+
filteredMessages = append(filteredMessages, ']')
206+
207+
updated, err := sjson.SetRawBytes(body, "messages", filteredMessages)
208+
if err != nil {
209+
return body
210+
}
211+
return updated
212+
}

0 commit comments

Comments
 (0)