Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def _consolidate_calls(items: Iterable[dict[str, Any]]) -> Iterator[dict[str, An
pass
collapsed["type"] = "web_search_call"

if current.get("name") == "file_search":
elif current.get("name") == "file_search":
collapsed = {"id": current["id"]}
if "args" in current and "queries" in current["args"]:
collapsed["queries"] = current["args"]["queries"]
Expand Down Expand Up @@ -392,7 +392,10 @@ def _consolidate_calls(items: Iterable[dict[str, Any]]) -> Iterator[dict[str, An
if k not in ("server_label", "error"):
collapsed[k] = v
else:
pass
# Unrecognized server tool name — emit both items unchanged
yield current
yield nxt
continue

yield collapsed

Expand Down
79 changes: 79 additions & 0 deletions libs/partners/openai/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from langchain_openai import ChatOpenAI
from langchain_openai.chat_models._compat import (
_FUNCTION_CALL_IDS_MAP_KEY,
_consolidate_calls,
_convert_from_v1_to_chat_completions,
_convert_from_v1_to_responses,
_convert_to_v03_ai_message,
Expand Down Expand Up @@ -2882,6 +2883,84 @@ def test_convert_from_v1_to_responses(
assert message_v1 != result


class TestConsolidateCalls:
"""Tests for _consolidate_calls in _compat.py."""

def test_web_search_collapses(self) -> None:
items = [
{"type": "server_tool_call", "name": "web_search", "id": "call_1"},
{
"type": "server_tool_result",
"tool_call_id": "call_1",
"status": "success",
},
]
result = list(_consolidate_calls(items))
assert len(result) == 1
assert result[0]["type"] == "web_search_call"
assert result[0]["status"] == "completed"
assert result[0]["id"] == "call_1"

def test_file_search_collapses(self) -> None:
items: list[dict[str, Any]] = [
{
"type": "server_tool_call",
"name": "file_search",
"id": "call_2",
"args": {"queries": ["test"]},
},
{
"type": "server_tool_result",
"tool_call_id": "call_2",
"status": "success",
"output": [{"text": "found"}],
},
]
result = list(_consolidate_calls(items))
assert len(result) == 1
assert result[0]["type"] == "file_search_call"
assert result[0]["queries"] == ["test"]
assert result[0]["results"] == [{"text": "found"}]

def test_unrecognized_tool_name_passes_through(self) -> None:
"""Unrecognized server tool names should not raise UnboundLocalError."""
items = [
{"type": "server_tool_call", "name": "computer", "id": "call_3"},
{
"type": "server_tool_result",
"tool_call_id": "call_3",
"output": "result",
},
]
result = list(_consolidate_calls(items))
assert len(result) == 2
assert result[0]["type"] == "server_tool_call"
assert result[1]["type"] == "server_tool_result"

def test_non_server_tool_passes_through(self) -> None:
items = [
{"type": "text", "text": "hello"},
{"type": "function_call", "name": "foo", "call_id": "c1"},
]
result = list(_consolidate_calls(items))
assert len(result) == 2

def test_unmatched_pair_emits_both(self) -> None:
"""server_tool_call followed by non-matching result emits both."""
items = [
{"type": "server_tool_call", "name": "web_search", "id": "call_a"},
{
"type": "server_tool_result",
"tool_call_id": "call_b",
"status": "success",
},
]
result = list(_consolidate_calls(items))
assert len(result) == 2
assert result[0]["type"] == "server_tool_call"
assert result[1]["type"] == "server_tool_result"


def test_get_last_messages() -> None:
messages: list[BaseMessage] = [HumanMessage("Hello")]
last_messages, previous_response_id = _get_last_messages(messages)
Expand Down
Loading