Skip to content

Commit 543d2b1

Browse files
DouweMclaude
andcommitted
refactor: address review feedback — processor validate/call, types, imports
Major changes: - Add validate() and call() methods to BaseOutputProcessor with identity defaults. Override in ObjectOutputProcessor (existing), UnionOutputProcessor (new: clean decomposition of union envelope + inner validation/execution), TextFunctionOutputProcessor (new: function call in call()). - Eliminate isinstance branching in _build_output_handlers — now uses polymorphic processor.validate() and processor.call(). - Move OutputContext to public pydantic_ai.output module. - Fix output_type for functions: now the function input type (what the model produces), not the return type. - Fix execute hook parameter types: Any instead of RawOutput, since validated output varies by processor type. - Top-level imports in _tool_manager.py, keyword-only args in helpers. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4f31ecb commit 543d2b1

File tree

9 files changed

+218
-160
lines changed

9 files changed

+218
-160
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 151 additions & 112 deletions
Large diffs are not rendered by default.

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from . import messages as _messages
1616
from ._instrumentation import InstrumentationNames
17+
from ._output import OutputToolset, run_output_execute_hooks, run_output_validate_hooks
1718
from ._run_context import AgentDepsT, RunContext
1819
from .exceptions import (
1920
ApprovalRequired,
@@ -506,12 +507,6 @@ async def _raw_execute_output_tool(
506507
validated: ValidatedToolCall[AgentDepsT],
507508
) -> Any:
508509
"""Execute an output tool call with output hooks."""
509-
from ._output import (
510-
OutputToolset,
511-
run_output_execute_hooks,
512-
run_output_validate_hooks,
513-
)
514-
515510
assert validated.tool is not None
516511
assert validated.validated_args is not None
517512
assert self.root_capability is not None
@@ -535,18 +530,18 @@ async def do_validate(data: str | dict[str, Any]) -> str | dict[str, Any]:
535530

536531
validated_output = await run_output_validate_hooks(
537532
cap,
538-
ctx,
539-
output_context,
540-
validated.validated_args,
541-
do_validate,
533+
run_context=ctx,
534+
output_context=output_context,
535+
raw_output=validated.validated_args,
536+
do_validate=do_validate,
542537
allow_partial=False,
543538
wrap_validation_errors=False,
544539
)
545540

546541
# --- Output execute phase (wraps processor.call + output validators) ---
547-
async def do_execute(output: str | dict[str, Any]) -> Any:
542+
async def do_execute(output: Any) -> Any:
548543
try:
549-
result = await processor.call(output, ctx, wrap_validation_errors=False) # type: ignore[arg-type]
544+
result = await processor.call(output, ctx, wrap_validation_errors=False)
550545
for validator in toolset.output_validators:
551546
result = await validator.validate(result, ctx, wrap_validation_errors=False)
552547
return result
@@ -555,7 +550,13 @@ async def do_execute(output: str | dict[str, Any]) -> Any:
555550
self.failed_tools.add(name)
556551
raise self._wrap_error_as_retry(name, validated.call, e) from e
557552

558-
return await run_output_execute_hooks(cap, ctx, output_context, validated_output, do_execute)
553+
return await run_output_execute_hooks(
554+
cap,
555+
run_context=ctx,
556+
output_context=output_context,
557+
validated=validated_output,
558+
do_execute=do_execute,
559+
)
559560

560561
async def _execute_function_tool_call(
561562
self,

pydantic_ai_slim/pydantic_ai/capabilities/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any
22

3-
from pydantic_ai._output import OutputContext
3+
from pydantic_ai.output import OutputContext
44

55
from .abstract import (
66
AbstractCapability,

pydantic_ai_slim/pydantic_ai/capabilities/abstract.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515

1616
if TYPE_CHECKING:
1717
from pydantic_ai import _agent_graph
18-
from pydantic_ai._output import OutputContext
1918
from pydantic_ai.agent.abstract import AgentModelSettings
2019
from pydantic_ai.capabilities.prefix_tools import PrefixTools
2120
from pydantic_ai.models import ModelRequestContext
21+
from pydantic_ai.output import OutputContext
2222
from pydantic_ai.result import FinalResult
2323
from pydantic_ai.run import AgentRunResult
2424
from pydantic_graph import End
@@ -59,7 +59,7 @@
5959
WrapOutputValidateHandler: TypeAlias = 'Callable[[str | dict[str, Any]], Awaitable[str | dict[str, Any]]]'
6060
"""Handler type for wrap_output_validate."""
6161

62-
WrapOutputExecuteHandler: TypeAlias = 'Callable[[str | dict[str, Any]], Awaitable[Any]]'
62+
WrapOutputExecuteHandler: TypeAlias = 'Callable[[Any], Awaitable[Any]]'
6363
"""Handler type for wrap_output_execute."""
6464

6565

@@ -570,17 +570,17 @@ async def before_output_execute(
570570
self,
571571
ctx: RunContext[AgentDepsT],
572572
*,
573-
output: RawOutput,
573+
output: Any,
574574
output_context: OutputContext,
575-
) -> RawOutput:
575+
) -> Any:
576576
"""Modify validated output before execution (extraction + function call)."""
577577
return output
578578

579579
async def after_output_execute(
580580
self,
581581
ctx: RunContext[AgentDepsT],
582582
*,
583-
validated_output: RawOutput,
583+
validated_output: Any,
584584
output: Any,
585585
output_context: OutputContext,
586586
) -> Any:
@@ -591,7 +591,7 @@ async def wrap_output_execute(
591591
self,
592592
ctx: RunContext[AgentDepsT],
593593
*,
594-
output: RawOutput,
594+
output: Any,
595595
output_context: OutputContext,
596596
handler: WrapOutputExecuteHandler,
597597
) -> Any:
@@ -602,7 +602,7 @@ async def on_output_execute_error(
602602
self,
603603
ctx: RunContext[AgentDepsT],
604604
*,
605-
output: RawOutput,
605+
output: Any,
606606
output_context: OutputContext,
607607
error: Exception,
608608
) -> Any:

pydantic_ai_slim/pydantic_ai/capabilities/combined.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
if TYPE_CHECKING:
2222
from pydantic_ai import _agent_graph
23-
from pydantic_ai._output import OutputContext
2423
from pydantic_ai.models import ModelRequestContext
24+
from pydantic_ai.output import OutputContext
2525
from pydantic_ai.result import FinalResult
2626
from pydantic_ai.run import AgentRunResult
2727
from pydantic_graph import End
@@ -462,9 +462,9 @@ async def before_output_execute(
462462
self,
463463
ctx: RunContext[AgentDepsT],
464464
*,
465-
output: RawOutput,
465+
output: Any,
466466
output_context: OutputContext,
467-
) -> RawOutput:
467+
) -> Any:
468468
for capability in self.capabilities:
469469
output = await capability.before_output_execute(ctx, output=output, output_context=output_context)
470470
return output
@@ -473,7 +473,7 @@ async def after_output_execute(
473473
self,
474474
ctx: RunContext[AgentDepsT],
475475
*,
476-
validated_output: RawOutput,
476+
validated_output: Any,
477477
output: Any,
478478
output_context: OutputContext,
479479
) -> Any:
@@ -487,7 +487,7 @@ async def wrap_output_execute(
487487
self,
488488
ctx: RunContext[AgentDepsT],
489489
*,
490-
output: RawOutput,
490+
output: Any,
491491
output_context: OutputContext,
492492
handler: WrapOutputExecuteHandler,
493493
) -> Any:
@@ -500,7 +500,7 @@ async def on_output_execute_error(
500500
self,
501501
ctx: RunContext[AgentDepsT],
502502
*,
503-
output: RawOutput,
503+
output: Any,
504504
output_context: OutputContext,
505505
error: Exception,
506506
) -> Any:
@@ -606,9 +606,9 @@ def _make_output_execute_wrap(
606606
cap: AbstractCapability[AgentDepsT],
607607
ctx: RunContext[AgentDepsT],
608608
output_context: OutputContext,
609-
inner: Callable[[RawOutput], Awaitable[Any]],
610-
) -> Callable[[RawOutput], Awaitable[Any]]:
611-
async def wrapped(output: RawOutput) -> Any:
609+
inner: Callable[[Any], Awaitable[Any]],
610+
) -> Callable[[Any], Awaitable[Any]]:
611+
async def wrapped(output: Any) -> Any:
612612
return await cap.wrap_output_execute(ctx, output=output, output_context=output_context, handler=inner)
613613

