Skip to content

Commit b75573e

Browse files
ianchivbarda
andauthored
core: add tool_call exclusion in filter_message (#30289)
Extend functionallity to allow to filter pairs of tool calls (ai + tool). --------- Co-authored-by: vbarda <[email protected]>
1 parent 673ec00 commit b75573e

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

libs/core/langchain_core/messages/utils.py

+45
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ def filter_messages(
404404
exclude_types: Optional[Sequence[Union[str, type[BaseMessage]]]] = None,
405405
include_ids: Optional[Sequence[str]] = None,
406406
exclude_ids: Optional[Sequence[str]] = None,
407+
exclude_tool_calls: Optional[Sequence[str] | bool] = None,
407408
) -> list[BaseMessage]:
408409
"""Filter messages based on name, type or id.
409410
@@ -419,6 +420,13 @@ def filter_messages(
419420
SystemMessage, HumanMessage, AIMessage, ...). Default is None.
420421
include_ids: Message IDs to include. Default is None.
421422
exclude_ids: Message IDs to exclude. Default is None.
423+
exclude_tool_calls: Tool call IDs to exclude. Default is None.
424+
Can be one of the following:
425+
- `True`: all AIMessages with tool calls and all ToolMessages will be excluded.
426+
- a sequence of tool call IDs to exclude:
427+
- ToolMessages with the corresponding tool call ID will be excluded.
428+
- The `tool_calls` in the AIMessage will be updated to exclude matching tool calls.
429+
If all tool_calls are filtered from an AIMessage, the whole message is excluded.
422430
423431
Returns:
424432
A list of Messages that meets at least one of the incl_* conditions and none
@@ -467,6 +475,43 @@ def filter_messages(
467475
else:
468476
pass
469477

478+
if exclude_tool_calls is True and (
479+
(isinstance(msg, AIMessage) and msg.tool_calls)
480+
or isinstance(msg, ToolMessage)
481+
):
482+
continue
483+
484+
if isinstance(exclude_tool_calls, (list, tuple, set)):
485+
if isinstance(msg, AIMessage) and msg.tool_calls:
486+
tool_calls = [
487+
tool_call
488+
for tool_call in msg.tool_calls
489+
if tool_call["id"] not in exclude_tool_calls
490+
]
491+
if not tool_calls:
492+
continue
493+
494+
content = msg.content
495+
# handle Anthropic content blocks
496+
if isinstance(msg.content, list):
497+
content = [
498+
content_block
499+
for content_block in msg.content
500+
if (
501+
not isinstance(content_block, dict)
502+
or content_block.get("type") != "tool_use"
503+
or content_block.get("id") not in exclude_tool_calls
504+
)
505+
]
506+
507+
msg = msg.model_copy(
508+
update={"tool_calls": tool_calls, "content": content}
509+
)
510+
elif (
511+
isinstance(msg, ToolMessage) and msg.tool_call_id in exclude_tool_calls
512+
):
513+
continue
514+
470515
# default to inclusion when no inclusion criteria given.
471516
if (
472517
not (include_types or include_ids or include_names)

libs/core/tests/unit_tests/messages/test_utils.py

+88
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,94 @@ def test_filter_message(filters: dict) -> None:
165165
assert messages == messages_model_copy
166166

167167

168+
def test_filter_message_exclude_tool_calls() -> None:
169+
tool_calls = [
170+
{"name": "foo", "id": "1", "args": {}, "type": "tool_call"},
171+
{"name": "bar", "id": "2", "args": {}, "type": "tool_call"},
172+
]
173+
messages = [
174+
HumanMessage("foo", name="blah", id="1"),
175+
AIMessage("foo-response", name="blah", id="2"),
176+
HumanMessage("bar", name="blur", id="3"),
177+
AIMessage(
178+
"bar-response",
179+
tool_calls=tool_calls,
180+
id="4",
181+
),
182+
ToolMessage("baz", tool_call_id="1", id="5"),
183+
ToolMessage("qux", tool_call_id="2", id="6"),
184+
]
185+
messages_model_copy = [m.model_copy(deep=True) for m in messages]
186+
expected = messages[:3]
187+
188+
# test excluding all tool calls
189+
actual = filter_messages(messages, exclude_tool_calls=True)
190+
assert expected == actual
191+
192+
# test explicitly excluding all tool calls
193+
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
194+
assert expected == actual
195+
196+
# test excluding a specific tool call
197+
expected = messages[:5]
198+
expected[3] = expected[3].model_copy(update={"tool_calls": [tool_calls[0]]})
199+
actual = filter_messages(messages, exclude_tool_calls=["2"])
200+
assert expected == actual
201+
202+
# assert that we didn't mutate the original messages
203+
assert messages == messages_model_copy
204+
205+
206+
def test_filter_message_exclude_tool_calls_content_blocks() -> None:
207+
tool_calls = [
208+
{"name": "foo", "id": "1", "args": {}, "type": "tool_call"},
209+
{"name": "bar", "id": "2", "args": {}, "type": "tool_call"},
210+
]
211+
messages = [
212+
HumanMessage("foo", name="blah", id="1"),
213+
AIMessage("foo-response", name="blah", id="2"),
214+
HumanMessage("bar", name="blur", id="3"),
215+
AIMessage(
216+
[
217+
{"text": "bar-response", "type": "text"},
218+
{"name": "foo", "type": "tool_use", "id": "1"},
219+
{"name": "bar", "type": "tool_use", "id": "2"},
220+
],
221+
tool_calls=tool_calls,
222+
id="4",
223+
),
224+
ToolMessage("baz", tool_call_id="1", id="5"),
225+
ToolMessage("qux", tool_call_id="2", id="6"),
226+
]
227+
messages_model_copy = [m.model_copy(deep=True) for m in messages]
228+
expected = messages[:3]
229+
230+
# test excluding all tool calls
231+
actual = filter_messages(messages, exclude_tool_calls=True)
232+
assert expected == actual
233+
234+
# test explicitly excluding all tool calls
235+
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
236+
assert expected == actual
237+
238+
# test excluding a specific tool call
239+
expected = messages[:4] + messages[-1:]
240+
expected[3] = expected[3].model_copy(
241+
update={
242+
"tool_calls": [tool_calls[1]],
243+
"content": [
244+
{"text": "bar-response", "type": "text"},
245+
{"name": "bar", "type": "tool_use", "id": "2"},
246+
],
247+
}
248+
)
249+
actual = filter_messages(messages, exclude_tool_calls=["1"])
250+
assert expected == actual
251+
252+
# assert that we didn't mutate the original messages
253+
assert messages == messages_model_copy
254+
255+
168256
_MESSAGES_TO_TRIM = [
169257
SystemMessage("This is a 4 token text."),
170258
HumanMessage("This is a 4 token text.", id="first"),

0 commit comments

Comments
 (0)