@@ -165,6 +165,94 @@ def test_filter_message(filters: dict) -> None:
165
165
assert messages == messages_model_copy
166
166
167
167
168
+ def test_filter_message_exclude_tool_calls () -> None :
169
+ tool_calls = [
170
+ {"name" : "foo" , "id" : "1" , "args" : {}, "type" : "tool_call" },
171
+ {"name" : "bar" , "id" : "2" , "args" : {}, "type" : "tool_call" },
172
+ ]
173
+ messages = [
174
+ HumanMessage ("foo" , name = "blah" , id = "1" ),
175
+ AIMessage ("foo-response" , name = "blah" , id = "2" ),
176
+ HumanMessage ("bar" , name = "blur" , id = "3" ),
177
+ AIMessage (
178
+ "bar-response" ,
179
+ tool_calls = tool_calls ,
180
+ id = "4" ,
181
+ ),
182
+ ToolMessage ("baz" , tool_call_id = "1" , id = "5" ),
183
+ ToolMessage ("qux" , tool_call_id = "2" , id = "6" ),
184
+ ]
185
+ messages_model_copy = [m .model_copy (deep = True ) for m in messages ]
186
+ expected = messages [:3 ]
187
+
188
+ # test excluding all tool calls
189
+ actual = filter_messages (messages , exclude_tool_calls = True )
190
+ assert expected == actual
191
+
192
+ # test explicitly excluding all tool calls
193
+ actual = filter_messages (messages , exclude_tool_calls = {"1" , "2" })
194
+ assert expected == actual
195
+
196
+ # test excluding a specific tool call
197
+ expected = messages [:5 ]
198
+ expected [3 ] = expected [3 ].model_copy (update = {"tool_calls" : [tool_calls [0 ]]})
199
+ actual = filter_messages (messages , exclude_tool_calls = ["2" ])
200
+ assert expected == actual
201
+
202
+ # assert that we didn't mutate the original messages
203
+ assert messages == messages_model_copy
204
+
205
+
206
+ def test_filter_message_exclude_tool_calls_content_blocks () -> None :
207
+ tool_calls = [
208
+ {"name" : "foo" , "id" : "1" , "args" : {}, "type" : "tool_call" },
209
+ {"name" : "bar" , "id" : "2" , "args" : {}, "type" : "tool_call" },
210
+ ]
211
+ messages = [
212
+ HumanMessage ("foo" , name = "blah" , id = "1" ),
213
+ AIMessage ("foo-response" , name = "blah" , id = "2" ),
214
+ HumanMessage ("bar" , name = "blur" , id = "3" ),
215
+ AIMessage (
216
+ [
217
+ {"text" : "bar-response" , "type" : "text" },
218
+ {"name" : "foo" , "type" : "tool_use" , "id" : "1" },
219
+ {"name" : "bar" , "type" : "tool_use" , "id" : "2" },
220
+ ],
221
+ tool_calls = tool_calls ,
222
+ id = "4" ,
223
+ ),
224
+ ToolMessage ("baz" , tool_call_id = "1" , id = "5" ),
225
+ ToolMessage ("qux" , tool_call_id = "2" , id = "6" ),
226
+ ]
227
+ messages_model_copy = [m .model_copy (deep = True ) for m in messages ]
228
+ expected = messages [:3 ]
229
+
230
+ # test excluding all tool calls
231
+ actual = filter_messages (messages , exclude_tool_calls = True )
232
+ assert expected == actual
233
+
234
+ # test explicitly excluding all tool calls
235
+ actual = filter_messages (messages , exclude_tool_calls = {"1" , "2" })
236
+ assert expected == actual
237
+
238
+ # test excluding a specific tool call
239
+ expected = messages [:4 ] + messages [- 1 :]
240
+ expected [3 ] = expected [3 ].model_copy (
241
+ update = {
242
+ "tool_calls" : [tool_calls [1 ]],
243
+ "content" : [
244
+ {"text" : "bar-response" , "type" : "text" },
245
+ {"name" : "bar" , "type" : "tool_use" , "id" : "2" },
246
+ ],
247
+ }
248
+ )
249
+ actual = filter_messages (messages , exclude_tool_calls = ["1" ])
250
+ assert expected == actual
251
+
252
+ # assert that we didn't mutate the original messages
253
+ assert messages == messages_model_copy
254
+
255
+
168
256
_MESSAGES_TO_TRIM = [
169
257
SystemMessage ("This is a 4 token text." ),
170
258
HumanMessage ("This is a 4 token text." , id = "first" ),
0 commit comments