Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
61d172f
Allow hooks to raise `ModelRetry` for retry control flow
DouweM Mar 26, 2026
3516f8d
Address review feedback: extract helper, fix streaming bug, improve t…
DouweM Mar 26, 2026
559169c
Fix Python 3.14 SyntaxError: avoid 'return' in 'finally' block
DouweM Mar 26, 2026
894242b
Add streaming tests for wrap_model_request ModelRetry coverage
DouweM Mar 26, 2026
b14a821
Fix coverage: remove dead branch, pragma no cover on not-called handlers
DouweM Mar 26, 2026
e3c1a59
Fix remaining coverage: exercise hook return paths after retry
DouweM Mar 26, 2026
44e040d
Remove incorrect pragma: no cover on stream_fn that IS executed
DouweM Mar 26, 2026
0b65a08
Merge remote-tracking branch 'origin/main' into hook-model-retry
DouweM Mar 26, 2026
054c774
Preserve model response in history when wrap_model_request raises Mod…
DouweM Mar 26, 2026
71556db
Fix coverage: assert handler_response in normal streaming path
DouweM Mar 27, 2026
4d434cd
Fix double usage.requests increment, add snapshot assertions
DouweM Mar 27, 2026
89175ed
Fix short-circuit usage.requests, add snapshots to all tests
DouweM Mar 27, 2026
758f7ee
Fix non-streaming short-circuit usage.requests, add docs section
DouweM Mar 27, 2026
8291df5
Remove unused import in docs example
DouweM Mar 27, 2026
4820998
Fix docs example: add type annotations, title, correct output
DouweM Mar 27, 2026
a753eef
Fix misleading docs: clarify not all hooks support ModelRetry
DouweM Mar 27, 2026
5c04a67
Simplify docs: drop explicit before_model_request exclusion
DouweM Mar 27, 2026
8508c9d
Merge remote-tracking branch 'origin/main' into hook-model-retry
DouweM Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions docs/hooks.md
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,52 @@ Error hooks (`*_error` in the `hooks.on` namespace, `on_*_error` on `AbstractCap

See [Error hooks](capabilities.md#error-hooks) for the full pattern and recovery types.

## Triggering retries with `ModelRetry`

Hooks can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with a custom message — the same exception used in [tool functions](tools.md#model-retry) and output validators.

**Model request hooks** (`after_model_request`, `wrap_model_request`, `on_model_request_error`):

- The retry message is sent back to the model as a [`RetryPromptPart`][pydantic_ai.messages.RetryPromptPart]
- `after_model_request`: the original response is preserved in message history so the model can see what it said
- `wrap_model_request`: the response is preserved only if the handler was called
- Retries count against the agent's `output_retries` limit

**Tool hooks** (`before/after_tool_validate`, `before/after_tool_execute`, `wrap_tool_execute`, `on_tool_execute_error`):

- Converted to tool retry prompts, same as when a tool function raises `ModelRetry`
- Retries count against the tool's `max_retries` limit

`ModelRetry` from `wrap_model_request` and `wrap_tool_execute` is treated as control flow — it bypasses `on_model_request_error` and `on_tool_execute_error` respectively.

```python {title="hooks_model_retry.py"}
from pydantic_ai import Agent, RunContext
from pydantic_ai.capabilities import Hooks
from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.messages import ModelResponse
from pydantic_ai.models import ModelRequestContext

hooks = Hooks()


@hooks.on.after_model_request
async def check_response(
ctx: RunContext[None],
*,
request_context: ModelRequestContext,
response: ModelResponse,
) -> ModelResponse:
if 'PLACEHOLDER' in str(response.parts):
raise ModelRetry('Response contains placeholder text. Please provide real data.')
return response


agent = Agent('test', capabilities=[hooks])
result = agent.run_sync('Hello')
print(result.output)
#> success (no tool calls)
```

## When to use `Hooks` vs `AbstractCapability`

| Use [`Hooks`][pydantic_ai.capabilities.Hooks] | Use [`AbstractCapability`][pydantic_ai.capabilities.AbstractCapability] |
Expand Down
176 changes: 136 additions & 40 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,15 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
request: _messages.ModelRequest
is_resuming_without_prompt: bool = False

_result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
_result: CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT] | None = field(
repr=False, init=False, default=None
)
_did_stream: bool = field(repr=False, init=False, default=False)
last_request_context: ModelRequestContext | None = field(repr=False, init=False, default=None)

async def run(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> CallToolsNode[DepsT, NodeRunEndT]:
) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]:
if self._result is not None:
return self._result

Expand Down Expand Up @@ -530,9 +532,12 @@ async def stream(
stream_done = asyncio.Event()
agent_stream_holder: list[result.AgentStream[DepsT, T]] = []

_handler_response: _messages.ModelResponse | None = None

async def _streaming_handler(
req_ctx: ModelRequestContext,
) -> _messages.ModelResponse:
nonlocal _handler_response
with set_current_run_context(run_context):
async with req_ctx.model.request_stream(
req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters, run_context
Expand All @@ -543,7 +548,9 @@ async def _streaming_handler(
agent_stream_holder.append(agent_stream)
stream_ready.set()
await stream_done.wait()
return sr.get()
response = sr.get()
_handler_response = response
return response

wrap_request_context = ModelRequestContext(
model=model,
Expand All @@ -567,13 +574,24 @@ async def _streaming_handler(
if wrap_task.done() and not stream_ready.is_set():
# wrap_model_request completed without calling handler — short-circuited or raised SkipModelRequest
try:
model_response = wrap_task.result()
except exceptions.SkipModelRequest as e:
model_response = e.response
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=wrap_request_context, error=e
result_or_exc: _messages.ModelResponse | Exception
try:
result_or_exc = wrap_task.result()
except Exception as e:
result_or_exc = e
model_response = await self._resolve_wrap_result(ctx, run_context, wrap_request_context, result_or_exc)
except exceptions.ModelRetry as e:
self._did_stream = True
# Don't increment usage.requests — handler was never called (short-circuit)
run_context = build_run_context(ctx)
await self._build_retry_node(ctx, run_context, e)
# Must still yield from @asynccontextmanager — yield an empty stream
dummy_sr = _SkipStreamedResponse(
model_request_parameters=model_request_parameters,
_response=_messages.ModelResponse(parts=[]),
)
yield self._build_agent_stream(ctx, dummy_sr, model_request_parameters)
return
self._did_stream = True
ctx.state.usage.requests += 1
skip_sr = _SkipStreamedResponse(model_request_parameters=model_request_parameters, _response=model_response)
Expand Down Expand Up @@ -605,14 +623,25 @@ async def _streaming_handler(
pass
else:
try:
model_response = await wrap_task
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=wrap_request_context, error=e
)
self.last_request_context = wrap_request_context
await self._finish_handling(ctx, model_response)
assert self._result is not None
try:
model_response = await wrap_task
except exceptions.ModelRetry:
raise # Propagate to outer handler
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=wrap_request_context, error=e
)
except exceptions.ModelRetry as e:
# Don't increment usage.requests — _streaming_handler already did
# In the normal streaming path the handler was always called (that's
# how the stream was created), so _handler_response is always set.
assert _handler_response is not None
self._append_response(ctx, _handler_response)
await self._build_retry_node(ctx, run_context, e)
else:
self.last_request_context = wrap_request_context
await self._finish_handling(ctx, model_response)
assert self._result is not None
Comment on lines +641 to +644
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Streaming _build_retry_node in finally block raises or silently ignores retry when FinalResultEvent was already yielded

When using run_stream, if the model returns text that triggers a FinalResultEvent, the agent loop yields StreamedRunResult to the user and breaks out of its while loop (at agent/abstract.py:709). However, the async with node.stream(graph_ctx) context is still active, so ModelRequestNode.stream()'s finally block runs afterward. Inside the finally block, _finish_handling calls after_model_request, which may raise ModelRetry. When it does, _build_retry_node calls ctx.state.increment_retries() — which can raise UnexpectedModelBehavior if max_result_retries is exceeded, crashing the user's async with agent.run_stream(...) exit even though they already received a valid response. If retries are NOT exceeded, the retry ModelRequestNode is stored in self._result but never processed (the while loop already broke at agent/abstract.py:709, so _wrap_and_advance at line 722 is never reached), meaning the ModelRetry is silently ignored — the user gets the response the hook intended to reject.

Trace of the silent-ignore path
  1. run_streamnode.stream(graph_ctx) yields stream with text → FinalResultEvent emitted
  2. Agent loop yields StreamedRunResult at abstract.py:699, sets yielded=True, break at line 709
  3. async with node.stream() exits → finally block runs at _agent_graph.py:615
  4. _finish_handling at line 643 → after_model_request raises ModelRetry
  5. _append_response appends response, _build_retry_node creates retry node at line 640 or 826
  6. Retry node stored in self._result but agent loop already exited — retry never executed
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.


@staticmethod
def _build_agent_stream(
Expand All @@ -634,7 +663,7 @@ def _build_agent_stream(

async def _make_request(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> CallToolsNode[DepsT, NodeRunEndT]:
) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]:
if self._result is not None:
return self._result # pragma: no cover

Expand All @@ -650,11 +679,16 @@ async def _make_request(
ctx.state.usage.requests += 1
return await self._finish_handling(ctx, e.response)

_handler_response: _messages.ModelResponse | None = None

async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse:
nonlocal _handler_response
with set_current_run_context(run_context):
return await req_ctx.model.request(
response = await req_ctx.model.request(
req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters
)
_handler_response = response
return response

request_context = ModelRequestContext(
model=model,
Expand All @@ -663,17 +697,27 @@ async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse
model_request_parameters=model_request_parameters,
)
try:
model_response = await ctx.deps.root_capability.wrap_model_request(
run_context,
request_context=request_context,
handler=model_handler,
)
except exceptions.SkipModelRequest as e:
model_response = e.response
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=request_context, error=e
)
try:
model_response = await ctx.deps.root_capability.wrap_model_request(
run_context,
request_context=request_context,
handler=model_handler,
)
except exceptions.SkipModelRequest as e:
model_response = e.response
except exceptions.ModelRetry:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The except ModelRetry: raise followed by except Exception pattern is repeated in three places (_make_request, and twice in stream). The re-raise is necessary to prevent the broad except Exception from swallowing ModelRetry, but it's worth noting that with the helper extraction suggested in the other comment, you may be able to structure this more cleanly — e.g. by checking isinstance(e, ModelRetry) inside the except Exception block and re-raising if so, which would eliminate the inner try/except.

raise # Propagate to outer handler
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=request_context, error=e
)
except exceptions.ModelRetry as e:
# ModelRetry from wrap_model_request or on_model_request_error — retry the model request.
# If the handler was called, preserve the response in history for context.
if _handler_response is not None:
ctx.state.usage.requests += 1
self._append_response(ctx, _handler_response)
return await self._build_retry_node(ctx, run_context, e)
self.last_request_context = request_context
ctx.state.usage.requests += 1

Expand Down Expand Up @@ -765,30 +809,82 @@ async def _finish_handling(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
response: _messages.ModelResponse,
) -> CallToolsNode[DepsT, NodeRunEndT]:
) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]:
response.run_id = response.run_id or ctx.state.run_id

run_context = build_run_context(ctx)
assert self.last_request_context is not None, 'last_request_context must be set before _finish_handling'
request_context = self.last_request_context
run_context.model_settings = request_context.model_settings
response = await ctx.deps.root_capability.after_model_request(
run_context, request_context=request_context, response=response
)

# Update usage
ctx.state.usage.incr(response.usage)
if ctx.deps.usage_limits: # pragma: no branch
ctx.deps.usage_limits.check_tokens(ctx.state.usage)
try:
response = await ctx.deps.root_capability.after_model_request(
run_context, request_context=request_context, response=response
)
except exceptions.ModelRetry as e:
# Hook rejected the response — append it to history (model DID respond) and retry
self._append_response(ctx, response)
return await self._build_retry_node(ctx, run_context, e)

# Append the model response to state.message_history
ctx.state.message_history.append(response)
self._append_response(ctx, response)

# Set the `_result` attribute since we can't use `return` in an async iterator
self._result = CallToolsNode(response)

return self._result

async def _resolve_wrap_result(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
run_context: RunContext[DepsT],
request_context: ModelRequestContext,
result_or_exc: _messages.ModelResponse | Exception,
) -> _messages.ModelResponse:
"""Resolve a wrap_model_request result, handling SkipModelRequest and errors.

Returns ModelResponse on success.
Raises ModelRetry if the result or on_model_request_error raises it.
"""
if isinstance(result_or_exc, Exception):
exc = result_or_exc
if isinstance(exc, exceptions.SkipModelRequest):
return exc.response
if isinstance(exc, exceptions.ModelRetry):
raise exc
return await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=request_context, error=exc
)
return result_or_exc

@staticmethod
def _append_response(
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Any, Any]],
response: _messages.ModelResponse,
) -> None:
"""Append a model response to history, updating usage tracking."""
response.run_id = response.run_id or ctx.state.run_id
ctx.state.usage.incr(response.usage)
if ctx.deps.usage_limits: # pragma: no branch
ctx.deps.usage_limits.check_tokens(ctx.state.usage)
ctx.state.message_history.append(response)

async def _build_retry_node(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
run_context: RunContext[DepsT],
error: exceptions.ModelRetry,
) -> ModelRequestNode[DepsT, NodeRunEndT]:
"""Build a retry ModelRequestNode from a ModelRetry exception.

Increments the retry counter and creates a new request with a RetryPromptPart.
"""
ctx.state.increment_retries(ctx.deps.max_result_retries, error=error)
Copy link
Copy Markdown
Contributor

@devin-ai-integration devin-ai-integration bot Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 Info: ModelRetry retries count against max_result_retries, sharing the budget with output validation retries

_build_retry_node calls ctx.state.increment_retries(ctx.deps.max_result_retries, error=error) at line 825, which shares the same global retry counter (ctx.state.retries) and budget (max_result_retries) as output validation retries in CallToolsNode. This means a ModelRetry from after_model_request consumes one of the retries that would otherwise be available for output validation. This is a design choice rather than a bug, but users setting output_retries=2 might be surprised that a hook retry reduces their remaining output validation retries.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

m = _messages.RetryPromptPart(content=error.message)
instructions = await ctx.deps.get_instructions(run_context)
retry_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[m], instructions=instructions))
self._result = retry_node
return retry_node

__repr__ = dataclasses_no_defaults_repr


Expand Down
37 changes: 24 additions & 13 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,21 +297,32 @@ async def do_execute(args: dict[str, Any]) -> Any:
if cap is not None:
tool_def = validated.tool.tool_def

# before_tool_execute
args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args)

# wrap_tool_execute wraps the execution; on_tool_execute_error on failure
try:
tool_result = await cap.wrap_tool_execute(
ctx, call=call, tool_def=tool_def, args=args, handler=do_execute
# before_tool_execute
args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args)

# wrap_tool_execute wraps the execution; on_tool_execute_error on failure
try:
tool_result = await cap.wrap_tool_execute(
ctx, call=call, tool_def=tool_def, args=args, handler=do_execute
)
except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError):
raise # Control flow, not errors
except ModelRetry:
raise # Propagate to outer handler
except Exception as e:
tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e)
Comment on lines +309 to +314
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚩 Behavioral change: ModelRetry from wrap_tool_execute no longer reaches on_tool_execute_error

Previously, if wrap_tool_execute raised ModelRetry, it would be caught by except Exception as e: and passed to on_tool_execute_error. With the new except ModelRetry: raise clause at pydantic_ai_slim/pydantic_ai/_tool_manager.py:311-312, ModelRetry now bypasses on_tool_execute_error entirely and is converted directly to ToolRetryError by the outer handler. This is documented as intentional in the updated docstring for on_tool_execute_error at pydantic_ai_slim/pydantic_ai/capabilities/abstract.py:491-492, but it is a semantic change to the hook contract that could affect existing capability implementations relying on on_tool_execute_error seeing ModelRetry exceptions.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.


# after_tool_execute
tool_result = await cap.after_tool_execute(
ctx, call=call, tool_def=tool_def, args=args, result=tool_result
)
except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError):
raise # Control flow, not errors
except Exception as e:
tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e)

# after_tool_execute
tool_result = await cap.after_tool_execute(ctx, call=call, tool_def=tool_def, args=args, result=tool_result)
except ModelRetry as e:
# Hook raised ModelRetry — convert to ToolRetryError for retry handling
name = call.tool_name
self._check_max_retries(name, validated.tool.max_retries, e)
self.failed_tools.add(name)
raise self._wrap_error_as_retry(name, call, e) from e
else:
tool_result = await do_execute(validated.validated_args)

Expand Down
Loading
Loading