Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
61d172f
Allow hooks to raise `ModelRetry` for retry control flow
DouweM Mar 26, 2026
3516f8d
Address review feedback: extract helper, fix streaming bug, improve t…
DouweM Mar 26, 2026
559169c
Fix Python 3.14 SyntaxError: avoid 'return' in 'finally' block
DouweM Mar 26, 2026
894242b
Add streaming tests for wrap_model_request ModelRetry coverage
DouweM Mar 26, 2026
b14a821
Fix coverage: remove dead branch, pragma no cover on not-called handlers
DouweM Mar 26, 2026
e3c1a59
Fix remaining coverage: exercise hook return paths after retry
DouweM Mar 26, 2026
44e040d
Remove incorrect pragma: no cover on stream_fn that IS executed
DouweM Mar 26, 2026
0b65a08
Merge remote-tracking branch 'origin/main' into hook-model-retry
DouweM Mar 26, 2026
054c774
Preserve model response in history when wrap_model_request raises Mod…
DouweM Mar 26, 2026
71556db
Fix coverage: assert handler_response in normal streaming path
DouweM Mar 27, 2026
4d434cd
Fix double usage.requests increment, add snapshot assertions
DouweM Mar 27, 2026
89175ed
Fix short-circuit usage.requests, add snapshots to all tests
DouweM Mar 27, 2026
758f7ee
Fix non-streaming short-circuit usage.requests, add docs section
DouweM Mar 27, 2026
8291df5
Remove unused import in docs example
DouweM Mar 27, 2026
4820998
Fix docs example: add type annotations, title, correct output
DouweM Mar 27, 2026
a753eef
Fix misleading docs: clarify not all hooks support ModelRetry
DouweM Mar 27, 2026
5c04a67
Simplify docs: drop explicit before_model_request exclusion
DouweM Mar 27, 2026
8508c9d
Merge remote-tracking branch 'origin/main' into hook-model-retry
DouweM Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 82 additions & 29 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,15 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
request: _messages.ModelRequest
is_resuming_without_prompt: bool = False

_result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
_result: CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT] | None = field(
repr=False, init=False, default=None
)
_did_stream: bool = field(repr=False, init=False, default=False)
last_request_context: ModelRequestContext | None = field(repr=False, init=False, default=None)

async def run(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> CallToolsNode[DepsT, NodeRunEndT]:
) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]:
if self._result is not None:
return self._result

Expand Down Expand Up @@ -564,13 +566,24 @@ async def _streaming_handler(
if wrap_task.done() and not stream_ready.is_set():
# wrap_model_request completed without calling handler — short-circuited or raised SkipModelRequest
try:
model_response = wrap_task.result()
except exceptions.SkipModelRequest as e:
model_response = e.response
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=wrap_request_context, error=e
)
try:
model_response = wrap_task.result()
except exceptions.SkipModelRequest as e:
model_response = e.response
except exceptions.ModelRetry:
raise # Propagate to outer handler
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=wrap_request_context, error=e
)
except exceptions.ModelRetry as e:
self._did_stream = True
ctx.state.usage.requests += 1
ctx.state.increment_retries(ctx.deps.max_result_retries, error=e)
m = _messages.RetryPromptPart(content=e.message)
instructions = await ctx.deps.get_instructions(run_context)
self._result = ModelRequestNode(_messages.ModelRequest(parts=[m], instructions=instructions))
return
self._did_stream = True
ctx.state.usage.requests += 1
skip_sr = _SkipStreamedResponse(model_request_parameters=model_request_parameters, _response=model_response)
Expand Down Expand Up @@ -601,11 +614,21 @@ async def _streaming_handler(
pass
else:
try:
model_response = await wrap_task
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=wrap_request_context, error=e
)
try:
model_response = await wrap_task
except exceptions.ModelRetry:
raise # Propagate to outer handler
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=wrap_request_context, error=e
)
except exceptions.ModelRetry as e:
ctx.state.usage.requests += 1
ctx.state.increment_retries(ctx.deps.max_result_retries, error=e)
m = _messages.RetryPromptPart(content=e.message)
instructions = await ctx.deps.get_instructions(run_context)
self._result = ModelRequestNode(_messages.ModelRequest(parts=[m], instructions=instructions))
return
await self._finish_handling(ctx, model_response)
assert self._result is not None

