forked from strands-agents/evals
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtools_use_extractor.py
More file actions
176 lines (145 loc) · 6.75 KB
/
tools_use_extractor.py
File metadata and controls
176 lines (145 loc) · 6.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from typing import Union, cast
from strands import Agent
from ..types.trace import Session, ToolLevelInput
from .trace_extractor import TraceExtractor
def _dict_content_blocks(message):
    """Return the dict-shaped content blocks of a message, ignoring anything else."""
    content = message.get("content")
    # Missing/empty content or a bare string carries no structured blocks;
    # iterating a string would hand single characters to ``.get``.
    if not content or isinstance(content, str):
        return []
    return [block for block in content if isinstance(block, dict)]


def _first_text(tool_result_dict):
    """Return the text of the first text item in a toolResult's content, or None."""
    # ``or []`` also covers a "content" key that is present but set to None.
    for result_item in tool_result_dict.get("content") or []:
        if isinstance(result_item, dict) and "text" in result_item:
            return result_item.get("text")
    return None


def extract_agent_tools_used_from_messages(agent_messages):
    """
    Extract tool usage information from agent message history.

    Every ``toolUse`` block found in an assistant message is paired with the
    first ``toolResult`` block carrying the same ``toolUseId`` in any later
    user message. Content entries that are not dictionaries are skipped.

    Args:
        agent_messages: List of message dictionaries from agent conversation

    Returns:
        list: Tool usage information with name, input, tool_result and is_error
            [{name: str, input: dict, tool_result: str | None, is_error: bool}, ...]
            ``tool_result`` is the text of the first text item of the matching
            result (None when there is no matching result or no text item);
            ``is_error`` is True only when the matching result has status "error".
    """
    tools_used = []
    for i, message in enumerate(agent_messages):
        if message.get("role") != "assistant":
            continue

        # Collect tool uses from this message
        tools = [block["toolUse"] for block in _dict_content_blocks(message) if block.get("toolUse")]
        if not tools:
            continue

        # Build lookup dict of tool results from subsequent user messages
        tool_ids_needed = {tool.get("toolUseId") for tool in tools}
        tool_results_by_id: dict[str, dict] = {}
        for next_message in agent_messages[i + 1 :]:
            if next_message.get("role") != "user":
                continue
            for content_block in _dict_content_blocks(next_message):
                tool_result_dict = content_block.get("toolResult")
                if not tool_result_dict:
                    continue
                tool_id = tool_result_dict.get("toolUseId")
                if tool_id in tool_ids_needed:
                    # First result seen for an id wins; later duplicates are ignored.
                    tool_results_by_id.setdefault(tool_id, tool_result_dict)
            # Stop scanning once every tool use has its result.
            if len(tool_results_by_id) == len(tool_ids_needed):
                break

        for tool in tools:
            tool_result = None
            is_error = False
            # Find the matching tool result block
            tool_result_dict = tool_results_by_id.get(tool.get("toolUseId"))
            if tool_result_dict:
                tool_result = _first_text(tool_result_dict)
                is_error = tool_result_dict.get("status") == "error"
            tools_used.append(
                {
                    "name": tool.get("name"),
                    "input": tool.get("input"),
                    "tool_result": tool_result,
                    "is_error": is_error,
                }
            )
    return tools_used
def extract_agent_tools_used_from_metrics(agent_result):
    """
    Summarize per-tool metrics recorded on an agent execution result.

    Reads ``agent_result.metrics.tool_metrics`` (a mapping of tool name to a
    metrics entry) and flattens each entry into a plain dictionary.

    Args:
        agent_result: Agent result object containing metrics

    Returns:
        list: One dictionary per entry in ``tool_metrics``, in mapping order
            [{
                name: str,
                input: dict,
                call_count: int,
                success_count: int,
                total_time: float
            }, ...]
    """
    return [
        {
            "name": name,
            # The input is read from the "input" key of the entry's ``tool`` mapping.
            "input": entry.tool.get("input"),
            "call_count": entry.call_count,
            "success_count": entry.success_count,
            "total_time": entry.total_time,
        }
        for name, entry in agent_result.metrics.tool_metrics.items()
    ]
def extract_agent_tools_used_from_trace(session: Session) -> list[dict]:
    """
    Extract tool usage information from trace data (Session object).

    The session is run through a TraceExtractor configured for TOOL_LEVEL, and
    each resulting ToolLevelInput is flattened into a plain dictionary. Unlike
    the message-based extractor, the records built here carry no ``is_error``
    key.

    Args:
        session: Session object containing trace data

    Returns:
        list: Tool usage information with name, input, and tool_result
            [{name: str, input: dict, tool_result: str}, ...]
            ``tool_result`` is None when the execution has no tool result.
    """
    from ..types.trace import EvaluationLevel

    # Parse the session into tool-level inputs.
    tool_level_extractor = TraceExtractor(evaluation_level=EvaluationLevel.TOOL_LEVEL)
    tool_level_items = cast(list[ToolLevelInput], tool_level_extractor.extract(session))

    # Flatten every tool execution into a {name, input, tool_result} record.
    records = []
    for item in tool_level_items:
        details = item.tool_execution_details
        call = details.tool_call
        record = {"name": call.name, "input": call.arguments}
        record["tool_result"] = details.tool_result.content if details.tool_result else None
        records.append(record)
    return records
def extract_agent_tools_used(source: Union[list, Session]) -> list[dict]:
    """
    Extract tool usage information from either agent messages or trace data.

    Single entry point that dispatches on the type of ``source``:
    - a Session object is handled by extract_agent_tools_used_from_trace
    - a list is handled by extract_agent_tools_used_from_messages

    Args:
        source: Either agent_messages (list) or Session object

    Returns:
        list: Tool usage information with name, input, and tool_result
            [{name: str, input: dict, tool_result: str}, ...]
            Records built from a message list additionally carry ``is_error``.

    Raises:
        TypeError: If source is neither a list nor a Session object
    """
    # Session is checked first, mirroring the order callers have always seen.
    if isinstance(source, Session):
        return extract_agent_tools_used_from_trace(source)

    if isinstance(source, list):
        return extract_agent_tools_used_from_messages(source)

    raise TypeError(f"source must be either a list (agent messages) or Session object, got {type(source).__name__}")
def extract_tools_description(agent: Agent, is_short: bool = True):
    """
    Build a dictionary describing every tool registered on an agent.

    Args:
        agent (Agent): Target agent to extract tool registry from
        is_short (bool, optional): When True, keep only each tool's
            "description" entry; when False, return the registry's full tool
            configuration unchanged. Defaults to True.

    Returns:
        dict: Tool name and its corresponding description
            {<tool_name>: <tool_description>, ...}
            or, when ``is_short`` is False, whatever
            ``agent.tool_registry.get_all_tools_config()`` returned.
    """
    tools_config = agent.tool_registry.get_all_tools_config()
    if not is_short:
        return tools_config

    # Keep just the description of each tool.
    return {name: config["description"] for name, config in tools_config.items()}