-
Notifications
You must be signed in to change notification settings - Fork 21.5k
Expand file tree
/
Copy pathtest_tool_node.py
More file actions
110 lines (82 loc) · 3.23 KB
/
test_tool_node.py
File metadata and controls
110 lines (82 loc) · 3.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Tests for the langchain ToolNode subclass (NotRequired state field handling)."""
from __future__ import annotations
from typing import Annotated
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, ToolCall
from langchain_core.tools import tool
from langgraph.prebuilt import InjectedState, ToolRuntime
from typing_extensions import NotRequired, TypedDict
from langchain.tools.tool_node import ToolNode
# -- helpers ----------------------------------------------------------------
class StateWithOptional(TypedDict):
    """State schema whose ``city`` key is optional (``NotRequired``).

    Describes the state shape the regression tests in this module emulate: a
    state that always carries ``messages`` but may omit ``city`` entirely.

    NOTE(review): the tests here build plain dicts / objects for
    ``runtime.state`` and never reference this class directly — presumably it
    is kept to document the intended shape; confirm, or wire it into the tests.
    """

    messages: list[AIMessage]
    # NotRequired: a conforming state dict may leave this key out altogether,
    # which is exactly the case InjectedState("city") must tolerate.
    city: NotRequired[str]
@tool
def get_weather(city: Annotated[str, InjectedState("city")]) -> str:
    """Get weather for a given city."""
    # ``city`` is populated from the ``city`` state field via InjectedState
    # (see the injection tests in this module). The reply is a fixed stub:
    # only argument injection is under test, never this body's output.
    forecast = f"Sunny in {city}"
    return forecast
@tool
def get_full_state(state: Annotated[dict[str, object], InjectedState()]) -> str:
    """Tool that receives the full state."""
    # A field-less InjectedState() receives the entire state object, as
    # exercised by test_inject_full_state_when_field_is_none.
    rendered = str(state)
    return rendered
# -- tests ------------------------------------------------------------------
def test_inject_state_field_present() -> None:
    """Baseline: ``city`` is injected from state when the key is present."""
    tool_node = ToolNode(tools=[get_weather])
    # The call carries no args; ``city`` has to come from the runtime state.
    call = ToolCall(name="get_weather", args={}, id="call_1", type="tool_call")
    state: StateWithOptional = {"messages": [], "city": "Rome"}
    fake_runtime = MagicMock(spec=ToolRuntime)
    fake_runtime.state = state

    injected_args = tool_node._inject_tool_args(call, fake_runtime)["args"]
    assert injected_args["city"] == "Rome"
def test_inject_state_not_required_field_absent() -> None:
    """An absent ``NotRequired`` state field injects ``None`` instead of raising.

    Core regression test for
    https://github.com/langchain-ai/langchain/issues/35585
    """
    tool_node = ToolNode(tools=[get_weather])
    call = ToolCall(name="get_weather", args={}, id="call_2", type="tool_call")
    state: StateWithOptional = {"messages": []}  # "city" deliberately omitted
    fake_runtime = MagicMock(spec=ToolRuntime)
    fake_runtime.state = state

    # Prior to the fix, this injection raised ``KeyError: 'city'``.
    injected_args = tool_node._inject_tool_args(call, fake_runtime)["args"]
    assert injected_args["city"] is None
def test_inject_full_state_when_field_is_none() -> None:
    """A field-less ``InjectedState()`` receives the whole state object."""
    tool_node = ToolNode(tools=[get_full_state])
    call = ToolCall(name="get_full_state", args={}, id="call_3", type="tool_call")
    full_state: StateWithOptional = {
        "messages": [AIMessage(content="hi", tool_calls=[])],
    }
    fake_runtime = MagicMock(spec=ToolRuntime)
    fake_runtime.state = full_state

    injected_args = tool_node._inject_tool_args(call, fake_runtime)["args"]
    # Identity rather than equality: the very same state object is injected.
    assert injected_args["state"] is full_state
def test_inject_state_object_attr_missing() -> None:
    """Attribute-style (non-dict) state lacking the field injects ``None``."""

    class _MessagesOnlyState:
        # Plain object state: exposes ``messages`` but has no ``city`` attribute.
        def __init__(self) -> None:
            self.messages: list[AIMessage] = []

    tool_node = ToolNode(tools=[get_weather])
    call = ToolCall(name="get_weather", args={}, id="call_4", type="tool_call")
    fake_runtime = MagicMock(spec=ToolRuntime)
    fake_runtime.state = _MessagesOnlyState()

    injected_args = tool_node._inject_tool_args(call, fake_runtime)["args"]
    assert injected_args["city"] is None