Skip to content

Commit 57fbc6b

Browse files
committed
fix(agents): support parallel tool calls and improve recovery prompt
1 parent ad6e3cc commit 57fbc6b

3 files changed

Lines changed: 335 additions & 14 deletions

File tree

src/gaia/agents/base/agent.py

Lines changed: 165 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -976,10 +976,45 @@ def _parse_llm_response(self, response: str) -> Dict[str, Any]:
976976
f"finishing the call — increase AgentConfig.max_tokens."
977977
)
978978
if len(tool_calls) > 1:
979-
raise NotImplementedError(
980-
"Parallel tool calls (multiple tool_calls in one response) are not yet supported. "
981-
f"Received {len(tool_calls)} tool calls."
979+
# Support multiple native tool_calls by returning a list of
980+
# parsed tool call dicts. Each element has the same shape
981+
# as the single-call return value so callers can either
982+
# handle a list (preferred for tool-calling models) or
983+
# fall back to the old single-dict behaviour.
984+
parsed_calls = []
985+
for tc in tool_calls:
986+
name = tc["function"]["name"]
987+
arguments_raw = tc["function"].get("arguments")
988+
989+
if arguments_raw is None or arguments_raw == "":
990+
tool_args = {}
991+
elif isinstance(arguments_raw, dict):
992+
tool_args = arguments_raw
993+
elif isinstance(arguments_raw, (str, bytes, bytearray)):
994+
try:
995+
tool_args = json.loads(arguments_raw)
996+
except json.JSONDecodeError as exc:
997+
raise ValueError(
998+
f"Malformed tool_call arguments for '{name}': {exc}. "
999+
f"Raw arguments: {str(arguments_raw)[:200]}"
1000+
) from exc
1001+
else:
1002+
raise ValueError(
1003+
f"Malformed tool_call arguments for '{name}': expected "
1004+
f"str or dict, got {type(arguments_raw).__name__}"
1005+
)
1006+
1007+
parsed_calls.append({
1008+
"thought": "",
1009+
"goal": "",
1010+
"tool": name,
1011+
"tool_args": tool_args,
1012+
})
1013+
logger.debug(
1014+
"[PARSE] Native tool_calls: returning %d parsed calls",
1015+
len(parsed_calls),
9821016
)
1017+
return parsed_calls
9831018
tc = tool_calls[0]
9841019
name = tc["function"]["name"]
9851020
arguments_raw = tc["function"].get("arguments")
@@ -2514,33 +2549,149 @@ def process_query(
25142549
"rephrase or break the request into smaller pieces?"
25152550
)
25162551
break
2552+
25172553
# Push a synthetic assistant turn + recovery user message so the
25182554
# next LLM call has context. Don't include the raw envelope to
25192555
# keep noise out of the conversation history.
2520-
messages.append(
2521-
{
2522-
"role": "assistant",
2556+
recovery_assistant = {
2557+
"role": "assistant",
2558+
"content": "[I tried to call a tool but my arguments were malformed.]",
2559+
}
2560+
messages.append(recovery_assistant)
2561+
conversation.append(recovery_assistant)
2562+
2563+
# Provide different guidance depending on the parse failure type.
2564+
if isinstance(parse_exc, NotImplementedError):
2565+
# NotImplementedError historically meant "multiple tool_calls"
2566+
# when native tool-calling models returned parallel calls.
2567+
# Give the model a clear instruction to either emit a single
2568+
# tool call or a JSON `plan` describing multiple steps.
2569+
recovery_user = {
2570+
"role": "user",
25232571
"content": (
2524-
"[I tried to call a tool but my arguments were "
2525-
"malformed.]"
2572+
"Your last response contained MULTIPLE tool calls in a single reply. "
2573+
"This agent prefers either a single tool call per response, "
2574+
"or a structured JSON 'plan' containing an ordered array of steps. "
2575+
"Please either: (A) output a single tool call JSON object, "
2576+
"or (B) output a JSON plan in the format: {\"plan\": [{\"tool\": \"name\", \"tool_args\": {...}}]}. "
2577+
"If you don't need to call a tool, answer in plain text."
25262578
),
25272579
}
2528-
)
2529-
messages.append(
2530-
{
2580+
messages.append(recovery_user)
2581+
conversation.append(recovery_user)
2582+
else:
2583+
# ValueError or other parse errors usually mean malformed args.
2584+
recovery_user = {
25312585
"role": "user",
25322586
"content": (
25332587
"Your last tool call had malformed arguments. "
25342588
"Please try again. Use ONLY the documented enum "
2535-
"values for each argument (e.g. 'brief', "
2536-
"'detailed', 'bullets' — never a long sentence). "
2589+
"values for each argument (e.g. 'brief', 'detailed', 'bullets'). "
25372590
"If you don't need a tool, answer in plain text."
25382591
),
25392592
}
2540-
)
2593+
messages.append(recovery_user)
2594+
conversation.append(recovery_user)
2595+
25412596
steps_taken += 1
25422597
continue
25432598
logger.debug(f"Parsed response: {parsed}")
2599+
2600+
# If the parser returned multiple native tool calls, execute them
2601+
# sequentially in this same LLM turn (one LLM turn -> N tool turns).
2602+
if isinstance(parsed, list):
2603+
# Record assistant turn containing multiple tool_calls
2604+
conversation.append({"role": "assistant", "content": {"tool_calls": parsed}})
2605+
# Preserve raw assistant response for history
2606+
messages.append({"role": "assistant", "content": response})
2607+
2608+
for call in parsed:
2609+
if not call.get("tool") or "tool_args" not in call:
2610+
continue
2611+
2612+
tool_name = call["tool"]
2613+
tool_args = call["tool_args"]
2614+
logger.debug(f"Sequential native tool call: {tool_name} {tool_args}")
2615+
2616+
# Display the tool call in real-time
2617+
self.console.print_tool_usage(tool_name)
2618+
if tool_args:
2619+
self.console.pretty_print_json(tool_args, "Arguments")
2620+
2621+
# Start progress indicator for tool execution
2622+
self.console.start_progress(f"Executing {tool_name}")
2623+
2624+
# Track call history and detect repeats
2625+
current_call = (tool_name, str(tool_args))
2626+
tool_call_history.append(current_call)
2627+
tool_call_log.append(current_call)
2628+
if len(tool_call_history) > 5:
2629+
tool_call_history.pop(0)
2630+
2631+
consecutive_count = 0
2632+
for c in reversed(tool_call_history):
2633+
if c == current_call:
2634+
consecutive_count += 1
2635+
else:
2636+
break
2637+
if consecutive_count >= self.max_consecutive_repeats:
2638+
self.console.stop_progress()
2639+
final_answer = f"Task completed with {tool_name}. No further action needed."
2640+
self.console.print_repeated_tool_warning()
2641+
break
2642+
2643+
# Execute the tool
2644+
tool_result = self._execute_tool(tool_name, tool_args)
2645+
2646+
# Stop progress indicator
2647+
self.console.stop_progress()
2648+
2649+
# Domain-specific post-processing
2650+
self._post_process_tool_result(tool_name, tool_args, tool_result)
2651+
2652+
# Handle and append large tool results
2653+
truncated_result = self._handle_large_tool_result(
2654+
tool_name, tool_result, conversation, tool_args
2655+
)
2656+
2657+
# Display the tool result
2658+
self.console.pretty_print_json(tool_result, "Result")
2659+
self.console.print_tool_complete()
2660+
2661+
previous_outputs.append({"tool": tool_name, "args": tool_args, "result": truncated_result})
2662+
step_results.append(tool_result)
2663+
2664+
# Share tool output with subsequent LLM calls
2665+
messages.append(self._create_tool_message(tool_name, truncated_result))
2666+
2667+
# Error handling
2668+
is_error = isinstance(tool_result, dict) and (
2669+
tool_result.get("status") == "error"
2670+
or tool_result.get("success") is False
2671+
or tool_result.get("has_errors") is True
2672+
or tool_result.get("return_code", 0) != 0
2673+
)
2674+
if is_error:
2675+
error_count += 1
2676+
last_error = (
2677+
tool_result.get("error_brief")
2678+
or tool_result.get("error")
2679+
or tool_result.get("stderr")
2680+
or tool_result.get("hint")
2681+
or tool_result.get("suggested_fix")
2682+
or f"Command failed with return code {tool_result.get('return_code')}"
2683+
)
2684+
logger.warning(f"Tool execution error in sequential calls (count: {error_count}): {last_error}")
2685+
if not tool_result.get("error_displayed"):
2686+
self.console.print_error(last_error)
2687+
self.execution_state = self.STATE_ERROR_RECOVERY
2688+
# Continue processing remaining calls (or break?) — prefer to continue
2689+
2690+
# After executing all sequential native calls, continue the main loop
2691+
# so the LLM can process the combined tool results.
2692+
continue
2693+
2694+
# Single parsed response — append as before
25442695
conversation.append({"role": "assistant", "content": parsed})
25452696

25462697
# Add assistant response to messages for chat history
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import json
2+
import pytest
3+
4+
from gaia.agents.base.agent import Agent
5+
from gaia.agents.base.tools import _TOOL_REGISTRY
6+
7+
8+
def test_process_query_executes_multiple_native_tool_calls(monkeypatch):
9+
# Register two simple tools for the test
10+
def tool_one(a=""):
11+
return {"status": "success", "value": f"one:{a}"}
12+
13+
def tool_two(b=""):
14+
return {"status": "success", "value": f"two:{b}"}
15+
16+
_TOOL_REGISTRY["tool_one"] = {
17+
"function": tool_one,
18+
"parameters": {"a": {"type": "str", "required": False}},
19+
"description": "Test tool one",
20+
}
21+
_TOOL_REGISTRY["tool_two"] = {
22+
"function": tool_two,
23+
"parameters": {"b": {"type": "str", "required": False}},
24+
"description": "Test tool two",
25+
}
26+
27+
class DummyAgent(Agent):
28+
def _register_tools(self):
29+
# No-op; tests inject tools directly into registry
30+
return None
31+
32+
agent = DummyAgent(skip_lemonade=True, silent_mode=True)
33+
34+
# Prepare a native envelope with two tool_calls (as Lemonade encodes them)
35+
envelope = {
36+
"__tool_calls__": [
37+
{"function": {"name": "tool_one", "arguments": json.dumps({"a": "X"})}},
38+
{"function": {"name": "tool_two", "arguments": json.dumps({"b": "Y"})}},
39+
],
40+
"finish_reason": "",
41+
}
42+
43+
# Monkeypatch send_messages to return our envelope as the LLM response
44+
# AgentSDK.send_messages returns an object with .text and .stats attributes
45+
monkeypatch.setattr(
46+
agent.chat,
47+
"send_messages",
48+
lambda messages, system_prompt, tools: type(
49+
"R", (), {"text": json.dumps(envelope), "stats": {}}
50+
)(),
51+
)
52+
53+
result = agent.process_query("execute both tools", max_steps=6)
54+
55+
# Verify both tool results were appended to conversation
56+
tool_names = [m.get("name") for m in result["conversation"] if m.get("role") == "tool"]
57+
assert "tool_one" in tool_names
58+
assert "tool_two" in tool_names
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import json
2+
import pytest
3+
4+
from gaia.agents.base.agent import Agent
5+
from gaia.agents.base.tools import _TOOL_REGISTRY
6+
7+
8+
def _make_agent(monkeypatch):
9+
class DummyAgent(Agent):
10+
def _register_tools(self):
11+
return None
12+
13+
agent = DummyAgent(skip_lemonade=True, silent_mode=True)
14+
return agent
15+
16+
17+
def test_parallel_calls_with_error(monkeypatch):
18+
# Tools: two success, one error
19+
def t_ok1(x=""):
20+
return {"status": "success", "value": f"ok1:{x}"}
21+
22+
def t_err(y=""):
23+
return {"status": "error", "error": "boom"}
24+
25+
def t_ok2(z=""):
26+
return {"status": "success", "value": f"ok2:{z}"}
27+
28+
_TOOL_REGISTRY["ok1"] = {"function": t_ok1, "parameters": {}, "description": ""}
29+
_TOOL_REGISTRY["errtool"] = {"function": t_err, "parameters": {}, "description": ""}
30+
_TOOL_REGISTRY["ok2"] = {"function": t_ok2, "parameters": {}, "description": ""}
31+
32+
agent = _make_agent(monkeypatch)
33+
34+
envelope = {
35+
"__tool_calls__": [
36+
{"function": {"name": "ok1", "arguments": json.dumps({"x": "A"})}},
37+
{"function": {"name": "errtool", "arguments": json.dumps({"y": "B"})}},
38+
{"function": {"name": "ok2", "arguments": json.dumps({"z": "C"})}},
39+
],
40+
"finish_reason": "",
41+
}
42+
43+
# make send_messages return envelope
44+
responses = [type("R", (), {"text": json.dumps(envelope), "stats": {}})()]
45+
46+
monkeypatch.setattr(agent.chat, "send_messages", lambda messages, system_prompt, tools: responses.pop(0))
47+
48+
result = agent.process_query("run three tools", max_steps=10)
49+
50+
# Ensure we got three tool entries in conversation
51+
tool_entries = [m for m in result["conversation"] if m.get("role") == "tool"]
52+
names = [t.get("name") for t in tool_entries]
53+
assert "ok1" in names and "errtool" in names and "ok2" in names
54+
55+
# Find the errtool result and ensure it's an error
56+
err_entry = next((t for t in tool_entries if t.get("name") == "errtool"), None)
57+
assert err_entry is not None
58+
assert isinstance(err_entry.get("content"), dict) and err_entry["content"].get("status") == "error"
59+
60+
61+
def test_plan_then_native_tool_calls(monkeypatch):
62+
# Tools
63+
def q(a=""):
64+
return {"status": "success", "value": f"q:{a}"}
65+
66+
_TOOL_REGISTRY["q"] = {"function": q, "parameters": {}, "description": ""}
67+
68+
agent = _make_agent(monkeypatch)
69+
70+
envelope = {
71+
"__tool_calls__": [
72+
{"function": {"name": "q", "arguments": json.dumps({"a": "1"})}},
73+
{"function": {"name": "q", "arguments": json.dumps({"a": "2"})}},
74+
],
75+
"finish_reason": "",
76+
}
77+
78+
# Second LLM response will be a final answer
79+
final_answer = {"answer": "All done"}
80+
81+
responses = [
82+
type("R", (), {"text": json.dumps(envelope), "stats": {}})(),
83+
type("R", (), {"text": json.dumps(final_answer), "stats": {}})(),
84+
]
85+
86+
def fake_send(messages, system_prompt, tools):
87+
return responses.pop(0)
88+
89+
monkeypatch.setattr(agent.chat, "send_messages", fake_send)
90+
91+
result = agent.process_query("do q twice and answer", max_steps=10)
92+
93+
# Should have run two q tool calls and then returned the final answer
94+
tool_entries = [m for m in result["conversation"] if m.get("role") == "tool"]
95+
assert len([t for t in tool_entries if t.get("name") == "q"]) == 2
96+
assert result.get("result") and "All done" in result.get("result")
97+
98+
99+
def test_notimplementederror_recovery_message(monkeypatch):
100+
agent = _make_agent(monkeypatch)
101+
102+
# Make the parser raise NotImplementedError
103+
monkeypatch.setattr(agent, "_parse_llm_response", lambda r: (_ for _ in ()).throw(NotImplementedError("multiple")))
104+
105+
# Make send_messages return something (will be ignored by parser)
106+
monkeypatch.setattr(agent.chat, "send_messages", lambda messages, system_prompt, tools: type("R", (), {"text": "{\"bad\":1}", "stats": {}})())
107+
108+
result = agent.process_query("trigger parse error", max_steps=3)
109+
110+
# Last user message in conversation should instruct about multiple tool calls
111+
user_msgs = [m for m in result["conversation"] if m.get("role") == "user"]
112+
assert any("MULTIPLE tool calls" in str(m.get("content")) or "single tool call" in str(m.get("content")) for m in user_msgs)

0 commit comments

Comments
 (0)