Expand All @@ -629,7 +652,7 @@ def _build_agent_stream(

async def _make_request(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> CallToolsNode[DepsT, NodeRunEndT]:
) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]:
if self._result is not None:
return self._result # pragma: no cover

Expand All @@ -655,17 +678,32 @@ async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse
model_request_parameters=model_request_parameters,
)
try:
model_response = await ctx.deps.root_capability.wrap_model_request(
run_context,
request_context=request_context,
handler=model_handler,
)
except exceptions.SkipModelRequest as e:
model_response = e.response
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=request_context, error=e
try:
model_response = await ctx.deps.root_capability.wrap_model_request(
run_context,
request_context=request_context,
handler=model_handler,
)
except exceptions.SkipModelRequest as e:
model_response = e.response
except exceptions.ModelRetry:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The except ModelRetry: raise followed by except Exception pattern is repeated in three places (_make_request, and twice in stream). The re-raise is necessary to prevent the broad except Exception from swallowing ModelRetry, but it's worth noting that with the helper extraction suggested in the other comment, you may be able to structure this more cleanly — e.g. by checking isinstance(e, ModelRetry) inside the except Exception block and re-raising if so, which would eliminate the inner try/except.

raise # Propagate to outer handler
except Exception as e:
model_response = await ctx.deps.root_capability.on_model_request_error(
run_context, request_context=request_context, error=e
)
except exceptions.ModelRetry as e:
# ModelRetry from wrap_model_request or on_model_request_error — retry the model request.
# No model response to append (handler may not have been called).
ctx.state.usage.requests += 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usage.requests is unconditionally incremented here, even when _handler_response is None (i.e. wrap_model_request short-circuited without calling the handler, meaning no actual model request was made). This is inconsistent with the streaming short-circuit ModelRetry path at line 585, which explicitly does not increment usage.requests because "handler was never called."

The streaming path's logic seems more correct: if the handler was never called, no model request was made, so counting it as a request is wrong. This should probably be conditional:

if _handler_response is not None:
    ctx.state.usage.requests += 1
    self._append_response(ctx, _handler_response)

Consider also adding a test that asserts on result.usage().requests to verify request counting is correct across all ModelRetry paths (short-circuit vs after-handler, streaming vs non-streaming).

ctx.state.increment_retries(ctx.deps.max_result_retries, error=e)
m = _messages.RetryPromptPart(content=e.message)
instructions = await ctx.deps.get_instructions(run_context)
retry_node = ModelRequestNode[DepsT, NodeRunEndT](
_messages.ModelRequest(parts=[m], instructions=instructions)
)
self._result = retry_node
return retry_node
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The retry node creation logic (increment retries → build RetryPromptPart → get instructions → create ModelRequestNode → set self._result) is duplicated four times across _make_request, _finish_handling, and stream (twice). The only differences are a few context-specific lines before the shared part (e.g. appending the response in _finish_handling, setting _did_stream in stream).

Consider extracting a helper like _build_retry_node(ctx, run_context, error) that handles the common retry node construction, so each call site only needs to handle its own context-specific bookkeeping before calling the helper.

ctx.state.usage.requests += 1

return await self._finish_handling(ctx, model_response)
Expand Down Expand Up @@ -748,16 +786,31 @@ async def _finish_handling(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
response: _messages.ModelResponse,
) -> CallToolsNode[DepsT, NodeRunEndT]:
) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]:
response.run_id = response.run_id or ctx.state.run_id

