Skip to content

Commit f033ce0

Browse files
committed
use tool simulator with strands tool decorator
1 parent 3c3150d commit f033ce0

File tree

2 files changed

+169
-75
lines changed

2 files changed

+169
-75
lines changed

src/strands_evals/simulation/tool_simulator.py

Lines changed: 60 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
import json
32
import logging
43
import warnings
@@ -10,6 +9,7 @@
109
from strands import Agent
1110
from strands.agent import AgentResult
1211
from strands.models.model import Model
12+
from strands.tools.decorator import DecoratedFunctionTool
1313

1414
from strands_evals.types.simulation.tool import RegisteredTool
1515

@@ -122,8 +122,6 @@ def cache_tool_call(
122122

123123
# Append to deque with automatic FIFO eviction when cache is full
124124
state["previous_calls"].append(call_record)
125-
126-
# Return converted state for external use
127125
return self.get_state(state_key)
128126

129127
def clear_state(self, state_key: str) -> None:
@@ -145,6 +143,15 @@ class ToolSimulator:
145143
registered tools. It can be configured to override tool behavior for simulation purposes,
146144
enabling controlled testing scenarios.
147145
146+
IMPORTANT: This simulator expects functions to be decorated with Strands' @tool decorator first.
147+
148+
Example usage:
149+
@simulator.tool(share_state_id="room_environment")
150+
@tool
151+
def my_tool(param: str) -> dict:
152+
'''Tool description'''
153+
pass
154+
148155
The simulator automatically maintains a bounded cache of tool calls for context.
149156
The maximum number of tool calls stored per state key is configurable via
150157
max_tool_call_cache_size parameter (default: 20).
@@ -153,7 +160,6 @@ class ToolSimulator:
153160
state_registry: Registry for maintaining tool state across calls.
154161
function_tool_prompt: Custom prompt template for tool response generation.
155162
model: Provider for running inference or model identifier for Bedrock.
156-
framework: Agent framework to use (default: "strands").
157163
max_tool_call_cache_size: Maximum number of tool calls to store per state key.
158164
"""
159165

@@ -165,7 +171,6 @@ def __init__(
165171
state_registry: StateRegistry | None = None,
166172
function_tool_prompt: str | None = None,
167173
model: Model | str | None = None,
168-
framework: str = "strands",
169174
max_tool_call_cache_size: int = 20,
170175
):
171176
"""
@@ -176,34 +181,21 @@ def __init__(
176181
a new StateRegistry will be created with max_tool_call_cache_size.
177182
function_tool_prompt: Optional custom prompt for tool response generation
178183
model: Provider for running inference or a string representing the model-id for Bedrock to use
179-
framework: Agent framework to use (default: "strands")
180184
max_tool_call_cache_size: Maximum number of tool calls to store per state key.
181185
Only used when creating a new StateRegistry (ignored if state_registry
182186
is provided). Older calls are automatically evicted when limit is exceeded.
183187
Default is 20.
184188
"""
185-
# Store framework selection
186-
self.framework = framework
187-
# Store model configuration for creating internal agents
188189
self.model = model
189-
190-
# Set custom prompt or use default
191190
self.function_tool_prompt = function_tool_prompt or FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT
192-
193-
# Set up state registry
194191
self.state_registry = state_registry or StateRegistry(max_tool_call_cache_size=max_tool_call_cache_size)
195-
196-
# Initialize shared states from registered tools
197192
self._initialize_shared_states()
198193

199194
def _initialize_shared_states(self):
200195
"""Initialize shared states from registered tools' initial descriptions."""
201196
for tool_name, registered_tool in self._registered_tools.items():
202197
if registered_tool.initial_state_description:
203-
# Determine state key from share_state_id or tool name
204198
state_key = registered_tool.share_state_id or registered_tool.name
205-
206-
# Initialize state with description
207199
self.state_registry.initialize_state_via_description(
208200
registered_tool.initial_state_description, state_key
209201
)
@@ -213,10 +205,8 @@ def _create_tool_wrapper(self, registered_tool: RegisteredTool):
213205
"""Create a framework-compatible tool wrapper."""
214206

215207
def wrapper(*args, **kwargs):
216-
# Determine state key
217208
state_key = registered_tool.share_state_id or registered_tool.name
218209

219-
# Build parameters string for tool
220210
parameters_string = (
221211
json.dumps({"args": args, "kwargs": kwargs}, indent=2) if args else json.dumps(kwargs, indent=2)
222212
)
@@ -236,61 +226,36 @@ def wrapper(*args, **kwargs):
236226
wrapper.__name__ = registered_tool.name
237227
wrapper.__doc__ = f"Simulated {registered_tool.name} tool"
238228

239-
# Use framework-specific method to create the tool wrapper
240-
if self.framework == "strands":
241-
return self._create_strands_tool_wrapper(registered_tool, wrapper)
242-
else:
243-
raise ValueError(f"Framework '{self.framework}' is not supported. Only 'strands' is currently supported.")
229+
return self._create_strands_tool_wrapper(registered_tool, wrapper)
244230

245231
def _create_strands_tool_wrapper(self, registered_tool: RegisteredTool, wrapper: Callable):
246-
"""Create a Strands-specific DecoratedFunctionTool wrapper."""
247-
from strands.tools.decorator import DecoratedFunctionTool, FunctionToolMetadata
232+
"""
233+
Create a Strands-specific DecoratedFunctionTool wrapper.
248234
249-
# Create tool spec based on function signature and docstring
250-
tool_description = wrapper.__doc__ or f"Simulated {registered_tool.name} tool"
235+
Since the registered function is already a DecoratedFunctionTool (from @tool decorator),
236+
we reuse its existing metadata and spec, but replace the tool_func with our simulation wrapper.
237+
"""
238+
original_tool = registered_tool.function
251239

252-
# Build input schema from function signature
253-
input_schema: dict[str, Any] = {"type": "object", "properties": {}}
254-
if registered_tool.function:
255-
try:
256-
sig = inspect.signature(registered_tool.function)
257-
for param_name, param in sig.parameters.items():
258-
if param.annotation != inspect.Parameter.empty:
259-
param_type = (
260-
str(param.annotation).replace("<class '", "").replace("'>", "").replace("typing.", "")
261-
)
262-
if "str" in param_type.lower():
263-
input_schema["properties"][param_name] = {"type": "string"}
264-
elif "int" in param_type.lower():
265-
input_schema["properties"][param_name] = {"type": "integer"}
266-
elif "float" in param_type.lower():
267-
input_schema["properties"][param_name] = {"type": "number"}
268-
elif "bool" in param_type.lower():
269-
input_schema["properties"][param_name] = {"type": "boolean"}
270-
else:
271-
input_schema["properties"][param_name] = {"type": "object"}
272-
else:
273-
input_schema["properties"][param_name] = {"type": "string"} # default
274-
except Exception:
275-
pass # fallback to empty schema
276-
277-
# Create Strands tool's FunctionToolMetadata object and DecoratedFunctionTool instance
278-
metadata = FunctionToolMetadata(registered_tool.function or wrapper)
279-
280-
# Extract tool_spec from metadata; override with our custom description if needed
281-
extracted_tool_spec = metadata.extract_metadata()
282-
if tool_description != extracted_tool_spec.get("description"):
283-
extracted_tool_spec["description"] = tool_description
284-
extracted_tool_spec["name"] = registered_tool.name
285-
286-
decorated_tool = DecoratedFunctionTool(
240+
if not isinstance(original_tool, DecoratedFunctionTool):
241+
raise TypeError(
242+
f"Expected DecoratedFunctionTool, got {type(original_tool).__name__}. "
243+
f"Ensure your function is decorated with @tool first."
244+
)
245+
246+
# Reuse existing tool spec and metadata, but override name if specified
247+
tool_spec = original_tool.tool_spec.copy()
248+
tool_spec["name"] = registered_tool.name
249+
250+
# Create new DecoratedFunctionTool with simulation wrapper as the function
251+
simulated_tool = DecoratedFunctionTool(
287252
tool_name=registered_tool.name,
288-
tool_spec=extracted_tool_spec,
289-
tool_func=wrapper, # Always use wrapper to ensure simulation logic is executed
290-
metadata=metadata,
253+
tool_spec=tool_spec,
254+
tool_func=wrapper, # Use our simulation wrapper instead of original function
255+
metadata=original_tool._metadata, # Reuse existing metadata
291256
)
292257

293-
return decorated_tool
258+
return simulated_tool
294259

295260
def _simulate_tool_call(self, prompt: str, structured_output_model=None) -> Any:
296261
"""Tool simulation agent creation and response generation."""
@@ -348,9 +313,14 @@ def tool(
348313
"""
349314
Decorator for registering tools with flexible output schemas.
350315
316+
IMPORTANT: This decorator expects the function to already be decorated with @tool
317+
from strands.tools.decorator. When output_schema is not provided, the input_model
318+
from the DecoratedFunctionTool's metadata will be automatically used as the output_schema.
319+
351320
Args:
352-
name: Optional name for the tool. If None, uses function.__name__
353-
output_schema: Optional Pydantic BaseModel for output schema
321+
name: Optional name for the tool. If None, uses DecoratedFunctionTool.tool_name
322+
output_schema: Optional Pydantic BaseModel for output schema. If None, uses the
323+
input_model from the DecoratedFunctionTool's metadata.
354324
share_state_id: Optional shared state ID for sharing state between tools
355325
initial_state_description: Optional initial state description for the tool's context
356326
@@ -360,19 +330,34 @@ def tool(
360330

361331
def decorator(func: Callable) -> Callable:
362332
try:
363-
tool_name = name or func.__name__
333+
if not isinstance(func, DecoratedFunctionTool):
334+
raise TypeError(
335+
f"Expected DecoratedFunctionTool (from @tool decorator), got {type(func).__name__}. "
336+
f"Please ensure your function is decorated with @tool first, then @simulator.tool()."
337+
)
338+
339+
tool_name = name or func.tool_name
340+
341+
final_output_schema = output_schema
342+
if (
343+
final_output_schema is None
344+
and hasattr(func, "_metadata")
345+
and hasattr(func._metadata, "input_model")
346+
):
347+
final_output_schema = func._metadata.input_model
348+
logger.info(
349+
f"Using input_model from DecoratedFunctionTool metadata as output_schema for tool '{tool_name}'"
350+
)
364351

365-
# Register tool
366352
registered_tool = RegisteredTool(
367353
name=tool_name,
368354
function=func,
369-
output_schema=output_schema,
355+
output_schema=final_output_schema,
370356
initial_state_description=initial_state_description,
371357
share_state_id=share_state_id,
372358
)
373359
self._registered_tools[tool_name] = registered_tool
374360

375-
# Initialize state if initial_state_description is provided
376361
if initial_state_description:
377362
state_key = share_state_id or tool_name
378363
self.state_registry.initialize_state_via_description(initial_state_description, state_key)
@@ -381,7 +366,7 @@ def decorator(func: Callable) -> Callable:
381366
logger.info(f"Registered tool: {tool_name}")
382367

383368
except Exception as e:
384-
raise RuntimeError(f"Error registering tool {name or func.__name__}: {e}") from e
369+
raise RuntimeError(f"Error registering tool {name or getattr(func, '__name__', 'unknown')}: {e}") from e
385370

386371
return func
387372

0 commit comments

Comments
 (0)