diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 1d7e8d279be4c..132b8e9d56810 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -21,7 +21,7 @@ from langgraph._internal._runnable import RunnableCallable from langgraph.constants import END, START from langgraph.graph.state import StateGraph -from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode +from langgraph.prebuilt.tool_node import ToolCallWithContext from langgraph.types import Command, Send from langsmith import traceable from typing_extensions import NotRequired, Required, TypedDict @@ -53,6 +53,7 @@ ToolStrategy, ) from langchain.chat_models import init_chat_model +from langchain.tools.tool_node import ToolNode @dataclass diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 4474c8ba935fb..b3170751e7dc8 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -1,20 +1,101 @@ """Utils file included for backwards compat imports.""" +from __future__ import annotations + +from copy import copy +from typing import TYPE_CHECKING, Any + from langgraph.prebuilt import InjectedState, InjectedStore, ToolRuntime from langgraph.prebuilt.tool_node import ( ToolCallRequest, ToolCallWithContext, ToolCallWrapper, + _get_all_injected_args, ) from langgraph.prebuilt.tool_node import ( - ToolNode as _ToolNode, # noqa: F401 + ToolNode as _ToolNode, ) +if TYPE_CHECKING: + from langchain_core.messages import ToolCall + from langchain_core.tools import BaseTool + + +class ToolNode(_ToolNode): + """ToolNode subclass that gracefully handles ``NotRequired`` state fields. + + Keep the override as narrow as possible: delegate to upstream + ``langgraph.prebuilt.ToolNode._inject_tool_args`` by default, and only + recover the specific case where an injected state field is optional and + absent at runtime. That keeps LangChain aligned with upstream ToolNode + changes while still fixing ``#35585``. + + See: https://github.com/langchain-ai/langchain/issues/35585 + """ + + def _inject_tool_args( + self, + tool_call: ToolCall, + tool_runtime: ToolRuntime, + tool: BaseTool | None = None, + ) -> ToolCall: + try: + return super()._inject_tool_args(tool_call, tool_runtime, tool=tool) + except (KeyError, AttributeError) as err: + injected = self._injected_args.get(tool_call["name"]) + if not injected and tool is not None: + injected = _get_all_injected_args(tool) + if not injected or not injected.state: + raise + + state: Any = tool_runtime.state + if isinstance(state, dict): + missing_optional_state = any( + state_field and state_field not in state + for state_field in injected.state.values() + ) + if not missing_optional_state: + raise + injected_args: dict[str, Any] = { + tool_arg: state.get(state_field) if state_field else state + for tool_arg, state_field in injected.state.items() + } + else: + missing_optional_state = any( + state_field and not hasattr(state, state_field) + for state_field in injected.state.values() + ) + if not missing_optional_state: + raise + injected_args = { + tool_arg: getattr(state, state_field, None) if state_field else state + for tool_arg, state_field in injected.state.items() + } + + tool_call_copy: ToolCall = copy(tool_call) + + if injected.store: + if tool_runtime.store is None: + msg = ( + "Cannot inject store into tools with InjectedStore " + "annotations - please compile your graph with a store." + ) + raise ValueError(msg) from err + injected_args[injected.store] = tool_runtime.store + + if injected.runtime: + injected_args[injected.runtime] = tool_runtime + + tool_call_copy["args"] = {**tool_call_copy["args"], **injected_args} + return tool_call_copy + + __all__ = [ "InjectedState", "InjectedStore", "ToolCallRequest", "ToolCallWithContext", "ToolCallWrapper", + "ToolNode", "ToolRuntime", ] diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py b/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py new file mode 100644 index 0000000000000..5093dee5c06fe --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py @@ -0,0 +1,110 @@ +"""Tests for the langchain ToolNode subclass (NotRequired state field handling).""" + +from __future__ import annotations + +from typing import Annotated +from unittest.mock import MagicMock + +from langchain_core.messages import AIMessage, ToolCall +from langchain_core.tools import tool +from langgraph.prebuilt import InjectedState, ToolRuntime +from typing_extensions import NotRequired, TypedDict + +from langchain.tools.tool_node import ToolNode + +# -- helpers ---------------------------------------------------------------- + + +class StateWithOptional(TypedDict): + messages: list[AIMessage] + city: NotRequired[str] + + +@tool +def get_weather(city: Annotated[str, InjectedState("city")]) -> str: + """Get weather for a given city.""" + return f"Sunny in {city}" + + +@tool +def get_full_state(state: Annotated[dict[str, object], InjectedState()]) -> str: + """Tool that receives the full state.""" + return str(state) + + +# -- tests ------------------------------------------------------------------ + + +def test_inject_state_field_present() -> None: + """InjectedState works normally when the referenced field IS in state.""" + node = ToolNode(tools=[get_weather]) + tc: ToolCall = { + "name": "get_weather", + "args": {}, + "id": "call_1", + "type": "tool_call", + } + runtime = MagicMock(spec=ToolRuntime) + runtime.state = {"messages": [], "city": "Rome"} + + result = node._inject_tool_args(tc, runtime) + assert result["args"]["city"] == "Rome" + + +def test_inject_state_not_required_field_absent() -> None: + """InjectedState must not raise KeyError when a NotRequired field is absent. + + This is the core regression test for + https://github.com/langchain-ai/langchain/issues/35585 + """ + node = ToolNode(tools=[get_weather]) + tc: ToolCall = { + "name": "get_weather", + "args": {}, + "id": "call_2", + "type": "tool_call", + } + runtime = MagicMock(spec=ToolRuntime) + runtime.state = {"messages": []} # "city" is absent + + # Before the fix this raised KeyError: 'city' + result = node._inject_tool_args(tc, runtime) + assert result["args"]["city"] is None + + +def test_inject_full_state_when_field_is_none() -> None: + """When InjectedState() has no field, the entire state dict is injected.""" + node = ToolNode(tools=[get_full_state]) + tc: ToolCall = { + "name": "get_full_state", + "args": {}, + "id": "call_3", + "type": "tool_call", + } + state_dict = {"messages": [AIMessage(content="hi", tool_calls=[])]} + runtime = MagicMock(spec=ToolRuntime) + runtime.state = state_dict + + result = node._inject_tool_args(tc, runtime) + assert result["args"]["state"] is state_dict + + +def test_inject_state_object_attr_missing() -> None: + """Handles missing attributes on non-dict state objects gracefully.""" + + class ObjState: + def __init__(self) -> None: + self.messages: list[AIMessage] = [] + + node = ToolNode(tools=[get_weather]) + tc: ToolCall = { + "name": "get_weather", + "args": {}, + "id": "call_4", + "type": "tool_call", + } + runtime = MagicMock(spec=ToolRuntime) + runtime.state = ObjState() # no 'city' attribute + + result = node._inject_tool_args(tc, runtime) + assert result["args"]["city"] is None