-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathabstract.py
More file actions
536 lines (436 loc) · 21.1 KB
/
abstract.py
File metadata and controls
536 lines (436 loc) · 21.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
from __future__ import annotations
from abc import ABC
from collections.abc import AsyncIterable, Awaitable, Callable, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, TypeAlias
from pydantic import ValidationError
from pydantic_ai._instructions import AgentInstructions
from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.messages import AgentStreamEvent, ModelResponse, ToolCallPart
from pydantic_ai.tools import AgentBuiltinTool, AgentDepsT, RunContext, ToolDefinition
from pydantic_ai.toolsets import AbstractToolset, AgentToolset
# Imported only for static type checking. These names are used solely in annotations
# (kept unevaluated by `from __future__ import annotations`) and inside the string-valued
# type aliases below, so they are never imported at runtime by this module.
# `prefix_tools()` imports `PrefixTools` locally when it actually needs the class.
if TYPE_CHECKING:
    from pydantic_ai import _agent_graph
    from pydantic_ai.agent.abstract import AgentModelSettings
    from pydantic_ai.capabilities.prefix_tools import PrefixTools
    from pydantic_ai.models import ModelRequestContext
    from pydantic_ai.result import FinalResult
    from pydantic_ai.run import AgentRunResult
    from pydantic_graph import End
# --- Handler type aliases for use in hook method signatures ---
# These make it easier to write correct type annotations when subclassing AbstractCapability.
# Every alias value is a string so it can reference names that are only imported under
# `TYPE_CHECKING` (`_agent_graph`, `End`, `FinalResult`, `AgentRunResult`,
# `ModelRequestContext`). At runtime each alias is therefore just a `str`; only type
# checkers interpret it. Aliases that mention `AgentDepsT` are generic and are subscripted
# in annotations, e.g. `AgentNode[AgentDepsT]`.

# Graph node aliases.
AgentNode: TypeAlias = '_agent_graph.AgentNode[AgentDepsT, Any]'
"""Type alias for an agent graph node (`UserPromptNode`, `ModelRequestNode`, `CallToolsNode`)."""
NodeResult: TypeAlias = '_agent_graph.AgentNode[AgentDepsT, Any] | End[FinalResult[Any]]'
"""Type alias for the result of executing an agent graph node: either the next node or `End`."""

# Run-level handler: zero-argument callable that performs the whole run.
WrapRunHandler: TypeAlias = 'Callable[[], Awaitable[AgentRunResult[Any]]]'
"""Handler type for [`wrap_run`][pydantic_ai.capabilities.AbstractCapability.wrap_run]."""
# Node-level handler: takes the node to execute and returns the next node or `End`.
WrapNodeRunHandler: TypeAlias = 'Callable[[_agent_graph.AgentNode[AgentDepsT, Any]], Awaitable[_agent_graph.AgentNode[AgentDepsT, Any] | End[FinalResult[Any]]]]'
"""Handler type for [`wrap_node_run`][pydantic_ai.capabilities.AbstractCapability.wrap_node_run]."""
# Model-request handler: takes the request context and returns the model's response.
WrapModelRequestHandler: TypeAlias = 'Callable[[ModelRequestContext], Awaitable[ModelResponse]]'
"""Handler type for [`wrap_model_request`][pydantic_ai.capabilities.AbstractCapability.wrap_model_request]."""

# Tool-argument aliases: raw args may be a `str` or a `dict`; validation yields a `dict`.
RawToolArgs: TypeAlias = 'str | dict[str, Any]'
"""Type alias for raw (pre-validation) tool arguments."""
ValidatedToolArgs: TypeAlias = 'dict[str, Any]'
"""Type alias for validated tool arguments."""

