Skip to content

Commit ccf9241

Browse files
authored
fix: Handle parallel tool calls during tool extraction (#137)
1 parent d307f40 commit ccf9241

File tree

4 files changed

+108
-46
lines changed

4 files changed

+108
-46
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ repl_state
1414
dataset_files
1515
report_files
1616
.venv
17-
*.DS_Store*
17+
*.DS_Store*
18+
uv.lock

src/strands_evals/evaluators/coherence_evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ class CoherenceRating(BaseModel):
3131

3232
class CoherenceEvaluator(Evaluator[InputT, OutputT]):
3333
"""Evaluates the logical cohesion of the assistant's response.
34-
34+
3535
This evaluator assesses whether the assistant's response maintains logical consistency,
3636
flows naturally, and presents ideas in a well-organized manner. It uses an LLM-as-judge
3737
approach to provide categorical ratings that are then normalized to numeric scores.
38-
38+
3939
Scores:
4040
- NOT_AT_ALL (0.0): Response is completely incoherent or contradictory
4141
- NOT_GENERALLY (0.25): Response has significant logical gaps or inconsistencies

src/strands_evals/extractors/tools_use_extractor.py

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,51 +21,46 @@ def extract_agent_tools_used_from_messages(agent_messages):
2121
for i, message in enumerate(agent_messages):
2222
if message.get("role") == "assistant":
2323
message_info = message.get("content")
24-
if len(message_info) > 0:
25-
tools = []
26-
for message in message_info:
27-
if "toolUse" in message:
28-
tools.append(message.get("toolUse"))
24+
if message_info:
25+
# Collect tool uses from this message
26+
tools = [cb.get("toolUse") for cb in message_info if cb.get("toolUse")]
27+
if not tools:
28+
continue
29+
30+
# Build lookup dict of tool results from subsequent user messages
31+
tool_ids_needed = {tool.get("toolUseId") for tool in tools}
32+
tool_results_by_id: dict[str, dict] = {}
33+
for next_message in agent_messages[i + 1 :]:
34+
if next_message.get("role") == "user":
35+
for content_block in next_message.get("content") or []:
36+
tool_result_dict = content_block.get("toolResult")
37+
if tool_result_dict:
38+
tool_id = tool_result_dict.get("toolUseId")
39+
if tool_id in tool_ids_needed and tool_id not in tool_results_by_id:
40+
tool_results_by_id[tool_id] = tool_result_dict
41+
if len(tool_results_by_id) == len(tool_ids_needed):
42+
break
2943

3044
for tool in tools:
31-
if tool:
32-
tool_name = tool.get("name")
33-
tool_input = tool.get("input")
34-
tool_id = tool.get("toolUseId")
35-
# get the tool result from the next message
36-
tool_result = None
37-
is_error = False
38-
next_message_i = i + 1
39-
while next_message_i < len(agent_messages):
40-
next_message = agent_messages[next_message_i]
41-
next_message_i += 1
42-
43-
if next_message.get("role") == "user":
44-
content = next_message.get("content")
45-
if content:
46-
# Find toolResult in content blocks - may not be at index 0
47-
tool_result_dict = None
48-
for content_block in content:
49-
if "toolResult" in content_block:
50-
tool_result_dict = content_block.get("toolResult")
51-
break
52-
53-
if tool_result_dict and tool_result_dict.get("toolUseId") == tool_id:
54-
tool_result_content = tool_result_dict.get("content", [])
55-
# Find first text in tool result content - may not be at index 0
56-
tool_result = None
57-
if tool_result_content:
58-
for result_item in tool_result_content:
59-
if isinstance(result_item, dict) and "text" in result_item:
60-
tool_result = result_item.get("text")
61-
break
62-
is_error = tool_result_dict.get("status") == "error"
63-
break
64-
65-
tools_used.append(
66-
{"name": tool_name, "input": tool_input, "tool_result": tool_result, "is_error": is_error}
67-
)
68-
tool = message.get("toolUse")
45+
tool_name = tool.get("name")
46+
tool_input = tool.get("input")
47+
tool_id = tool.get("toolUseId")
48+
tool_result = None
49+
is_error = False
50+
51+
# Find the matching tool result block
52+
tool_result_dict = tool_results_by_id.get(tool_id)
53+
if tool_result_dict:
54+
tool_result_content = tool_result_dict.get("content", [])
55+
for result_item in tool_result_content:
56+
if isinstance(result_item, dict) and "text" in result_item:
57+
tool_result = result_item.get("text")
58+
break
59+
is_error = tool_result_dict.get("status") == "error"
60+
61+
tools_used.append(
62+
{"name": tool_name, "input": tool_input, "tool_result": tool_result, "is_error": is_error}
63+
)
6964
return tools_used
7065

7166

tests/strands_evals/extractors/test_tools_use_extractor.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,69 @@ def test_tools_use_extractor_extract_from_messages_user_message_without_tool_res
342342
assert result[0]["input"] == {"expression": "5+5"}
343343
assert result[0]["tool_result"] == "Result: 10"
344344
assert result[0]["is_error"] is False
345+
346+
347+
def test_tools_use_extractor_extract_from_messages_parallel_tool_calls():
348+
"""Test extracting multiple parallel tool calls with results in same user message."""
349+
messages = [
350+
{"role": "user", "content": [{"text": "Calculate 2+2 and 3+3"}]},
351+
{
352+
"role": "assistant",
353+
"content": [
354+
{"toolUse": {"toolUseId": "tool_1", "name": "calculator", "input": {"expression": "2+2"}}},
355+
{"toolUse": {"toolUseId": "tool_2", "name": "calculator", "input": {"expression": "3+3"}}},
356+
],
357+
},
358+
{
359+
"role": "user",
360+
"content": [
361+
{"toolResult": {"status": "success", "content": [{"text": "4"}], "toolUseId": "tool_1"}},
362+
{"toolResult": {"status": "success", "content": [{"text": "6"}], "toolUseId": "tool_2"}},
363+
],
364+
},
365+
]
366+
367+
result = extract_agent_tools_used_from_messages(messages)
368+
369+
assert len(result) == 2
370+
assert result[0]["name"] == "calculator"
371+
assert result[0]["input"] == {"expression": "2+2"}
372+
assert result[0]["tool_result"] == "4"
373+
assert result[1]["name"] == "calculator"
374+
assert result[1]["input"] == {"expression": "3+3"}
375+
assert result[1]["tool_result"] == "6"
376+
377+
378+
def test_tools_use_extractor_extract_from_messages_reused_tool_ids():
379+
"""Test extracting tool calls when tool IDs are reused across the session."""
380+
messages = [
381+
{"role": "user", "content": [{"text": "Calculate 2+2"}]},
382+
{
383+
"role": "assistant",
384+
"content": [{"toolUse": {"toolUseId": "call_123", "name": "calculator", "input": {"expression": "2+2"}}}],
385+
},
386+
{
387+
"role": "user",
388+
"content": [{"toolResult": {"status": "success", "content": [{"text": "4"}], "toolUseId": "call_123"}}],
389+
},
390+
{"role": "assistant", "content": [{"text": "The answer is 4"}]},
391+
{"role": "user", "content": [{"text": "Now calculate 5+5"}]},
392+
{
393+
"role": "assistant",
394+
"content": [{"toolUse": {"toolUseId": "call_123", "name": "calculator", "input": {"expression": "5+5"}}}],
395+
},
396+
{
397+
"role": "user",
398+
"content": [{"toolResult": {"status": "success", "content": [{"text": "10"}], "toolUseId": "call_123"}}],
399+
},
400+
]
401+
402+
result = extract_agent_tools_used_from_messages(messages)
403+
404+
assert len(result) == 2
405+
assert result[0]["name"] == "calculator"
406+
assert result[0]["input"] == {"expression": "2+2"}
407+
assert result[0]["tool_result"] == "4"
408+
assert result[1]["name"] == "calculator"
409+
assert result[1]["input"] == {"expression": "5+5"}
410+
assert result[1]["tool_result"] == "10"

0 commit comments

Comments
 (0)