run_context = build_run_context(ctx)
assert self.last_request_context is not None, 'last_request_context must be set before _finish_handling'
request_context = self.last_request_context
run_context.model_settings = request_context.model_settings
response = await ctx.deps.root_capability.after_model_request(
run_context, request_context=request_context, response=response
)
try:
response = await ctx.deps.root_capability.after_model_request(
run_context, request_context=request_context, response=response
)
except exceptions.ModelRetry as e:
# Hook rejected the response — append it to history (model DID respond) and retry
ctx.state.usage.incr(response.usage)
if ctx.deps.usage_limits: # pragma: no branch
ctx.deps.usage_limits.check_tokens(ctx.state.usage)
ctx.state.message_history.append(response)
ctx.state.increment_retries(ctx.deps.max_result_retries, error=e)
m = _messages.RetryPromptPart(content=e.message)
instructions = await ctx.deps.get_instructions(run_context)
retry_node = ModelRequestNode[DepsT, NodeRunEndT](
_messages.ModelRequest(parts=[m], instructions=instructions)
)
self._result = retry_node
return retry_node

# Update usage
ctx.state.usage.incr(response.usage)
Expand Down
37 changes: 24 additions & 13 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,21 +297,32 @@ async def do_execute(args: dict[str, Any]) -> Any:
if cap is not None:
tool_def = validated.tool.tool_def

# before_tool_execute
args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args)

# wrap_tool_execute wraps the execution; on_tool_execute_error on failure
try:
tool_result = await cap.wrap_tool_execute(
ctx, call=call, tool_def=tool_def, args=args, handler=do_execute
# before_tool_execute
args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args)

# wrap_tool_execute wraps the execution; on_tool_execute_error on failure
try:
tool_result = await cap.wrap_tool_execute(
ctx, call=call, tool_def=tool_def, args=args, handler=do_execute
)
except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError):
raise # Control flow, not errors
except ModelRetry:
raise # Propagate to outer handler
except Exception as e:
tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e)
Comment on lines +309 to +314
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚩 Behavioral change: ModelRetry from wrap_tool_execute no longer reaches on_tool_execute_error

Previously, if wrap_tool_execute raised ModelRetry, it would be caught by except Exception as e: and passed to on_tool_execute_error. With the new except ModelRetry: raise clause at pydantic_ai_slim/pydantic_ai/_tool_manager.py:311-312, ModelRetry now bypasses on_tool_execute_error entirely and is converted directly to ToolRetryError by the outer handler. This is documented as intentional in the updated docstring for on_tool_execute_error at pydantic_ai_slim/pydantic_ai/capabilities/abstract.py:491-492, but it is a semantic change to the hook contract that could affect existing capability implementations relying on on_tool_execute_error seeing ModelRetry exceptions.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.


# after_tool_execute
tool_result = await cap.after_tool_execute(
ctx, call=call, tool_def=tool_def, args=args, result=tool_result
)
except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError):
raise # Control flow, not errors
except Exception as e:
tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e)

# after_tool_execute
tool_result = await cap.after_tool_execute(ctx, call=call, tool_def=tool_def, args=args, result=tool_result)
except ModelRetry as e:
# Hook raised ModelRetry — convert to ToolRetryError for retry handling
name = call.tool_name
self._check_max_retries(name, validated.tool.max_retries, e)
self.failed_tools.add(name)
raise self._wrap_error_as_retry(name, call, e) from e
else:
tool_result = await do_execute(validated.validated_args)

Expand Down
44 changes: 37 additions & 7 deletions pydantic_ai_slim/pydantic_ai/capabilities/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,12 @@ async def after_model_request(
request_context: ModelRequestContext,
response: ModelResponse,
) -> ModelResponse:
"""Called after each model response. Can modify the response before further processing."""
"""Called after each model response. Can modify the response before further processing.

Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the response and
ask the model to try again. The original response is still appended to message history
so the model can see what it said. Retries count against `max_result_retries`.
"""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new ModelRetry support from capability hooks is a meaningful feature addition, but there are no documentation updates. The capabilities docs should describe:

  • Which hooks support ModelRetry and the semantics for each (model hooks vs tool hooks, which retry pool is used)
  • The interaction between ModelRetry and on_model_request_error / on_tool_execute_error (i.e. ModelRetry bypasses the error hooks)
  • Examples showing common use cases (e.g. post-processing validation in after_model_request, guardrails in wrap_model_request)

