diff --git a/docs/hooks.md b/docs/hooks.md
index 50f30394e3..7a40848daa 100644
--- a/docs/hooks.md
+++ b/docs/hooks.md
@@ -295,6 +295,52 @@ Error hooks (`*_error` in the `hooks.on` namespace, `on_*_error` on `AbstractCap
 
 See [Error hooks](capabilities.md#error-hooks) for the full pattern and recovery types.
 
+## Triggering retries with `ModelRetry`
+
+Hooks can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with a custom message — the same exception used in [tool functions](tools.md#model-retry) and output validators.
+
+**Model request hooks** (`after_model_request`, `wrap_model_request`, `on_model_request_error`):
+
+- The retry message is sent back to the model as a [`RetryPromptPart`][pydantic_ai.messages.RetryPromptPart]
+- `after_model_request`: the original response is preserved in message history so the model can see what it said
+- `wrap_model_request`: the response is preserved only if the handler was called
+- Retries count against the agent's `output_retries` limit
+
+**Tool hooks** (`before/after_tool_validate`, `before/after_tool_execute`, `wrap_tool_execute`, `on_tool_execute_error`):
+
+- Converted to tool retry prompts, same as when a tool function raises `ModelRetry`
+- Retries count against the tool's `max_retries` limit
+
+`ModelRetry` from `wrap_model_request` and `wrap_tool_execute` is treated as control flow — it bypasses `on_model_request_error` and `on_tool_execute_error` respectively.
+
+```python {title="hooks_model_retry.py"}
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.capabilities import Hooks
+from pydantic_ai.exceptions import ModelRetry
+from pydantic_ai.messages import ModelResponse
+from pydantic_ai.models import ModelRequestContext
+
+hooks = Hooks()
+
+
+@hooks.on.after_model_request
+async def check_response(
+    ctx: RunContext[None],
+    *,
+    request_context: ModelRequestContext,
+    response: ModelResponse,
+) -> ModelResponse:
+    if 'PLACEHOLDER' in str(response.parts):
+        raise ModelRetry('Response contains placeholder text. Please provide real data.')
+    return response
+
+
+agent = Agent('test', capabilities=[hooks])
+result = agent.run_sync('Hello')
+print(result.output)
+#> success (no tool calls)
+```
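+
+Tool hooks follow the same pattern. Here is a minimal sketch that rejects an empty tool result from `after_tool_execute` (assuming the `hooks.on.after_tool_execute` decorator mirrors the model request hooks above); the retry is converted to a tool retry prompt and counts against the tool's `max_retries`:
+
+```python {title="hooks_tool_retry.py"}
+from typing import Any
+
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.capabilities import Hooks
+from pydantic_ai.exceptions import ModelRetry
+from pydantic_ai.messages import ToolCallPart
+from pydantic_ai.tools import ToolDefinition
+
+hooks = Hooks()
+
+
+@hooks.on.after_tool_execute
+async def check_result(
+    ctx: RunContext[None],
+    *,
+    call: ToolCallPart,
+    tool_def: ToolDefinition,
+    args: dict[str, Any],
+    result: Any,
+) -> Any:
+    # Hypothetical check: reject a falsy result and ask the model to redo the call
+    if not result:
+        raise ModelRetry('Tool returned an empty result, please try different arguments.')
+    return result
+
+
+agent = Agent('test', capabilities=[hooks])
+```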
+
 ## When to use `Hooks` vs `AbstractCapability`
 
 | Use [`Hooks`][pydantic_ai.capabilities.Hooks] | Use [`AbstractCapability`][pydantic_ai.capabilities.AbstractCapability] |
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 1002f3135a..37d3eb71f4 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -476,13 +476,15 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]): request: _messages.ModelRequest is_resuming_without_prompt: bool = False - _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None) + _result: CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT] | None = field( + repr=False, init=False, default=None + ) _did_stream: bool = field(repr=False, init=False, default=False) last_request_context: ModelRequestContext | None = field(repr=False, init=False, default=None) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> CallToolsNode[DepsT, NodeRunEndT]: + ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: if self._result is not None: return self._result
@@ -530,9 +532,12 @@ async def stream( stream_done = asyncio.Event() agent_stream_holder: list[result.AgentStream[DepsT, T]] = [] + _handler_response: _messages.ModelResponse | None = None + async def _streaming_handler( req_ctx: ModelRequestContext, ) -> _messages.ModelResponse: + nonlocal _handler_response with set_current_run_context(run_context): async with req_ctx.model.request_stream( req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters, run_context
@@ -543,7 +548,9 @@ async def _streaming_handler( agent_stream_holder.append(agent_stream) stream_ready.set() await stream_done.wait() - return sr.get() + response = sr.get() + _handler_response = response + return response wrap_request_context = ModelRequestContext( model=model,
@@ -567,13 +574,24 @@ async def _streaming_handler( if wrap_task.done() and not stream_ready.is_set(): # wrap_model_request completed without calling handler — short-circuited or raised SkipModelRequest try: - model_response = wrap_task.result() - except exceptions.SkipModelRequest as e: - model_response = e.response - except Exception as e: - model_response = await ctx.deps.root_capability.on_model_request_error( - run_context, request_context=wrap_request_context, error=e + result_or_exc: _messages.ModelResponse | Exception + try: + result_or_exc = wrap_task.result() + except Exception as e: + result_or_exc = e + model_response = await self._resolve_wrap_result(ctx, run_context, wrap_request_context, result_or_exc) + except exceptions.ModelRetry as e: + self._did_stream = True + # Don't increment usage.requests — handler was never called (short-circuit) + run_context = build_run_context(ctx) + await self._build_retry_node(ctx, run_context, e) + # Must still yield from @asynccontextmanager — yield an empty stream + dummy_sr = _SkipStreamedResponse( + model_request_parameters=model_request_parameters, + _response=_messages.ModelResponse(parts=[]), ) + yield self._build_agent_stream(ctx, dummy_sr, model_request_parameters) + return self._did_stream = True ctx.state.usage.requests += 1 skip_sr = _SkipStreamedResponse(model_request_parameters=model_request_parameters, _response=model_response)
@@ -605,14 +623,25 @@ async def _streaming_handler( pass else: try: - model_response = await wrap_task - except Exception as e: - model_response = await ctx.deps.root_capability.on_model_request_error( - run_context, request_context=wrap_request_context, error=e - ) - self.last_request_context = wrap_request_context - await self._finish_handling(ctx, model_response) - assert self._result is not None + try: + model_response = await wrap_task + except exceptions.ModelRetry: + raise # Propagate to outer handler + except Exception as e: + model_response = await ctx.deps.root_capability.on_model_request_error( + run_context, request_context=wrap_request_context, error=e + ) + except exceptions.ModelRetry as e: + # Don't increment usage.requests — _streaming_handler already did + # In the normal streaming path the handler was always called (that's + # how the stream was created), so _handler_response is always set. + assert _handler_response is not None + self._append_response(ctx, _handler_response) + await self._build_retry_node(ctx, run_context, e) + else: + self.last_request_context = wrap_request_context + await self._finish_handling(ctx, model_response) + assert self._result is not None @staticmethod def _build_agent_stream(
@@ -634,7 +663,7 @@ def _build_agent_stream( async def _make_request( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> CallToolsNode[DepsT, NodeRunEndT]: + ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: if self._result is not None: return self._result # pragma: no cover
@@ -650,11 +679,16 @@ async def _make_request( ctx.state.usage.requests += 1 return await self._finish_handling(ctx, e.response) + _handler_response: _messages.ModelResponse | None = None + async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse: + nonlocal _handler_response with set_current_run_context(run_context): - return await req_ctx.model.request( + response = await req_ctx.model.request( req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters ) + _handler_response = response + return response request_context = ModelRequestContext( model=model,
@@ -663,17 +697,27 @@ async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse model_request_parameters=model_request_parameters, ) try: - model_response = await ctx.deps.root_capability.wrap_model_request( - run_context, - request_context=request_context, - handler=model_handler, - ) - except exceptions.SkipModelRequest as e: - model_response = e.response - except Exception as e: - model_response = await ctx.deps.root_capability.on_model_request_error( - run_context, request_context=request_context, error=e - ) + try: + model_response = await ctx.deps.root_capability.wrap_model_request( + run_context, + request_context=request_context, + handler=model_handler, + ) + except exceptions.SkipModelRequest as e: + model_response = e.response + except exceptions.ModelRetry: + raise # Propagate to outer handler + except Exception as e: + model_response = await ctx.deps.root_capability.on_model_request_error( + run_context, request_context=request_context, error=e + ) + except exceptions.ModelRetry as e: + # ModelRetry from wrap_model_request or on_model_request_error — retry the model request. + # If the handler was called, preserve the response in history for context.
+ if _handler_response is not None: + ctx.state.usage.requests += 1 + self._append_response(ctx, _handler_response) + return await self._build_retry_node(ctx, run_context, e) self.last_request_context = request_context ctx.state.usage.requests += 1 @@ -765,30 +809,82 @@ async def _finish_handling( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], response: _messages.ModelResponse, - ) -> CallToolsNode[DepsT, NodeRunEndT]: + ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: response.run_id = response.run_id or ctx.state.run_id run_context = build_run_context(ctx) assert self.last_request_context is not None, 'last_request_context must be set before _finish_handling' request_context = self.last_request_context run_context.model_settings = request_context.model_settings - response = await ctx.deps.root_capability.after_model_request( - run_context, request_context=request_context, response=response - ) - - # Update usage - ctx.state.usage.incr(response.usage) - if ctx.deps.usage_limits: # pragma: no branch - ctx.deps.usage_limits.check_tokens(ctx.state.usage) + try: + response = await ctx.deps.root_capability.after_model_request( + run_context, request_context=request_context, response=response + ) + except exceptions.ModelRetry as e: + # Hook rejected the response — append it to history (model DID respond) and retry + self._append_response(ctx, response) + return await self._build_retry_node(ctx, run_context, e) # Append the model response to state.message_history - ctx.state.message_history.append(response) + self._append_response(ctx, response) # Set the `_result` attribute since we can't use `return` in an async iterator self._result = CallToolsNode(response) return self._result + async def _resolve_wrap_result( + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + run_context: RunContext[DepsT], + request_context: ModelRequestContext, + result_or_exc: _messages.ModelResponse | Exception, + ) -> _messages.ModelResponse: + """Resolve a wrap_model_request result, handling SkipModelRequest and errors. + + Returns ModelResponse on success. + Raises ModelRetry if the result or on_model_request_error raises it. + """ + if isinstance(result_or_exc, Exception): + exc = result_or_exc + if isinstance(exc, exceptions.SkipModelRequest): + return exc.response + if isinstance(exc, exceptions.ModelRetry): + raise exc + return await ctx.deps.root_capability.on_model_request_error( + run_context, request_context=request_context, error=exc + ) + return result_or_exc + + @staticmethod + def _append_response( + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Any, Any]], + response: _messages.ModelResponse, + ) -> None: + """Append a model response to history, updating usage tracking.""" + response.run_id = response.run_id or ctx.state.run_id + ctx.state.usage.incr(response.usage) + if ctx.deps.usage_limits: # pragma: no branch + ctx.deps.usage_limits.check_tokens(ctx.state.usage) + ctx.state.message_history.append(response) + + async def _build_retry_node( + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + run_context: RunContext[DepsT], + error: exceptions.ModelRetry, + ) -> ModelRequestNode[DepsT, NodeRunEndT]: + """Build a retry ModelRequestNode from a ModelRetry exception. + + Increments the retry counter and creates a new request with a RetryPromptPart. 
+ """ + ctx.state.increment_retries(ctx.deps.max_result_retries, error=error) + m = _messages.RetryPromptPart(content=error.message) + instructions = await ctx.deps.get_instructions(run_context) + retry_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[m], instructions=instructions)) + self._result = retry_node + return retry_node + __repr__ = dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index ce78c4ef98..10572aa9cb 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -297,21 +297,32 @@ async def do_execute(args: dict[str, Any]) -> Any: if cap is not None: tool_def = validated.tool.tool_def - # before_tool_execute - args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args) - - # wrap_tool_execute wraps the execution; on_tool_execute_error on failure try: - tool_result = await cap.wrap_tool_execute( - ctx, call=call, tool_def=tool_def, args=args, handler=do_execute + # before_tool_execute + args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args) + + # wrap_tool_execute wraps the execution; on_tool_execute_error on failure + try: + tool_result = await cap.wrap_tool_execute( + ctx, call=call, tool_def=tool_def, args=args, handler=do_execute + ) + except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError): + raise # Control flow, not errors + except ModelRetry: + raise # Propagate to outer handler + except Exception as e: + tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e) + + # after_tool_execute + tool_result = await cap.after_tool_execute( + ctx, call=call, tool_def=tool_def, args=args, result=tool_result ) - except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError): - raise # Control flow, not errors - except Exception as e: - tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e) - - # after_tool_execute - tool_result = await cap.after_tool_execute(ctx, call=call, tool_def=tool_def, args=args, result=tool_result) + except ModelRetry as e: + # Hook raised ModelRetry — convert to ToolRetryError for retry handling + name = call.tool_name + self._check_max_retries(name, validated.tool.max_retries, e) + self.failed_tools.add(name) + raise self._wrap_error_as_retry(name, call, e) from e else: tool_result = await do_execute(validated.validated_args) diff --git a/pydantic_ai_slim/pydantic_ai/capabilities/abstract.py b/pydantic_ai_slim/pydantic_ai/capabilities/abstract.py index 4e9f3cd36b..42e9bb9b53 100644 --- a/pydantic_ai_slim/pydantic_ai/capabilities/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/capabilities/abstract.py @@ -337,7 +337,12 @@ async def after_model_request( request_context: ModelRequestContext, response: ModelResponse, ) -> ModelResponse: - """Called after each model response. Can modify the response before further processing.""" + """Called after each model response. Can modify the response before further processing. + + Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the response and + ask the model to try again. The original response is still appended to message history + so the model can see what it said. Retries count against `max_result_retries`. 
+ """ return response async def wrap_model_request( @@ -347,7 +352,12 @@ async def wrap_model_request( request_context: ModelRequestContext, handler: WrapModelRequestHandler, ) -> ModelResponse: - """Wraps the model request. handler() calls the model.""" + """Wraps the model request. handler() calls the model. + + Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip `on_model_request_error` + and directly retry the model request with a retry prompt. If the handler was called, + the model response is preserved in history for context (same as `after_model_request`). + """ return await handler(request_context) async def on_model_request_error( @@ -365,8 +375,11 @@ async def on_model_request_error( **Raise** the original `error` (or a different exception) to propagate it. **Return** a [`ModelResponse`][pydantic_ai.messages.ModelResponse] to suppress the error and use the response as if the model call succeeded. + **Raise** [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to retry the model request + with a retry prompt instead of recovering or propagating. - Not called for [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest]. + Not called for [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest] + or [`ModelRetry`][pydantic_ai.exceptions.ModelRetry]. """ raise error @@ -380,7 +393,11 @@ async def before_tool_validate( tool_def: ToolDefinition, args: RawToolArgs, ) -> RawToolArgs: - """Modify raw args before validation.""" + """Modify raw args before validation. + + Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip validation and + ask the model to redo the tool call. + """ return args async def after_tool_validate( @@ -391,7 +408,11 @@ async def after_tool_validate( tool_def: ToolDefinition, args: ValidatedToolArgs, ) -> ValidatedToolArgs: - """Modify validated args. Called only on successful validation.""" + """Modify validated args. Called only on successful validation. + + Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the validated args + and ask the model to redo the tool call. + """ return args async def wrap_tool_validate( @@ -439,7 +460,11 @@ async def before_tool_execute( tool_def: ToolDefinition, args: ValidatedToolArgs, ) -> ValidatedToolArgs: - """Modify validated args before execution.""" + """Modify validated args before execution. + + Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip execution and + ask the model to redo the tool call. + """ return args async def after_tool_execute( @@ -451,7 +476,11 @@ async def after_tool_execute( args: ValidatedToolArgs, result: Any, ) -> Any: - """Modify result after execution.""" + """Modify result after execution. + + Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the tool result + and ask the model to redo the tool call. + """ return result async def wrap_tool_execute( @@ -482,6 +511,8 @@ async def on_tool_execute_error( **Raise** the original `error` (or a different exception) to propagate it. **Return** any value to suppress the error and use it as the tool result. + **Raise** [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to + redo the tool call instead of recovering or propagating. 
Not called for control flow exceptions ([`SkipToolExecution`][pydantic_ai.exceptions.SkipToolExecution],
diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py
index d472ca733c..f73a3884e9 100644
--- a/pydantic_ai_slim/pydantic_ai/exceptions.py
+++ b/pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -37,9 +37,11 @@ class ModelRetry(Exception): - """Exception to raise when a tool function should be retried. + """Exception to raise to request a model retry. - The agent will return the message to the model and ask it to try calling the function/tool again. + Can be raised from tool functions, output validators, and capability hooks + (such as `after_model_request`, `after_tool_execute`, etc.) to send + a retry prompt back to the model asking it to try again. """ message: str
diff --git a/tests/test_capabilities.py b/tests/test_capabilities.py
index 287044b21d..e2861b6f7e 100644
--- a/tests/test_capabilities.py
+++ b/tests/test_capabilities.py
@@ -30,12 +30,20 @@ from pydantic_ai.capabilities.builtin_tool import BuiltinTool as BuiltinToolCap from pydantic_ai.capabilities.combined import CombinedCapability from pydantic_ai.capabilities.hooks import Hooks, HookTimeoutError -from pydantic_ai.exceptions import SkipModelRequest, SkipToolExecution, SkipToolValidation, UserError +from pydantic_ai.exceptions import ( + ModelRetry, + SkipModelRequest, + SkipToolExecution, + SkipToolValidation, + UnexpectedModelBehavior, + UserError, +) from pydantic_ai.messages import ( AgentStreamEvent, ModelMessage, ModelRequest, ModelResponse, + RetryPromptPart, TextPart, ToolCallPart, ToolReturnPart,
@@ -6184,3 +6192,955 @@ def my_tool() -> str: pass assert error_log == ['CallToolsNode'] + + +# --- ModelRetry from hooks tests --- + + +class TestModelRetryFromHooks: + """Tests for raising ModelRetry from capability hooks."""
+ + async def test_after_model_request_model_retry(self): + """after_model_request raises ModelRetry — model is called again with retry prompt.""" + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + return make_text_response('bad response') + return make_text_response('good response') + + @dataclass + class RetryCap(AbstractCapability[Any]): + retried: bool = False + + async def after_model_request( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + response: ModelResponse, + ) -> ModelResponse: + if not self.retried: + self.retried = True + raise ModelRetry('Response was bad, please try again') + return response + + cap = RetryCap() + agent = Agent(FunctionModel(model_fn), capabilities=[cap]) + result = await agent.run('hello') + assert result.output == 'good response' + assert call_count == 2 + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='bad response')], + usage=RequestUsage(input_tokens=51, output_tokens=2), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Response was bad, please try again', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='good response')], + usage=RequestUsage(input_tokens=66, output_tokens=4), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_after_model_request_model_retry_max_retries(self): + """after_model_request raises ModelRetry repeatedly — hits max_result_retries.""" + + @dataclass + class AlwaysRetryCap(AbstractCapability[Any]): + async def after_model_request( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + response: ModelResponse, + ) -> ModelResponse: + raise ModelRetry('always bad') + + agent = Agent( + FunctionModel(simple_model_function), + capabilities=[AlwaysRetryCap()], + output_retries=2, + ) + with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + await agent.run('hello')
+ + async def test_after_model_request_model_retry_streaming(self): + """after_model_request raises ModelRetry during streaming with tool calls — model is called again.""" + call_count = 0 + + async def stream_fn(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: return a tool call that after_model_request will reject + yield {0: DeltaToolCall(name='my_tool', json_args='{}', tool_call_id='call-1')} + elif call_count == 2: + # Second call (after retry): return text + yield 'good response' + else: + yield 'unexpected' # pragma: no cover + + @dataclass + class RetryCap(AbstractCapability[Any]): + retried: bool = False + + async def after_model_request( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + response: ModelResponse, + ) -> ModelResponse: + if not self.retried: + self.retried = True + raise ModelRetry('Response was bad, please try again') + return response + + cap = RetryCap() + agent = Agent( + FunctionModel(simple_model_function, stream_function=stream_fn), + capabilities=[cap], + ) + + @agent.tool_plain + def my_tool() -> str: + return 'tool result' # pragma: no cover + + async with agent.run_stream('hello') as streamed: + result = await streamed.get_output() + assert result == 'good response' + assert call_count == 2 + assert streamed.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=50, output_tokens=1), + model_name='function:simple_model_function:stream_fn', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Response was bad, please try again', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='good response')], + usage=RequestUsage(input_tokens=50, output_tokens=2), + model_name='function:simple_model_function:stream_fn', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_wrap_model_request_model_retry_streaming_short_circuit(self): + """wrap_model_request raises ModelRetry without calling handler during streaming.""" + + async def stream_fn(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]: + yield 'good response' + + @dataclass + class ShortCircuitRetryCap(AbstractCapability[Any]): + call_count: int = 0 + + async def wrap_model_request( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + handler: Any, + ) -> ModelResponse: + self.call_count += 1 + if self.call_count == 1: + # Short-circuit: don't call handler, raise ModelRetry + raise ModelRetry('Short-circuit retry') + return await handler(request_context) + + cap = ShortCircuitRetryCap() + agent = Agent(FunctionModel(simple_model_function, stream_function=stream_fn), capabilities=[cap]) + async with agent.run_stream('hello') as streamed: + result = await streamed.get_output() + assert result == 'good response' + assert cap.call_count == 2 + assert streamed.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Short-circuit retry', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='good response')], + usage=RequestUsage(input_tokens=50, output_tokens=2), + model_name='function:simple_model_function:stream_fn', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_wrap_model_request_model_retry_streaming_after_handler(self): + """wrap_model_request raises ModelRetry after calling handler during streaming (tool call scenario).""" + call_count = 0 + + async def stream_fn(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: tool call that wrap hook will reject + yield {0: DeltaToolCall(name='my_tool', json_args='{}', tool_call_id='call-1')} + else: + yield 'good response' + + @dataclass + class AfterHandlerRetryCap(AbstractCapability[Any]): + retried: bool = False + + async def wrap_model_request( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + handler: Any, + ) -> ModelResponse: + response = await handler(request_context) + if not self.retried: + self.retried = True + raise ModelRetry('Post-handler retry') + return response + + cap = AfterHandlerRetryCap() + agent = Agent(FunctionModel(simple_model_function, stream_function=stream_fn), capabilities=[cap]) + + @agent.tool_plain + def my_tool() -> str: + return 'tool result' # pragma: no cover + + async with agent.run_stream('hello') as streamed: + result = await streamed.get_output() + assert result == 'good response' + assert call_count == 2 + assert streamed.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=50, output_tokens=1), + model_name='function:simple_model_function:stream_fn', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Post-handler retry', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='good response')], + usage=RequestUsage(input_tokens=50, output_tokens=2), + model_name='function:simple_model_function:stream_fn', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_wrap_model_request_model_retry(self): + """wrap_model_request raises ModelRetry after calling handler — triggers retry.""" + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + return make_text_response('first attempt') + return make_text_response('second attempt') + + @dataclass + class WrapRetryCap(AbstractCapability[Any]): + retried: bool = False + + async def wrap_model_request( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + handler: Any, + ) -> ModelResponse: + response = await handler(request_context) + if not self.retried: + self.retried = True + raise ModelRetry('Wrap says retry') + return response + + cap = WrapRetryCap() + agent = Agent(FunctionModel(model_fn), capabilities=[cap]) + result = await agent.run('hello') + assert result.output == 'second attempt' + assert call_count == 2 + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='first attempt')], + usage=RequestUsage(input_tokens=51, output_tokens=2), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Wrap says retry', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='second attempt')], + usage=RequestUsage(input_tokens=63, output_tokens=4), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_wrap_model_request_model_retry_skips_on_error(self): + """wrap_model_request raising ModelRetry should NOT call on_model_request_error.""" + on_error_called = False + + @dataclass + class WrapRetrySkipErrorCap(AbstractCapability[Any]): + async def wrap_model_request( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + handler: Any, + ) -> ModelResponse: + raise ModelRetry('retry please') + + async def on_model_request_error( # pragma: no cover — verifying this is NOT called + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + error: Exception, + ) -> ModelResponse: + nonlocal on_error_called + on_error_called = True + raise error + + agent = Agent(FunctionModel(simple_model_function), capabilities=[WrapRetrySkipErrorCap()], output_retries=1) + with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + await agent.run('hello') + assert not on_error_called
+ + async def test_on_model_request_error_model_retry(self): + """on_model_request_error raises ModelRetry to recover via retry.""" + call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError('model failed') + return make_text_response('recovered response') + + @dataclass + class ErrorRetryCap(AbstractCapability[Any]): + async def on_model_request_error( + self, + ctx: RunContext[Any], + *, + request_context: ModelRequestContext, + error: Exception, + ) -> ModelResponse: + raise ModelRetry('Model failed, please try again') + + cap = ErrorRetryCap() + agent = Agent(FunctionModel(model_fn), capabilities=[cap]) + result = await agent.run('hello') + assert result.output == 'recovered response' + assert call_count == 2 + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='hello', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Model failed, please try again', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='recovered response')], + usage=RequestUsage(input_tokens=65, output_tokens=2), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_after_tool_execute_model_retry(self): + """after_tool_execute raises ModelRetry — tool retry prompt sent to model, tool retried on success.""" + tool_call_count = 0 + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # Always call the tool — after retry, the hook won't raise again + if info.function_tools: + # Check if we already got a tool return (second call succeeded) + for msg in messages: + for part in msg.parts: + if isinstance(part, ToolReturnPart): + return make_text_response(f'got: {part.content}') + return ModelResponse( + parts=[ToolCallPart(tool_name=info.function_tools[0].name, args='{}', tool_call_id='call-1')] + ) + return make_text_response('no tools') # pragma: no cover + + @dataclass + class AfterExecRetryCap(AbstractCapability[Any]): + retried: bool = False + + async def after_tool_execute( + self, + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: dict[str, Any], + result: Any, + ) -> Any: + if not self.retried: + self.retried = True + raise ModelRetry('Tool result is bad, try again') + return result + + cap = AfterExecRetryCap() + agent = Agent(FunctionModel(model_fn), capabilities=[cap]) + + @agent.tool_plain + def my_tool() -> str: + nonlocal tool_call_count + tool_call_count += 1 + return 'tool result' + + result = await agent.run('call tool') + assert result.output == 'got: tool result' + assert tool_call_count == 2 # Tool called twice: first rejected by hook, second succeeds + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='call tool', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=52, output_tokens=2), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Tool result is bad, try again', + tool_name='my_tool', + tool_call_id='call-1', + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=65, output_tokens=4), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='my_tool', content='tool result', tool_call_id='call-1', timestamp=IsDatetime() + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='got: tool result')], + usage=RequestUsage(input_tokens=67, output_tokens=7), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_before_tool_execute_model_retry(self): + """before_tool_execute raises ModelRetry — tool execution is skipped, then succeeds on retry.""" + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # Always call the tool — after retry, the hook won't raise again + if info.function_tools: + for msg in messages: + for part in msg.parts: + if isinstance(part, ToolReturnPart): + return make_text_response(f'got: {part.content}') + return ModelResponse( + parts=[ToolCallPart(tool_name=info.function_tools[0].name, args='{}', tool_call_id='call-1')] + ) + return make_text_response('no tools') # pragma: no cover + + hooks = Hooks[Any]() + hook_called = False + + @hooks.on.before_tool_execute + async def reject_first( + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: dict[str, Any], + ) -> dict[str, Any]: + nonlocal hook_called + if not hook_called: + hook_called = True + raise ModelRetry('Not ready to execute, try again') + return args + + agent = Agent(FunctionModel(model_fn), capabilities=[hooks], retries=2) + + @agent.tool_plain + def my_tool() -> str: + return 'tool result' + + result = await agent.run('call tool') + assert result.output == 'got: tool result' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='call tool', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=52, output_tokens=2), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Not ready to execute, try again', + tool_name='my_tool', + tool_call_id='call-1', + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=65, output_tokens=4), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='my_tool', content='tool result', tool_call_id='call-1', timestamp=IsDatetime() + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='got: tool result')], + usage=RequestUsage(input_tokens=67, output_tokens=7), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_wrap_tool_execute_model_retry_skips_on_error(self): + """wrap_tool_execute raising ModelRetry should NOT call on_tool_execute_error.""" + on_error_called = False + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + for msg in messages: + for part in msg.parts: + if isinstance(part, RetryPromptPart): + return make_text_response('got retry') + if info.function_tools: + return ModelResponse( + parts=[ToolCallPart(tool_name=info.function_tools[0].name, args='{}', tool_call_id='call-1')] + ) + return make_text_response('no tools') # pragma: no cover + + @dataclass + class WrapExecRetryCap(AbstractCapability[Any]): + async def wrap_tool_execute( + self, + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: dict[str, Any], + handler: Any, + ) -> Any: + raise ModelRetry('Wrap says retry tool') + + async def on_tool_execute_error( # pragma: no cover — verifying this is NOT called + self, + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: dict[str, Any], + error: Exception, + ) -> Any: + nonlocal on_error_called + on_error_called = True + raise error + + agent = Agent(FunctionModel(model_fn), capabilities=[WrapExecRetryCap()], retries=2) + + @agent.tool_plain + def my_tool() -> str: + return 'tool result' # pragma: no cover + + result = await agent.run('call tool') + assert result.output == 'got retry' + assert not on_error_called + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='call tool', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=52, output_tokens=2), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Wrap says retry tool', + tool_name='my_tool', + tool_call_id='call-1', + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='got retry')], + usage=RequestUsage(input_tokens=63, output_tokens=4), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_on_tool_execute_error_model_retry(self): + """on_tool_execute_error raises ModelRetry to recover via retry.""" + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + for msg in messages: + for part in msg.parts: + if isinstance(part, RetryPromptPart): + return make_text_response('got retry after error') + if info.function_tools: + return ModelResponse( + parts=[ToolCallPart(tool_name=info.function_tools[0].name, args='{}', tool_call_id='call-1')] + ) + return make_text_response('no tools') # pragma: no cover + + @dataclass + class ErrorRetryCap(AbstractCapability[Any]): + async def on_tool_execute_error( + self, + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: dict[str, Any], + error: Exception, + ) -> Any: + raise ModelRetry('Tool errored, please retry') + + agent = Agent(FunctionModel(model_fn), capabilities=[ErrorRetryCap()], retries=2) + + @agent.tool_plain + def my_tool() -> str: + raise ValueError('tool failed') + + result = await agent.run('call tool') + assert result.output == 'got retry after error' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='call tool', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=52, output_tokens=2), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Tool errored, please retry', + tool_name='my_tool', + tool_call_id='call-1', + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='got retry after error')], + usage=RequestUsage(input_tokens=63, output_tokens=6), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )
+ + async def test_after_tool_validate_model_retry(self): + """after_tool_validate raises ModelRetry — validation retry sent to model.""" + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + for msg in messages: + for part in msg.parts: + if isinstance(part, RetryPromptPart): + return make_text_response('got validation retry') + if info.function_tools: + return ModelResponse( + parts=[ToolCallPart(tool_name=info.function_tools[0].name, args='{}', tool_call_id='call-1')] + ) + return make_text_response('no tools') # pragma: no cover + + @dataclass + class AfterValRetryCap(AbstractCapability[Any]): + async def after_tool_validate( + self, + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: dict[str, Any], + ) -> dict[str, Any]:
+ raise ModelRetry('Validated args are bad') + + agent = Agent(FunctionModel(model_fn), capabilities=[AfterValRetryCap()], retries=2) + + @agent.tool_plain + def my_tool() -> str: + return 'tool result' # pragma: no cover + + result = await agent.run('call tool') + assert result.output == 'got validation retry' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='call tool', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=52, output_tokens=2), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Validated args are bad', + tool_name='my_tool', + tool_call_id='call-1', + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='got validation retry')], + usage=RequestUsage(input_tokens=63, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + ) + + async def test_before_tool_validate_model_retry(self): + """before_tool_validate raises ModelRetry — validation retry sent to model.""" + + def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + for msg in messages: + for part in msg.parts: + if isinstance(part, RetryPromptPart): + return make_text_response('got pre-validation retry') + if info.function_tools: + return ModelResponse( + parts=[ToolCallPart(tool_name=info.function_tools[0].name, args='{}', tool_call_id='call-1')] + ) + return make_text_response('no tools') # pragma: no cover + + @dataclass + class BeforeValRetryCap(AbstractCapability[Any]): + async def before_tool_validate( + self, + ctx: RunContext[Any], + *, + call: ToolCallPart, + tool_def: ToolDefinition, + args: str | dict[str, Any], + ) -> str | dict[str, Any]: + raise ModelRetry('Args look bad before validation') + + agent = Agent(FunctionModel(model_fn), capabilities=[BeforeValRetryCap()], retries=2) + + @agent.tool_plain + def my_tool() -> str: + return 'tool result' # pragma: no cover + + result = await agent.run('call tool') + assert result.output == 'got pre-validation retry' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='call tool', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ToolCallPart(tool_name='my_tool', args='{}', tool_call_id='call-1')], + usage=RequestUsage(input_tokens=52, output_tokens=2), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Args look bad before validation', + tool_name='my_tool', + tool_call_id='call-1', + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='got pre-validation retry')], + usage=RequestUsage(input_tokens=64, output_tokens=5), + model_name='function:model_fn:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + )