Skip to content

Commit e512a67

Browse files
committed
fix: extend dedupeToolOutputs to handle both Responses (input[].call_id) and Chat (messages[].tool_call_id) formats for multi-vendor compatibility
1 parent 71c0096 commit e512a67

1 file changed

Lines changed: 45 additions & 30 deletions

File tree

internal/runtime/executor/codex_executor.go

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,55 +1049,74 @@ func codexConfigLookupAttrs(auth *cliproxyauth.Auth) (apiKey, baseURL string) {
10491049
return strings.TrimSpace(auth.Attributes["api_key"]), strings.TrimSpace(auth.Attributes["base_url"])
10501050
}
10511051

1052-
// dedupeToolOutputs 移除 input 数组中 call_id 重复的 function_call_output
1053-
// 和 tool_search_output 项(保留首次出现),防止上游翻译链或重试逻辑
1054-
// 导致模型收到重复工具结果。
1052+
// dedupeToolOutputs 移除工具结果数组中 call_id / tool_call_id 重复的项(保留最后一次出现)。
1053+
// 同时兼容两种上游格式:
1054+
// - Responses 格式:input[].{type, call_id} (type 为 function_call_output / tool_search_output)
1055+
// - Chat 格式: messages[].{role, tool_call_id}(role 为 tool)
10551056
func dedupeToolOutputs(body []byte) []byte {
1056-
inputItems := gjson.GetBytes(body, "input")
1057-
if !inputItems.IsArray() {
1058-
return body
1057+
// ---- Responses 格式 ----
1058+
if inputItems := gjson.GetBytes(body, "input"); inputItems.IsArray() {
1059+
if deduped, changed := dedupeInputArray(inputItems, "type", "call_id", "function_call_output", "tool_search_output"); changed {
1060+
body, _ = sjson.SetRawBytes(body, "input", deduped)
1061+
}
1062+
}
1063+
1064+
// ---- Chat Completions 格式 ----
1065+
if messages := gjson.GetBytes(body, "messages"); messages.IsArray() {
1066+
if deduped, changed := dedupeInputArray(messages, "role", "tool_call_id", "tool"); changed {
1067+
body, _ = sjson.SetRawBytes(body, "messages", deduped)
1068+
}
10591069
}
10601070

1061-
arr := inputItems.Array()
1062-
lastIdxByCallID := make(map[string]int, len(arr))
1063-
toolOutputIdx := make([]int, 0, len(arr))
1071+
return body
1072+
}
1073+
1074+
// dedupeInputArray 对数组中匹配指定 keyField 和 matchTypes 的项按 idField 去重,
1075+
// 保留最后一次出现。返回重建后的 JSON 数组字节和是否有变更。
1076+
func dedupeInputArray(arr gjson.Result, typeField, idField string, matchTypes ...string) ([]byte, bool) {
1077+
items := arr.Array()
1078+
lastIdxByID := make(map[string]int, len(items))
1079+
outputIdx := make([]int, 0, len(items))
10641080

1065-
for i, item := range arr {
1066-
typ := item.Get("type").String()
1067-
if typ != "function_call_output" && typ != "tool_search_output" {
1081+
matchSet := make(map[string]struct{}, len(matchTypes))
1082+
for _, t := range matchTypes {
1083+
matchSet[t] = struct{}{}
1084+
}
1085+
1086+
for i, item := range items {
1087+
typ := item.Get(typeField).String()
1088+
if _, ok := matchSet[typ]; !ok {
10681089
continue
10691090
}
1070-
callID := strings.TrimSpace(item.Get("call_id").String())
1071-
if callID == "" {
1091+
id := strings.TrimSpace(item.Get(idField).String())
1092+
if id == "" {
10721093
continue
10731094
}
1074-
toolOutputIdx = append(toolOutputIdx, i)
1075-
lastIdxByCallID[callID] = i // 最后一次出现覆盖前面的
1095+
outputIdx = append(outputIdx, i)
1096+
lastIdxByID[id] = i // 最后出现覆盖前面的
10761097
}
10771098

1078-
// 构建保留集合:每个 call_id 只保留最后出现的那一项
1079-
keep := make(map[int]bool, len(lastIdxByCallID))
1080-
for _, idx := range lastIdxByCallID {
1099+
keep := make(map[int]bool, len(lastIdxByID))
1100+
for _, idx := range lastIdxByID {
10811101
keep[idx] = true
10821102
}
10831103

1084-
// 统计需要移除的项
10851104
dupes := make(map[int]bool)
1086-
for _, idx := range toolOutputIdx {
1105+
for _, idx := range outputIdx {
10871106
if !keep[idx] {
10881107
dupes[idx] = true
10891108
}
10901109
}
10911110

10921111
if len(dupes) == 0 {
1093-
return body
1112+
return nil, false
10941113
}
10951114

1096-
// 重建 input 数组,跳过标记为重复的索引
1097-
filtered := make([]byte, 0, len(inputItems.Raw))
1115+
// 重建数组,跳过重复项
1116+
filtered := make([]byte, 0, len(arr.Raw))
10981117
filtered = append(filtered, '[')
10991118
first := true
1100-
for i, item := range arr {
1119+
for i, item := range items {
11011120
if dupes[i] {
11021121
continue
11031122
}
@@ -1109,9 +1128,5 @@ func dedupeToolOutputs(body []byte) []byte {
11091128
}
11101129
filtered = append(filtered, ']')
11111130

1112-
out, err := sjson.SetRawBytes(body, "input", filtered)
1113-
if err != nil {
1114-
return body
1115-
}
1116-
return out
1131+
return filtered, true
11171132
}

0 commit comments

Comments
 (0)