The docstring updates here are a good start, but users need documentation to discover and understand this feature.

return response

async def wrap_model_request(
Expand All @@ -347,7 +352,11 @@ async def wrap_model_request(
request_context: ModelRequestContext,
handler: WrapModelRequestHandler,
) -> ModelResponse:
"""Wraps the model request. handler() calls the model."""
"""Wraps the model request. handler() calls the model.

Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip `on_model_request_error`
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says "Raise ModelRetry to skip on_model_request_error and directly retry the model request with a retry prompt" — but looking at the implementation, ModelRetry raised from wrap_model_request after the handler has been called means the model's response is silently discarded (not appended to history). This is unlike after_model_request where the response is preserved so the model can see what it said.

This is a meaningful behavioral difference that should be documented explicitly. Users might expect the response to be visible to the model on retry, especially if they're raising ModelRetry based on inspecting the response inside wrap_model_request.

and directly retry the model request with a retry prompt.
"""
return await handler(request_context)

async def on_model_request_error(
Expand All @@ -365,8 +374,11 @@ async def on_model_request_error(
**Raise** the original `error` (or a different exception) to propagate it.
**Return** a [`ModelResponse`][pydantic_ai.messages.ModelResponse] to suppress
the error and use the response as if the model call succeeded.
**Raise** [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to retry the model request
with a retry prompt instead of recovering or propagating.

Not called for [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest].
Not called for [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest]
or [`ModelRetry`][pydantic_ai.exceptions.ModelRetry].
"""
raise error

Expand All @@ -380,7 +392,11 @@ async def before_tool_validate(
tool_def: ToolDefinition,
args: RawToolArgs,
) -> RawToolArgs:
"""Modify raw args before validation."""
"""Modify raw args before validation.

Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip validation and
ask the model to redo the tool call.
"""
return args

async def after_tool_validate(
Expand All @@ -391,7 +407,11 @@ async def after_tool_validate(
tool_def: ToolDefinition,
args: ValidatedToolArgs,
) -> ValidatedToolArgs:
"""Modify validated args. Called only on successful validation."""
"""Modify validated args. Called only on successful validation.

Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the validated args
and ask the model to redo the tool call.
"""
return args

async def wrap_tool_validate(
Expand Down Expand Up @@ -439,7 +459,11 @@ async def before_tool_execute(
tool_def: ToolDefinition,
args: ValidatedToolArgs,
) -> ValidatedToolArgs:
"""Modify validated args before execution."""
"""Modify validated args before execution.

Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip execution and
ask the model to redo the tool call.
"""
return args

async def after_tool_execute(
Expand All @@ -451,7 +475,11 @@ async def after_tool_execute(
args: ValidatedToolArgs,
result: Any,
) -> Any:
"""Modify result after execution."""
"""Modify result after execution.

Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the tool result
and ask the model to redo the tool call.
"""
return result

async def wrap_tool_execute(
Expand Down Expand Up @@ -482,6 +510,8 @@ async def on_tool_execute_error(

**Raise** the original `error` (or a different exception) to propagate it.
**Return** any value to suppress the error and use it as the tool result.
**Raise** [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to
redo the tool call instead of recovering or propagating.

Not called for control flow exceptions
([`SkipToolExecution`][pydantic_ai.exceptions.SkipToolExecution],
Expand Down
6 changes: 4 additions & 2 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@


class ModelRetry(Exception):
"""Exception to raise when a tool function should be retried.
"""Exception to raise to request a model retry.

The agent will return the message to the model and ask it to try calling the function/tool again.
Can be raised from tool functions, output validators, and capability hooks
(such as `after_model_request`, `after_tool_execute`, etc.) to send
a retry prompt back to the model asking it to try again.
"""

message: str
Expand Down
Loading
Loading