-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Allow hooks to raise ModelRetry for retry control flow
#4858
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
61d172f
3516f8d
559169c
894242b
b14a821
e3c1a59
44e040d
0b65a08
054c774
71556db
4d434cd
89175ed
758f7ee
8291df5
4820998
a753eef
5c04a67
8508c9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -476,13 +476,15 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]): | |
| request: _messages.ModelRequest | ||
| is_resuming_without_prompt: bool = False | ||
|
|
||
| _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None) | ||
| _result: CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT] | None = field( | ||
| repr=False, init=False, default=None | ||
| ) | ||
| _did_stream: bool = field(repr=False, init=False, default=False) | ||
| last_request_context: ModelRequestContext | None = field(repr=False, init=False, default=None) | ||
|
|
||
| async def run( | ||
| self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT]: | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: | ||
| if self._result is not None: | ||
| return self._result | ||
|
|
||
|
|
@@ -530,9 +532,12 @@ async def stream( | |
| stream_done = asyncio.Event() | ||
| agent_stream_holder: list[result.AgentStream[DepsT, T]] = [] | ||
|
|
||
| _handler_response: _messages.ModelResponse | None = None | ||
|
|
||
| async def _streaming_handler( | ||
| req_ctx: ModelRequestContext, | ||
| ) -> _messages.ModelResponse: | ||
| nonlocal _handler_response | ||
| with set_current_run_context(run_context): | ||
| async with req_ctx.model.request_stream( | ||
| req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters, run_context | ||
|
|
@@ -543,7 +548,9 @@ async def _streaming_handler( | |
| agent_stream_holder.append(agent_stream) | ||
| stream_ready.set() | ||
| await stream_done.wait() | ||
| return sr.get() | ||
| response = sr.get() | ||
| _handler_response = response | ||
| return response | ||
|
|
||
| wrap_request_context = ModelRequestContext( | ||
| model=model, | ||
|
|
@@ -567,13 +574,24 @@ async def _streaming_handler( | |
| if wrap_task.done() and not stream_ready.is_set(): | ||
| # wrap_model_request completed without calling handler — short-circuited or raised SkipModelRequest | ||
| try: | ||
| model_response = wrap_task.result() | ||
| except exceptions.SkipModelRequest as e: | ||
| model_response = e.response | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=wrap_request_context, error=e | ||
| result_or_exc: _messages.ModelResponse | Exception | ||
| try: | ||
| result_or_exc = wrap_task.result() | ||
| except Exception as e: | ||
| result_or_exc = e | ||
| model_response = await self._resolve_wrap_result(ctx, run_context, wrap_request_context, result_or_exc) | ||
| except exceptions.ModelRetry as e: | ||
| self._did_stream = True | ||
| # Don't increment usage.requests — handler was never called (short-circuit) | ||
| run_context = build_run_context(ctx) | ||
| await self._build_retry_node(ctx, run_context, e) | ||
| # Must still yield from @asynccontextmanager — yield an empty stream | ||
| dummy_sr = _SkipStreamedResponse( | ||
| model_request_parameters=model_request_parameters, | ||
| _response=_messages.ModelResponse(parts=[]), | ||
| ) | ||
| yield self._build_agent_stream(ctx, dummy_sr, model_request_parameters) | ||
| return | ||
| self._did_stream = True | ||
| ctx.state.usage.requests += 1 | ||
| skip_sr = _SkipStreamedResponse(model_request_parameters=model_request_parameters, _response=model_response) | ||
|
|
@@ -605,14 +623,25 @@ async def _streaming_handler( | |
| pass | ||
| else: | ||
| try: | ||
| model_response = await wrap_task | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=wrap_request_context, error=e | ||
| ) | ||
| self.last_request_context = wrap_request_context | ||
| await self._finish_handling(ctx, model_response) | ||
| assert self._result is not None | ||
| try: | ||
| model_response = await wrap_task | ||
| except exceptions.ModelRetry: | ||
| raise # Propagate to outer handler | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=wrap_request_context, error=e | ||
| ) | ||
| except exceptions.ModelRetry as e: | ||
| # Don't increment usage.requests — _streaming_handler already did | ||
| # In the normal streaming path the handler was always called (that's | ||
| # how the stream was created), so _handler_response is always set. | ||
| assert _handler_response is not None | ||
| self._append_response(ctx, _handler_response) | ||
| await self._build_retry_node(ctx, run_context, e) | ||
| else: | ||
| self.last_request_context = wrap_request_context | ||
| await self._finish_handling(ctx, model_response) | ||
| assert self._result is not None | ||
|
Comment on lines
+641
to
+644
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Streaming When using Trace of the silent-ignore path
Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| @staticmethod | ||
| def _build_agent_stream( | ||
|
|
@@ -634,7 +663,7 @@ def _build_agent_stream( | |
|
|
||
| async def _make_request( | ||
| self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT]: | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: | ||
| if self._result is not None: | ||
| return self._result # pragma: no cover | ||
|
|
||
|
|
@@ -650,11 +679,16 @@ async def _make_request( | |
| ctx.state.usage.requests += 1 | ||
| return await self._finish_handling(ctx, e.response) | ||
|
|
||
| _handler_response: _messages.ModelResponse | None = None | ||
|
|
||
| async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse: | ||
| nonlocal _handler_response | ||
| with set_current_run_context(run_context): | ||
| return await req_ctx.model.request( | ||
| response = await req_ctx.model.request( | ||
| req_ctx.messages, req_ctx.model_settings, req_ctx.model_request_parameters | ||
| ) | ||
| _handler_response = response | ||
| return response | ||
|
|
||
| request_context = ModelRequestContext( | ||
| model=model, | ||
|
|
@@ -663,17 +697,27 @@ async def model_handler(req_ctx: ModelRequestContext) -> _messages.ModelResponse | |
| model_request_parameters=model_request_parameters, | ||
| ) | ||
| try: | ||
| model_response = await ctx.deps.root_capability.wrap_model_request( | ||
| run_context, | ||
| request_context=request_context, | ||
| handler=model_handler, | ||
| ) | ||
| except exceptions.SkipModelRequest as e: | ||
| model_response = e.response | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=request_context, error=e | ||
| ) | ||
| try: | ||
| model_response = await ctx.deps.root_capability.wrap_model_request( | ||
| run_context, | ||
| request_context=request_context, | ||
| handler=model_handler, | ||
| ) | ||
| except exceptions.SkipModelRequest as e: | ||
| model_response = e.response | ||
| except exceptions.ModelRetry: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
| raise # Propagate to outer handler | ||
| except Exception as e: | ||
| model_response = await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=request_context, error=e | ||
| ) | ||
| except exceptions.ModelRetry as e: | ||
| # ModelRetry from wrap_model_request or on_model_request_error — retry the model request. | ||
| # If the handler was called, preserve the response in history for context. | ||
| if _handler_response is not None: | ||
| ctx.state.usage.requests += 1 | ||
| self._append_response(ctx, _handler_response) | ||
| return await self._build_retry_node(ctx, run_context, e) | ||
| self.last_request_context = request_context | ||
| ctx.state.usage.requests += 1 | ||
|
|
||
|
|
@@ -765,30 +809,82 @@ async def _finish_handling( | |
| self, | ||
| ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], | ||
| response: _messages.ModelResponse, | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT]: | ||
| ) -> CallToolsNode[DepsT, NodeRunEndT] | ModelRequestNode[DepsT, NodeRunEndT]: | ||
| response.run_id = response.run_id or ctx.state.run_id | ||
|
|
||
| run_context = build_run_context(ctx) | ||
| assert self.last_request_context is not None, 'last_request_context must be set before _finish_handling' | ||
| request_context = self.last_request_context | ||
| run_context.model_settings = request_context.model_settings | ||
| response = await ctx.deps.root_capability.after_model_request( | ||
| run_context, request_context=request_context, response=response | ||
| ) | ||
|
|
||
| # Update usage | ||
| ctx.state.usage.incr(response.usage) | ||
| if ctx.deps.usage_limits: # pragma: no branch | ||
| ctx.deps.usage_limits.check_tokens(ctx.state.usage) | ||
| try: | ||
| response = await ctx.deps.root_capability.after_model_request( | ||
| run_context, request_context=request_context, response=response | ||
| ) | ||
| except exceptions.ModelRetry as e: | ||
| # Hook rejected the response — append it to history (model DID respond) and retry | ||
| self._append_response(ctx, response) | ||
| return await self._build_retry_node(ctx, run_context, e) | ||
|
|
||
| # Append the model response to state.message_history | ||
| ctx.state.message_history.append(response) | ||
| self._append_response(ctx, response) | ||
|
|
||
| # Set the `_result` attribute since we can't use `return` in an async iterator | ||
| self._result = CallToolsNode(response) | ||
|
|
||
| return self._result | ||
|
|
||
| async def _resolve_wrap_result( | ||
| self, | ||
| ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], | ||
| run_context: RunContext[DepsT], | ||
| request_context: ModelRequestContext, | ||
| result_or_exc: _messages.ModelResponse | Exception, | ||
| ) -> _messages.ModelResponse: | ||
| """Resolve a wrap_model_request result, handling SkipModelRequest and errors. | ||
|
|
||
| Returns ModelResponse on success. | ||
| Raises ModelRetry if the result or on_model_request_error raises it. | ||
| """ | ||
| if isinstance(result_or_exc, Exception): | ||
| exc = result_or_exc | ||
| if isinstance(exc, exceptions.SkipModelRequest): | ||
| return exc.response | ||
| if isinstance(exc, exceptions.ModelRetry): | ||
| raise exc | ||
| return await ctx.deps.root_capability.on_model_request_error( | ||
| run_context, request_context=request_context, error=exc | ||
| ) | ||
| return result_or_exc | ||
|
|
||
| @staticmethod | ||
| def _append_response( | ||
| ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Any, Any]], | ||
| response: _messages.ModelResponse, | ||
| ) -> None: | ||
| """Append a model response to history, updating usage tracking.""" | ||
| response.run_id = response.run_id or ctx.state.run_id | ||
| ctx.state.usage.incr(response.usage) | ||
| if ctx.deps.usage_limits: # pragma: no branch | ||
| ctx.deps.usage_limits.check_tokens(ctx.state.usage) | ||
| ctx.state.message_history.append(response) | ||
|
|
||
| async def _build_retry_node( | ||
| self, | ||
| ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], | ||
| run_context: RunContext[DepsT], | ||
| error: exceptions.ModelRetry, | ||
| ) -> ModelRequestNode[DepsT, NodeRunEndT]: | ||
| """Build a retry ModelRequestNode from a ModelRetry exception. | ||
|
|
||
| Increments the retry counter and creates a new request with a RetryPromptPart. | ||
| """ | ||
| ctx.state.increment_retries(ctx.deps.max_result_retries, error=error) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 📝 Info: ModelRetry retries count against `max_result_retries` (see the `increment_retries(ctx.deps.max_result_retries, ...)` call above)
Was this helpful? React with 👍 or 👎 to provide feedback. |
||
| m = _messages.RetryPromptPart(content=error.message) | ||
| instructions = await ctx.deps.get_instructions(run_context) | ||
| retry_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[m], instructions=instructions)) | ||
| self._result = retry_node | ||
| return retry_node | ||
|
|
||
| __repr__ = dataclasses_no_defaults_repr | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -297,21 +297,32 @@ async def do_execute(args: dict[str, Any]) -> Any: | |
| if cap is not None: | ||
| tool_def = validated.tool.tool_def | ||
|
|
||
| # before_tool_execute | ||
| args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args) | ||
|
|
||
| # wrap_tool_execute wraps the execution; on_tool_execute_error on failure | ||
| try: | ||
| tool_result = await cap.wrap_tool_execute( | ||
| ctx, call=call, tool_def=tool_def, args=args, handler=do_execute | ||
| # before_tool_execute | ||
| args = await cap.before_tool_execute(ctx, call=call, tool_def=tool_def, args=validated.validated_args) | ||
|
|
||
| # wrap_tool_execute wraps the execution; on_tool_execute_error on failure | ||
| try: | ||
| tool_result = await cap.wrap_tool_execute( | ||
| ctx, call=call, tool_def=tool_def, args=args, handler=do_execute | ||
| ) | ||
| except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError): | ||
| raise # Control flow, not errors | ||
| except ModelRetry: | ||
| raise # Propagate to outer handler | ||
| except Exception as e: | ||
| tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e) | ||
|
Comment on lines
+309
to
+314
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🚩 Behavioral change: Previously, if […] Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| # after_tool_execute | ||
| tool_result = await cap.after_tool_execute( | ||
| ctx, call=call, tool_def=tool_def, args=args, result=tool_result | ||
| ) | ||
| except (SkipToolExecution, CallDeferred, ApprovalRequired, ToolRetryError): | ||
| raise # Control flow, not errors | ||
| except Exception as e: | ||
| tool_result = await cap.on_tool_execute_error(ctx, call=call, tool_def=tool_def, args=args, error=e) | ||
|
|
||
| # after_tool_execute | ||
| tool_result = await cap.after_tool_execute(ctx, call=call, tool_def=tool_def, args=args, result=tool_result) | ||
| except ModelRetry as e: | ||
| # Hook raised ModelRetry — convert to ToolRetryError for retry handling | ||
| name = call.tool_name | ||
| self._check_max_retries(name, validated.tool.max_retries, e) | ||
| self.failed_tools.add(name) | ||
| raise self._wrap_error_as_retry(name, call, e) from e | ||
| else: | ||
| tool_result = await do_execute(validated.validated_args) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.