614614
return wrapped

pydantic_ai_slim/pydantic_ai/capabilities/hooks.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ async def log_request(ctx, request_context):
4848
)
4949

5050
if TYPE_CHECKING:
51-
from pydantic_ai._output import OutputContext
5251
from pydantic_ai.models import ModelRequestContext
52+
from pydantic_ai.output import OutputContext
5353
from pydantic_ai.run import AgentRunResult
5454

5555
_FuncT = TypeVar('_FuncT', bound=Callable[..., Any])
@@ -202,19 +202,19 @@ def __call__(self, ctx: RunContext[Any], /, *, raw_output: RawOutput, output_con
202202

203203
class BeforeOutputExecuteHookFunc(Protocol):
204204
"""Protocol for :meth:`~AbstractCapability.before_output_execute` hook functions."""
205-
def __call__(self, ctx: RunContext[Any], /, *, output: RawOutput, output_context: OutputContext) -> RawOutput | Awaitable[RawOutput]: ...
205+
def __call__(self, ctx: RunContext[Any], /, *, output: Any, output_context: OutputContext) -> Any | Awaitable[Any]: ...
206206

207207
class AfterOutputExecuteHookFunc(Protocol):
208208
"""Protocol for :meth:`~AbstractCapability.after_output_execute` hook functions."""
209-
def __call__(self, ctx: RunContext[Any], /, *, validated_output: RawOutput, output: Any, output_context: OutputContext) -> Any | Awaitable[Any]: ...
209+
def __call__(self, ctx: RunContext[Any], /, *, validated_output: Any, output: Any, output_context: OutputContext) -> Any | Awaitable[Any]: ...
210210

211211
class WrapOutputExecuteHookFunc(Protocol):
212212
"""Protocol for :meth:`~AbstractCapability.wrap_output_execute` hook functions."""
213-
def __call__(self, ctx: RunContext[Any], /, *, output: RawOutput, output_context: OutputContext, handler: WrapOutputExecuteHandler) -> Any | Awaitable[Any]: ...
213+
def __call__(self, ctx: RunContext[Any], /, *, output: Any, output_context: OutputContext, handler: WrapOutputExecuteHandler) -> Any | Awaitable[Any]: ...
214214

215215
class OnOutputExecuteErrorHookFunc(Protocol):
216216
"""Protocol for :meth:`~AbstractCapability.on_output_execute_error` hook functions."""
217-
def __call__(self, ctx: RunContext[Any], /, *, output: RawOutput, output_context: OutputContext, error: Exception) -> Any | Awaitable[Any]: ...
217+
def __call__(self, ctx: RunContext[Any], /, *, output: Any, output_context: OutputContext, error: Exception) -> Any | Awaitable[Any]: ...
218218
# fmt: on
219219

220220

@@ -1114,16 +1114,16 @@ async def on_output_validate_error(
11141114
raise error
11151115

11161116
async def before_output_execute(
1117-
self, ctx: RunContext[AgentDepsT], *, output: RawOutput, output_context: OutputContext
1118-
) -> RawOutput:
1117+
self, ctx: RunContext[AgentDepsT], *, output: Any, output_context: OutputContext
1118+
) -> Any:
11191119
for entry in self._get('before_output_execute'):
11201120
output = await _call_entry(
11211121
entry, 'before_output_execute', ctx, output=output, output_context=output_context
11221122
)
11231123
return output
11241124

11251125
async def after_output_execute(
1126-
self, ctx: RunContext[AgentDepsT], *, validated_output: RawOutput, output: Any, output_context: OutputContext
1126+
self, ctx: RunContext[AgentDepsT], *, validated_output: Any, output: Any, output_context: OutputContext
11271127
) -> Any:
11281128
for entry in self._get('after_output_execute'):
11291129
output = await _call_entry(
@@ -1140,7 +1140,7 @@ async def wrap_output_execute(
11401140
self,
11411141
ctx: RunContext[AgentDepsT],
11421142
*,
1143-
output: RawOutput,
1143+
output: Any,
11441144
output_context: OutputContext,
11451145
handler: WrapOutputExecuteHandler,
11461146
) -> Any:
@@ -1158,7 +1158,7 @@ async def on_output_execute_error(
11581158
self,
11591159
ctx: RunContext[AgentDepsT],
11601160
*,
1161-
output: RawOutput,
1161+
output: Any,
11621162
output_context: OutputContext,
11631163
error: Exception,
11641164
) -> Any:

pydantic_ai_slim/pydantic_ai/capabilities/wrapper.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
)
3030

3131
if TYPE_CHECKING:
32-
from pydantic_ai._output import OutputContext
3332
from pydantic_ai.agent.abstract import AgentModelSettings
3433
from pydantic_ai.models import ModelRequestContext
34+
from pydantic_ai.output import OutputContext
3535
from pydantic_ai.run import AgentRunResult
3636

3737

@@ -338,16 +338,16 @@ async def before_output_execute(
338338
self,
339339
ctx: RunContext[AgentDepsT],
340340
*,
341-
output: RawOutput,
341+
output: Any,
342342
output_context: OutputContext,
343-
) -> RawOutput:
343+
) -> Any:
344344
return await self.wrapped.before_output_execute(ctx, output=output, output_context=output_context)
345345