# Tool handlers: validation maps raw args to validated args; execution maps validated args to a result.
WrapToolValidateHandler: TypeAlias = 'Callable[[str | dict[str, Any]], Awaitable[dict[str, Any]]]'
"""Handler type for [`wrap_tool_validate`][pydantic_ai.capabilities.AbstractCapability.wrap_tool_validate]."""
WrapToolExecuteHandler: TypeAlias = 'Callable[[dict[str, Any]], Awaitable[Any]]'
"""Handler type for [`wrap_tool_execute`][pydantic_ai.capabilities.AbstractCapability.wrap_tool_execute]."""
@dataclass
class AbstractCapability(ABC, Generic[AgentDepsT]):
    """Abstract base class for agent capabilities.

    A capability is a reusable, composable unit of agent behavior that can provide
    instructions, model settings, tools, and request/response hooks.

    Lifecycle: capabilities are passed to an [`Agent`][pydantic_ai.Agent] at construction time, where
    most `get_*` methods are called to collect static configuration (instructions, model
    settings, toolsets, builtin tools). The exception is
    [`get_wrapper_toolset`][pydantic_ai.capabilities.AbstractCapability.get_wrapper_toolset],
    which is called per-run during toolset assembly. Then, on each model request during a
    run, the [`before_model_request`][pydantic_ai.capabilities.AbstractCapability.before_model_request]
    and [`after_model_request`][pydantic_ai.capabilities.AbstractCapability.after_model_request]
    hooks are called to allow dynamic adjustments.

    See the [capabilities documentation](capabilities.md) for built-in capabilities.

    [`get_serialization_name`][pydantic_ai.capabilities.AbstractCapability.get_serialization_name]
    and [`from_spec`][pydantic_ai.capabilities.AbstractCapability.from_spec] support
    YAML/JSON specs (via [`Agent.from_spec`][pydantic_ai.Agent.from_spec]); they have
    sensible defaults and typically don't need to be overridden.

    Every `get_*` method and hook defined here has a default that leaves agent behavior
    unchanged (return `None`/an empty list, pass the input through, call the handler, or
    re-raise the error), so subclasses only need to override the hooks they care about.
    """

    @property
    def has_wrap_node_run(self) -> bool:
        """Whether this capability (or any sub-capability) overrides wrap_node_run.

        The base implementation only checks whether `type(self)` overrides
        [`wrap_node_run`][pydantic_ai.capabilities.AbstractCapability.wrap_node_run];
        a capability that delegates to sub-capabilities must override this property
        for the sub-capability part of the contract to hold.
        """
        # An override defined on a subclass is a different function object than the
        # base-class default, so an identity comparison on the class attribute detects it.
        return type(self).wrap_node_run is not AbstractCapability.wrap_node_run

    @classmethod
    def get_serialization_name(cls) -> str | None:
        """Return the name used for spec serialization (CamelCase class name by default).

        Return None to opt out of spec-based construction.
        """
        return cls.__name__

    @classmethod
    def from_spec(cls, *args: Any, **kwargs: Any) -> AbstractCapability[Any]:
        """Create from spec arguments. Default: `cls(*args, **kwargs)`.

        Override when `__init__` takes non-serializable types.
        """
        return cls(*args, **kwargs)

    async def for_run(self, ctx: RunContext[AgentDepsT]) -> AbstractCapability[AgentDepsT]:
        """Return the capability instance to use for this agent run.

        Called once per run, before `get_*()` re-extraction and before any hooks fire.
        Override to return a fresh instance for per-run state isolation.

        Default: return `self` (shared across runs).
        """
        return self

    def get_instructions(self) -> AgentInstructions[AgentDepsT] | None:
        """Return instructions to include in the system prompt, or None.

        This method is called once at agent construction time. To get dynamic
        per-request behavior, return a callable that receives
        [`RunContext`][pydantic_ai.tools.RunContext] or a
        [`TemplateStr`][pydantic_ai.TemplateStr] — not a dynamic string.
        """
        return None

    def get_model_settings(self) -> AgentModelSettings[AgentDepsT] | None:
        """Return model settings to merge into the agent's defaults, or None.

        This method is called once at agent construction time. Return a static
        `ModelSettings` dict when the settings don't change between requests.

        Return a callable that receives [`RunContext`][pydantic_ai.tools.RunContext]
        when settings need to vary per step (e.g. based on `ctx.run_step` or `ctx.deps`).

        When the callable is invoked, `ctx.model_settings` contains the merged
        result of all layers resolved before this capability (model defaults and
        agent-level settings). The returned dict is merged on top of that.
        """
        return None

    def get_toolset(self) -> AgentToolset[AgentDepsT] | None:
        """Return a toolset to register with the agent, or None."""
        return None

    def get_builtin_tools(self) -> Sequence[AgentBuiltinTool[AgentDepsT]]:
        """Return builtin tools to register with the agent."""
        # A new list per call, so callers may extend the result without affecting other calls.
        return []

    def get_wrapper_toolset(self, toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT] | None:
        """Wrap the agent's assembled toolset, or return None to leave it unchanged.

        Called per-run with the combined non-output toolset (after agent-level
        [`prepare_tools`][pydantic_ai.tools.ToolsPrepareFunc] wrapping).
        Output tools are added separately and are not included.

        Unlike the other `get_*` methods which are called once at agent construction,
        this is called each run (after [`for_run`][pydantic_ai.capabilities.AbstractCapability.for_run]).
        When multiple capabilities provide wrappers, each receives the already-wrapped
        toolset from earlier capabilities (first capability wraps innermost).

        Use this to apply cross-cutting toolset wrappers like
        [`PreparedToolset`][pydantic_ai.toolsets.PreparedToolset],
        [`FilteredToolset`][pydantic_ai.toolsets.FilteredToolset],
        or custom [`WrapperToolset`][pydantic_ai.toolsets.WrapperToolset] subclasses.
        """
        return None

    # --- Tool preparation hook ---

    async def prepare_tools(
        self,
        ctx: RunContext[AgentDepsT],
        tool_defs: list[ToolDefinition],
    ) -> list[ToolDefinition]:
        """Filter or modify tool definitions visible to the model for this step.

        The list contains all tool kinds (function, output, unapproved) distinguished
        by [`tool_def.kind`][pydantic_ai.tools.ToolDefinition.kind]. Return a filtered
        or modified list. Called after the agent-level
        [`prepare_tools`][pydantic_ai.tools.ToolsPrepareFunc] has already run.
        """
        return tool_defs

    # --- Run lifecycle hooks ---

    async def before_run(
        self,
        ctx: RunContext[AgentDepsT],
    ) -> None:
        """Called before the agent run starts. Observe-only; use wrap_run for modification."""

    async def after_run(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        result: AgentRunResult[Any],
    ) -> AgentRunResult[Any]:
        """Called after the agent run completes. Can modify the result."""
        return result

    async def wrap_run(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        handler: WrapRunHandler,
    ) -> AgentRunResult[Any]:
        """Wraps the entire agent run. `handler()` executes the run.

        If `handler()` raises and this method catches the exception and
        returns a result instead, the error is suppressed and the recovery
        result is used.

        If this method does not call `handler()` (short-circuit), the run
        is skipped and the returned result is used directly.

        Note: if the caller cancels the run (e.g. by breaking out of an
        `iter()` loop), this method receives an `asyncio.CancelledError`.
        Implementations that hold resources should handle cleanup accordingly.
        """
        return await handler()

    async def on_run_error(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        error: BaseException,
    ) -> AgentRunResult[Any]:
        """Called when the agent run fails with an exception.

        This is the error counterpart to
        [`after_run`][pydantic_ai.capabilities.AbstractCapability.after_run]:
        while `after_run` is called on success, `on_run_error` is called on
        failure (after [`wrap_run`][pydantic_ai.capabilities.AbstractCapability.wrap_run]
        has had its chance to recover).

        **Raise** the original `error` (or a different exception) to propagate it.
        **Return** an [`AgentRunResult`][pydantic_ai.run.AgentRunResult] to suppress
        the error and recover the run.

        Not called for `GeneratorExit` or `KeyboardInterrupt`.
        """
        # Default: propagate the failure unchanged.
        raise error

    # --- Node run lifecycle hooks ---

    async def before_node_run(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        node: AgentNode[AgentDepsT],
    ) -> AgentNode[AgentDepsT]:
        """Called before each graph node executes. Can observe or replace the node."""
        return node

    async def after_node_run(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        node: AgentNode[AgentDepsT],
        result: NodeResult[AgentDepsT],
    ) -> NodeResult[AgentDepsT]:
        """Called after each graph node succeeds. Can modify the result (next node or `End`)."""
        return result

    async def wrap_node_run(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        node: AgentNode[AgentDepsT],
        handler: WrapNodeRunHandler[AgentDepsT],
    ) -> NodeResult[AgentDepsT]:
        """Wraps execution of each agent graph node (run step).

        Called for every node in the agent graph (`UserPromptNode`,
        `ModelRequestNode`, `CallToolsNode`). `handler(node)` executes
        the node and returns the next node (or `End`).

        Override to inspect or modify nodes before execution, inspect or modify
        the returned next node, call `handler` multiple times (retry), or
        return a different node to redirect graph progression.

        Note: this hook fires when using [`agent.run()`][pydantic_ai.Agent.run],
        [`agent.run_stream()`][pydantic_ai.Agent.run_stream], and when manually driving
        an [`agent.iter()`][pydantic_ai.Agent.iter] run with
        [`next()`][pydantic_ai.result.AgentRun.next], but it does **not** fire when
        iterating over the run with bare `async for` (which yields stream events, not
        node results).

        When using `agent.run()` with `event_stream_handler`, the handler wraps both
        streaming and graph advancement (i.e. the model call happens inside the wrapper).
        When using `agent.run_stream()`, the handler wraps only graph advancement — streaming
        happens before the wrapper because `run_stream()` must yield the stream to the caller
        while the stream context is still open, which cannot happen from inside a callback.
        """
        return await handler(node)

    async def on_node_run_error(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        node: AgentNode[AgentDepsT],
        error: Exception,
    ) -> NodeResult[AgentDepsT]:
        """Called when a graph node fails with an exception.

        This is the error counterpart to
        [`after_node_run`][pydantic_ai.capabilities.AbstractCapability.after_node_run].

        **Raise** the original `error` (or a different exception) to propagate it.
        **Return** a next node or `End` to recover and continue the graph.

        Useful for recovering from
        [`UnexpectedModelBehavior`][pydantic_ai.exceptions.UnexpectedModelBehavior]
        by redirecting to a different node (e.g. retry with different model settings).
        """
        # Default: propagate the failure unchanged.
        raise error

    # --- Event stream hook ---

    async def wrap_run_event_stream(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        stream: AsyncIterable[AgentStreamEvent],
    ) -> AsyncIterable[AgentStreamEvent]:
        """Wraps the event stream for a streamed node. Can observe or transform events."""
        # The `yield` makes this method an async generator; the default re-yields
        # every event from `stream` unchanged and in order.
        async for event in stream:
            yield event

    # --- Model request lifecycle hooks ---

    async def before_model_request(
        self,
        ctx: RunContext[AgentDepsT],
        request_context: ModelRequestContext,
    ) -> ModelRequestContext:
        """Called before each model request. Can modify messages, settings, and parameters."""
        return request_context

    async def after_model_request(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        request_context: ModelRequestContext,
        response: ModelResponse,
    ) -> ModelResponse:
        """Called after each model response. Can modify the response before further processing.

        Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the response and
        ask the model to try again. The original response is still appended to message history
        so the model can see what it said. Retries count against `max_result_retries`.
        """
        return response

    async def wrap_model_request(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        request_context: ModelRequestContext,
        handler: WrapModelRequestHandler,
    ) -> ModelResponse:
        """Wraps the model request. `handler(request_context)` calls the model and returns its response.

        Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip `on_model_request_error`
        and directly retry the model request with a retry prompt.
        """
        return await handler(request_context)

    async def on_model_request_error(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        request_context: ModelRequestContext,
        error: Exception,
    ) -> ModelResponse:
        """Called when a model request fails with an exception.

        This is the error counterpart to
        [`after_model_request`][pydantic_ai.capabilities.AbstractCapability.after_model_request].

        **Raise** the original `error` (or a different exception) to propagate it.
        **Return** a [`ModelResponse`][pydantic_ai.messages.ModelResponse] to suppress
        the error and use the response as if the model call succeeded.
        **Raise** [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to retry the model request
        with a retry prompt instead of recovering or propagating.

        Not called for [`SkipModelRequest`][pydantic_ai.exceptions.SkipModelRequest]
        or [`ModelRetry`][pydantic_ai.exceptions.ModelRetry].
        """
        # Default: propagate the failure unchanged.
        raise error

    # --- Tool validate lifecycle hooks ---

    async def before_tool_validate(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        call: ToolCallPart,
        tool_def: ToolDefinition,
        args: RawToolArgs,
    ) -> RawToolArgs:
        """Modify raw args before validation.

        Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip validation and
        ask the model to redo the tool call.
        """
        return args

    async def after_tool_validate(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        call: ToolCallPart,
        tool_def: ToolDefinition,
        args: ValidatedToolArgs,
    ) -> ValidatedToolArgs:
        """Modify validated args. Called only on successful validation.

        Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the validated args
        and ask the model to redo the tool call.
        """
        return args

    async def wrap_tool_validate(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        call: ToolCallPart,
        tool_def: ToolDefinition,
        args: RawToolArgs,
        handler: WrapToolValidateHandler,
    ) -> ValidatedToolArgs:
        """Wraps tool argument validation. `handler(args)` runs the validation and returns the validated args."""
        return await handler(args)

    async def on_tool_validate_error(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        call: ToolCallPart,
        tool_def: ToolDefinition,
        args: RawToolArgs,
        error: ValidationError | ModelRetry,
    ) -> ValidatedToolArgs:
        """Called when tool argument validation fails.

        This is the error counterpart to
        [`after_tool_validate`][pydantic_ai.capabilities.AbstractCapability.after_tool_validate].
        Fires for [`ValidationError`][pydantic.ValidationError] (schema mismatch) and
        [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] (custom validator rejection).

        **Raise** the original `error` (or a different exception) to propagate it.
        **Return** validated args to suppress the error and continue as if validation passed.

        Not called for [`SkipToolValidation`][pydantic_ai.exceptions.SkipToolValidation].
        """
        # Default: propagate the failure unchanged.
        raise error

    # --- Tool execute lifecycle hooks ---

    async def before_tool_execute(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        call: ToolCallPart,
        tool_def: ToolDefinition,
        args: ValidatedToolArgs,
    ) -> ValidatedToolArgs:
        """Modify validated args before execution.

        Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to skip execution and
        ask the model to redo the tool call.
        """
        return args

    async def after_tool_execute(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        call: ToolCallPart,
        tool_def: ToolDefinition,
        args: ValidatedToolArgs,
        result: Any,
    ) -> Any:
        """Modify result after execution.

        Raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to reject the tool result
        and ask the model to redo the tool call.
        """
        return result

    async def wrap_tool_execute(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        call: ToolCallPart,
        tool_def: ToolDefinition,
        args: ValidatedToolArgs,
        handler: WrapToolExecuteHandler,
    ) -> Any:
        """Wraps tool execution. `handler(args)` runs the tool and returns its result."""
        return await handler(args)

    async def on_tool_execute_error(
        self,
        ctx: RunContext[AgentDepsT],
        *,
        call: ToolCallPart,
        tool_def: ToolDefinition,
        args: ValidatedToolArgs,
        error: Exception,
    ) -> Any:
        """Called when tool execution fails with an exception.

        This is the error counterpart to
        [`after_tool_execute`][pydantic_ai.capabilities.AbstractCapability.after_tool_execute].

        **Raise** the original `error` (or a different exception) to propagate it.
        **Return** any value to suppress the error and use it as the tool result.
        **Raise** [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to
        redo the tool call instead of recovering or propagating.

        Not called for control flow exceptions
        ([`SkipToolExecution`][pydantic_ai.exceptions.SkipToolExecution],
        [`CallDeferred`][pydantic_ai.exceptions.CallDeferred],
        [`ApprovalRequired`][pydantic_ai.exceptions.ApprovalRequired])
        or retry signals ([`ToolRetryError`][pydantic_ai.exceptions.ToolRetryError]
        from [`ModelRetry`][pydantic_ai.exceptions.ModelRetry]).
        Use [`wrap_tool_execute`][pydantic_ai.capabilities.AbstractCapability.wrap_tool_execute]
        to intercept retries.
        """
        # Default: propagate the failure unchanged.
        raise error

    # --- Convenience methods ---

    def prefix_tools(self, prefix: str) -> PrefixTools[AgentDepsT]:
        """Returns a new capability that wraps this one and prefixes its tool names.

        Only this capability's tools are prefixed; other agent tools are unaffected.
        """
        # Imported lazily: at module level `PrefixTools` is only imported under
        # TYPE_CHECKING (presumably to avoid a circular import with `prefix_tools`).
        from .prefix_tools import PrefixTools

        return PrefixTools(wrapped=self, prefix=prefix)