1111from agentscope .model import ChatModelBase
1212from agentscope .formatter import FormatterBase
1313from agentscope .memory import MemoryBase , LongTermMemoryBase
14- from agentscope .message import Msg , TextBlock , ToolUseBlock , ToolResultBlock
14+ from agentscope .message import Msg , ToolUseBlock , ToolResultBlock
1515
1616from alias .agent .tools import AliasToolkit
1717from alias .agent .utils .constants import DEFAULT_PLANNER_NAME
2323from alias .agent .utils .constants import MODEL_MAX_RETRIES
2424
2525
26- def alias_agent_post_reply_hook (
27- self : "AliasAgentBase" ,
28- kwargs : dict [str , Any ], # pylint: disable=unused-argument
29- output : Any ,
30- ):
31- """
32- This is a monkey patch to ensure that when the agent is interrupted in
33- a tool call, the control returns to user
34- """
35- if (
36- self .tool_call_interrupt_return
37- and isinstance (output , Msg )
38- and output .metadata
39- and output .metadata .get ("is_interrupted" , False )
40- ):
41- raise asyncio .CancelledError ()
42-
43-
4426class AliasAgentBase (ReActAgent ):
4527 def __init__ (
4628 self ,
@@ -53,7 +35,6 @@ def __init__(
5335 state_saving_dir : Optional [str ] = None ,
5436 sys_prompt : Optional [str ] = None ,
5537 max_iters : int = 10 ,
56- tool_call_interrupt_return : bool = True ,
5738 long_term_memory : Optional [LongTermMemoryBase ] = None ,
5839 long_term_memory_mode : Literal [
5940 "agent_control" ,
@@ -77,22 +58,24 @@ def __init__(
7758 self .message_sending_mapping = {}
7859 self .state_saving_dir = state_saving_dir
7960 self .agent_stop_function_names = [self .finish_function_name ]
80- self .tool_call_interrupt_return = tool_call_interrupt_return
8161
82- # interrupted if the
83- self .register_instance_hook (
84- "post_reply" ,
85- "alias_agent_post_reply_hook" ,
86- alias_agent_post_reply_hook ,
87- )
8862 # for message output to backend
8963 self .register_instance_hook (
9064 "post_print" ,
9165 "alias_post_print_hook" ,
9266 alias_post_print_hook ,
9367 )
9468
95- async def _reasoning (self ):
69+ # register finish_function_name
70+ if self .finish_function_name not in self .toolkit .tools :
71+ self .toolkit .register_tool_function (
72+ getattr (self , self .finish_function_name ),
73+ )
74+
75+ async def _reasoning (
76+ self ,
77+ tool_choice : Literal ["auto" , "none" , "required" ] | None = None ,
78+ ):
9679 """Override _reasoning to add retry logic."""
9780
9881 # Call the parent class's _reasoning method directly to
@@ -109,7 +92,8 @@ async def call_parent_reasoning():
10992 if hasattr (original_method , "__wrapped__" ):
11093 # This is the wrapped version, get the original
11194 original_method = original_method .__wrapped__
112- return await original_method (self )
95+
96+ return await original_method (self , tool_choice = tool_choice )
11397
11498 for i in range (MODEL_MAX_RETRIES - 1 ):
11599 try :
@@ -132,21 +116,18 @@ async def call_parent_reasoning():
132116 # final attempt
133117 await call_parent_reasoning ()
134118
135- async def _acting (self , tool_call : ToolUseBlock ) -> Msg | None :
136- """Perform the acting process.
137-
138- TODO: (part 2)
139- this is just a monkey patch for AS when not support interruption
140- during tool call; can be remove when AS framework updated
119+ async def _acting (self , tool_call : ToolUseBlock ) -> dict | None :
120+ """Perform the acting process, and return the structured output if
121+ it's generated and verified in the finish function call.
141122
142123 Args:
143124 tool_call (`ToolUseBlock`):
144125 The tool use block to be executed.
145126
146127 Returns:
147- `Union[Msg , None]`:
148- Return a message to the user if the `_finish_function` is
149- called, otherwise return `None` .
128+ `Union[dict , None]`:
129+ Return the structured output if it's verified in the finish
130+ function call .
150131 """
151132
152133 tool_res_msg = Msg (
@@ -165,7 +146,6 @@ async def _acting(self, tool_call: ToolUseBlock) -> Msg | None:
165146 # Execute the tool call
166147 tool_res = await self .toolkit .call_tool_function (tool_call )
167148
168- response_msg = None
169149 # Async generator handling
170150 async for chunk in tool_res :
171151 # Turn into a tool result block
@@ -191,28 +171,26 @@ async def _acting(self, tool_call: ToolUseBlock) -> Msg | None:
191171 tool_call ["name" ] != self .finish_function_name
192172 or (
193173 tool_call ["name" ] == self .finish_function_name
174+ and chunk .metadata
194175 and not chunk .metadata .get ("success" )
195176 )
196177 ):
197178 await self .print (tool_res_msg , chunk .is_last )
198179
199180 # Return message if generate_response is called successfully
200- if tool_call [
201- "name"
202- ] in self .agent_stop_function_names and chunk .metadata .get (
203- "success" ,
204- True ,
181+ if (
182+ tool_call ["name" ] in self .agent_stop_function_names
183+ and chunk .metadata
184+ and chunk .metadata .get (
185+ "success" ,
186+ True ,
187+ )
205188 ):
206- response_msg = chunk .metadata .get ("response_msg " )
189+ return chunk .metadata .get ("structured_output " )
207190 elif chunk .is_interrupted :
208- # TODO: monkey patch happens here
209- response_msg = tool_res_msg
210- if response_msg .metadata is None :
211- response_msg .metadata = {"is_interrupted" : True }
212- else :
213- response_msg .metadata ["is_interrupted" ] = True
191+ raise asyncio .CancelledError
214192
215- return response_msg
193+ return None
216194 finally :
217195 # Record the tool result message in the memory
218196 await self .memory .add (tool_res_msg )
@@ -228,16 +206,16 @@ async def handle_interrupt(
228206 """
229207 response_msg = Msg (
230208 self .name ,
231- content = [
232- TextBlock (
233- type = "text" ,
234- text = "I got interrupted by the user. "
235- "Pivot to handle the user's new request." ,
236- ),
237- ],
238- role = "assistant" ,
239- metadata = {},
209+ "I noticed that you have interrupted me. What can I "
210+ "do for you?" ,
211+ "assistant" ,
212+ metadata = {
213+ # Expose this field to indicate the interruption
214+ "_is_interrupted" : True ,
215+ },
240216 )
217+
218+ await self .print (response_msg , True )
241219 await self .memory .add (response_msg )
242220
243221 # update and save agent states
0 commit comments