Skip to content

Commit 4ee7a2e

Browse files
Revert "refactor: reimplement tool patching via reverse traversal (#1050)"
This reverts commit 7491283.
1 parent 7491283 commit 4ee7a2e

1 file changed

Lines changed: 78 additions & 55 deletions

File tree

adk/middlewares/patchtoolcalls/patchtoolcalls.go

Lines changed: 78 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ type typedMiddleware[M adk.MessageType] struct {
6969
}
7070

7171
func (m *typedMiddleware[M]) BeforeModelRewriteState(ctx context.Context, state *adk.TypedChatModelAgentState[M],
72-
mc *adk.TypedModelContext[M],
73-
) (context.Context, *adk.TypedChatModelAgentState[M], error) {
72+
mc *adk.TypedModelContext[M]) (context.Context, *adk.TypedChatModelAgentState[M], error) {
73+
7474
if len(state.Messages) == 0 {
7575
return ctx, state, nil
7676
}
@@ -89,82 +89,112 @@ func (m *typedMiddleware[M]) BeforeModelRewriteState(ctx context.Context, state
8989
func patchToolCallsForMessage[M adk.MessageType](ctx context.Context,
9090
gen func(ctx context.Context, toolName, toolCallID string) (string, error),
9191
state *adk.TypedChatModelAgentState[*schema.Message],
92-
_ *adk.TypedModelContext[M],
93-
) (context.Context, *adk.TypedChatModelAgentState[M], error) {
94-
// seenIDs stores unique tool call IDs collected by reverse traversal
95-
seenIDs := make(map[string]struct{})
92+
_ *adk.TypedModelContext[M]) (context.Context, *adk.TypedChatModelAgentState[M], error) {
93+
9694
patched := make([]*schema.Message, 0, len(state.Messages))
9795

98-
// Iterate messages in reverse order to track existing tool call IDs
99-
for i := len(state.Messages) - 1; i >= 0; i-- {
100-
msg := state.Messages[i]
96+
for i, msg := range state.Messages {
97+
patched = append(patched, msg)
10198

102-
if msg.Role == schema.Tool {
103-
seenIDs[msg.ToolCallID] = struct{}{}
99+
if msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
100+
continue
104101
}
105102

106-
if msg.Role == schema.Assistant && len(msg.ToolCalls) > 0 {
107-
for _, tc := range msg.ToolCalls {
108-
if _, exists := seenIDs[tc.ID]; !exists {
109-
toolMsg, err := createPatchedToolMessage(ctx, gen, tc)
110-
if err != nil {
111-
return ctx, nil, err
112-
}
113-
patched = append(patched, toolMsg)
114-
}
103+
for _, tc := range msg.ToolCalls {
104+
if hasCorrespondingToolMessage(state.Messages[i+1:], tc.ID) {
105+
continue
115106
}
116-
}
117107

118-
patched = append(patched, msg)
108+
toolMsg, err := createPatchedToolMessage(ctx, gen, tc)
109+
if err != nil {
110+
return ctx, nil, err
111+
}
112+
patched = append(patched, toolMsg)
113+
}
119114
}
120115

121116
nState := *state
122-
nState.Messages = reverse(patched)
117+
nState.Messages = patched
123118
return ctx, any(&nState).(*adk.TypedChatModelAgentState[M]), nil
124119
}
125120

126121
func patchToolCallsForAgenticMessage[M adk.MessageType](ctx context.Context,
127122
gen func(ctx context.Context, toolName, toolCallID string) (string, error),
128123
state *adk.TypedChatModelAgentState[*schema.AgenticMessage],
129-
_ *adk.TypedModelContext[M],
130-
) (context.Context, *adk.TypedChatModelAgentState[M], error) {
131-
// seenIDs stores unique tool call IDs collected by reverse traversal
132-
seenIDs := make(map[string]struct{})
124+
_ *adk.TypedModelContext[M]) (context.Context, *adk.TypedChatModelAgentState[M], error) {
125+
133126
patched := make([]*schema.AgenticMessage, 0, len(state.Messages))
134127

135-
// Iterate messages in reverse order to track existing tool call IDs
136-
for i := len(state.Messages) - 1; i >= 0; i-- {
137-
msg := state.Messages[i]
128+
for i, msg := range state.Messages {
129+
patched = append(patched, msg)
130+
131+
if msg.Role != schema.AgenticRoleTypeAssistant {
132+
continue
133+
}
138134

135+
// Collect tool call IDs from this assistant message.
136+
var toolCalls []struct {
137+
callID string
138+
name string
139+
}
139140
for _, block := range msg.ContentBlocks {
140-
if block == nil {
141-
continue
141+
if block != nil && block.Type == schema.ContentBlockTypeFunctionToolCall && block.FunctionToolCall != nil {
142+
toolCalls = append(toolCalls, struct {
143+
callID string
144+
name string
145+
}{callID: block.FunctionToolCall.CallID, name: block.FunctionToolCall.Name})
142146
}
143-
if block.Type == schema.ContentBlockTypeFunctionToolResult && block.FunctionToolResult != nil {
144-
seenIDs[block.FunctionToolResult.CallID] = struct{}{}
145-
}
146-
if block.Type == schema.ContentBlockTypeToolSearchResult && block.ToolSearchFunctionToolResult != nil {
147-
seenIDs[block.ToolSearchFunctionToolResult.CallID] = struct{}{}
147+
}
148+
if len(toolCalls) == 0 {
149+
continue
150+
}
151+
152+
for _, tc := range toolCalls {
153+
if hasCorrespondingAgenticToolResult(state.Messages[i+1:], tc.callID) {
154+
continue
148155
}
149-
if block.Type == schema.ContentBlockTypeFunctionToolCall && block.FunctionToolCall != nil {
150-
if _, exists := seenIDs[block.FunctionToolCall.CallID]; !exists {
151-
toolMsg, err := createPatchedAgenticToolMessage(ctx, gen, block.FunctionToolCall.Name, block.FunctionToolCall.CallID)
152-
if err != nil {
153-
return ctx, nil, err
154-
}
155-
patched = append(patched, toolMsg)
156-
}
156+
157+
toolMsg, err := createPatchedAgenticToolMessage(ctx, gen, tc.name, tc.callID)
158+
if err != nil {
159+
return ctx, nil, err
157160
}
161+
patched = append(patched, toolMsg)
158162
}
159-
160-
patched = append(patched, msg)
161163
}
162164

163165
nState := *state
164-
nState.Messages = reverse(patched)
166+
nState.Messages = patched
165167
return ctx, any(&nState).(*adk.TypedChatModelAgentState[M]), nil
166168
}
167169

170+
func hasCorrespondingToolMessage(messages []*schema.Message, toolCallID string) bool {
171+
for _, msg := range messages {
172+
if msg.Role == schema.Tool && msg.ToolCallID == toolCallID {
173+
return true
174+
}
175+
}
176+
return false
177+
}
178+
179+
func hasCorrespondingAgenticToolResult(messages []*schema.AgenticMessage, toolCallID string) bool {
180+
for _, msg := range messages {
181+
for _, block := range msg.ContentBlocks {
182+
if block == nil {
183+
continue
184+
}
185+
if block.Type == schema.ContentBlockTypeFunctionToolResult &&
186+
block.FunctionToolResult != nil && block.FunctionToolResult.CallID == toolCallID {
187+
return true
188+
}
189+
if block.Type == schema.ContentBlockTypeToolSearchResult &&
190+
block.ToolSearchFunctionToolResult != nil && block.ToolSearchFunctionToolResult.CallID == toolCallID {
191+
return true
192+
}
193+
}
194+
}
195+
return false
196+
}
197+
168198
func createPatchedToolMessage(ctx context.Context, gen func(ctx context.Context, toolName, toolCallID string) (string, error), tc schema.ToolCall) (*schema.Message, error) {
169199
if gen != nil {
170200
content, err := gen(ctx, tc.Function.Name, tc.ID)
@@ -211,13 +241,6 @@ func createPatchedAgenticToolMessage(ctx context.Context, gen func(ctx context.C
211241
}, nil
212242
}
213243

214-
func reverse[M adk.MessageType](s []M) []M {
215-
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
216-
s[i], s[j] = s[j], s[i]
217-
}
218-
return s
219-
}
220-
221244
const (
222245
defaultPatchedToolMessageTemplate = "Tool call %s with id %s was canceled - another message came in before it could be completed."
223246
defaultPatchedToolMessageTemplateChinese = "工具调用 %s(ID 为 %s)已被取消——在其完成之前收到了另一条消息。"

0 commit comments

Comments
 (0)