From d631608af81fd9d064fa621acdbf8ee83f619ca9 Mon Sep 17 00:00:00 2001 From: Giulio Leone <6887247+giulio-leone@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:37:36 +0100 Subject: [PATCH 1/6] fix(langchain): handle NotRequired fields in InjectedState without KeyError When InjectedState references a TypedDict field marked as NotRequired, the upstream langgraph ToolNode accesses state[key] which raises KeyError if the field is absent from the runtime state. This introduces a ToolNode subclass in langchain that overrides _inject_tool_args to use .get() (dict state) and getattr with a default (object state) so that missing optional fields resolve to None instead of crashing. The factory's create_agent now uses this patched ToolNode automatically. Fixes #35585 --- libs/langchain_v1/langchain/agents/factory.py | 4 +- .../langchain_v1/langchain/tools/tool_node.py | 105 ++++++++++++++++- .../tests/unit_tests/tools/test_tool_node.py | 110 ++++++++++++++++++ 3 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 1d7e8d279be4c..38fa34551c58b 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -21,7 +21,9 @@ from langgraph._internal._runnable import RunnableCallable from langgraph.constants import END, START from langgraph.graph.state import StateGraph -from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode +from langgraph.prebuilt.tool_node import ToolCallWithContext + +from langchain.tools.tool_node import ToolNode from langgraph.types import Command, Send from langsmith import traceable from typing_extensions import NotRequired, Required, TypedDict diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 4474c8ba935fb..c9d509af432db 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -1,20 +1,123 @@ """Utils file included for backwards compat imports.""" +from __future__ import annotations + +from copy import copy +from typing import TYPE_CHECKING + from langgraph.prebuilt import InjectedState, InjectedStore, ToolRuntime from langgraph.prebuilt.tool_node import ( ToolCallRequest, ToolCallWithContext, ToolCallWrapper, + _get_all_injected_args, ) from langgraph.prebuilt.tool_node import ( - ToolNode as _ToolNode, # noqa: F401 + ToolNode as _ToolNode, ) +if TYPE_CHECKING: + from langchain_core.messages import ToolCall + from langchain_core.tools import BaseTool + + +class ToolNode(_ToolNode): + """ToolNode subclass that gracefully handles ``NotRequired`` state fields. + + Upstream ``langgraph.prebuilt.ToolNode._inject_tool_args`` accesses state + fields with ``state[field]`` which raises ``KeyError`` when the field is + declared as ``NotRequired`` in the state schema and is absent at runtime. + + This subclass overrides ``_inject_tool_args`` to use ``.get()`` (for dict + state) and ``getattr(…, None)`` (for object state) so that missing optional + fields resolve to ``None`` instead of crashing. + + See: https://github.com/langchain-ai/langchain/issues/35585 + """ + + def _inject_tool_args( + self, + tool_call: ToolCall, + tool_runtime: ToolRuntime, + tool: BaseTool | None = None, + ) -> ToolCall: + injected = self._injected_args.get(tool_call["name"]) + if not injected and tool is not None: + injected = _get_all_injected_args(tool) + if not injected: + return tool_call + + tool_call_copy: ToolCall = copy(tool_call) + injected_args: dict = {} + + # Inject state + if injected.state: + state = tool_runtime.state + # Handle list state by converting to dict + if isinstance(state, list): + required_fields = list(injected.state.values()) + if ( + len(required_fields) == 1 + and required_fields[0] == self._messages_key + ) or required_fields[0] is None: + state = {self._messages_key: state} + else: + err_msg = ( + f"Invalid input to ToolNode. " + f"Tool {tool_call['name']} requires " + f"graph state dict as input." + ) + if any( + state_field for state_field in injected.state.values() + ): + required_fields_str = ", ".join( + f for f in required_fields if f + ) + err_msg += ( + f" State should contain fields " + f"{required_fields_str}." + ) + raise ValueError(err_msg) + + # Extract state values — use .get() / getattr default so that + # NotRequired fields that are absent resolve to None (#35585). + if isinstance(state, dict): + for tool_arg, state_field in injected.state.items(): + injected_args[tool_arg] = ( + state.get(state_field) if state_field else state + ) + else: + for tool_arg, state_field in injected.state.items(): + injected_args[tool_arg] = ( + getattr(state, state_field, None) + if state_field + else state + ) + + # Inject store + if injected.store: + if tool_runtime.store is None: + msg = ( + "Cannot inject store into tools with InjectedStore " + "annotations - please compile your graph with a store." + ) + raise ValueError(msg) + injected_args[injected.store] = tool_runtime.store + + # Inject runtime + if injected.runtime: + injected_args[injected.runtime] = tool_runtime + + tool_call_copy["args"] = {**tool_call_copy["args"], **injected_args} + return tool_call_copy + + __all__ = [ "InjectedState", "InjectedStore", "ToolCallRequest", "ToolCallWithContext", "ToolCallWrapper", + "ToolNode", "ToolRuntime", ] diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py b/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py new file mode 100644 index 0000000000000..4bc5b104dffcb --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py @@ -0,0 +1,110 @@ +"""Tests for the langchain ToolNode subclass (NotRequired state field handling).""" + +from __future__ import annotations + +from typing import Annotated +from unittest.mock import MagicMock + +from langchain_core.messages import AIMessage, ToolCall +from langchain_core.tools import tool +from langgraph.prebuilt import InjectedState, ToolRuntime +from typing_extensions import NotRequired, TypedDict + +from langchain.tools.tool_node import ToolNode + + +# -- helpers ---------------------------------------------------------------- + + +class StateWithOptional(TypedDict): + messages: list + city: NotRequired[str] + + +@tool +def get_weather(city: Annotated[str, InjectedState("city")]) -> str: + """Get weather for a given city.""" + return f"Sunny in {city}" + + +@tool +def get_full_state(state: Annotated[dict, InjectedState()]) -> str: + """Tool that receives the full state.""" + return str(state) + + +# -- tests ------------------------------------------------------------------ + + +def test_inject_state_field_present() -> None: + """InjectedState works normally when the referenced field IS in state.""" + node = ToolNode(tools=[get_weather]) + tc: ToolCall = { + "name": "get_weather", + "args": {}, + "id": "call_1", + "type": "tool_call", + } + runtime = MagicMock(spec=ToolRuntime) + runtime.state = {"messages": [], "city": "Rome"} + + result = node._inject_tool_args(tc, runtime) + assert result["args"]["city"] == "Rome" + + +def test_inject_state_not_required_field_absent() -> None: + """InjectedState must not raise KeyError when a NotRequired field is absent. + + This is the core regression test for + https://github.com/langchain-ai/langchain/issues/35585 + """ + node = ToolNode(tools=[get_weather]) + tc: ToolCall = { + "name": "get_weather", + "args": {}, + "id": "call_2", + "type": "tool_call", + } + runtime = MagicMock(spec=ToolRuntime) + runtime.state = {"messages": []} # "city" is absent + + # Before the fix this raised KeyError: 'city' + result = node._inject_tool_args(tc, runtime) + assert result["args"]["city"] is None + + +def test_inject_full_state_when_field_is_none() -> None: + """When InjectedState() has no field, the entire state dict is injected.""" + node = ToolNode(tools=[get_full_state]) + tc: ToolCall = { + "name": "get_full_state", + "args": {}, + "id": "call_3", + "type": "tool_call", + } + state_dict = {"messages": [AIMessage(content="hi", tool_calls=[])]} + runtime = MagicMock(spec=ToolRuntime) + runtime.state = state_dict + + result = node._inject_tool_args(tc, runtime) + assert result["args"]["state"] is state_dict + + +def test_inject_state_object_attr_missing() -> None: + """Handles missing attributes on non-dict state objects gracefully.""" + + class ObjState: + messages: list = [] + + node = ToolNode(tools=[get_weather]) + tc: ToolCall = { + "name": "get_weather", + "args": {}, + "id": "call_4", + "type": "tool_call", + } + runtime = MagicMock(spec=ToolRuntime) + runtime.state = ObjState() # no 'city' attribute + + result = node._inject_tool_args(tc, runtime) + assert result["args"]["city"] is None From f8e4854935c9c64481b6d698ab824669332e6adc Mon Sep 17 00:00:00 2001 From: Giulio Leone <6887247+giulio-leone@users.noreply.github.com> Date: Mon, 9 Mar 2026 15:07:16 +0100 Subject: [PATCH 2/6] fix(langchain): sort imports in factory.py to fix ruff I001 lint --- libs/langchain_v1/langchain/agents/factory.py | 3 +-- .../langchain_v1/langchain/tools/tool_node.py | 24 +++++-------------- 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 38fa34551c58b..132b8e9d56810 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -22,8 +22,6 @@ from langgraph.constants import END, START from langgraph.graph.state import StateGraph from langgraph.prebuilt.tool_node import ToolCallWithContext - -from langchain.tools.tool_node import ToolNode from langgraph.types import Command, Send from langsmith import traceable from typing_extensions import NotRequired, Required, TypedDict @@ -55,6 +53,7 @@ ToolStrategy, ) from langchain.chat_models import init_chat_model +from langchain.tools.tool_node import ToolNode @dataclass diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index c9d509af432db..89f87a9fc959a 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -57,8 +57,7 @@ def _inject_tool_args( if isinstance(state, list): required_fields = list(injected.state.values()) if ( - len(required_fields) == 1 - and required_fields[0] == self._messages_key + len(required_fields) == 1 and required_fields[0] == self._messages_key ) or required_fields[0] is None: state = {self._messages_key: state} else: @@ -67,31 +66,20 @@ def _inject_tool_args( f"Tool {tool_call['name']} requires " f"graph state dict as input." ) - if any( - state_field for state_field in injected.state.values() - ): - required_fields_str = ", ".join( - f for f in required_fields if f - ) - err_msg += ( - f" State should contain fields " - f"{required_fields_str}." - ) + if any(state_field for state_field in injected.state.values()): + required_fields_str = ", ".join(f for f in required_fields if f) + err_msg += f" State should contain fields {required_fields_str}." raise ValueError(err_msg) # Extract state values — use .get() / getattr default so that # NotRequired fields that are absent resolve to None (#35585). if isinstance(state, dict): for tool_arg, state_field in injected.state.items(): - injected_args[tool_arg] = ( - state.get(state_field) if state_field else state - ) + injected_args[tool_arg] = state.get(state_field) if state_field else state else: for tool_arg, state_field in injected.state.items(): injected_args[tool_arg] = ( - getattr(state, state_field, None) - if state_field - else state + getattr(state, state_field, None) if state_field else state ) # Inject store From 3a45240c867957707f76f46efdf2bba08f44f5ed Mon Sep 17 00:00:00 2001 From: Giulio Leone <6887247+giulio-leone@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:12:01 +0100 Subject: [PATCH 3/6] fix: resolve mypy errors in tool_node.py - Add type parameters to generic dict (dict[str, Any]) - Annotate state as Any to prevent unreachable-code warnings from isinstance checks (StateT defaults to dict, hiding list/object branches) --- libs/langchain_v1/langchain/tools/tool_node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 89f87a9fc959a..259a3203503ad 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -3,7 +3,7 @@ from __future__ import annotations from copy import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from langgraph.prebuilt import InjectedState, InjectedStore, ToolRuntime from langgraph.prebuilt.tool_node import ( @@ -48,11 +48,11 @@ def _inject_tool_args( return tool_call tool_call_copy: ToolCall = copy(tool_call) - injected_args: dict = {} + injected_args: dict[str, Any] = {} # Inject state if injected.state: - state = tool_runtime.state + state: Any = tool_runtime.state # Handle list state by converting to dict if isinstance(state, list): required_fields = list(injected.state.values()) From 6aea9d92db370a6a5b7a335051472e81c804415f Mon Sep 17 00:00:00 2001 From: giulio-leone Date: Tue, 10 Mar 2026 01:35:13 +0100 Subject: [PATCH 4/6] test: fix tool_node lint regression Resolve the failing langchain_v1 lint job on PR #35684 by auto-formatting the test import block and replacing the mutable class attribute used in ObjState with an instance attribute. Targeted validation in libs/langchain_v1/.venv now passes: ruff check, ruff format --check, and pytest tests/unit_tests/tools/test_tool_node.py -q. --- libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py b/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py index 4bc5b104dffcb..306fbadd1d132 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py @@ -12,7 +12,6 @@ from langchain.tools.tool_node import ToolNode - # -- helpers ---------------------------------------------------------------- @@ -94,7 +93,8 @@ def test_inject_state_object_attr_missing() -> None: """Handles missing attributes on non-dict state objects gracefully.""" class ObjState: - messages: list = [] + def __init__(self) -> None: + self.messages = [] node = ToolNode(tools=[get_weather]) tc: ToolCall = { From 67e9685d67915277e07b623210081d2bc7411c2f Mon Sep 17 00:00:00 2001 From: giulio-leone Date: Tue, 10 Mar 2026 02:12:55 +0100 Subject: [PATCH 5/6] Fix mypy errors in tool node tests --- libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py b/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py index 306fbadd1d132..5093dee5c06fe 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py @@ -16,7 +16,7 @@ class StateWithOptional(TypedDict): - messages: list + messages: list[AIMessage] city: NotRequired[str] @@ -27,7 +27,7 @@ def get_weather(city: Annotated[str, InjectedState("city")]) -> str: @tool -def get_full_state(state: Annotated[dict, InjectedState()]) -> str: +def get_full_state(state: Annotated[dict[str, object], InjectedState()]) -> str: """Tool that receives the full state.""" return str(state) @@ -94,7 +94,7 @@ def test_inject_state_object_attr_missing() -> None: class ObjState: def __init__(self) -> None: - self.messages = [] + self.messages: list[AIMessage] = [] node = ToolNode(tools=[get_weather]) tc: ToolCall = { From b9c7503d1c8b0344db4dd398603147397d7ac514 Mon Sep 17 00:00:00 2001 From: giulio-leone Date: Sat, 21 Mar 2026 16:03:48 +0100 Subject: [PATCH 6/6] refactor(tools): narrow tool node override Delegate to the upstream ToolNode implementation by default and only\nrecover the missing optional-state case from #35585. This keeps\nLangChain aligned with future langgraph changes while preserving\nthe NotRequired-state fix on the real create_agent path.\n\nCo-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../langchain_v1/langchain/tools/tool_node.py | 104 ++++++++---------- 1 file changed, 47 insertions(+), 57 deletions(-) diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 259a3203503ad..b3170751e7dc8 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -24,13 +24,11 @@ class ToolNode(_ToolNode): """ToolNode subclass that gracefully handles ``NotRequired`` state fields. - Upstream ``langgraph.prebuilt.ToolNode._inject_tool_args`` accesses state - fields with ``state[field]`` which raises ``KeyError`` when the field is - declared as ``NotRequired`` in the state schema and is absent at runtime. - - This subclass overrides ``_inject_tool_args`` to use ``.get()`` (for dict - state) and ``getattr(…, None)`` (for object state) so that missing optional - fields resolve to ``None`` instead of crashing. + Keep the override as narrow as possible: delegate to upstream + ``langgraph.prebuilt.ToolNode._inject_tool_args`` by default, and only + recover the specific case where an injected state field is optional and + absent at runtime. That keeps LangChain aligned with upstream ToolNode + changes while still fixing ``#35585``. See: https://github.com/langchain-ai/langchain/issues/35585 """ @@ -41,63 +39,55 @@ def _inject_tool_args( tool_runtime: ToolRuntime, tool: BaseTool | None = None, ) -> ToolCall: - injected = self._injected_args.get(tool_call["name"]) - if not injected and tool is not None: - injected = _get_all_injected_args(tool) - if not injected: - return tool_call - - tool_call_copy: ToolCall = copy(tool_call) - injected_args: dict[str, Any] = {} + try: + return super()._inject_tool_args(tool_call, tool_runtime, tool=tool) + except (KeyError, AttributeError) as err: + injected = self._injected_args.get(tool_call["name"]) + if not injected and tool is not None: + injected = _get_all_injected_args(tool) + if not injected or not injected.state: + raise - # Inject state - if injected.state: state: Any = tool_runtime.state - # Handle list state by converting to dict - if isinstance(state, list): - required_fields = list(injected.state.values()) - if ( - len(required_fields) == 1 and required_fields[0] == self._messages_key - ) or required_fields[0] is None: - state = {self._messages_key: state} - else: - err_msg = ( - f"Invalid input to ToolNode. " - f"Tool {tool_call['name']} requires " - f"graph state dict as input." - ) - if any(state_field for state_field in injected.state.values()): - required_fields_str = ", ".join(f for f in required_fields if f) - err_msg += f" State should contain fields {required_fields_str}." - raise ValueError(err_msg) - - # Extract state values — use .get() / getattr default so that - # NotRequired fields that are absent resolve to None (#35585). if isinstance(state, dict): - for tool_arg, state_field in injected.state.items(): - injected_args[tool_arg] = state.get(state_field) if state_field else state + missing_optional_state = any( + state_field and state_field not in state + for state_field in injected.state.values() + ) + if not missing_optional_state: + raise + injected_args: dict[str, Any] = { + tool_arg: state.get(state_field) if state_field else state + for tool_arg, state_field in injected.state.items() + } else: - for tool_arg, state_field in injected.state.items(): - injected_args[tool_arg] = ( - getattr(state, state_field, None) if state_field else state - ) - - # Inject store - if injected.store: - if tool_runtime.store is None: - msg = ( - "Cannot inject store into tools with InjectedStore " - "annotations - please compile your graph with a store." + missing_optional_state = any( + state_field and not hasattr(state, state_field) + for state_field in injected.state.values() ) - raise ValueError(msg) - injected_args[injected.store] = tool_runtime.store + if not missing_optional_state: + raise + injected_args = { + tool_arg: getattr(state, state_field, None) if state_field else state + for tool_arg, state_field in injected.state.items() + } + + tool_call_copy: ToolCall = copy(tool_call) + + if injected.store: + if tool_runtime.store is None: + msg = ( + "Cannot inject store into tools with InjectedStore " + "annotations - please compile your graph with a store." + ) + raise ValueError(msg) from err + injected_args[injected.store] = tool_runtime.store - # Inject runtime - if injected.runtime: - injected_args[injected.runtime] = tool_runtime + if injected.runtime: + injected_args[injected.runtime] = tool_runtime - tool_call_copy["args"] = {**tool_call_copy["args"], **injected_args} - return tool_call_copy + tool_call_copy["args"] = {**tool_call_copy["args"], **injected_args} + return tool_call_copy __all__ = [