1
1
import logging
2
- import inspect
3
- from copy import deepcopy
4
- from typing import Any , Callable , Literal
2
+ from typing import Any , Callable , Optional
5
3
6
4
from litellm import ContextWindowExceededError
7
5
@@ -22,8 +20,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
22
20
self .signature = signature = ensure_signature (signature )
23
21
self .max_iters = max_iters
24
22
25
- tools = [t if isinstance (t , Tool ) else Tool (t ) for t in tools ]
26
- tools = {tool .name : tool for tool in tools }
23
+ tools = self ._convert_tools (tools )
27
24
28
25
inputs = ", " .join ([f"`{ k } `" for k in signature .input_fields .keys ()])
29
26
outputs = ", " .join ([f"`{ k } `" for k in signature .output_fields .keys ()])
@@ -36,7 +33,6 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
36
33
"To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task." ,
37
34
"After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n " ,
38
35
"When writing next_thought, you may reason about the current situation and plan for future steps." ,
39
- "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n " ,
40
36
]
41
37
)
42
38
@@ -47,14 +43,12 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
47
43
args = {},
48
44
)
49
45
50
- for idx , tool in enumerate (tools .values ()):
51
- instr .append (f"({ idx + 1 } ) { tool } " )
52
-
53
46
react_signature = (
54
47
dspy .Signature ({** signature .input_fields }, "\n " .join (instr ))
55
48
.append ("trajectory" , dspy .InputField (), type_ = str )
49
+ .append ("tools" , dspy .InputField (desc = "Tools you select from when selecting the next_tool_name and its next_tool_args" ), type_ = list [str ])
56
50
.append ("next_thought" , dspy .OutputField (), type_ = str )
57
- .append ("next_tool_name" , dspy .OutputField (), type_ = Literal [ tuple ( tools . keys ())] )
51
+ .append ("next_tool_name" , dspy .OutputField (), type_ = str )
58
52
.append ("next_tool_args" , dspy .OutputField (), type_ = dict [str , Any ])
59
53
)
60
54
@@ -67,18 +61,18 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
67
61
self .react = dspy .Predict (react_signature )
68
62
self .extract = dspy .ChainOfThought (fallback_signature )
69
63
70
- def _format_trajectory (self , trajectory : dict [str , Any ]):
71
- adapter = dspy .settings .adapter or dspy .ChatAdapter ()
72
- trajectory_signature = dspy .Signature (f"{ ', ' .join (trajectory .keys ())} -> x" )
73
- return adapter .format_user_message_content (trajectory_signature , trajectory )
74
-
75
- def forward (self , ** input_args ):
64
+ def forward (self , additional_tools : Optional [list [Callable ]] = None , ** input_args ):
76
65
trajectory = {}
77
66
max_iters = input_args .pop ("max_iters" , self .max_iters )
78
- tools = self ._copy_tools ( self .tools )
67
+ tools = self .tools | self ._convert_tools ( additional_tools )
79
68
for idx in range (max_iters ):
80
69
try :
81
- pred = self ._call_with_potential_trajectory_truncation (self .react , trajectory , ** input_args )
70
+ pred = self ._call_with_potential_trajectory_truncation (
71
+ self .react ,
72
+ trajectory ,
73
+ tools = self ._format_tools_string (tools ),
74
+ ** input_args
75
+ )
82
76
except ValueError as err :
83
77
logger .warning (f"Ending the trajectory: Agent failed to select a valid tool: { _fmt_exc (err )} " )
84
78
break
@@ -98,13 +92,18 @@ def forward(self, **input_args):
98
92
extract = self ._call_with_potential_trajectory_truncation (self .extract , trajectory , ** input_args )
99
93
return dspy .Prediction (trajectory = trajectory , ** extract )
100
94
101
- async def aforward (self , ** input_args ):
95
+ async def aforward (self , additional_tools : Optional [ list [ Callable ]] = None , ** input_args ):
102
96
trajectory = {}
103
97
max_iters = input_args .pop ("max_iters" , self .max_iters )
104
- tools = self ._copy_tools ( self .tools )
98
+ tools = self .tools | self ._convert_tools ( additional_tools )
105
99
for idx in range (max_iters ):
106
100
try :
107
- pred = await self ._async_call_with_potential_trajectory_truncation (self .react , trajectory , ** input_args )
101
+ pred = await self ._async_call_with_potential_trajectory_truncation (
102
+ self .react ,
103
+ trajectory ,
104
+ tools = self ._format_tools_string (tools ),
105
+ ** input_args
106
+ )
108
107
except ValueError as err :
109
108
logger .warning (f"Ending the trajectory: Agent failed to select a valid tool: { _fmt_exc (err )} " )
110
109
break
@@ -164,18 +163,19 @@ def truncate_trajectory(self, trajectory):
164
163
165
164
return trajectory
166
165
167
- def _copy_tools (self , tools ):
168
- results = tools .copy ()
169
- for tool_name , tool in tools .items ():
170
- if inspect .isfunction (tool .func ):
171
- results [tool_name ] = tool
172
- else :
173
- try :
174
- results [tool_name ] = deepcopy (tool )
175
- except Exception :
176
- logger .warning (f"Failed to deepcopy tool: { tool !r} . Consider making your tool deep-copyable "
177
- "if it needs to manage internal state. Error: {e}." )
178
- return results
166
+ def _format_trajectory (self , trajectory : dict [str , Any ]):
167
+ adapter = dspy .settings .adapter or dspy .ChatAdapter ()
168
+ trajectory_signature = dspy .Signature (f"{ ', ' .join (trajectory .keys ())} -> x" )
169
+ return adapter .format_user_message_content (trajectory_signature , trajectory )
170
+
171
+ def _convert_tools (self , tools : Optional [list [Callable ]]) -> dict [str , Tool ]:
172
+ """Convert the tools to a dictionary of name -> tool."""
173
+ tools = [t if isinstance (t , Tool ) else Tool (t ) for t in tools or []]
174
+ return {tool .name : tool for tool in tools }
175
+
176
+ def _format_tools_string (self , tools : dict [str , Tool ]) -> list [str ]:
177
+ """Format the tools into a list of string."""
178
+ return [f"({ idx + 1 } ) { tool } " for idx , tool in enumerate (tools .values ())]
179
179
180
180
181
181
def _fmt_exc (err : BaseException , * , limit : int = 5 ) -> str :
0 commit comments