Skip to content

Commit 1d335e2

Browse files
committed
add additional_tools arg
1 parent 3e8a9c7 commit 1d335e2

File tree

2 files changed

+36
-39
lines changed

2 files changed

+36
-39
lines changed

dspy/predict/react.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import logging
2-
import inspect
3-
from copy import deepcopy
4-
from typing import Any, Callable, Literal
2+
from typing import Any, Callable, Optional
53

64
from litellm import ContextWindowExceededError
75

@@ -22,8 +20,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
2220
self.signature = signature = ensure_signature(signature)
2321
self.max_iters = max_iters
2422

25-
tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
26-
tools = {tool.name: tool for tool in tools}
23+
tools = self._convert_tools(tools)
2724

2825
inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()])
2926
outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()])
@@ -36,7 +33,6 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
3633
"To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.",
3734
"After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n",
3835
"When writing next_thought, you may reason about the current situation and plan for future steps.",
39-
"When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n",
4036
]
4137
)
4238

@@ -47,14 +43,12 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
4743
args={},
4844
)
4945

50-
for idx, tool in enumerate(tools.values()):
51-
instr.append(f"({idx + 1}) {tool}")
52-
5346
react_signature = (
5447
dspy.Signature({**signature.input_fields}, "\n".join(instr))
5548
.append("trajectory", dspy.InputField(), type_=str)
49+
.append("tools", dspy.InputField(desc="Tools you select from when selecting the next_tool_name and its next_tool_args"), type_=list[str])
5650
.append("next_thought", dspy.OutputField(), type_=str)
57-
.append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
51+
.append("next_tool_name", dspy.OutputField(), type_=str)
5852
.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
5953
)
6054

@@ -67,18 +61,18 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
6761
self.react = dspy.Predict(react_signature)
6862
self.extract = dspy.ChainOfThought(fallback_signature)
6963

70-
def _format_trajectory(self, trajectory: dict[str, Any]):
71-
adapter = dspy.settings.adapter or dspy.ChatAdapter()
72-
trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x")
73-
return adapter.format_user_message_content(trajectory_signature, trajectory)
74-
75-
def forward(self, **input_args):
64+
def forward(self, additional_tools: Optional[list[Callable]] = None, **input_args):
7665
trajectory = {}
7766
max_iters = input_args.pop("max_iters", self.max_iters)
78-
tools = self._copy_tools(self.tools)
67+
tools = self.tools | self._convert_tools(additional_tools)
7968
for idx in range(max_iters):
8069
try:
81-
pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
70+
pred = self._call_with_potential_trajectory_truncation(
71+
self.react,
72+
trajectory,
73+
tools=self._format_tools_string(tools),
74+
**input_args
75+
)
8276
except ValueError as err:
8377
logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
8478
break
@@ -98,13 +92,18 @@ def forward(self, **input_args):
9892
extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
9993
return dspy.Prediction(trajectory=trajectory, **extract)
10094

101-
async def aforward(self, **input_args):
95+
async def aforward(self, additional_tools: Optional[list[Callable]] = None, **input_args):
10296
trajectory = {}
10397
max_iters = input_args.pop("max_iters", self.max_iters)
104-
tools = self._copy_tools(self.tools)
98+
tools = self.tools | self._convert_tools(additional_tools)
10599
for idx in range(max_iters):
106100
try:
107-
pred = await self._async_call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
101+
pred = await self._async_call_with_potential_trajectory_truncation(
102+
self.react,
103+
trajectory,
104+
tools=self._format_tools_string(tools),
105+
**input_args
106+
)
108107
except ValueError as err:
109108
logger.warning(f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}")
110109
break
@@ -164,18 +163,19 @@ def truncate_trajectory(self, trajectory):
164163

165164
return trajectory
166165

167-
def _copy_tools(self, tools):
168-
results = tools.copy()
169-
for tool_name, tool in tools.items():
170-
if inspect.isfunction(tool.func):
171-
results[tool_name] = tool
172-
else:
173-
try:
174-
results[tool_name] = deepcopy(tool)
175-
except Exception:
176-
logger.warning(f"Failed to deepcopy tool: {tool!r}. Consider making your tool deep-copyable "
177-
"if it needs to manage internal state. Error: {e}.")
178-
return results
166+
def _format_trajectory(self, trajectory: dict[str, Any]):
167+
adapter = dspy.settings.adapter or dspy.ChatAdapter()
168+
trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x")
169+
return adapter.format_user_message_content(trajectory_signature, trajectory)
170+
171+
def _convert_tools(self, tools: Optional[list[Callable]]) -> dict[str, Tool]:
172+
"""Convert the tools to a dictionary of name -> tool."""
173+
tools = [t if isinstance(t, Tool) else Tool(t) for t in tools or []]
174+
return {tool.name: tool for tool in tools}
175+
176+
def _format_tools_string(self, tools: dict[str, Tool]) -> list[str]:
177+
"""Format the tools into a list of string."""
178+
return [f"({idx + 1}) {tool}" for idx, tool in enumerate(tools.values())]
179179

180180

181181
def _fmt_exc(err: BaseException, *, limit: int = 5) -> str:

tests/predict/test_react.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,8 @@ def __init__(self):
343343
def __call__(self, i: int):
344344
self.count += i
345345
return self.count
346-
347-
counter = Counter()
348346

349-
react = dspy.ReAct("max -> sum:int", tools=[counter])
347+
react = dspy.ReAct("max -> sum:int", tools=[])
350348
lm = DummyLM(
351349
[
352350
{"next_thought": "I need to add 1", "next_tool_name": "Counter", "next_tool_args": {"i": 1}},
@@ -358,11 +356,10 @@ def __call__(self, i: int):
358356
)
359357
dspy.settings.configure(lm=lm)
360358

361-
outputs = react(call_count=3)
359+
# Pass the tool object as an additional tool
360+
outputs = react(call_count=3, additional_tools=[Counter()])
362361

363362
assert outputs.sum == 6
364-
# Verify the state is not changed after the forward call
365-
assert counter.count == 0
366363

367364
# Check the state is managed during the forward call
368365
expected_trajectory = {

0 commit comments

Comments
 (0)