
Commit 400bb22

Merge branch 'main' into continuation-node-approach
2 parents: 0f2cfc5 + a4d6e74

64 files changed: +9,804 −3,464 lines


.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions

```diff
@@ -102,7 +102,7 @@ jobs:
         && 'depot-ubuntu-24.04-4'
         || 'ubuntu-latest'
       }}
-    timeout-minutes: 15
+    timeout-minutes: 20
     strategy:
       fail-fast: false
       matrix:
@@ -163,7 +163,7 @@ jobs:
         && 'depot-ubuntu-24.04-4'
         || 'ubuntu-latest'
       }}
-    timeout-minutes: 15
+    timeout-minutes: 20
    strategy:
       fail-fast: false
       matrix:
```

docs/evals/evaluators/report-evaluators.md

Lines changed: 146 additions & 3 deletions

````diff
@@ -177,7 +177,11 @@ PrecisionRecallEvaluator(
 | `title` | `str` | `'Precision-Recall Curve'` | Title shown in reports |
 | `n_thresholds` | `int` | `100` | Number of threshold points on the curve |
 
-**Returns:** [`PrecisionRecall`][pydantic_evals.reporting.analyses.PrecisionRecall]
+**Returns:** [`PrecisionRecall`][pydantic_evals.reporting.analyses.PrecisionRecall] + [`ScalarResult`][pydantic_evals.reporting.analyses.ScalarResult] (AUC)
+
+The AUC is computed at full resolution (using every unique score as a threshold) for accuracy,
+then the curve points are downsampled to `n_thresholds` for display. The AUC is returned both
+on the curve (for chart rendering) and as a separate `ScalarResult` for querying and sorting.
 
 **Score Sources:**
 
@@ -237,6 +241,77 @@ dataset = Dataset(
 
 ---
 
+### ROCAUCEvaluator
+
+Computes an ROC (Receiver Operating Characteristic) curve and AUC from numeric scores
+and binary ground-truth labels. The ROC curve plots the True Positive Rate against the
+False Positive Rate at various threshold values, with a dashed random-baseline diagonal
+for reference.
+
+```python
+from pydantic_evals.evaluators import ROCAUCEvaluator
+
+ROCAUCEvaluator(
+    score_key='confidence',
+    positive_from='assertions',
+    positive_key='is_correct',
+)
+```
+
+**Parameters:**
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `score_key` | `str` | _(required)_ | Key in scores or metrics dict |
+| `positive_from` | `'expected_output' \| 'assertions' \| 'labels'` | _(required)_ | Source for ground-truth binary labels |
+| `positive_key` | `str \| None` | `None` | Key in assertions or labels dict |
+| `score_from` | `'scores' \| 'metrics'` | `'scores'` | Source for numeric scores |
+| `title` | `str` | `'ROC Curve'` | Title shown in reports |
+| `n_thresholds` | `int` | `100` | Number of threshold points on the curve |
+
+**Returns:** [`LinePlot`][pydantic_evals.reporting.analyses.LinePlot] + [`ScalarResult`][pydantic_evals.reporting.analyses.ScalarResult] (AUC)
+
+The AUC is computed at full resolution. The chart includes a dashed "Random" baseline
+diagonal from (0, 0) to (1, 1) for visual comparison.
+
+**Score and Positive Sources:** Same as [`PrecisionRecallEvaluator`](#precisionrecallevaluator).
+
+---
+
+### KolmogorovSmirnovEvaluator
+
+Computes a Kolmogorov-Smirnov plot and KS statistic from numeric scores and binary
+ground-truth labels. The KS plot shows the empirical CDFs (cumulative distribution functions)
+of the score distribution for positive and negative cases. The KS statistic is the maximum
+vertical distance between the two CDFs — higher values indicate better class separation.
+
+```python
+from pydantic_evals.evaluators import KolmogorovSmirnovEvaluator
+
+KolmogorovSmirnovEvaluator(
+    score_key='confidence',
+    positive_from='assertions',
+    positive_key='is_correct',
+)
+```
+
+**Parameters:**
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `score_key` | `str` | _(required)_ | Key in scores or metrics dict |
+| `positive_from` | `'expected_output' \| 'assertions' \| 'labels'` | _(required)_ | Source for ground-truth binary labels |
+| `positive_key` | `str \| None` | `None` | Key in assertions or labels dict |
+| `score_from` | `'scores' \| 'metrics'` | `'scores'` | Source for numeric scores |
+| `title` | `str` | `'KS Plot'` | Title shown in reports |
+| `n_thresholds` | `int` | `100` | Number of threshold points on the curve |
+
+**Returns:** [`LinePlot`][pydantic_evals.reporting.analyses.LinePlot] + [`ScalarResult`][pydantic_evals.reporting.analyses.ScalarResult] (KS Statistic)
+
+**Score and Positive Sources:** Same as [`PrecisionRecallEvaluator`](#precisionrecallevaluator).
+
+---
+
 ## Custom Report Evaluators
 
 Write custom report evaluators by inheriting from [`ReportEvaluator`][pydantic_evals.evaluators.ReportEvaluator]
@@ -373,6 +448,54 @@ Precision-recall curve data (typically produced by `PrecisionRecallEvaluator`):
 Each `PrecisionRecallCurve` contains a `name`, a list of `PrecisionRecallPoint`s (with `threshold`,
 `precision`, `recall`), and an optional `auc` value.
 
+---
+
+#### LinePlot
+
+A generic XY line chart with labeled axes, supporting multiple curves. Use this for ROC curves,
+KS plots, calibration curves, or any custom line chart:
+
+```python
+from pydantic_evals.reporting.analyses import LinePlot, LinePlotCurve, LinePlotPoint
+
+LinePlot(
+    title='ROC Curve',
+    x_label='False Positive Rate',
+    y_label='True Positive Rate',
+    x_range=(0, 1),
+    y_range=(0, 1),
+    curves=[
+        LinePlotCurve(
+            name='Model (AUC: 0.95)',
+            points=[LinePlotPoint(x=0.0, y=0.0), LinePlotPoint(x=0.1, y=0.8), LinePlotPoint(x=1.0, y=1.0)],
+        ),
+        LinePlotCurve(
+            name='Random',
+            points=[LinePlotPoint(x=0, y=0), LinePlotPoint(x=1, y=1)],
+            style='dashed',
+        ),
+    ],
+)
+```
+
+| Field | Type | Description |
+|-------|------|-------------|
+| `title` | `str` | Display name |
+| `x_label` | `str` | Label for the x-axis |
+| `y_label` | `str` | Label for the y-axis |
+| `x_range` | `tuple[float, float] \| None` | Optional fixed range for x-axis |
+| `y_range` | `tuple[float, float] \| None` | Optional fixed range for y-axis |
+| `curves` | `list[LinePlotCurve]` | One or more curves to plot |
+| `description` | `str \| None` | Optional longer description |
+
+Each `LinePlotCurve` contains a `name`, a list of `LinePlotPoint`s (with `x`, `y`),
+an optional `style` (`'solid'` or `'dashed'`), and an optional `step` interpolation
+mode (`'start'`, `'middle'`, or `'end'`) for step functions like empirical CDFs.
+
+`LinePlot` is the recommended return type for custom curve-based evaluators — any evaluator
+that returns a `LinePlot` will be rendered as a line chart in the Logfire UI without requiring
+any frontend changes.
+
 ### Returning Multiple Analyses
 
 A single report evaluator can return multiple analyses by returning a list:
@@ -471,8 +594,8 @@ report_evaluators:
     positive_key: is_correct
 ```
 
-Built-in report evaluators (`ConfusionMatrixEvaluator`, `PrecisionRecallEvaluator`) are
-recognized automatically. For custom report evaluators, pass them via `custom_report_evaluator_types`:
+Built-in report evaluators (`ConfusionMatrixEvaluator`, `PrecisionRecallEvaluator`,
+`ROCAUCEvaluator`, `KolmogorovSmirnovEvaluator`) are recognized automatically. For custom report evaluators, pass them via `custom_report_evaluator_types`:
 
 ```python {test="skip" lint="skip"}
 from pydantic_evals import Dataset
@@ -501,6 +624,7 @@ as interactive visualizations:
 
 - **Confusion matrices** are displayed as heatmaps
 - **Precision-recall curves** are rendered as line charts with AUC in the legend
+- **Line plots** (ROC curves, KS plots, etc.) are rendered as line charts with configurable axes
 - **Scalar results** are shown as labeled values
 - **Tables** are rendered as formatted data tables
 
@@ -520,9 +644,11 @@ from pydantic_evals.evaluators import (
     ConfusionMatrixEvaluator,
     Evaluator,
     EvaluatorContext,
+    KolmogorovSmirnovEvaluator,
     PrecisionRecallEvaluator,
     ReportEvaluator,
     ReportEvaluatorContext,
+    ROCAUCEvaluator,
 )
 from pydantic_evals.reporting.analyses import ScalarResult
 
@@ -586,6 +712,18 @@ dataset = Dataset(
             positive_from='assertions',
             positive_key='is_correct',
         ),
+        ROCAUCEvaluator(
+            score_from='scores',
+            score_key='confidence',
+            positive_from='assertions',
+            positive_key='is_correct',
+        ),
+        KolmogorovSmirnovEvaluator(
+            score_from='scores',
+            score_key='confidence',
+            positive_from='assertions',
+            positive_key='is_correct',
+        ),
         AccuracyEvaluator(),
     ],
 )
@@ -597,6 +735,11 @@ for analysis in report.analyses:
     print(f'{analysis.type}: {analysis.title}')
 #> confusion_matrix: Animal Classification
 #> precision_recall: Precision-Recall Curve
+#> scalar: Precision-Recall Curve AUC
+#> line_plot: ROC Curve
+#> scalar: ROC Curve AUC
+#> line_plot: KS Plot
+#> scalar: KS Statistic
 #> scalar: Accuracy
 ```
````
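The ROC AUC and KS statistic that the new report evaluators expose are standard rank-based quantities. As a rough standalone sketch (plain Python, not the library's actual implementation; the function names here are illustrative), full-resolution AUC is the probability that a randomly chosen positive case outscores a randomly chosen negative one, and the KS statistic is the largest gap between the two empirical CDFs evaluated at every unique score:

```python
def roc_auc(scores: list[float], labels: list[bool]) -> float:
    """AUC via the Mann-Whitney statistic: the probability that a random
    positive case outscores a random negative case (ties count half)."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def ks_statistic(scores: list[float], labels: list[bool]) -> float:
    """Maximum vertical distance between the positive and negative empirical
    CDFs, evaluated at every unique score (i.e. at full resolution)."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]

    def cdf(values: list[float], t: float) -> float:
        return sum(1 for v in values if v <= t) / len(values)

    return max(abs(cdf(pos, t) - cdf(neg, t)) for t in set(scores))


scores = [0.1, 0.4, 0.35, 0.8]
labels = [False, False, True, True]
print(roc_auc(scores, labels))       # -> 0.75
print(ks_statistic(scores, labels))  # -> 0.5
```

A production implementation would sort once and sweep thresholds rather than use these O(n²) loops, but the statistics computed are the same.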

docs/install.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -50,7 +50,7 @@ pip/uv-add "pydantic-ai-slim[openai]"
 * `mistral` — installs [Mistral Model](models/mistral.md) dependency `mistralai` [PyPI ↗](https://pypi.org/project/mistralai){:target="_blank"}
 * `cohere` - installs [Cohere Model](models/cohere.md) dependency `cohere` [PyPI ↗](https://pypi.org/project/cohere){:target="_blank"}
 * `bedrock` - installs [Bedrock Model](models/bedrock.md) dependency `boto3` [PyPI ↗](https://pypi.org/project/boto3){:target="_blank"}
-* `huggingface` - installs [Hugging Face Model](models/huggingface.md) dependency `huggingface-hub[inference]` [PyPI ↗](https://pypi.org/project/huggingface-hub){:target="_blank"}
+* `huggingface` - installs [Hugging Face Model](models/huggingface.md) dependency `huggingface-hub` [PyPI ↗](https://pypi.org/project/huggingface-hub){:target="_blank"}
 * `outlines-transformers` - installs [Outlines Model](models/outlines.md) dependency `outlines[transformers]` [PyPI ↗](https://pypi.org/project/outlines){:target="_blank"}
 * `outlines-llamacpp` - installs [Outlines Model](models/outlines.md) dependency `outlines[llamacpp]` [PyPI ↗](https://pypi.org/project/outlines){:target="_blank"}
 * `outlines-mlxlm` - installs [Outlines Model](models/outlines.md) dependency `outlines[mlxlm]` [PyPI ↗](https://pypi.org/project/outlines){:target="_blank"}
```

docs/ui/vercel-ai.md

Lines changed: 24 additions & 1 deletion

````diff
@@ -1,6 +1,6 @@
 # Vercel AI Data Stream Protocol
 
-Pydantic AI natively supports the [Vercel AI Data Stream Protocol](https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol) to receive agent run input from, and stream events to, a [Vercel AI Elements](https://ai-sdk.dev/elements) frontend.
+Pydantic AI natively supports the [Vercel AI Data Stream Protocol](https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol) to receive agent run input from, and stream events to, a frontend using [AI SDK UI](https://ai-sdk.dev/docs/ai-sdk-ui/overview) hooks like [`useChat`](https://ai-sdk.dev/docs/reference/ai-sdk-ui/use-chat). You can optionally use [AI Elements](https://ai-sdk.dev/elements) for pre-built UI components.
 
 !!! note
     By default, the adapter targets AI SDK v5 for backwards compatibility. To use features introduced in AI SDK v6, set `sdk_version=6` on the adapter.
@@ -123,3 +123,26 @@ async def search_docs(query: str) -> ToolReturn:
 
 !!! note
     Protocol-control chunks such as `StartChunk`, `FinishChunk`, `StartStepChunk`, or `FinishStepChunk` are automatically filtered out — only the four data-carrying chunk types listed above are forwarded to the stream and preserved in `dump_messages`.
+
+## Tool Approval
+
+!!! note
+    Tool approval requires AI SDK UI v6 or later on the frontend.
+
+Pydantic AI supports human-in-the-loop tool approval workflows with AI SDK UI, allowing users to approve or deny tool executions before they run. See the [deferred tool calls documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for details on setting up tools that require approval.
+
+To enable tool approval streaming, pass `sdk_version=6` to `dispatch_request`:
+
+```py {test="skip" lint="skip"}
+@app.post('/chat')
+async def chat(request: Request) -> Response:
+    return await VercelAIAdapter.dispatch_request(request, agent=agent, sdk_version=6)
+```
+
+When `sdk_version=6`, the adapter will:
+
+1. Emit `tool-approval-request` chunks when tools with `requires_approval=True` are called
+2. Automatically extract approval responses from follow-up requests
+3. Emit `tool-output-denied` chunks for rejected tools
+
+On the frontend, AI SDK UI's [`useChat`](https://ai-sdk.dev/docs/reference/ai-sdk-ui/use-chat) hook handles the approval flow. You can use the [`Confirmation`](https://ai-sdk.dev/elements/components/confirmation) component from AI Elements for a pre-built approval UI, or build your own using the hook's `addToolApprovalResponse` function.
````

mkdocs.yml

Lines changed: 5 additions & 2 deletions

```diff
@@ -360,13 +360,13 @@ plugins:
       - troubleshooting.md
     Concepts documentation:
       - a2a.md
-      - ag-ui.md
-      - Agents: agent.md
+      - agent.md
       - builtin-tools.md
       - dependencies.md
       - deferred-tools.md
       - direct.md
       - embeddings.md
+      - gateway.md
       - input.md
       - tools.md
       - common-tools.md
@@ -378,14 +378,17 @@ plugins:
       - third-party-tools.md
       - tools-advanced.md
       - toolsets.md
+      - web.md
     Models:
       - models/*.md
     Graphs:
       - graph.md
+      - graph/*.md
     API Reference:
       - api/*.md
     Evals:
       - evals.md
+      - evals/*.md
     Durable Execution:
       - durable_execution/*.md
     MCP:
```

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 10 additions & 5 deletions

```diff
@@ -646,13 +646,18 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:  # noqa
 # Check for content filter on empty response
 if self.model_response.finish_reason == 'content_filter':
     details = self.model_response.provider_details or {}
-    reason = details.get('finish_reason', 'content_filter')
-
     body = _messages.ModelMessagesTypeAdapter.dump_json([self.model_response]).decode()
 
-    raise exceptions.ContentFilterError(
-        f"Content filter triggered. Finish reason: '{reason}'", body=body
-    )
+    if reason := details.get('finish_reason'):
+        message = f"Content filter triggered. Finish reason: '{reason}'"
+    elif reason := details.get('block_reason'):
+        message = f"Content filter triggered. Block reason: '{reason}'"
+    elif refusal := details.get('refusal'):
+        message = f'Content filter triggered. Refusal: {refusal!r}'
+    else:  # pragma: no cover
+        message = 'Content filter triggered.'
+
+    raise exceptions.ContentFilterError(message, body=body)
 
 # we got an empty response.
 # this sometimes happens with anthropic (and perhaps other models)
```
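The new branching is easy to exercise in isolation. Here is a minimal standalone sketch of the same priority order (the helper name is illustrative, with a plain dict standing in for `provider_details`):

```python
def content_filter_message(details: dict[str, str]) -> str:
    # Same priority order as the diff above: prefer the provider's
    # finish_reason, then block_reason, then a refusal string, and fall
    # back to a generic message when none are present.
    if reason := details.get('finish_reason'):
        return f"Content filter triggered. Finish reason: '{reason}'"
    elif reason := details.get('block_reason'):
        return f"Content filter triggered. Block reason: '{reason}'"
    elif refusal := details.get('refusal'):
        return f'Content filter triggered. Refusal: {refusal!r}'
    return 'Content filter triggered.'


print(content_filter_message({'block_reason': 'SAFETY'}))
#> Content filter triggered. Block reason: 'SAFETY'
print(content_filter_message({}))
#> Content filter triggered.
```

This replaces the old behavior, which always reported a `finish_reason` (defaulting to `'content_filter'`) even when the provider supplied a block reason or refusal instead.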

pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_mcp_server.py

Lines changed: 25 additions & 4 deletions

```diff
@@ -2,14 +2,18 @@
 
 from pydantic_ai import ToolsetTool
 from pydantic_ai.mcp import MCPServer
-from pydantic_ai.tools import AgentDepsT, ToolDefinition
+from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
 
 from ._mcp import DBOSMCPToolset
 from ._utils import StepConfig
 
 
 class DBOSMCPServer(DBOSMCPToolset[AgentDepsT]):
-    """A wrapper for MCPServer that integrates with DBOS, turning call_tool and get_tools to DBOS steps."""
+    """A wrapper for MCPServer that integrates with DBOS, turning call_tool and get_tools into DBOS steps.
+
+    Tool definitions are cached across steps to avoid redundant MCP server round-trips,
+    respecting the wrapped server's `cache_tools` setting.
+    """
 
     def __init__(
         self,
@@ -23,7 +27,24 @@ def __init__(
             step_name_prefix=step_name_prefix,
             step_config=step_config,
         )
+        # Cached across steps to avoid redundant MCP connections per step.
+        # Not invalidated by `tools/list_changed` notifications — users who need
+        # dynamic tools during a workflow should set `cache_tools=False`.
+        self._cached_tool_defs: dict[str, ToolDefinition] | None = None
 
-    def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
+    @property
+    def _server(self) -> MCPServer:
         assert isinstance(self.wrapped, MCPServer)
-        return self.wrapped.tool_for_tool_def(tool_def)
+        return self.wrapped
+
+    def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
+        return self._server.tool_for_tool_def(tool_def)
+
+    async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
+        if self._server.cache_tools and self._cached_tool_defs is not None:
+            return {name: self.tool_for_tool_def(td) for name, td in self._cached_tool_defs.items()}
+
+        result = await super().get_tools(ctx)
+        if self._server.cache_tools:
+            self._cached_tool_defs = {name: tool.tool_def for name, tool in result.items()}
+        return result
```
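The caching pattern in this change is self-contained enough to sketch outside DBOS: fetch tool definitions once, then reuse them only when the wrapped server opts in. A simplified illustration (`CachingToolset` and `FakeServer` are hypothetical names, not part of pydantic-ai):

```python
import asyncio


class FakeServer:
    """Hypothetical stand-in for an MCP server; counts round-trips."""

    def __init__(self, cache_tools: bool = True):
        self.cache_tools = cache_tools
        self.calls = 0

    async def list_tools(self) -> dict[str, str]:
        self.calls += 1  # the expensive per-step round-trip being avoided
        return {'echo': 'tool-def'}


class CachingToolset:
    """Caches tool definitions after the first fetch, but only when the
    wrapped server opts in via `cache_tools`. The cache is never
    invalidated; disable `cache_tools` when tools can change mid-run."""

    def __init__(self, server: FakeServer):
        self.server = server
        self._cached_defs: dict[str, str] | None = None

    async def get_tools(self) -> dict[str, str]:
        if self.server.cache_tools and self._cached_defs is not None:
            return dict(self._cached_defs)
        result = await self.server.list_tools()
        if self.server.cache_tools:
            self._cached_defs = dict(result)
        return result


async def main() -> None:
    toolset = CachingToolset(FakeServer())
    await toolset.get_tools()
    await toolset.get_tools()
    print(toolset.server.calls)  # -> 1, the second call is served from cache


asyncio.run(main())
```

The real class caches `ToolDefinition`s rather than full tools and rebuilds `ToolsetTool`s via `tool_for_tool_def`, so the cached state stays serializable across DBOS steps.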

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -15,7 +15,7 @@ def _default_setup_logfire() -> Logfire:
     import logfire
 
     instance = logfire.configure()
-    logfire.instrument_pydantic_ai()
+    instance.instrument_pydantic_ai()
     return instance
```
