Skip to content

Commit 7465806

Browse files
committed
fix(output): respect ToolOutput.max_retries parameter
When `ToolOutput.max_retries` is specified, the per-tool retry limit is now used for output tool validation and execution retries, instead of always falling back to the agent-level `max_result_retries`. Fixes #4678
1 parent 20ba061 commit 7465806

File tree

4 files changed

+228
-10
lines changed

4 files changed

+228
-10
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from datetime import datetime
4949

5050
from .models.instrumented import InstrumentationSettings
51+
from .toolsets.abstract import ToolsetTool
5152

5253
__all__ = (
5354
'GraphAgentState',
@@ -1242,6 +1243,14 @@ def _emit_skipped_output_tool(
12421243
yield _messages.FunctionToolResultEvent(part)
12431244

12441245

1246+
def _get_output_tool_max_retries(
1247+
tool: ToolsetTool[DepsT] | None,
1248+
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]],
1249+
) -> int:
1250+
"""Get the max_retries for an output tool, falling back to the agent-level max_result_retries."""
1251+
return tool.max_retries if tool else ctx.deps.max_result_retries
1252+
1253+
12451254
async def process_tool_calls( # noqa: C901
12461255
tool_manager: ToolManager[DepsT],
12471256
tool_calls: list[_messages.ToolCallPart],
@@ -1301,7 +1310,8 @@ async def process_tool_calls( # noqa: C901
13011310
):
13021311
yield event
13031312
continue
1304-
ctx.state.increment_retries(ctx.deps.max_result_retries, error=e)
1313+
tool = tool_manager.tools.get(call.tool_name) if tool_manager.tools else None
1314+
ctx.state.increment_retries(_get_output_tool_max_retries(tool, ctx), error=e)
13051315
raise # pragma: lax no cover
13061316

13071317
if not validated.args_valid:
@@ -1313,7 +1323,9 @@ async def process_tool_calls( # noqa: C901
13131323
yield event
13141324
continue
13151325

1316-
ctx.state.increment_retries(ctx.deps.max_result_retries, error=validated.validation_error)
1326+
ctx.state.increment_retries(
1327+
_get_output_tool_max_retries(validated.tool, ctx), error=validated.validation_error
1328+
)
13171329
yield _messages.FunctionToolCallEvent(call, args_valid=False)
13181330
output_parts.append(validated.validation_error.tool_retry)
13191331
yield _messages.FunctionToolResultEvent(validated.validation_error.tool_retry)
@@ -1329,13 +1341,13 @@ async def process_tool_calls( # noqa: C901
13291341
):
13301342
yield event
13311343
continue
1332-
ctx.state.increment_retries(ctx.deps.max_result_retries, error=e)
1344+
ctx.state.increment_retries(_get_output_tool_max_retries(validated.tool, ctx), error=e)
13331345
raise # pragma: lax no cover
13341346
except ToolRetryError as e:
13351347
# If we already have a valid final result, don't increment retries for invalid output tools
13361348
# This allows the run to succeed if at least one output tool returned a valid result
13371349
if not final_result:
1338-
ctx.state.increment_retries(ctx.deps.max_result_retries, error=e)
1350+
ctx.state.increment_retries(_get_output_tool_max_retries(validated.tool, ctx), error=e)
13391351
yield _messages.FunctionToolCallEvent(call, args_valid=True)
13401352
output_parts.append(e.tool_retry)
13411353
yield _messages.FunctionToolResultEvent(e.tool_retry)

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,10 @@ class OutputToolset(AbstractToolset[AgentDepsT]):
863863
"""The tool definitions for the output tools in this toolset."""
864864
processors: dict[str, ObjectOutputProcessor[Any]]
865865
"""The processors for the output tools in this toolset."""
866-
max_retries: int
866+
max_retries: int | None
867+
"""Default max retries for output tools, set by the Agent. Per-tool overrides from `ToolOutput.max_retries` take priority."""
868+
_max_retries_overrides: dict[str, int]
869+
"""Per-tool max_retries overrides from `ToolOutput(max_retries=N)`."""
867870
output_validators: list[OutputValidator[AgentDepsT, Any]]
868871

869872
@classmethod
@@ -884,6 +887,9 @@ def build(
884887
default_description = description
885888
default_strict = strict
886889

890+
max_retries_overrides: dict[str, int] = {}
891+
tool_output_max_retries: int | None = None
892+
887893
multiple = len(outputs) > 1
888894
for output in outputs:
889895
name = None
@@ -894,6 +900,7 @@ def build(
894900
name = output.name
895901
description = output.description
896902
strict = output.strict
903+
tool_output_max_retries = output.max_retries
897904

898905
output = output.output # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
899906

@@ -933,19 +940,24 @@ def build(
933940
)
934941
processors[name] = processor
935942
tool_defs.append(tool_def)
943+
if tool_output_max_retries is not None:
944+
max_retries_overrides[name] = tool_output_max_retries
945+
tool_output_max_retries = None
936946

937-
return cls(processors=processors, tool_defs=tool_defs)
947+
return cls(processors=processors, tool_defs=tool_defs, max_retries_overrides=max_retries_overrides)
938948

939949
def __init__(
940950
self,
941951
tool_defs: list[ToolDefinition],
942952
processors: dict[str, ObjectOutputProcessor[Any]],
943-
max_retries: int = 1,
953+
max_retries: int | None = None,
954+
max_retries_overrides: dict[str, int] | None = None,
944955
output_validators: list[OutputValidator[AgentDepsT, Any]] | None = None,
945956
):
946957
self.processors = processors
947958
self._tool_defs = tool_defs
948959
self.max_retries = max_retries
960+
self._max_retries_overrides = max_retries_overrides or {}
949961
self.output_validators = output_validators or []
950962

951963
@property
@@ -957,11 +969,13 @@ def label(self) -> str:
957969
return "the agent's output tools"
958970

959971
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
972+
# max_retries should always be set by the Agent before get_tools() is called
973+
assert self.max_retries is not None, 'OutputToolset.max_retries must be set before get_tools() is called'
960974
return {
961975
tool_def.name: ToolsetTool(
962976
toolset=self,
963977
tool_def=tool_def,
964-
max_retries=self.max_retries,
978+
max_retries=self._max_retries_overrides.get(tool_def.name, self.max_retries),
965979
args_validator=self.processors[tool_def.name].validator,
966980
)
967981
for tool_def in self._tool_defs

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def __init__(
441441
self._prepare_output_tools = prepare_output_tools
442442

443443
self._output_toolset = self._output_schema.toolset
444-
if self._output_toolset:
444+
if self._output_toolset and self._output_toolset.max_retries is None:
445445
self._output_toolset.max_retries = self._max_result_retries
446446

447447
self._function_toolset = _AgentFunctionToolset(
@@ -1101,7 +1101,8 @@ def _merged_meta(ctx: RunContext[AgentDepsT]) -> dict[str, Any]:
11011101
if output_schema != self._output_schema or output_validators:
11021102
output_toolset = output_schema.toolset
11031103
if output_toolset:
1104-
output_toolset.max_retries = self._max_result_retries
1104+
if output_toolset.max_retries is None:
1105+
output_toolset.max_retries = self._max_result_retries
11051106
output_toolset.output_validators = output_validators
11061107

11071108
# Build the graph

tests/test_agent.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,197 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse
455455
assert max_retries_log == [target_retries] * (target_retries + 1)
456456

457457

458+
def test_tool_output_max_retries_overrides_agent_retries():
459+
"""ToolOutput.max_retries takes priority over Agent retries. Regression test for #4678."""
460+
retries_log: list[int] = []
461+
max_retries_log: list[int] = []
462+
target_retries = 5
463+
464+
def get_weather(ctx: RunContext[None], city: str) -> str:
465+
retries_log.append(ctx.retry)
466+
max_retries_log.append(ctx.max_retries)
467+
if ctx.retry < target_retries:
468+
raise ModelRetry(f'Retry {ctx.retry}')
469+
return f'Weather in {city}'
470+
471+
def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
472+
assert info.output_tools is not None
473+
args_json = '{"city": "Mexico City"}'
474+
return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)])
475+
476+
# Agent retries=2 (lower than ToolOutput), ToolOutput max_retries=5
477+
# The ToolOutput value should take priority, allowing 5 retries
478+
agent = Agent(
479+
FunctionModel(return_model),
480+
output_type=ToolOutput(get_weather, max_retries=target_retries),
481+
retries=2,
482+
)
483+
484+
result = agent.run_sync('Hello')
485+
assert result.output == 'Weather in Mexico City'
486+
assert retries_log == [0, 1, 2, 3, 4, 5]
487+
assert max_retries_log == [target_retries] * (target_retries + 1)
488+
assert result.all_messages() == snapshot(
489+
[
490+
ModelRequest(
491+
parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())],
492+
timestamp=IsDatetime(),
493+
run_id=IsStr(),
494+
),
495+
ModelResponse(
496+
parts=[
497+
ToolCallPart(
498+
tool_name='final_result',
499+
args='{"city": "Mexico City"}',
500+
tool_call_id=IsStr(),
501+
)
502+
],
503+
usage=RequestUsage(input_tokens=51, output_tokens=6),
504+
model_name='function:return_model:',
505+
timestamp=IsDatetime(),
506+
run_id=IsStr(),
507+
),
508+
ModelRequest(
509+
parts=[
510+
RetryPromptPart(
511+
content='Retry 0',
512+
tool_name='final_result',
513+
tool_call_id=IsStr(),
514+
timestamp=IsDatetime(),
515+
)
516+
],
517+
timestamp=IsDatetime(),
518+
run_id=IsStr(),
519+
),
520+
ModelResponse(
521+
parts=[
522+
ToolCallPart(
523+
tool_name='final_result',
524+
args='{"city": "Mexico City"}',
525+
tool_call_id=IsStr(),
526+
)
527+
],
528+
usage=RequestUsage(input_tokens=60, output_tokens=12),
529+
model_name='function:return_model:',
530+
timestamp=IsDatetime(),
531+
run_id=IsStr(),
532+
),
533+
ModelRequest(
534+
parts=[
535+
RetryPromptPart(
536+
content='Retry 1',
537+
tool_name='final_result',
538+
tool_call_id=IsStr(),
539+
timestamp=IsDatetime(),
540+
)
541+
],
542+
timestamp=IsDatetime(),
543+
run_id=IsStr(),
544+
),
545+
ModelResponse(
546+
parts=[
547+
ToolCallPart(
548+
tool_name='final_result',
549+
args='{"city": "Mexico City"}',
550+
tool_call_id=IsStr(),
551+
)
552+
],
553+
usage=RequestUsage(input_tokens=69, output_tokens=18),
554+
model_name='function:return_model:',
555+
timestamp=IsDatetime(),
556+
run_id=IsStr(),
557+
),
558+
ModelRequest(
559+
parts=[
560+
RetryPromptPart(
561+
content='Retry 2',
562+
tool_name='final_result',
563+
tool_call_id=IsStr(),
564+
timestamp=IsDatetime(),
565+
)
566+
],
567+
timestamp=IsDatetime(),
568+
run_id=IsStr(),
569+
),
570+
ModelResponse(
571+
parts=[
572+
ToolCallPart(
573+
tool_name='final_result',
574+
args='{"city": "Mexico City"}',
575+
tool_call_id=IsStr(),
576+
)
577+
],
578+
usage=RequestUsage(input_tokens=78, output_tokens=24),
579+
model_name='function:return_model:',
580+
timestamp=IsDatetime(),
581+
run_id=IsStr(),
582+
),
583+
ModelRequest(
584+
parts=[
585+
RetryPromptPart(
586+
content='Retry 3',
587+
tool_name='final_result',
588+
tool_call_id=IsStr(),
589+
timestamp=IsDatetime(),
590+
)
591+
],
592+
timestamp=IsDatetime(),
593+
run_id=IsStr(),
594+
),
595+
ModelResponse(
596+
parts=[
597+
ToolCallPart(
598+
tool_name='final_result',
599+
args='{"city": "Mexico City"}',
600+
tool_call_id=IsStr(),
601+
)
602+
],
603+
usage=RequestUsage(input_tokens=87, output_tokens=30),
604+
model_name='function:return_model:',
605+
timestamp=IsDatetime(),
606+
run_id=IsStr(),
607+
),
608+
ModelRequest(
609+
parts=[
610+
RetryPromptPart(
611+
content='Retry 4',
612+
tool_name='final_result',
613+
tool_call_id=IsStr(),
614+
timestamp=IsDatetime(),
615+
)
616+
],
617+
timestamp=IsDatetime(),
618+
run_id=IsStr(),
619+
),
620+
ModelResponse(
621+
parts=[
622+
ToolCallPart(
623+
tool_name='final_result',
624+
args='{"city": "Mexico City"}',
625+
tool_call_id=IsStr(),
626+
)
627+
],
628+
usage=RequestUsage(input_tokens=96, output_tokens=36),
629+
model_name='function:return_model:',
630+
timestamp=IsDatetime(),
631+
run_id=IsStr(),
632+
),
633+
ModelRequest(
634+
parts=[
635+
ToolReturnPart(
636+
tool_name='final_result',
637+
content='Final result processed.',
638+
tool_call_id=IsStr(),
639+
timestamp=IsDatetime(),
640+
)
641+
],
642+
timestamp=IsDatetime(),
643+
run_id=IsStr(),
644+
),
645+
]
646+
)
647+
648+
458649
class TestPartialOutput:
459650
"""Tests for `ctx.partial_output` flag in output validators and output functions."""
460651

0 commit comments

Comments
 (0)