1- import inspect
21import json
32import logging
43import warnings
109from strands import Agent
1110from strands .agent import AgentResult
1211from strands .models .model import Model
12+ from strands .tools .decorator import DecoratedFunctionTool
1313
1414from strands_evals .types .simulation .tool import RegisteredTool
1515
@@ -122,8 +122,6 @@ def cache_tool_call(
122122
123123 # Append to deque with automatic FIFO eviction when cache is full
124124 state ["previous_calls" ].append (call_record )
125-
126- # Return converted state for external use
127125 return self .get_state (state_key )
128126
129127 def clear_state (self , state_key : str ) -> None :
@@ -145,6 +143,15 @@ class ToolSimulator:
145143 registered tools. It can be configured to override tool behavior for simulation purposes,
146144 enabling controlled testing scenarios.
147145
146+ IMPORTANT: This simulator expects functions to be decorated with Strands' @tool decorator first.
147+
148+ Example usage:
149+ @simulator.tool(share_state_id="room_environment")
150+ @tool
151+ def my_tool(param: str) -> dict:
152+ '''Tool description'''
153+ pass
154+
148155 The simulator automatically maintains a bounded cache of tool calls for context.
149156 The maximum number of tool calls stored per state key is configurable via
150157 max_tool_call_cache_size parameter (default: 20).
@@ -153,7 +160,6 @@ class ToolSimulator:
153160 state_registry: Registry for maintaining tool state across calls.
154161 function_tool_prompt: Custom prompt template for tool response generation.
155162 model: Provider for running inference or model identifier for Bedrock.
156- framework: Agent framework to use (default: "strands").
157163 max_tool_call_cache_size: Maximum number of tool calls to store per state key.
158164 """
159165
@@ -165,7 +171,6 @@ def __init__(
165171 state_registry : StateRegistry | None = None ,
166172 function_tool_prompt : str | None = None ,
167173 model : Model | str | None = None ,
168- framework : str = "strands" ,
169174 max_tool_call_cache_size : int = 20 ,
170175 ):
171176 """
@@ -176,34 +181,21 @@ def __init__(
176181 a new StateRegistry will be created with max_tool_call_cache_size.
177182 function_tool_prompt: Optional custom prompt for tool response generation
178183 model: Provider for running inference or a string representing the model-id for Bedrock to use
179- framework: Agent framework to use (default: "strands")
180184 max_tool_call_cache_size: Maximum number of tool calls to store per state key.
181185 Only used when creating a new StateRegistry (ignored if state_registry
182186 is provided). Older calls are automatically evicted when limit is exceeded.
183187 Default is 20.
184188 """
185- # Store framework selection
186- self .framework = framework
187- # Store model configuration for creating internal agents
188189 self .model = model
189-
190- # Set custom prompt or use default
191190 self .function_tool_prompt = function_tool_prompt or FUNCTION_TOOL_RESPONSE_GENERATION_PROMPT
192-
193- # Set up state registry
194191 self .state_registry = state_registry or StateRegistry (max_tool_call_cache_size = max_tool_call_cache_size )
195-
196- # Initialize shared states from registered tools
197192 self ._initialize_shared_states ()
198193
199194 def _initialize_shared_states (self ):
200195 """Initialize shared states from registered tools' initial descriptions."""
201196 for tool_name , registered_tool in self ._registered_tools .items ():
202197 if registered_tool .initial_state_description :
203- # Determine state key from share_state_id or tool name
204198 state_key = registered_tool .share_state_id or registered_tool .name
205-
206- # Initialize state with description
207199 self .state_registry .initialize_state_via_description (
208200 registered_tool .initial_state_description , state_key
209201 )
@@ -213,10 +205,8 @@ def _create_tool_wrapper(self, registered_tool: RegisteredTool):
213205 """Create a framework-compatible tool wrapper."""
214206
215207 def wrapper (* args , ** kwargs ):
216- # Determine state key
217208 state_key = registered_tool .share_state_id or registered_tool .name
218209
219- # Build parameters string for tool
220210 parameters_string = (
221211 json .dumps ({"args" : args , "kwargs" : kwargs }, indent = 2 ) if args else json .dumps (kwargs , indent = 2 )
222212 )
@@ -236,61 +226,36 @@ def wrapper(*args, **kwargs):
236226 wrapper .__name__ = registered_tool .name
237227 wrapper .__doc__ = f"Simulated { registered_tool .name } tool"
238228
239- # Use framework-specific method to create the tool wrapper
240- if self .framework == "strands" :
241- return self ._create_strands_tool_wrapper (registered_tool , wrapper )
242- else :
243- raise ValueError (f"Framework '{ self .framework } ' is not supported. Only 'strands' is currently supported." )
229+ return self ._create_strands_tool_wrapper (registered_tool , wrapper )
244230
245231 def _create_strands_tool_wrapper (self , registered_tool : RegisteredTool , wrapper : Callable ):
246- """Create a Strands-specific DecoratedFunctionTool wrapper."""
247- from strands . tools . decorator import DecoratedFunctionTool , FunctionToolMetadata
232+ """
233+ Create a Strands-specific DecoratedFunctionTool wrapper.
248234
249- # Create tool spec based on function signature and docstring
250- tool_description = wrapper .__doc__ or f"Simulated { registered_tool .name } tool"
235+ Since the registered function is already a DecoratedFunctionTool (from @tool decorator),
236+ we reuse its existing metadata and spec, but replace the tool_func with our simulation wrapper.
237+ """
238+ original_tool = registered_tool .function
251239
252- # Build input schema from function signature
253- input_schema : dict [str , Any ] = {"type" : "object" , "properties" : {}}
254- if registered_tool .function :
255- try :
256- sig = inspect .signature (registered_tool .function )
257- for param_name , param in sig .parameters .items ():
258- if param .annotation != inspect .Parameter .empty :
259- param_type = (
260- str (param .annotation ).replace ("<class '" , "" ).replace ("'>" , "" ).replace ("typing." , "" )
261- )
262- if "str" in param_type .lower ():
263- input_schema ["properties" ][param_name ] = {"type" : "string" }
264- elif "int" in param_type .lower ():
265- input_schema ["properties" ][param_name ] = {"type" : "integer" }
266- elif "float" in param_type .lower ():
267- input_schema ["properties" ][param_name ] = {"type" : "number" }
268- elif "bool" in param_type .lower ():
269- input_schema ["properties" ][param_name ] = {"type" : "boolean" }
270- else :
271- input_schema ["properties" ][param_name ] = {"type" : "object" }
272- else :
273- input_schema ["properties" ][param_name ] = {"type" : "string" } # default
274- except Exception :
275- pass # fallback to empty schema
276-
277- # Create Strands tool's FunctionToolMetadata object and DecoratedFunctionTool instance
278- metadata = FunctionToolMetadata (registered_tool .function or wrapper )
279-
280- # Extract tool_spec from metadata; override with our custom description if needed
281- extracted_tool_spec = metadata .extract_metadata ()
282- if tool_description != extracted_tool_spec .get ("description" ):
283- extracted_tool_spec ["description" ] = tool_description
284- extracted_tool_spec ["name" ] = registered_tool .name
285-
286- decorated_tool = DecoratedFunctionTool (
240+ if not isinstance (original_tool , DecoratedFunctionTool ):
241+ raise TypeError (
242+ f"Expected DecoratedFunctionTool, got { type (original_tool ).__name__ } . "
243+ f"Ensure your function is decorated with @tool first."
244+ )
245+
246+ # Reuse existing tool spec and metadata, but override name if specified
247+ tool_spec = original_tool .tool_spec .copy ()
248+ tool_spec ["name" ] = registered_tool .name
249+
250+ # Create new DecoratedFunctionTool with simulation wrapper as the function
251+ simulated_tool = DecoratedFunctionTool (
287252 tool_name = registered_tool .name ,
288- tool_spec = extracted_tool_spec ,
289- tool_func = wrapper , # Always use wrapper to ensure simulation logic is executed
290- metadata = metadata ,
253+ tool_spec = tool_spec ,
254+ tool_func = wrapper , # Use our simulation wrapper instead of original function
255+ metadata = original_tool . _metadata , # Reuse existing metadata
291256 )
292257
293- return decorated_tool
258+ return simulated_tool
294259
295260 def _simulate_tool_call (self , prompt : str , structured_output_model = None ) -> Any :
296261 """Tool simulation agent creation and response generation."""
@@ -348,9 +313,14 @@ def tool(
348313 """
349314 Decorator for registering tools with flexible output schemas.
350315
316+ IMPORTANT: This decorator expects the function to already be decorated with @tool
317+ from strands.tools.decorator. When output_schema is not provided, the input_model
318+ from the DecoratedFunctionTool's metadata will be automatically used as the output_schema.
319+
351320 Args:
352- name: Optional name for the tool. If None, uses function.__name__
353- output_schema: Optional Pydantic BaseModel for output schema
321+ name: Optional name for the tool. If None, uses DecoratedFunctionTool.tool_name
322+ output_schema: Optional Pydantic BaseModel for output schema. If None, uses the
323+ input_model from the DecoratedFunctionTool's metadata.
354324 share_state_id: Optional shared state ID for sharing state between tools
355325 initial_state_description: Optional initial state description for the tool's context
356326
@@ -360,19 +330,34 @@ def tool(
360330
361331 def decorator (func : Callable ) -> Callable :
362332 try :
363- tool_name = name or func .__name__
333+ if not isinstance (func , DecoratedFunctionTool ):
334+ raise TypeError (
335+ f"Expected DecoratedFunctionTool (from @tool decorator), got { type (func ).__name__ } . "
336+ f"Please ensure your function is decorated with @tool first, then @simulator.tool()."
337+ )
338+
339+ tool_name = name or func .tool_name
340+
341+ final_output_schema = output_schema
342+ if (
343+ final_output_schema is None
344+ and hasattr (func , "_metadata" )
345+ and hasattr (func ._metadata , "input_model" )
346+ ):
347+ final_output_schema = func ._metadata .input_model
348+ logger .info (
349+ f"Using input_model from DecoratedFunctionTool metadata as output_schema for tool '{ tool_name } '"
350+ )
364351
365- # Register tool
366352 registered_tool = RegisteredTool (
367353 name = tool_name ,
368354 function = func ,
369- output_schema = output_schema ,
355+ output_schema = final_output_schema ,
370356 initial_state_description = initial_state_description ,
371357 share_state_id = share_state_id ,
372358 )
373359 self ._registered_tools [tool_name ] = registered_tool
374360
375- # Initialize state if initial_state_description is provided
376361 if initial_state_description :
377362 state_key = share_state_id or tool_name
378363 self .state_registry .initialize_state_via_description (initial_state_description , state_key )
@@ -381,7 +366,7 @@ def decorator(func: Callable) -> Callable:
381366 logger .info (f"Registered tool: { tool_name } " )
382367
383368 except Exception as e :
384- raise RuntimeError (f"Error registering tool { name or func . __name__ } : { e } " ) from e
369+ raise RuntimeError (f"Error registering tool { name or getattr ( func , ' __name__' , 'unknown' ) } : { e } " ) from e
385370
386371 return func
387372
0 commit comments