Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libs/langchain_v1/langchain/agents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from langgraph._internal._runnable import RunnableCallable
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
from langgraph.prebuilt.tool_node import ToolCallWithContext
from langgraph.types import Command, Send
from typing_extensions import NotRequired, Required, TypedDict

Expand Down Expand Up @@ -51,6 +51,7 @@
ToolStrategy,
)
from langchain.chat_models import init_chat_model
from langchain.tools.tool_node import ToolNode


@dataclass
Expand Down
93 changes: 92 additions & 1 deletion libs/langchain_v1/langchain/tools/tool_node.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,111 @@
"""Utils file included for backwards compat imports."""

from __future__ import annotations

from copy import copy
from typing import TYPE_CHECKING, Any

from langgraph.prebuilt import InjectedState, InjectedStore, ToolRuntime
from langgraph.prebuilt.tool_node import (
ToolCallRequest,
ToolCallWithContext,
ToolCallWrapper,
_get_all_injected_args,
)
from langgraph.prebuilt.tool_node import (
ToolNode as _ToolNode, # noqa: F401
ToolNode as _ToolNode,
)

if TYPE_CHECKING:
from langchain_core.messages import ToolCall
from langchain_core.tools import BaseTool


class ToolNode(_ToolNode):
"""ToolNode subclass that gracefully handles ``NotRequired`` state fields.

Upstream ``langgraph.prebuilt.ToolNode._inject_tool_args`` accesses state
fields with ``state[field]`` which raises ``KeyError`` when the field is
declared as ``NotRequired`` in the state schema and is absent at runtime.

This subclass overrides ``_inject_tool_args`` to use ``.get()`` (for dict
state) and ``getattr(…, None)`` (for object state) so that missing optional
fields resolve to ``None`` instead of crashing.

See: https://github.com/langchain-ai/langchain/issues/35585
"""

def _inject_tool_args(
self,
tool_call: ToolCall,
tool_runtime: ToolRuntime,
tool: BaseTool | None = None,
) -> ToolCall:
injected = self._injected_args.get(tool_call["name"])
if not injected and tool is not None:
injected = _get_all_injected_args(tool)
if not injected:
return tool_call

tool_call_copy: ToolCall = copy(tool_call)
injected_args: dict[str, Any] = {}

# Inject state
if injected.state:
state: Any = tool_runtime.state
# Handle list state by converting to dict
if isinstance(state, list):
required_fields = list(injected.state.values())
if (
len(required_fields) == 1 and required_fields[0] == self._messages_key
) or required_fields[0] is None:
state = {self._messages_key: state}
else:
err_msg = (
f"Invalid input to ToolNode. "
f"Tool {tool_call['name']} requires "
f"graph state dict as input."
)
if any(state_field for state_field in injected.state.values()):
required_fields_str = ", ".join(f for f in required_fields if f)
err_msg += f" State should contain fields {required_fields_str}."
raise ValueError(err_msg)

# Extract state values — use .get() / getattr default so that
# NotRequired fields that are absent resolve to None (#35585).
if isinstance(state, dict):
for tool_arg, state_field in injected.state.items():
injected_args[tool_arg] = state.get(state_field) if state_field else state
else:
for tool_arg, state_field in injected.state.items():
injected_args[tool_arg] = (
getattr(state, state_field, None) if state_field else state
)

# Inject store
if injected.store:
if tool_runtime.store is None:
msg = (
"Cannot inject store into tools with InjectedStore "
"annotations - please compile your graph with a store."
)
raise ValueError(msg)
injected_args[injected.store] = tool_runtime.store

# Inject runtime
if injected.runtime:
injected_args[injected.runtime] = tool_runtime

tool_call_copy["args"] = {**tool_call_copy["args"], **injected_args}
return tool_call_copy


__all__ = [
"InjectedState",
"InjectedStore",
"ToolCallRequest",
"ToolCallWithContext",
"ToolCallWrapper",
"ToolNode",
"ToolRuntime",
]
110 changes: 110 additions & 0 deletions libs/langchain_v1/tests/unit_tests/tools/test_tool_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Tests for the langchain ToolNode subclass (NotRequired state field handling)."""

from __future__ import annotations

from typing import Annotated
from unittest.mock import MagicMock

from langchain_core.messages import AIMessage, ToolCall
from langchain_core.tools import tool
from langgraph.prebuilt import InjectedState, ToolRuntime
from typing_extensions import NotRequired, TypedDict

from langchain.tools.tool_node import ToolNode

# -- helpers ----------------------------------------------------------------


class StateWithOptional(TypedDict):
messages: list[AIMessage]
city: NotRequired[str]


@tool
def get_weather(city: Annotated[str, InjectedState("city")]) -> str:
"""Get weather for a given city."""
return f"Sunny in {city}"


@tool
def get_full_state(state: Annotated[dict[str, object], InjectedState()]) -> str:
"""Tool that receives the full state."""
return str(state)


# -- tests ------------------------------------------------------------------


def test_inject_state_field_present() -> None:
"""InjectedState works normally when the referenced field IS in state."""
node = ToolNode(tools=[get_weather])
tc: ToolCall = {
"name": "get_weather",
"args": {},
"id": "call_1",
"type": "tool_call",
}
runtime = MagicMock(spec=ToolRuntime)
runtime.state = {"messages": [], "city": "Rome"}

result = node._inject_tool_args(tc, runtime)
assert result["args"]["city"] == "Rome"


def test_inject_state_not_required_field_absent() -> None:
"""InjectedState must not raise KeyError when a NotRequired field is absent.

This is the core regression test for
https://github.com/langchain-ai/langchain/issues/35585
"""
node = ToolNode(tools=[get_weather])
tc: ToolCall = {
"name": "get_weather",
"args": {},
"id": "call_2",
"type": "tool_call",
}
runtime = MagicMock(spec=ToolRuntime)
runtime.state = {"messages": []} # "city" is absent

# Before the fix this raised KeyError: 'city'
result = node._inject_tool_args(tc, runtime)
assert result["args"]["city"] is None


def test_inject_full_state_when_field_is_none() -> None:
"""When InjectedState() has no field, the entire state dict is injected."""
node = ToolNode(tools=[get_full_state])
tc: ToolCall = {
"name": "get_full_state",
"args": {},
"id": "call_3",
"type": "tool_call",
}
state_dict = {"messages": [AIMessage(content="hi", tool_calls=[])]}
runtime = MagicMock(spec=ToolRuntime)
runtime.state = state_dict

result = node._inject_tool_args(tc, runtime)
assert result["args"]["state"] is state_dict


def test_inject_state_object_attr_missing() -> None:
"""Handles missing attributes on non-dict state objects gracefully."""

class ObjState:
def __init__(self) -> None:
self.messages: list[AIMessage] = []

node = ToolNode(tools=[get_weather])
tc: ToolCall = {
"name": "get_weather",
"args": {},
"id": "call_4",
"type": "tool_call",
}
runtime = MagicMock(spec=ToolRuntime)
runtime.state = ObjState() # no 'city' attribute

result = node._inject_tool_args(tc, runtime)
assert result["args"]["city"] is None
Loading