-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Allow hooks to raise ModelRetry for retry control flow
#4858
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
61d172f
3516f8d
559169c
894242b
b14a821
e3c1a59
44e040d
0b65a08
054c774
71556db
4d434cd
89175ed
758f7ee
8291df5
4820998
a753eef
5c04a67
8508c9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -476,13 +476,15 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]): | |
| request: _messages.ModelRequest | ||
| is_resuming_without_prompt: bool = False | ||
|
|
||
| _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None) | ||
| _result: CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT] | None = field( | ||
| repr=False, init=False, default=None | ||
| ) | ||
| _did_stream: bool = field(repr=False, init=False, default=False) | ||
| last_request_context: ModelRequestContext | None = field(repr=False, init=False, default=None) | ||
|
|
||
| async def run( | ||
| self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT]: | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: | ||
| if self._result is not None: | ||
| return self._result | ||
|
|
||
|
|
@@ -564,13 +566,24 @@ async def _streaming_handler( | |
| if wrap_task.done() and not stream_ready.is_set(): | ||
| # wrap_model_request completed without calling handler — short-circuited or raised SkipModelRequest | ||
| try: | ||
| model_response = wrap_task.result() | ||
| except exceptions.SkipModelRequest as e: | ||
| model_response = e.response | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=wrap_request_context, error=e | ||
| ) | ||
| try: | ||
| model_response = wrap_task.result() | ||
| except exceptions.SkipModelRequest as e: | ||
| model_response = e.response | ||
| except exceptions.ModelRetry: | ||
| raise # Propagate to outer handler | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=wrap_request_context, error=e | ||
| ) | ||
| except exceptions.ModelRetry as e: | ||
| self._did_stream = True | ||
| ctx.state.usage.requests += 1 | ||
| ctx.state.increment_retries(ctx.deps.max_result_retries, error=e) | ||
| m = _messages.RetryPromptPart(content=e.message) | ||
| instructions = await ctx.deps.get_instructions(run_context) | ||
| self._result = ModelRequestNode(_messages.ModelRequest(parts=[m], instructions=instructions)) | ||
| return | ||
devin-ai-integration[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self._did_stream = True | ||
| ctx.state.usage.requests += 1 | ||
| skip_sr = _SkipStreamedResponse(model_request_parameters=model_request_parameters, _response=model_response) | ||
|
|
@@ -601,11 +614,21 @@ async def _streaming_handler( | |
| pass | ||
| else: | ||
| try: | ||
| model_response = await wrap_task | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=wrap_request_context, error=e | ||
| ) | ||
| try: | ||
| model_response = await wrap_task | ||
| except exceptions.ModelRetry: | ||
| raise # Propagate to outer handler | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=wrap_request_context, error=e | ||
| ) | ||
| except exceptions.ModelRetry as e: | ||
| ctx.state.usage.requests += 1 | ||
devin-ai-integration[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ctx.state.increment_retries(ctx.deps.max_result_retries, error=e) | ||
| m = _messages.RetryPromptPart(content=e.message) | ||
| instructions = await ctx.deps.get_instructions(run_context) | ||
| self._result = ModelRequestNode(_messages.ModelRequest(parts=[m], instructions=instructions)) | ||
| return | ||
| await self._finish_handling(ctx, model_response) | ||
| assert self._result is not None | ||
|
|
||
|
|
@@ -629,7 +652,7 @@ def _build_agent_stream( | |
|
|
||
| async def _make_request( | ||
| self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT]: | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: | ||
| if self._result is not None: | ||
| return self._result # pragma: no cover | ||
|
|
||
|
|
@@ -655,17 +678,32 @@ async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse | |
| model_request_parameters=model_request_parameters, | ||
| ) | ||
| try: | ||
| model_response = await ctx.deps.root_capability.wrap_model_request( | ||
| run_context, | ||
| request_context=request_context, | ||
| handler=model_handler, | ||
| ) | ||
| except exceptions.SkipModelRequest as e: | ||
| model_response = e.response | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=request_context, error=e | ||
| try: | ||
| model_response = await ctx.deps.root_capability.wrap_model_request( | ||
| run_context, | ||
| request_context=request_context, | ||
| handler=model_handler, | ||
| ) | ||
| except exceptions.SkipModelRequest as e: | ||
| model_response = e.response | ||
| except exceptions.ModelRetry: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| raise # Propagate to outer handler | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=request_context, error=e | ||
| ) | ||
| except exceptions.ModelRetry as e: | ||
| # ModelRetry from wrap_model_request or on_model_request_error — retry the model request. | ||
| # No model response to append (handler may not have been called). | ||
| ctx.state.usage.requests += 1 | ||
|
||
| ctx.state.increment_retries(ctx.deps.max_result_retries, error=e) | ||
| m = _messages.RetryPromptPart(content=e.message) | ||
| instructions = await ctx.deps.get_instructions(run_context) | ||
| retry_node = ModelRequestNode[DepsT, NodeRunEndT]( | ||
| _messages.ModelRequest(parts=[m], instructions=instructions) | ||
| ) | ||
| self._result = retry_node | ||
| return retry_node | ||
|
||
| ctx.state.usage.requests += 1 | ||
|
|
||
| return await self._finish_handling(ctx, model_response) | ||
|
|
@@ -748,16 +786,31 @@ async def _finish_handling( | |
| self, | ||
| ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], | ||
| response: _messages.ModelResponse, | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT]: | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: | ||
| response.run_id = response.run_id or ctx.state.run_id | ||
|
|
||
| run_context = build_run_context(ctx) | ||
| assert self.last_request_context is not None, 'last_request_context must be set before _finish_handling' | ||
| request_context = self.last_request_context | ||
| run_context.model_settings = request_context.model_settings | ||
| response = await ctx.deps.root_capability.after_model_request( | ||
| run_context, request_context=request_context, response=response | ||
| ) | ||
| try: | ||
| response = await ctx.deps.root_capability.after_model_request( | ||
| run_context, request_context=request_context, response=response | ||
| ) | ||
| except exceptions.ModelRetry as e: | ||
| # Hook rejected the response — append it to history (model DID respond) and retry | ||
| ctx.state.usage.incr(response.usage) | ||
| if ctx.deps.usage_limits: # pragma: no branch | ||
| ctx.deps.usage_limits.check_tokens(ctx.state.usage) | ||
| ctx.state.message_history.append(response) | ||
| ctx.state.increment_retries(ctx.deps.max_result_retries, error=e) | ||
| m = _messages.RetryPromptPart(content=e.message) | ||
| instructions = await ctx.deps.get_instructions(run_context) | ||
| retry_node = ModelRequestNode[DepsT, NodeRunEndT]( | ||
| _messages.ModelRequest(parts=[m], instructions=instructions) | ||
| ) | ||
| self._result = retry_node | ||
| return retry_node | ||
|
|
||
| # Update usage | ||
| ctx.state.usage.incr(response.usage) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -297,21 +297,32 @@ async def do_execute(args: dict[str, Any]) -> Any: | |
| if cap is not None: | ||
| tool_def = validated.tool.tool_def | ||
|
|
||
| # before_tool_execute | ||
| args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args) | ||
|
|
||
| # wrap_tool_execute wraps the execution; on_tool_execute_error on failure | ||
| try: | ||
| tool_result = await cap.wrap_tool_execute( | ||
| ctx, call=call, tool_def=tool_def, args=args, handler=do_execute | ||
| # before_tool_execute | ||
| args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args) | ||
|
|
||
| # wrap_tool_execute wraps the execution; on_tool_execute_error on failure | ||
| try: | ||
| tool_result = await cap.wrap_tool_execute( | ||
| ctx, call=call, tool_def=tool_def, args=args, handler=do_execute | ||
| ) | ||
| except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError): | ||
| raise # Control flow, not errors | ||
| except ModelRetry: | ||
| raise # Propagate to outer handler | ||
| except Exception as e: | ||
| tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e) | ||
|
Comment on lines
+309
to
+314
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚩 Behavioral change: Previously, if Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| # after_tool_execute | ||
| tool_result = await cap.after_tool_execute( | ||
| ctx, call=call, tool_def=tool_def, args=args, result=tool_result | ||
| ) | ||
| except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError): | ||
| raise # Control flow, not errors | ||
| except Exception as e: | ||
| tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e) | ||
|
|
||
| # after_tool_execute | ||
| tool_result = await cap.after_tool_execute(ctx, call=call, tool_def=tool_def, args=args, result=tool_result) | ||
| except ModelRetry as e: | ||
| # Hook raised ModelRetry — convert to ToolRetryError for retry handling | ||
| name = call.tool_name | ||
| self._check_max_retries(name, validated.tool.max_retries, e) | ||
| self.failed_tools.add(name) | ||
| raise self._wrap_error_as_retry(name, call, e) from e | ||
| else: | ||
| tool_result = await do_execute(validated.validated_args) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -337,7 +337,12 @@ async def after_model_request( | |
| request_context: ModelRequestContext, | ||
| response: ModelResponse, | ||
| ) -> ModelResponse: | ||
| """Called after each model response. Can modify the response before further processing.""" | ||
| """Called after each model response. Can modify the response before further processing. | ||
|
|
||
| Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the response and | ||
| ask the model to try again. The original response is still appended to message history | ||
| so the model can see what it said. Retries count against `max_result_retries`. | ||
| """ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new
The docstring updates here are a good start, but users need documentation to discover and understand this feature. |
||
| return response | ||
|
|
||
| async def wrap_model_request( | ||
|
|
@@ -347,7 +352,11 @@ async def wrap_model_request( | |
| request_context: ModelRequestContext, | ||
| handler: WrapModelRequestHandler, | ||
| ) -> ModelResponse: | ||
| """Wraps the model request. handler() calls the model.""" | ||
| """Wraps the model request. handler() calls the model. | ||
|
|
||
| Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip `on_model_request_error` | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstring says "Raise This is a meaningful behavioral difference that should be documented explicitly. Users might expect the response to be visible to the model on retry, especially if they're raising |
||
| and directly retry the model request with a retry prompt. | ||
| """ | ||
| return await handler(request_context) | ||
|
|
||
| async def on_model_request_error( | ||
|
|
@@ -365,8 +374,11 @@ async def on_model_request_error( | |
| **Raise** the original `error` (or a different exception) to propagate it. | ||
| **Return** a [`ModelResponse`][pydantic_ai.messages.ModelResponse] to suppress | ||
| the error and use the response as if the model call succeeded. | ||
| **Raise** [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to retry the model request | ||
| with a retry prompt instead of recovering or propagating. | ||
|
|
||
| Not called for [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest]. | ||
| Not called for [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest] | ||
| or [`ModelRetry`][pydantic_ai.exceptions.ModelRetry]. | ||
| """ | ||
| raise error | ||
|
|
||
|
|
@@ -380,7 +392,11 @@ async def before_tool_validate( | |
| tool_def: ToolDefinition, | ||
| args: RawToolArgs, | ||
| ) -> RawToolArgs: | ||
| """Modify raw args before validation.""" | ||
| """Modify raw args before validation. | ||
|
|
||
| Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip validation and | ||
| ask the model to redo the tool call. | ||
| """ | ||
| return args | ||
|
|
||
| async def after_tool_validate( | ||
|
|
@@ -391,7 +407,11 @@ async def after_tool_validate( | |
| tool_def: ToolDefinition, | ||
| args: ValidatedToolArgs, | ||
| ) -> ValidatedToolArgs: | ||
| """Modify validated args. Called only on successful validation.""" | ||
| """Modify validated args. Called only on successful validation. | ||
|
|
||
| Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the validated args | ||
| and ask the model to redo the tool call. | ||
| """ | ||
| return args | ||
|
|
||
| async def wrap_tool_validate( | ||
|
|
@@ -439,7 +459,11 @@ async def before_tool_execute( | |
| tool_def: ToolDefinition, | ||
| args: ValidatedToolArgs, | ||
| ) -> ValidatedToolArgs: | ||
| """Modify validated args before execution.""" | ||
| """Modify validated args before execution. | ||
|
|
||
| Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip execution and | ||
| ask the model to redo the tool call. | ||
| """ | ||
| return args | ||
|
|
||
| async def after_tool_execute( | ||
|
|
@@ -451,7 +475,11 @@ async def after_tool_execute( | |
| args: ValidatedToolArgs, | ||
| result: Any, | ||
| ) -> Any: | ||
| """Modify result after execution.""" | ||
| """Modify result after execution. | ||
|
|
||
| Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the tool result | ||
| and ask the model to redo the tool call. | ||
| """ | ||
| return result | ||
|
|
||
| async def wrap_tool_execute( | ||
|
|
@@ -482,6 +510,8 @@ async def on_tool_execute_error( | |
|
|
||
| **Raise** the original `error` (or a different exception) to propagate it. | ||
| **Return** any value to suppress the error and use it as the tool result. | ||
| **Raise** [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to | ||
| redo the tool call instead of recovering or propagating. | ||
|
|
||
| Not called for control flow exceptions | ||
| ([`SkipToolExecution`][pydantic_ai.exceptions.SkipToolExecution], | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.