346346
async def after_output_execute(
347347
self,
348348
ctx: RunContext[AgentDepsT],
349349
*,
350-
validated_output: RawOutput,
350+
validated_output: Any,
351351
output: Any,
352352
output_context: OutputContext,
353353
) -> Any:
@@ -359,7 +359,7 @@ async def wrap_output_execute(
359359
self,
360360
ctx: RunContext[AgentDepsT],
361361
*,
362-
output: RawOutput,
362+
output: Any,
363363
output_context: OutputContext,
364364
handler: WrapOutputExecuteHandler,
365365
) -> Any:
@@ -371,7 +371,7 @@ async def on_output_execute_error(
371371
self,
372372
ctx: RunContext[AgentDepsT],
373373
*,
374-
output: RawOutput,
374+
output: Any,
375375
output_context: OutputContext,
376376
error: Exception,
377377
) -> Any:

pydantic_ai_slim/pydantic_ai/output.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
'TextOutput',
2424
'StructuredDict',
2525
'OutputObjectDefinition',
26+
'OutputContext',
2627
# types
2728
'OutputDataT',
2829
'OutputMode',
@@ -268,6 +269,24 @@ class OutputObjectDefinition:
268269
strict: bool | None = None
269270

270271

272+
@dataclass
273+
class OutputContext:
274+
"""Context about the output being processed, passed to output hooks."""
275+
276+
mode: OutputMode
277+
"""The output mode ('text', 'native', 'prompted', 'tool', 'auto')."""
278+
output_type: type[Any] | None
279+
"""The resolved output type (e.g. MyModel, str). For output functions, the function's input type (what the model produces)."""
280+
object_def: OutputObjectDefinition | None
281+
"""The output object definition (schema, name, description), if structured output."""
282+
has_function: bool
283+
"""Whether there's an output function to call in the execute step."""
284+
tool_call: ToolCallPart | None = None
285+
"""The tool call part, for tool-based output. None for text output."""
286+
tool_def: ToolDefinition | None = None
287+
"""The tool definition, for tool-based output. None for text output."""
288+
289+
271290
@dataclass
272291
class TextOutput(Generic[OutputDataT]):
273292
"""Marker class to use text output for an output function taking a string argument.

tests/test_output_hooks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pytest
99
from pydantic import BaseModel, ValidationError
1010

11-
from pydantic_ai._output import OutputContext
1211
from pydantic_ai._run_context import RunContext
1312
from pydantic_ai.agent import Agent
1413
from pydantic_ai.capabilities.abstract import AbstractCapability
@@ -21,7 +20,7 @@
2120
ToolCallPart,
2221
)
2322
from pydantic_ai.models.function import AgentInfo, FunctionModel
24-
from pydantic_ai.output import PromptedOutput, TextOutput
23+
from pydantic_ai.output import OutputContext, PromptedOutput, TextOutput
2524
from pydantic_ai.tools import ToolDefinition
2625

2726
pytestmark = [

0 commit comments

Comments
 (0)