diff --git a/py/core/agent/base.py b/py/core/agent/base.py index 84aae3f23..b197c4b80 100644 --- a/py/core/agent/base.py +++ b/py/core/agent/base.py @@ -557,17 +557,68 @@ async def sse_generator() -> AsyncGenerator[str, None]: calls_list, partial_text_buffer ) - # (c) Execute each tool call in parallel - await asyncio.gather( - *[ - self.handle_function_or_tool_call( + # (c) Execute each tool call and emit tool result events + for c in calls_list: + try: + tool_result = await self.handle_function_or_tool_call( c["name"], c["arguments"], tool_id=c["tool_call_id"], + save_messages=True, # Set to True to add tool results to conversation ) - for c in calls_list - ] - ) + result_content = tool_result.llm_formatted_result + + # Extract the ids from the tool tool_result.content + ids = self.BRACKET_PATTERN.findall(tool_result.llm_formatted_result) + # Get the raw result from source_collector + raw_results = [] + for id in ids: + raw_result = self.search_results_collector.find_by_short_id( + id + ) + if raw_result: + raw_results.append(raw_result) + + # Create the result data with raw results + result_data = { + "tool_call_id": c["tool_call_id"], + "role": "tool", + "content": json.dumps(raw_results), + } + + # Emit SSE tool result event + async for line in SSEFormatter.yield_tool_result_event( + result_data + ): + yield line + except Exception as e: + error_content = f"Error in tool '{c['name']}': {str(e)}" + logger.error(error_content) + + # Add error message to conversation + await self.conversation.add_message( + Message( + role="tool", + content=error_content, + name=c["name"], + tool_call_id=c["tool_call_id"], + ) + ) + + # Emit error as tool result + result_data = { + "tool_call_id": c["tool_call_id"], + "role": "tool", + "content": json.dumps( + convert_nonserializable_objects( + error_content + ) + ), + } + async for line in SSEFormatter.yield_tool_result_event( + result_data + ): + yield line # Reset buffer & calls pending_tool_calls.clear()