@@ -69,8 +69,8 @@ type typedMiddleware[M adk.MessageType] struct {
6969}
7070
7171func (m * typedMiddleware [M ]) BeforeModelRewriteState (ctx context.Context , state * adk.TypedChatModelAgentState [M ],
72- mc * adk.TypedModelContext [M ],
73- ) (context. Context , * adk. TypedChatModelAgentState [ M ], error ) {
72+ mc * adk.TypedModelContext [M ]) (context. Context , * adk. TypedChatModelAgentState [ M ], error ) {
73+
7474 if len (state .Messages ) == 0 {
7575 return ctx , state , nil
7676 }
@@ -89,82 +89,112 @@ func (m *typedMiddleware[M]) BeforeModelRewriteState(ctx context.Context, state
8989func patchToolCallsForMessage [M adk.MessageType ](ctx context.Context ,
9090 gen func (ctx context.Context , toolName , toolCallID string ) (string , error ),
9191 state * adk.TypedChatModelAgentState [* schema.Message ],
92- _ * adk.TypedModelContext [M ],
93- ) (context.Context , * adk.TypedChatModelAgentState [M ], error ) {
94- // seenIDs stores unique tool call IDs collected by reverse traversal
95- seenIDs := make (map [string ]struct {})
92+ _ * adk.TypedModelContext [M ]) (context.Context , * adk.TypedChatModelAgentState [M ], error ) {
93+
9694 patched := make ([]* schema.Message , 0 , len (state .Messages ))
9795
98- // Iterate messages in reverse order to track existing tool call IDs
99- for i := len (state .Messages ) - 1 ; i >= 0 ; i -- {
100- msg := state .Messages [i ]
96+ for i , msg := range state .Messages {
97+ patched = append (patched , msg )
10198
102- if msg .Role == schema .Tool {
103- seenIDs [ msg . ToolCallID ] = struct {}{}
99+ if msg .Role != schema .Assistant || len ( msg . ToolCalls ) == 0 {
100+ continue
104101 }
105102
106- if msg .Role == schema .Assistant && len (msg .ToolCalls ) > 0 {
107- for _ , tc := range msg .ToolCalls {
108- if _ , exists := seenIDs [tc .ID ]; ! exists {
109- toolMsg , err := createPatchedToolMessage (ctx , gen , tc )
110- if err != nil {
111- return ctx , nil , err
112- }
113- patched = append (patched , toolMsg )
114- }
103+ for _ , tc := range msg .ToolCalls {
104+ if hasCorrespondingToolMessage (state .Messages [i + 1 :], tc .ID ) {
105+ continue
115106 }
116- }
117107
118- patched = append (patched , msg )
108+ toolMsg , err := createPatchedToolMessage (ctx , gen , tc )
109+ if err != nil {
110+ return ctx , nil , err
111+ }
112+ patched = append (patched , toolMsg )
113+ }
119114 }
120115
121116 nState := * state
122- nState .Messages = reverse ( patched )
117+ nState .Messages = patched
123118 return ctx , any (& nState ).(* adk.TypedChatModelAgentState [M ]), nil
124119}
125120
126121func patchToolCallsForAgenticMessage [M adk.MessageType ](ctx context.Context ,
127122 gen func (ctx context.Context , toolName , toolCallID string ) (string , error ),
128123 state * adk.TypedChatModelAgentState [* schema.AgenticMessage ],
129- _ * adk.TypedModelContext [M ],
130- ) (context.Context , * adk.TypedChatModelAgentState [M ], error ) {
131- // seenIDs stores unique tool call IDs collected by reverse traversal
132- seenIDs := make (map [string ]struct {})
124+ _ * adk.TypedModelContext [M ]) (context.Context , * adk.TypedChatModelAgentState [M ], error ) {
125+
133126 patched := make ([]* schema.AgenticMessage , 0 , len (state .Messages ))
134127
135- // Iterate messages in reverse order to track existing tool call IDs
136- for i := len (state .Messages ) - 1 ; i >= 0 ; i -- {
137- msg := state .Messages [i ]
128+ for i , msg := range state .Messages {
129+ patched = append (patched , msg )
130+
131+ if msg .Role != schema .AgenticRoleTypeAssistant {
132+ continue
133+ }
138134
135+ // Collect tool call IDs from this assistant message.
136+ var toolCalls []struct {
137+ callID string
138+ name string
139+ }
139140 for _ , block := range msg .ContentBlocks {
140- if block == nil {
141- continue
141+ if block != nil && block .Type == schema .ContentBlockTypeFunctionToolCall && block .FunctionToolCall != nil {
142+ toolCalls = append (toolCalls , struct {
143+ callID string
144+ name string
145+ }{callID : block .FunctionToolCall .CallID , name : block .FunctionToolCall .Name })
142146 }
143- if block .Type == schema .ContentBlockTypeFunctionToolResult && block .FunctionToolResult != nil {
144- seenIDs [block .FunctionToolResult .CallID ] = struct {}{}
145- }
146- if block .Type == schema .ContentBlockTypeToolSearchResult && block .ToolSearchFunctionToolResult != nil {
147- seenIDs [block .ToolSearchFunctionToolResult .CallID ] = struct {}{}
147+ }
148+ if len (toolCalls ) == 0 {
149+ continue
150+ }
151+
152+ for _ , tc := range toolCalls {
153+ if hasCorrespondingAgenticToolResult (state .Messages [i + 1 :], tc .callID ) {
154+ continue
148155 }
149- if block .Type == schema .ContentBlockTypeFunctionToolCall && block .FunctionToolCall != nil {
150- if _ , exists := seenIDs [block .FunctionToolCall .CallID ]; ! exists {
151- toolMsg , err := createPatchedAgenticToolMessage (ctx , gen , block .FunctionToolCall .Name , block .FunctionToolCall .CallID )
152- if err != nil {
153- return ctx , nil , err
154- }
155- patched = append (patched , toolMsg )
156- }
156+
157+ toolMsg , err := createPatchedAgenticToolMessage (ctx , gen , tc .name , tc .callID )
158+ if err != nil {
159+ return ctx , nil , err
157160 }
161+ patched = append (patched , toolMsg )
158162 }
159-
160- patched = append (patched , msg )
161163 }
162164
163165 nState := * state
164- nState .Messages = reverse ( patched )
166+ nState .Messages = patched
165167 return ctx , any (& nState ).(* adk.TypedChatModelAgentState [M ]), nil
166168}
167169
170+ func hasCorrespondingToolMessage (messages []* schema.Message , toolCallID string ) bool {
171+ for _ , msg := range messages {
172+ if msg .Role == schema .Tool && msg .ToolCallID == toolCallID {
173+ return true
174+ }
175+ }
176+ return false
177+ }
178+
179+ func hasCorrespondingAgenticToolResult (messages []* schema.AgenticMessage , toolCallID string ) bool {
180+ for _ , msg := range messages {
181+ for _ , block := range msg .ContentBlocks {
182+ if block == nil {
183+ continue
184+ }
185+ if block .Type == schema .ContentBlockTypeFunctionToolResult &&
186+ block .FunctionToolResult != nil && block .FunctionToolResult .CallID == toolCallID {
187+ return true
188+ }
189+ if block .Type == schema .ContentBlockTypeToolSearchResult &&
190+ block .ToolSearchFunctionToolResult != nil && block .ToolSearchFunctionToolResult .CallID == toolCallID {
191+ return true
192+ }
193+ }
194+ }
195+ return false
196+ }
197+
168198func createPatchedToolMessage (ctx context.Context , gen func (ctx context.Context , toolName , toolCallID string ) (string , error ), tc schema.ToolCall ) (* schema.Message , error ) {
169199 if gen != nil {
170200 content , err := gen (ctx , tc .Function .Name , tc .ID )
@@ -211,13 +241,6 @@ func createPatchedAgenticToolMessage(ctx context.Context, gen func(ctx context.C
211241 }, nil
212242}
213243
214- func reverse [M adk.MessageType ](s []M ) []M {
215- for i , j := 0 , len (s )- 1 ; i < j ; i , j = i + 1 , j - 1 {
216- s [i ], s [j ] = s [j ], s [i ]
217- }
218- return s
219- }
220-
221244const (
222245 defaultPatchedToolMessageTemplate = "Tool call %s with id %s was canceled - another message came in before it could be completed."
223246 defaultPatchedToolMessageTemplateChinese = "工具调用 %s(ID 为 %s)已被取消——在其完成之前收到了另一条消息。"
0 commit comments