diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index eaa096bc41..671c172354 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -30,6 +30,7 @@ from .ui import SSE_CONTENT_TYPE, OnCompleteFunc, StateDeps, StateHandler from .ui.ag_ui import AGUIAdapter + from .ui.ag_ui._event_stream import DEFAULT_AG_UI_VERSION, AGUIVersion from .ui.ag_ui.app import AGUIApp except ImportError as e: # pragma: no cover raise ImportError( @@ -53,6 +54,8 @@ async def handle_ag_ui_request( agent: AbstractAgent[AgentDepsT, Any], request: Request, *, + ag_ui_version: AGUIVersion = DEFAULT_AG_UI_VERSION, + preserve_file_data: bool = False, output_type: OutputSpec[Any] | None = None, message_history: Sequence[ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, @@ -71,6 +74,9 @@ async def handle_ag_ui_request( Args: agent: The agent to run. request: The Starlette request (e.g. from FastAPI) containing the AG-UI run input. + ag_ui_version: AG-UI protocol version controlling thinking/reasoning event format. + preserve_file_data: Whether to preserve agent-generated files and uploaded files + in AG-UI message conversion. See [`AGUIAdapter.preserve_file_data`][pydantic_ai.ui.ag_ui.AGUIAdapter.preserve_file_data]. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. 
@@ -94,6 +100,8 @@ async def handle_ag_ui_request( return await AGUIAdapter[AgentDepsT].dispatch_request( request, agent=agent, + ag_ui_version=ag_ui_version, + preserve_file_data=preserve_file_data, deps=deps, output_type=output_type, message_history=message_history, @@ -114,6 +122,8 @@ def run_ag_ui( run_input: RunAgentInput, accept: str = SSE_CONTENT_TYPE, *, + ag_ui_version: AGUIVersion = DEFAULT_AG_UI_VERSION, + preserve_file_data: bool = False, output_type: OutputSpec[Any] | None = None, message_history: Sequence[ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, @@ -133,6 +143,9 @@ def run_ag_ui( agent: The agent to run. run_input: The AG-UI run input containing thread_id, run_id, messages, etc. accept: The accept header value for the run. + ag_ui_version: AG-UI protocol version controlling thinking/reasoning event format. + preserve_file_data: Whether to preserve agent-generated files and uploaded files + in AG-UI message conversion. See [`AGUIAdapter.preserve_file_data`][pydantic_ai.ui.ag_ui.AGUIAdapter.preserve_file_data]. output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's output type. @@ -153,7 +166,13 @@ def run_ag_ui( Yields: Streaming event chunks encoded as strings according to the accept header value. 
""" - adapter = AGUIAdapter(agent=agent, run_input=run_input, accept=accept) + adapter = AGUIAdapter( + agent=agent, + run_input=run_input, + accept=accept, + ag_ui_version=ag_ui_version, + preserve_file_data=preserve_file_data, + ) return adapter.encode_stream( adapter.run_stream( output_type=output_type, diff --git a/pydantic_ai_slim/pydantic_ai/ui/AGENTS.md b/pydantic_ai_slim/pydantic_ai/ui/AGENTS.md new file mode 100644 index 0000000000..5dcffe6b9c --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ui/AGENTS.md @@ -0,0 +1,6 @@ +## Backwards compatibility in UI adapters (specially AG-UI) + +Since [3971](https://github.com/pydantic/pydantic-ai/pull/3971#discussion_r3011028336) we decided to introduce the policy of sticking to the lower (existing) version requirement. In short, this means: +- version requirement bumps are disallowed +- new functionality should be gated behind version checks (including imports) +- older versions don't error out when they encounter new functionality, but instead skip it diff --git a/pydantic_ai_slim/pydantic_ai/ui/CLAUDE.md b/pydantic_ai_slim/pydantic_ai/ui/CLAUDE.md new file mode 120000 index 0000000000..47dc3e3d86 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ui/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/__init__.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/__init__.py index 6228771869..ec358d1f92 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/__init__.py @@ -1,9 +1,11 @@ """AG-UI protocol integration for Pydantic AI agents.""" from ._adapter import AGUIAdapter -from ._event_stream import AGUIEventStream +from ._event_stream import DEFAULT_AG_UI_VERSION, AGUIEventStream, AGUIVersion __all__ = [ 'AGUIAdapter', 'AGUIEventStream', + 'AGUIVersion', + 'DEFAULT_AG_UI_VERSION', ] diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py index 
1fe191c210..4b2fc3f696 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py @@ -3,8 +3,10 @@ from __future__ import annotations import json +import uuid from base64 import b64decode from collections.abc import Mapping, Sequence +from dataclasses import KW_ONLY, dataclass from functools import cached_property from typing import ( TYPE_CHECKING, @@ -12,19 +14,29 @@ cast, ) +from typing_extensions import assert_never + from ... import ExternalToolset, ToolDefinition from ...messages import ( AudioUrl, BinaryContent, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, DocumentUrl, + FilePart, ImageUrl, ModelMessage, + ModelRequest, + ModelResponse, + RetryPromptPart, SystemPromptPart, + TextContent, TextPart, + ThinkingPart, ToolCallPart, ToolReturnPart, + UploadedFile, UserPromptPart, VideoUrl, ) @@ -39,17 +51,27 @@ BaseEvent, BinaryInputContent, DeveloperMessage, + FunctionCall, Message, RunAgentInput, SystemMessage, TextInputContent, Tool as AGUITool, + ToolCall, ToolMessage, UserMessage, ) from .. 
import MessagesBuilder, UIAdapter, UIEventStream - from ._event_stream import BUILTIN_TOOL_CALL_ID_PREFIX, AGUIEventStream + from ._event_stream import ( + BUILTIN_TOOL_CALL_ID_PREFIX, + DEFAULT_AG_UI_VERSION, + REASONING_VERSION, + AGUIEventStream, + AGUIVersion, + parse_ag_ui_version, + thinking_encrypted_metadata, + ) except ImportError as e: # pragma: no cover raise ImportError( 'Please install the `ag-ui-protocol` package to use AG-UI integration, ' @@ -57,7 +79,18 @@ ) from e if TYPE_CHECKING: - pass + from ag_ui.core import ReasoningMessage + from starlette.requests import Request + + from ...agent import AbstractAgent +else: + try: + from ag_ui.core import ReasoningMessage + except ImportError: # pragma: no cover + + class ReasoningMessage: + """Stub for ag-ui-protocol < 0.1.13 — no instances exist, so pattern matching is a no-op.""" + __all__ = ['AGUIAdapter'] @@ -91,9 +124,64 @@ def label(self) -> str: return 'the AG-UI frontend tools' # pragma: no cover +def _new_message_id() -> str: + """Generate a new unique message ID.""" + return str(uuid.uuid4()) + + +def _user_content_to_input( + item: str | TextContent | ImageUrl | VideoUrl | AudioUrl | DocumentUrl | BinaryContent | UploadedFile | CachePoint, +) -> TextInputContent | BinaryInputContent | None: + """Convert a user content item to AG-UI input content.""" + if isinstance(item, str): + return TextInputContent(type='text', text=item) + elif isinstance(item, TextContent): + return TextInputContent(type='text', text=item.content) + elif isinstance(item, (ImageUrl, VideoUrl, AudioUrl, DocumentUrl)): + return BinaryInputContent(type='binary', url=item.url, mime_type=item.media_type or '') + elif isinstance(item, BinaryContent): + return BinaryInputContent(type='binary', data=item.base64, mime_type=item.media_type) + elif isinstance(item, UploadedFile): + # UploadedFile holds an opaque provider file_id (e.g. 'file-abc123'), not a URL or + # binary data, so it can't be mapped to AG-UI's BinaryInputContent. 
Skipped like CachePoint. + return None + elif isinstance(item, CachePoint): + return None + else: + assert_never(item) + + +@dataclass class AGUIAdapter(UIAdapter[RunAgentInput, Message, BaseEvent, AgentDepsT, OutputDataT]): """UI adapter for the Agent-User Interaction (AG-UI) protocol.""" + _: KW_ONLY + ag_ui_version: AGUIVersion = DEFAULT_AG_UI_VERSION + """AG-UI protocol version controlling thinking/reasoning event format. + + Defaults to the version detected from the installed `ag-ui-protocol` package. + + - `'0.1.10'`: emits `THINKING_*` events during streaming, drops `ThinkingPart` + from `dump_messages` output. Compatible with AG-UI frontends that don't support reasoning events. + - `'0.1.13'`: emits `REASONING_*` events with encrypted metadata during streaming, and + includes `ThinkingPart` as `ReasoningMessage` in `dump_messages` output for full round-trip + fidelity of thinking signatures and provider metadata. + + `load_messages` always accepts `ReasoningMessage` regardless of this setting. + """ + + preserve_file_data: bool = False + """Whether to preserve agent-generated files and uploaded files in AG-UI message conversion. + + When `True`, agent-generated files and uploaded files are stored as + [activity messages](https://docs.ag-ui.com/concepts/activities) during `dump_messages` + and restored during `load_messages`, enabling full round-trip fidelity. + When `False` (default), they are silently dropped. + + If your AG-UI frontend uses activities, be aware that `pydantic_ai_*` activity types + are reserved for internal round-trip use and should be ignored by frontend activity handlers. 
+ """ + @classmethod def build_run_input(cls, body: bytes) -> RunAgentInput: """Build an AG-UI run input object from the request body.""" @@ -101,12 +189,27 @@ def build_run_input(cls, body: bytes) -> RunAgentInput: def build_event_stream(self) -> UIEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]: """Build an AG-UI event stream transformer.""" - return AGUIEventStream(self.run_input, accept=self.accept) + return AGUIEventStream(self.run_input, accept=self.accept, ag_ui_version=self.ag_ui_version) + + @classmethod + async def from_request( + cls, + request: Request, + *, + agent: AbstractAgent[AgentDepsT, OutputDataT], + ag_ui_version: AGUIVersion = DEFAULT_AG_UI_VERSION, + preserve_file_data: bool = False, + **kwargs: Any, + ) -> AGUIAdapter[AgentDepsT, OutputDataT]: + """Extends [`from_request`][pydantic_ai.ui.UIAdapter.from_request] with AG-UI-specific parameters.""" + return await super().from_request( + request, agent=agent, ag_ui_version=ag_ui_version, preserve_file_data=preserve_file_data, **kwargs + ) @cached_property def messages(self) -> list[ModelMessage]: """Pydantic AI messages from the AG-UI run input.""" - return self.load_messages(self.run_input.messages) + return self.load_messages(self.run_input.messages, preserve_file_data=self.preserve_file_data) @cached_property def toolset(self) -> AbstractToolset[AgentDepsT] | None: @@ -128,7 +231,7 @@ def state(self) -> dict[str, Any] | None: return cast('dict[str, Any]', state) @classmethod - def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: # noqa: C901 + def load_messages(cls, messages: Sequence[Message], *, preserve_file_data: bool = False) -> list[ModelMessage]: # noqa: C901 """Transform AG-UI messages into Pydantic AI messages.""" builder = MessagesBuilder() tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping. 
@@ -163,8 +266,8 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: # no else: # pragma: no cover raise ValueError('BinaryInputContent must have either a `url` or `data` field.') user_prompt_content.append(binary_part) - case _: # pragma: no cover - raise ValueError(f'Unsupported user message part type: {type(part)}') + case _: + assert_never(part) if user_prompt_content: content_to_add = ( @@ -216,7 +319,7 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: # no if isinstance(content, str): try: content = json.loads(content) - except (json.JSONDecodeError, ValueError): + except json.JSONDecodeError: pass builder.add( BuiltinToolReturnPart( @@ -235,7 +338,288 @@ def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: # no ) ) - case ActivityMessage(): - pass + case ReasoningMessage() as reasoning_msg: + try: + metadata: dict[str, Any] = ( + json.loads(reasoning_msg.encrypted_value) if reasoning_msg.encrypted_value else {} + ) + if not isinstance(metadata, dict): + metadata = {} + except json.JSONDecodeError: + metadata = {} + builder.add( + ThinkingPart( + content=reasoning_msg.content, + id=metadata.get('id'), + signature=metadata.get('signature'), + provider_name=metadata.get('provider_name'), + provider_details=metadata.get('provider_details'), + ) + ) + + case ActivityMessage() as activity_msg: + if activity_msg.activity_type == 'pydantic_ai_file' and preserve_file_data: + activity_content = activity_msg.content + url = activity_content.get('url', '') + if not url: + raise ValueError( + 'ActivityMessage with activity_type=pydantic_ai_file must have a non-empty url.' 
+ ) + builder.add( + FilePart( + content=BinaryContent.from_data_uri(url), + id=activity_content.get('id'), + provider_name=activity_content.get('provider_name'), + provider_details=activity_content.get('provider_details'), + ) + ) + elif activity_msg.activity_type == 'pydantic_ai_uploaded_file' and preserve_file_data: + activity_content = activity_msg.content + file_id = activity_content.get('file_id', '') + provider_name = activity_content.get('provider_name', '') + if not file_id or not provider_name: + raise ValueError( + 'ActivityMessage with activity_type=pydantic_ai_uploaded_file ' + 'must have non-empty file_id and provider_name.' + ) + builder.add( + UserPromptPart( + content=[ + UploadedFile( + file_id=file_id, + provider_name=provider_name, + vendor_metadata=activity_content.get('vendor_metadata'), + media_type=activity_content.get('media_type'), + identifier=activity_content.get('identifier'), + ) + ] + ) + ) + + case _: + # this might crash if a user is using the latest AG-UI protocol with new message types + # in that case we can easily push a patch handling the new message type as a placeholder while we plan the actual implementation + assert_never(msg) return builder.messages + + @staticmethod + def _dump_request_parts(msg: ModelRequest, *, preserve_file_data: bool = False) -> list[Message]: + """Convert a `ModelRequest` into AG-UI messages.""" + result: list[Message] = [] + system_content: list[str] = [] + user_content: list[TextInputContent | BinaryInputContent] = [] + + for part in msg.parts: + if isinstance(part, SystemPromptPart): + system_content.append(part.content) + elif isinstance(part, UserPromptPart): + if isinstance(part.content, str): + user_content.append(TextInputContent(type='text', text=part.content)) + else: + for item in part.content: + if isinstance(item, UploadedFile) and preserve_file_data: + uploaded_content: dict[str, Any] = { + 'file_id': item.file_id, + 'provider_name': item.provider_name, + 'media_type': item.media_type, 
+ 'identifier': item.identifier, + } + if item.vendor_metadata is not None: + uploaded_content['vendor_metadata'] = item.vendor_metadata + result.append( + ActivityMessage( + id=_new_message_id(), + activity_type='pydantic_ai_uploaded_file', + content=uploaded_content, + ) + ) + else: + converted = _user_content_to_input(item) + if converted is not None: + user_content.append(converted) + elif isinstance(part, ToolReturnPart): + result.append( + ToolMessage( + id=_new_message_id(), + content=part.model_response_str(), + tool_call_id=part.tool_call_id, + ) + ) + elif isinstance(part, RetryPromptPart): + if part.tool_name: + result.append( + ToolMessage( + id=_new_message_id(), + content=part.model_response(), + tool_call_id=part.tool_call_id, + error=part.model_response(), + ) + ) + else: + user_content.append(TextInputContent(type='text', text=part.model_response())) + else: + assert_never(part) + + messages: list[Message] = [] + if system_content: + messages.append(SystemMessage(id=_new_message_id(), content='\n'.join(system_content))) + if user_content: + # Simplify to plain string if only single text item + if len(user_content) == 1 and isinstance(user_content[0], TextInputContent): + messages.append(UserMessage(id=_new_message_id(), content=user_content[0].text)) + else: + messages.append(UserMessage(id=_new_message_id(), content=user_content)) + messages.extend(result) + return messages + + @staticmethod + def _dump_response_parts( # noqa: C901 + msg: ModelResponse, *, ag_ui_version: AGUIVersion = DEFAULT_AG_UI_VERSION, preserve_file_data: bool = False + ) -> list[Message]: + """Convert a `ModelResponse` into AG-UI messages. + + Uses a flush pattern to preserve part ordering: text that appears after tool calls + gets its own AssistantMessage, and ThinkingPart/FilePart boundaries trigger a flush + so content on either side doesn't get merged. 
+ """ + result: list[Message] = [] + text_content: list[str] = [] + tool_calls_list: list[ToolCall] = [] + tool_messages: list[ToolMessage] = [] + + builtin_returns = {part.tool_call_id: part for part in msg.parts if isinstance(part, BuiltinToolReturnPart)} + + def flush() -> None: + nonlocal text_content, tool_calls_list, tool_messages + if not text_content and not tool_calls_list: + return + result.append( + AssistantMessage( + id=_new_message_id(), + content='\n'.join(text_content) if text_content else None, + tool_calls=tool_calls_list if tool_calls_list else None, + ) + ) + result.extend(tool_messages) + text_content = [] + tool_calls_list = [] + tool_messages = [] + + for part in msg.parts: + if isinstance(part, TextPart): + if tool_calls_list: + flush() + text_content.append(part.content) + elif isinstance(part, ThinkingPart): + if parse_ag_ui_version(ag_ui_version) >= REASONING_VERSION: + from ag_ui.core import ReasoningMessage + + flush() + encrypted = thinking_encrypted_metadata(part) + result.append( + ReasoningMessage( + id=_new_message_id(), + content=part.content, + encrypted_value=json.dumps(encrypted) if encrypted else None, + ) + ) + elif isinstance(part, ToolCallPart): + tool_calls_list.append( + ToolCall( + id=part.tool_call_id, + function=FunctionCall(name=part.tool_name, arguments=part.args_as_json_str()), + ) + ) + elif isinstance(part, BuiltinToolCallPart): + prefixed_id = '|'.join([BUILTIN_TOOL_CALL_ID_PREFIX, part.provider_name or '', part.tool_call_id]) + tool_calls_list.append( + ToolCall( + id=prefixed_id, + function=FunctionCall(name=part.tool_name, arguments=part.args_as_json_str()), + ) + ) + if builtin_return := builtin_returns.get(part.tool_call_id): + tool_messages.append( + ToolMessage( + id=_new_message_id(), + content=builtin_return.model_response_str(), + tool_call_id=prefixed_id, + ) + ) + elif isinstance(part, BuiltinToolReturnPart): + # Emitted when matching BuiltinToolCallPart is processed above. 
+ pass + elif isinstance(part, FilePart): + if preserve_file_data: + flush() + file_content: dict[str, Any] = { + 'url': part.content.data_uri, + 'media_type': part.content.media_type, + } + if part.id is not None: + file_content['id'] = part.id + if part.provider_name is not None: + file_content['provider_name'] = part.provider_name + if part.provider_details is not None: + file_content['provider_details'] = part.provider_details + result.append( + ActivityMessage( + id=_new_message_id(), + activity_type='pydantic_ai_file', + content=file_content, + ) + ) + else: + assert_never(part) + + flush() + return result + + @classmethod + def dump_messages( + cls, + messages: Sequence[ModelMessage], + *, + ag_ui_version: AGUIVersion = DEFAULT_AG_UI_VERSION, + preserve_file_data: bool = False, + ) -> list[Message]: + """Transform Pydantic AI messages into AG-UI messages. + + Note: The round-trip `dump_messages` -> `load_messages` is not fully lossless: + + - `TextPart.id`, `.provider_name`, `.provider_details` are lost. + - `ToolCallPart.id`, `.provider_name`, `.provider_details` are lost. + - `BuiltinToolCallPart.id`, `.provider_details` are lost (only `.provider_name` survives + via the prefixed tool call ID). + - `BuiltinToolReturnPart.provider_details` is lost. + - `RetryPromptPart` becomes `ToolReturnPart` (or `UserPromptPart`) on reload. + - `CachePoint` content items are always dropped; `UploadedFile` content items are dropped + unless `preserve_file_data=True`. + - `ThinkingPart` is dropped when `ag_ui_version='0.1.10'`. + - `FilePart` is silently dropped unless `preserve_file_data=True`. + - `UploadedFile` in a multi-item `UserPromptPart` is split into a separate activity message + when `preserve_file_data=True`, which reloads as a separate `UserPromptPart`. + - Part ordering within a `ModelResponse` may change when text follows tool calls. + + Args: + messages: A sequence of ModelMessage objects to convert. + ag_ui_version: AG-UI protocol version controlling `ThinkingPart` emission. 
+ preserve_file_data: Whether to include `FilePart` and `UploadedFile` as `ActivityMessage`. + + Returns: + A list of AG-UI Message objects. + """ + result: list[Message] = [] + + for msg in messages: + if isinstance(msg, ModelRequest): + request_messages = cls._dump_request_parts(msg, preserve_file_data=preserve_file_data) + result.extend(request_messages) + elif isinstance(msg, ModelResponse): + result.extend( + cls._dump_response_parts(msg, ag_ui_version=ag_ui_version, preserve_file_data=preserve_file_data) + ) + else: + assert_never(msg) + + return result diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py index 5ae697b1e2..35a2c9a8cb 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py @@ -6,10 +6,11 @@ from __future__ import annotations +import importlib.metadata import json from collections.abc import AsyncIterator, Iterable from dataclasses import dataclass, field -from typing import Final +from typing import Any, Final, Literal from uuid import uuid4 from ..._utils import now_utc @@ -42,11 +43,6 @@ TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, - ThinkingEndEvent, - ThinkingStartEvent, - ThinkingTextMessageContentEvent, - ThinkingTextMessageEndEvent, - ThinkingTextMessageStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent, @@ -60,8 +56,52 @@ 'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`' ) from e +AGUIVersion = Literal['0.1.10', '0.1.13'] +"""Supported AG-UI protocol versions. + +- `'0.1.10'`: emits `THINKING_*` events, drops `ThinkingPart` from `dump_messages`. +- `'0.1.13'`: emits `REASONING_*` events with encrypted metadata, preserves `ThinkingPart` + as `ReasoningMessage` in `dump_messages` for full round-trip fidelity. 
+""" + +REASONING_VERSION = (0, 1, 13) +"""AG-UI version that introduced REASONING_* events (replacing THINKING_*).""" + + +def parse_ag_ui_version(version: str) -> tuple[int, ...]: + """Parse an AG-UI version string (e.g. `'0.1.13'`) into a comparable tuple. + + Pre-release suffixes like `a1`, `b2`, `rc1`, `.dev0` are stripped before parsing. + """ + import re + + from ...exceptions import UserError + + match = re.match(r'(\d+(?:\.\d+)*)', version) + if not match: + raise UserError(f"Invalid AG-UI version {version!r}: expected a dotted numeric version like '0.1.13'") + return tuple(int(x) for x in match.group(1).split('.')) + + +def _detect_ag_ui_version() -> AGUIVersion: + """Detect installed ag-ui-protocol version and map to the nearest supported `AGUIVersion`.""" + try: + installed = importlib.metadata.version('ag-ui-protocol') + if parse_ag_ui_version(installed) >= REASONING_VERSION: + return '0.1.13' + except Exception: + pass + return '0.1.10' + + +DEFAULT_AG_UI_VERSION: AGUIVersion = _detect_ag_ui_version() +"""The default AG-UI version, auto-detected from the installed `ag-ui-protocol` package.""" + + __all__ = [ 'AGUIEventStream', + 'AGUIVersion', + 'DEFAULT_AG_UI_VERSION', 'RunAgentInput', 'RunStartedEvent', 'RunFinishedEvent', @@ -70,14 +110,37 @@ BUILTIN_TOOL_CALL_ID_PREFIX: Final[str] = 'pyd_ai_builtin' +def thinking_encrypted_metadata(part: ThinkingPart) -> dict[str, Any]: + """Collect non-None metadata fields from a ThinkingPart for AG-UI encrypted_value.""" + encrypted: dict[str, Any] = {} + if part.id is not None: + encrypted['id'] = part.id + if part.signature is not None: + encrypted['signature'] = part.signature + if part.provider_name is not None: + encrypted['provider_name'] = part.provider_name + if part.provider_details is not None: + encrypted['provider_details'] = part.provider_details + return encrypted + + @dataclass class AGUIEventStream(UIEventStream[RunAgentInput, BaseEvent, AgentDepsT, OutputDataT]): """UI event stream transformer 
for the Agent-User Interaction (AG-UI) protocol.""" - _thinking_text: bool = False + ag_ui_version: AGUIVersion = DEFAULT_AG_UI_VERSION + + _use_reasoning: bool = field(default=False, init=False) + _reasoning_message_id: str | None = None + _reasoning_started: bool = False + _reasoning_text: bool = False _builtin_tool_call_ids: dict[str, str] = field(default_factory=dict[str, str]) + _ended_tool_call_ids: set[str] = field(default_factory=set[str]) _error: bool = False + def __post_init__(self) -> None: + self._use_reasoning = parse_ag_ui_version(self.ag_ui_version) >= REASONING_VERSION + @property def _event_encoder(self) -> EventEncoder: return EventEncoder(accept=self.accept or SSE_CONTENT_TYPE) @@ -147,33 +210,42 @@ async def handle_text_end(self, part: TextPart, followed_by_text: bool = False) async def handle_thinking_start( self, part: ThinkingPart, follows_thinking: bool = False ) -> AsyncIterator[BaseEvent]: - if not follows_thinking: - yield ThinkingStartEvent(type=EventType.THINKING_START) + self._reasoning_message_id = str(uuid4()) + self._reasoning_started = False - if part.content: - yield ThinkingTextMessageStartEvent(type=EventType.THINKING_TEXT_MESSAGE_START) - yield ThinkingTextMessageContentEvent(type=EventType.THINKING_TEXT_MESSAGE_CONTENT, delta=part.content) - self._thinking_text = True + if self._use_reasoning: + from ._thinking_0_13 import handle_thinking_start as _impl + else: + from ._thinking_0_10 import handle_thinking_start as _impl + async for event in _impl(self, part): + yield event async def handle_thinking_delta(self, delta: ThinkingPartDelta) -> AsyncIterator[BaseEvent]: if not delta.content_delta: return # pragma: no cover - if not self._thinking_text: - yield ThinkingTextMessageStartEvent(type=EventType.THINKING_TEXT_MESSAGE_START) - self._thinking_text = True + assert self._reasoning_message_id is not None, ( + 'handle_thinking_start must be called before handle_thinking_delta' + ) - yield 
ThinkingTextMessageContentEvent(type=EventType.THINKING_TEXT_MESSAGE_CONTENT, delta=delta.content_delta) + if self._use_reasoning: + from ._thinking_0_13 import handle_thinking_delta as _impl + else: + from ._thinking_0_10 import handle_thinking_delta as _impl + async for event in _impl(self, delta): + yield event async def handle_thinking_end( self, part: ThinkingPart, followed_by_thinking: bool = False ) -> AsyncIterator[BaseEvent]: - if self._thinking_text: - yield ThinkingTextMessageEndEvent(type=EventType.THINKING_TEXT_MESSAGE_END) - self._thinking_text = False + assert self._reasoning_message_id is not None, 'handle_thinking_start must be called before handle_thinking_end' - if not followed_by_thinking: - yield ThinkingEndEvent(type=EventType.THINKING_END) + if self._use_reasoning: + from ._thinking_0_13 import handle_thinking_end as _impl + else: + from ._thinking_0_10 import handle_thinking_end as _impl + async for event in _impl(self, part): + yield event def handle_tool_call_start(self, part: ToolCallPart | BuiltinToolCallPart) -> AsyncIterator[BaseEvent]: return self._handle_tool_call_start(part) @@ -203,16 +275,21 @@ async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterato assert tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set' if tool_call_id in self._builtin_tool_call_ids: tool_call_id = self._builtin_tool_call_ids[tool_call_id] + if tool_call_id in self._ended_tool_call_ids: + return yield ToolCallArgsEvent( tool_call_id=tool_call_id, delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta), ) async def handle_tool_call_end(self, part: ToolCallPart) -> AsyncIterator[BaseEvent]: + self._ended_tool_call_ids.add(part.tool_call_id) yield ToolCallEndEvent(tool_call_id=part.tool_call_id) async def handle_builtin_tool_call_end(self, part: BuiltinToolCallPart) -> AsyncIterator[BaseEvent]: - yield ToolCallEndEvent(tool_call_id=self._builtin_tool_call_ids[part.tool_call_id]) + builtin_id 
= self._builtin_tool_call_ids[part.tool_call_id] + self._ended_tool_call_ids.add(builtin_id) + yield ToolCallEndEvent(tool_call_id=builtin_id) async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> AsyncIterator[BaseEvent]: tool_call_id = self._builtin_tool_call_ids[part.tool_call_id] diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_thinking_0_10.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_thinking_0_10.py new file mode 100644 index 0000000000..914cfe7044 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_thinking_0_10.py @@ -0,0 +1,72 @@ +# pyright: reportPrivateUsage=false +"""Thinking event handlers for AG-UI protocol < 0.1.13 (THINKING_* events). + +These are extracted class methods of `AGUIEventStream` — the `self` parameter +is the event stream instance, and access to its private fields is intentional. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING + +from ag_ui.core import ( + BaseEvent, + ThinkingEndEvent, + ThinkingStartEvent, + ThinkingTextMessageContentEvent, + ThinkingTextMessageEndEvent, + ThinkingTextMessageStartEvent, +) + +from ...messages import ThinkingPart, ThinkingPartDelta + +if TYPE_CHECKING: + from ...output import OutputDataT + from ...tools import AgentDepsT + from ._event_stream import AGUIEventStream + + +async def handle_thinking_start( + self: AGUIEventStream[AgentDepsT, OutputDataT], part: ThinkingPart +) -> AsyncIterator[BaseEvent]: + if part.content: + yield ThinkingStartEvent() + self._reasoning_started = True + yield ThinkingTextMessageStartEvent() + yield ThinkingTextMessageContentEvent(delta=part.content) + self._reasoning_text = True + + +async def handle_thinking_delta( + self: AGUIEventStream[AgentDepsT, OutputDataT], delta: ThinkingPartDelta +) -> AsyncIterator[BaseEvent]: + assert delta.content_delta is not None + + if not self._reasoning_started: + yield ThinkingStartEvent() + self._reasoning_started = True + + if not 
self._reasoning_text: + yield ThinkingTextMessageStartEvent() + self._reasoning_text = True + + yield ThinkingTextMessageContentEvent(delta=delta.content_delta) + + +async def handle_thinking_end( + self: AGUIEventStream[AgentDepsT, OutputDataT], part: ThinkingPart +) -> AsyncIterator[BaseEvent]: + if not self._reasoning_started and not part.content: + self._reasoning_message_id = None + return + + if not self._reasoning_started: + yield ThinkingStartEvent() + + if self._reasoning_text: + yield ThinkingTextMessageEndEvent() + self._reasoning_text = False + + yield ThinkingEndEvent() + self._reasoning_message_id = None diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_thinking_0_13.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_thinking_0_13.py new file mode 100644 index 0000000000..166254510c --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_thinking_0_13.py @@ -0,0 +1,89 @@ +# pyright: reportPrivateUsage=false +"""Reasoning event handlers for AG-UI protocol >= 0.1.13 (REASONING_* events). + +These are extracted class methods of `AGUIEventStream` — the `self` parameter +is the event stream instance, and access to its private fields is intentional. 
+""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING + +from ag_ui.core import ( + BaseEvent, + ReasoningEncryptedValueEvent, + ReasoningEndEvent, + ReasoningMessageContentEvent, + ReasoningMessageEndEvent, + ReasoningMessageStartEvent, + ReasoningStartEvent, +) + +from ...messages import ThinkingPart, ThinkingPartDelta +from ._event_stream import thinking_encrypted_metadata + +if TYPE_CHECKING: + from ...output import OutputDataT + from ...tools import AgentDepsT + from ._event_stream import AGUIEventStream + + +async def handle_thinking_start( + self: AGUIEventStream[AgentDepsT, OutputDataT], part: ThinkingPart +) -> AsyncIterator[BaseEvent]: + assert self._reasoning_message_id is not None + if part.content: + yield ReasoningStartEvent(message_id=self._reasoning_message_id) + self._reasoning_started = True + yield ReasoningMessageStartEvent(message_id=self._reasoning_message_id, role='assistant') + yield ReasoningMessageContentEvent(message_id=self._reasoning_message_id, delta=part.content) + self._reasoning_text = True + + +async def handle_thinking_delta( + self: AGUIEventStream[AgentDepsT, OutputDataT], delta: ThinkingPartDelta +) -> AsyncIterator[BaseEvent]: + assert self._reasoning_message_id is not None + assert delta.content_delta is not None + message_id = self._reasoning_message_id + + if not self._reasoning_started: + yield ReasoningStartEvent(message_id=message_id) + self._reasoning_started = True + + if not self._reasoning_text: + yield ReasoningMessageStartEvent(message_id=message_id, role='assistant') + self._reasoning_text = True + + yield ReasoningMessageContentEvent(message_id=message_id, delta=delta.content_delta) + + +async def handle_thinking_end( + self: AGUIEventStream[AgentDepsT, OutputDataT], part: ThinkingPart +) -> AsyncIterator[BaseEvent]: + assert self._reasoning_message_id is not None + message_id = self._reasoning_message_id + encrypted = 
thinking_encrypted_metadata(part) + + if not self._reasoning_started and not encrypted: + self._reasoning_message_id = None + return + + if not self._reasoning_started: + yield ReasoningStartEvent(message_id=message_id) + + if self._reasoning_text: + yield ReasoningMessageEndEvent(message_id=message_id) + self._reasoning_text = False + + if encrypted: + yield ReasoningEncryptedValueEvent( + subtype='message', + entity_id=message_id, + encrypted_value=json.dumps(encrypted), + ) + + yield ReasoningEndEvent(message_id=message_id) + self._reasoning_message_id = None diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py index 1f0fbe5262..4dfaf155c6 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/app.py @@ -21,6 +21,7 @@ from .. import OnCompleteFunc, StateHandler from ._adapter import AGUIAdapter +from ._event_stream import DEFAULT_AG_UI_VERSION, AGUIVersion try: from starlette.applications import Starlette @@ -44,6 +45,8 @@ def __init__( agent: AbstractAgent[AgentDepsT, OutputDataT], *, # AGUIAdapter.dispatch_request parameters + ag_ui_version: AGUIVersion = DEFAULT_AG_UI_VERSION, + preserve_file_data: bool = False, output_type: OutputSpec[Any] | None = None, message_history: Sequence[ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, @@ -75,6 +78,9 @@ def __init__( Args: agent: The agent to run. + ag_ui_version: AG-UI protocol version controlling thinking/reasoning event format. + preserve_file_data: Whether to preserve agent-generated files and uploaded files + in AG-UI message conversion. See [`AGUIAdapter.preserve_file_data`][pydantic_ai.ui.ag_ui.AGUIAdapter.preserve_file_data]. 
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's @@ -131,6 +137,8 @@ async def run_agent(request: Request) -> Response: return await AGUIAdapter[AgentDepsT, OutputDataT].dispatch_request( request, agent=agent, + ag_ui_version=ag_ui_version, + preserve_file_data=preserve_file_data, output_type=output_type, message_history=message_history, deferred_tool_results=deferred_tool_results, diff --git a/scripts/check_cassettes.py b/scripts/check_cassettes.py index 39ad943723..c56793991f 100644 --- a/scripts/check_cassettes.py +++ b/scripts/check_cassettes.py @@ -11,52 +11,73 @@ from __future__ import annotations +import ast import sys from collections import defaultdict from pathlib import Path -import pytest +_FORBIDDEN_CHARS = r"""<>?%*:|"'/\\""" -class _CollectVcrTests: - """Pytest plugin that collects cassette names referenced by VCR-marked tests. +def _sanitize_cassette_name(name: str) -> str: + """Replicate pytest-recording's cassette name sanitization.""" + for ch in _FORBIDDEN_CHARS: + name = name.replace(ch, '-') + return name - This is a class (not functions) because pytest's plugin system requires objects - with hook methods, and we need to accumulate state across all test items. 
- """ - def __init__(self) -> None: - self.tests: dict[str, set[str]] = defaultdict(set) +def _has_vcr_marker(decorator_list: list[ast.expr]) -> bool: + """Check if a decorator list contains pytest.mark.vcr (with or without parens).""" + for dec in decorator_list: + # @pytest.mark.vcr or @pytest.mark.vcr() + if isinstance(dec, ast.Attribute) and dec.attr == 'vcr': + return True + if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'vcr': + return True + return False - @staticmethod - def _remove_yaml_ext(s: str) -> str: - if s.endswith('.yaml'): - return s[:-5] - return s - def pytest_collection_modifyitems( - self, session: pytest.Session, config: pytest.Config, items: list[pytest.Item] - ) -> None: - # prevents pytest.PytestAssertRewriteWarning: Module already imported so cannot be rewritten; pytest_recording - from pytest_recording.plugin import get_default_cassette_name - - for item in items: - if not any(item.iter_markers('vcr')): +def _has_module_vcr_marker(tree: ast.Module) -> bool: + """Check if the module has pytestmark = [..., pytest.mark.vcr, ...].""" + for node in ast.iter_child_nodes(tree): + if not isinstance(node, ast.Assign): + continue + for target in node.targets: + if not (isinstance(target, ast.Name) and target.id == 'pytestmark'): continue + return 'vcr' in ast.dump(node.value) + return False - test_file_stem = Path(item.location[0]).stem - m = item.get_closest_marker('default_cassette') - if m and m.args: - self.tests[test_file_stem].add(self._remove_yaml_ext(m.args[0])) - else: - self.tests[test_file_stem].add( - self._remove_yaml_ext(get_default_cassette_name(getattr(item, 'cls', None), item.name)) - ) +def _collect_vcr_tests_from_file(path: Path) -> set[str]: + """Parse a Python test file and return cassette names for VCR-marked tests.""" + try: + tree = ast.parse(path.read_text()) + except SyntaxError: + return set() - for vm in item.iter_markers('vcr'): - for arg in vm.args: - 
self.tests[test_file_stem].add(self._remove_yaml_ext(arg)) + module_has_vcr = _has_module_vcr_marker(tree) + cassette_names: set[str] = set() + + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef): + if not node.name.startswith('test_'): + continue + if module_has_vcr or _has_vcr_marker(node.decorator_list): + # Parametrized tests get [] suffixes but cassettes use the base name + cassette_names.add(_sanitize_cassette_name(node.name)) + + elif isinstance(node, ast.ClassDef): + class_has_vcr = _has_vcr_marker(node.decorator_list) + for method in ast.iter_child_nodes(node): + if not isinstance(method, ast.FunctionDef | ast.AsyncFunctionDef): + continue + if not method.name.startswith('test_'): + continue + if module_has_vcr or class_has_vcr or _has_vcr_marker(method.decorator_list): + cassette_names.add(_sanitize_cassette_name(f'{node.name}.{method.name}')) + + return cassette_names def get_all_cassettes() -> dict[str, set[str]]: @@ -77,12 +98,15 @@ def get_all_cassettes() -> dict[str, set[str]]: def get_all_tests() -> dict[str, set[str]]: - """Use pytest collection to get all VCR-marked tests and their cassette names.""" - collector = _CollectVcrTests() - rc = pytest.main(['--collect-only', '-q', 'tests/'], plugins=[collector]) - if rc not in (pytest.ExitCode.OK, pytest.ExitCode.NO_TESTS_COLLECTED): - raise SystemExit(rc) - return dict(collector.tests) + """Use AST parsing to find all VCR-marked tests and their cassette names.""" + tests: dict[str, set[str]] = defaultdict(set) + + for test_file in Path('tests').rglob('test_*.py'): + cassette_names = _collect_vcr_tests_from_file(test_file) + if cassette_names: + tests[test_file.stem].update(cassette_names) + + return dict(tests) def main() -> int: @@ -93,7 +117,7 @@ def main() -> int: total_cassettes = sum(len(c) for c in cassettes.values()) print(f'Found {total_cassettes} cassettes in {len(cassettes)} test modules') - print('Collecting VCR-marked tests (this may take 
a moment)...') + print('Collecting VCR-marked tests...') tests = get_all_tests() total_tests = sum(len(t) for t in tests.values()) print(f'Found {total_tests} tests in {len(tests)} test modules') @@ -108,7 +132,10 @@ def main() -> int: print(f'Warning: No tests found for module {test_file}') for cassette in sorted(cassette_names): - if cassette in expected_cassettes: + # Parametrized tests produce cassettes like test_foo[param].yaml + # Strip the [param] suffix to match the base test name + base_name = cassette.split('[')[0] + if cassette in expected_cassettes or base_name in expected_cassettes: matched += 1 if verbose: print(f' OK: {test_file}/{cassette}.yaml') diff --git a/tests/cassettes/test_ag_ui/test_thinking_roundtrip_anthropic.yaml b/tests/cassettes/test_ag_ui/test_thinking_roundtrip_anthropic.yaml new file mode 100644 index 0000000000..aec2835ded --- /dev/null +++ b/tests/cassettes/test_ag_ui/test_thinking_roundtrip_anthropic.yaml @@ -0,0 +1,72 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, br, zstd + connection: + - keep-alive + content-length: + - '229' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 4096 + messages: + - content: + - text: What is 1+1? Reply in one word. 
+ type: text + role: user + model: claude-sonnet-4-5 + stream: false + thinking: + budget_tokens: 1024 + type: enabled + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '992' + content-security-policy: + - default-src 'none'; frame-ancestors 'none' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + vary: + - Accept-Encoding + parsed_body: + content: + - signature: EooCCkYICxgCKkDYW6Ka+Mo73ZE34HVijmFbdV6QH/iRdv+3WuisH3pR8D5aSFASMBsF1F1bZRQFQXuM0+G4H83czthKvHqdqWriEgwB0eJaWoXZWU18NKoaDMH4nN8ZwJ6W9DnYLyIwrdTWmfc5QTqDr8gye3/yrPpV2YPeZnUBoHBLOGl8MUaC6SuGmxcm8rGqf2s+P+ZtKnJPJJzQiTrvPcEkF3ij22w3bXC9yoyZCyJVPcibR2ZZpLYF/UOoZ+BRBs0FCdm/QFXUUe8W1tcQ/ZQgBaW44LTcdzwOSP5hJb25UrPiGWuTytGMxIr7QyG7INpVbmm8JRBIIEzj3gs2zlxdbl17yZ/yZXcYAQ== + thinking: The user is asking what 1+1 equals and wants a one-word reply. The answer is 2, which is one word. 
+ type: thinking + - text: Two + type: text + id: msg_01VWoy28sUMKXEpbwuUkrpmP + model: claude-sonnet-4-5-20250929 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation: + ephemeral_1h_input_tokens: 0 + ephemeral_5m_input_tokens: 0 + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + inference_geo: not_available + input_tokens: 48 + output_tokens: 42 + service_tier: standard + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 5a1e6c11c9..377b7714f7 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -2,12 +2,13 @@ from __future__ import annotations +import importlib.metadata import json import uuid from collections.abc import AsyncIterator, MutableMapping from dataclasses import dataclass from http import HTTPStatus -from typing import Any +from typing import Any, Literal import httpx import pytest @@ -17,32 +18,43 @@ from pydantic_ai import ( AudioUrl, BinaryContent, + BinaryImage, BuiltinToolCallPart, BuiltinToolReturnPart, + CachePoint, DocumentUrl, + FilePart, FunctionToolCallEvent, FunctionToolResultEvent, ImageUrl, ModelMessage, ModelRequest, + ModelRequestPart, ModelResponse, + ModelResponsePart, PartDeltaEvent, PartEndEvent, PartStartEvent, RequestUsage, + RetryPromptPart, SystemPromptPart, + TextContent, TextPart, TextPartDelta, + ThinkingPart, + ThinkingPartDelta, ToolCallPart, ToolCallPartDelta, ToolReturn, ToolReturnPart, + UploadedFile, UserPromptPart, VideoUrl, ) from pydantic_ai._run_context import RunContext from pydantic_ai.agent import Agent, AgentRunResult from pydantic_ai.builtin_tools import WebSearchTool +from pydantic_ai.exceptions import UserError from pydantic_ai.models.function import ( AgentInfo, BuiltinToolCallsReturns, @@ -70,6 +82,7 @@ EventType, FunctionCall, Message, + ReasoningMessage, RunAgentInput, StateSnapshotEvent, SystemMessage, @@ -92,6 +105,14 @@ run_ag_ui, ) from pydantic_ai.ui.ag_ui import 
AGUIEventStream + from pydantic_ai.ui.ag_ui._event_stream import ( + _detect_ag_ui_version, # pyright: ignore[reportPrivateUsage] + parse_ag_ui_version, + ) + +with try_import() as anthropic_imports_successful: + from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings + from pydantic_ai.providers.anthropic import AnthropicProvider pytestmark = [ @@ -144,10 +165,11 @@ async def run_and_collect_events( *run_inputs: RunAgentInput, deps: AgentDepsT = None, on_complete: OnCompleteFunc[BaseEvent] | None = None, + ag_ui_version: Literal['0.1.10', '0.1.13'] = '0.1.10', ) -> list[dict[str, Any]]: events = list[dict[str, Any]]() for run_input in run_inputs: - async for event in run_ag_ui(agent, run_input, deps=deps, on_complete=on_complete): + async for event in run_ag_ui(agent, run_input, ag_ui_version=ag_ui_version, deps=deps, on_complete=on_complete): events.append(json.loads(event.removeprefix('data: '))) return events @@ -1058,8 +1080,7 @@ async def stream_function( 'threadId': (thread_id := IsSameStr()), 'runId': (run_id := IsSameStr()), }, - {'type': 'THINKING_START', 'timestamp': IsInt()}, - {'type': 'THINKING_END', 'timestamp': IsInt()}, + # Part 0: empty thinking — skipped (no content, no metadata) { 'type': 'TEXT_MESSAGE_START', 'timestamp': IsInt(), @@ -1079,11 +1100,16 @@ async def stream_function( 'delta': ' and some more', }, {'type': 'TEXT_MESSAGE_END', 'timestamp': IsInt(), 'messageId': message_id}, + # Part 1: "Thinking about the weather" {'type': 'THINKING_START', 'timestamp': IsInt()}, {'type': 'THINKING_TEXT_MESSAGE_START', 'timestamp': IsInt()}, {'type': 'THINKING_TEXT_MESSAGE_CONTENT', 'timestamp': IsInt(), 'delta': 'Thinking '}, {'type': 'THINKING_TEXT_MESSAGE_CONTENT', 'timestamp': IsInt(), 'delta': 'about the weather'}, {'type': 'THINKING_TEXT_MESSAGE_END', 'timestamp': IsInt()}, + {'type': 'THINKING_END', 'timestamp': IsInt()}, + # Part 2: empty thinking — skipped (no content, no metadata) + # Part 3: "Thinking about the 
meaning of life" + {'type': 'THINKING_START', 'timestamp': IsInt()}, {'type': 'THINKING_TEXT_MESSAGE_START', 'timestamp': IsInt()}, { 'type': 'THINKING_TEXT_MESSAGE_CONTENT', @@ -1091,12 +1117,11 @@ async def stream_function( 'delta': 'Thinking about the meaning of life', }, {'type': 'THINKING_TEXT_MESSAGE_END', 'timestamp': IsInt()}, + {'type': 'THINKING_END', 'timestamp': IsInt()}, + # Part 4: "Thinking about the universe" + {'type': 'THINKING_START', 'timestamp': IsInt()}, {'type': 'THINKING_TEXT_MESSAGE_START', 'timestamp': IsInt()}, - { - 'type': 'THINKING_TEXT_MESSAGE_CONTENT', - 'timestamp': IsInt(), - 'delta': 'Thinking about the universe', - }, + {'type': 'THINKING_TEXT_MESSAGE_CONTENT', 'timestamp': IsInt(), 'delta': 'Thinking about the universe'}, {'type': 'THINKING_TEXT_MESSAGE_END', 'timestamp': IsInt()}, {'type': 'THINKING_END', 'timestamp': IsInt()}, { @@ -1109,6 +1134,800 @@ async def stream_function( ) +async def test_thinking_with_signature() -> None: + """Test that ReasoningEncryptedValueEvent is emitted with thinking metadata.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaThinkingCalls | str]: + yield {0: DeltaThinkingPart(content='Thinking deeply', signature='sig_abc123')} + yield 'Here is my response' + + agent = Agent(model=FunctionModel(stream_function=stream_function)) + + run_input = create_input( + UserMessage(id='msg_1', content='Think about something'), + ) + + events = await run_and_collect_events(agent, run_input, ag_ui_version='0.1.13') + + assert events == snapshot( + [ + { + 'type': 'RUN_STARTED', + 'timestamp': IsInt(), + 'threadId': (thread_id := IsSameStr()), + 'runId': (run_id := IsSameStr()), + }, + {'type': 'REASONING_START', 'timestamp': IsInt(), 'messageId': (reasoning_id := IsSameStr())}, + { + 'type': 'REASONING_MESSAGE_START', + 'timestamp': IsInt(), + 'messageId': reasoning_id, + 'role': 'assistant', + }, + { + 'type': 'REASONING_MESSAGE_CONTENT', + 
'timestamp': IsInt(), + 'messageId': reasoning_id, + 'delta': 'Thinking deeply', + }, + {'type': 'REASONING_MESSAGE_END', 'timestamp': IsInt(), 'messageId': reasoning_id}, + { + 'type': 'REASONING_ENCRYPTED_VALUE', + 'timestamp': IsInt(), + 'subtype': 'message', + 'entityId': reasoning_id, + 'encryptedValue': IsStr(), + }, + {'type': 'REASONING_END', 'timestamp': IsInt(), 'messageId': reasoning_id}, + { + 'type': 'TEXT_MESSAGE_START', + 'timestamp': IsInt(), + 'messageId': (message_id := IsSameStr()), + 'role': 'assistant', + }, + { + 'type': 'TEXT_MESSAGE_CONTENT', + 'timestamp': IsInt(), + 'messageId': message_id, + 'delta': 'Here is my response', + }, + {'type': 'TEXT_MESSAGE_END', 'timestamp': IsInt(), 'messageId': message_id}, + {'type': 'RUN_FINISHED', 'timestamp': IsInt(), 'threadId': thread_id, 'runId': run_id}, + ] + ) + + +async def test_thinking_consecutive_signatures() -> None: + """Test that consecutive ThinkingParts each preserve their own metadata via separate REASONING blocks.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaThinkingCalls | str]: + yield {0: DeltaThinkingPart(content='First thought', signature='sig_aaa')} + yield {1: DeltaThinkingPart(content='Second thought', signature='sig_bbb')} + yield {2: DeltaThinkingPart(content='Third thought', signature='sig_ccc')} + yield 'Final answer' + + agent = Agent(model=FunctionModel(stream_function=stream_function)) + + run_input = create_input( + UserMessage(id='msg_1', content='Think deeply'), + ) + + events = await run_and_collect_events(agent, run_input, ag_ui_version='0.1.13') + + assert events == snapshot( + [ + { + 'type': 'RUN_STARTED', + 'timestamp': IsInt(), + 'threadId': (thread_id := IsSameStr()), + 'runId': (run_id := IsSameStr()), + }, + # Part 0: signature=sig_aaa + {'type': 'REASONING_START', 'timestamp': IsInt(), 'messageId': (r0 := IsSameStr())}, + { + 'type': 'REASONING_MESSAGE_START', + 'timestamp': IsInt(), + 
'messageId': r0, + 'role': 'assistant', + }, + {'type': 'REASONING_MESSAGE_CONTENT', 'timestamp': IsInt(), 'messageId': r0, 'delta': 'First thought'}, + {'type': 'REASONING_MESSAGE_END', 'timestamp': IsInt(), 'messageId': r0}, + { + 'type': 'REASONING_ENCRYPTED_VALUE', + 'timestamp': IsInt(), + 'subtype': 'message', + 'entityId': r0, + 'encryptedValue': IsStr(), + }, + {'type': 'REASONING_END', 'timestamp': IsInt(), 'messageId': r0}, + # Part 1: signature=sig_bbb + {'type': 'REASONING_START', 'timestamp': IsInt(), 'messageId': (r1 := IsSameStr())}, + { + 'type': 'REASONING_MESSAGE_START', + 'timestamp': IsInt(), + 'messageId': r1, + 'role': 'assistant', + }, + {'type': 'REASONING_MESSAGE_CONTENT', 'timestamp': IsInt(), 'messageId': r1, 'delta': 'Second thought'}, + {'type': 'REASONING_MESSAGE_END', 'timestamp': IsInt(), 'messageId': r1}, + { + 'type': 'REASONING_ENCRYPTED_VALUE', + 'timestamp': IsInt(), + 'subtype': 'message', + 'entityId': r1, + 'encryptedValue': IsStr(), + }, + {'type': 'REASONING_END', 'timestamp': IsInt(), 'messageId': r1}, + # Part 2: signature=sig_ccc + {'type': 'REASONING_START', 'timestamp': IsInt(), 'messageId': (r2 := IsSameStr())}, + { + 'type': 'REASONING_MESSAGE_START', + 'timestamp': IsInt(), + 'messageId': r2, + 'role': 'assistant', + }, + {'type': 'REASONING_MESSAGE_CONTENT', 'timestamp': IsInt(), 'messageId': r2, 'delta': 'Third thought'}, + {'type': 'REASONING_MESSAGE_END', 'timestamp': IsInt(), 'messageId': r2}, + { + 'type': 'REASONING_ENCRYPTED_VALUE', + 'timestamp': IsInt(), + 'subtype': 'message', + 'entityId': r2, + 'encryptedValue': IsStr(), + }, + {'type': 'REASONING_END', 'timestamp': IsInt(), 'messageId': r2}, + # Text response + { + 'type': 'TEXT_MESSAGE_START', + 'timestamp': IsInt(), + 'messageId': (message_id := IsSameStr()), + 'role': 'assistant', + }, + { + 'type': 'TEXT_MESSAGE_CONTENT', + 'timestamp': IsInt(), + 'messageId': message_id, + 'delta': 'Final answer', + }, + {'type': 'TEXT_MESSAGE_END', 'timestamp': 
IsInt(), 'messageId': message_id}, + {'type': 'RUN_FINISHED', 'timestamp': IsInt(), 'threadId': thread_id, 'runId': run_id}, + ] + ) + + +def test_reasoning_message_thinking_roundtrip() -> None: + """Test that ReasoningMessage converts to ThinkingPart with metadata from encrypted_value.""" + messages = AGUIAdapter.load_messages( + [ + ReasoningMessage( + id='reasoning-1', + content='Let me think about this...', + encrypted_value=json.dumps( + { + 'id': 'thinking-1', + 'signature': 'sig_abc123', + 'provider_name': 'anthropic', + 'provider_details': {'some': 'details'}, + } + ), + ), + AssistantMessage(id='msg-1', content='Here is my response'), + ] + ) + + assert messages == snapshot( + [ + ModelResponse( + parts=[ + ThinkingPart( + content='Let me think about this...', + id='thinking-1', + signature='sig_abc123', + provider_name='anthropic', + provider_details={'some': 'details'}, + ), + TextPart(content='Here is my response'), + ], + timestamp=IsDatetime(), + ) + ] + ) + + +async def test_reasoning_events_with_all_metadata() -> None: + """Test that REASONING_* events emit encryptedValue with all metadata fields.""" + run_input = create_input(UserMessage(id='msg_1', content='test')) + event_stream = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.13') + + part = ThinkingPart( + content='Thinking content', + id='thinking-123', + signature='sig_xyz', + provider_name='anthropic', + provider_details={'model': 'claude-sonnet-4-5'}, + ) + + events: list[BaseEvent] = [] + async for e in event_stream.handle_thinking_start(part): + events.append(e) + async for e in event_stream.handle_thinking_end(part): + events.append(e) + + assert [e.model_dump(exclude_none=True) for e in events] == snapshot( + [ + {'type': 'REASONING_START', 'message_id': IsStr()}, + {'type': 'REASONING_MESSAGE_START', 'message_id': IsStr(), 'role': 'assistant'}, + {'type': 'REASONING_MESSAGE_CONTENT', 'message_id': IsStr(), 'delta': 'Thinking content'}, + {'type': 
'REASONING_MESSAGE_END', 'message_id': IsStr()}, + { + 'type': 'REASONING_ENCRYPTED_VALUE', + 'subtype': 'message', + 'entity_id': IsStr(), + 'encrypted_value': '{"id": "thinking-123", "signature": "sig_xyz", "provider_name": "anthropic", "provider_details": {"model": "claude-sonnet-4-5"}}', + }, + {'type': 'REASONING_END', 'message_id': IsStr()}, + ] + ) + + +def test_activity_message_other_types_ignored() -> None: + """Test that ActivityMessage with other activity types are ignored.""" + messages = AGUIAdapter.load_messages( + [ + ActivityMessage( + id='activity-1', + activity_type='some_other_activity', + content={'foo': 'bar'}, + ), + AssistantMessage(id='msg-1', content='Response'), + ] + ) + + assert messages == snapshot([ModelResponse(parts=[TextPart(content='Response')], timestamp=IsDatetime())]) + + +@pytest.mark.parametrize( + 'encrypted_value', + [ + pytest.param('not valid json{{{', id='invalid-json'), + pytest.param('"just a string"', id='non-dict-string'), + pytest.param('[1, 2, 3]', id='non-dict-list'), + pytest.param('42', id='non-dict-number'), + ], +) +def test_reasoning_message_malformed_encrypted_value(encrypted_value: str) -> None: + """Test that malformed or non-dict encrypted_value is handled gracefully.""" + messages = AGUIAdapter.load_messages( + [ + ReasoningMessage(id='r-1', content='Thinking...', encrypted_value=encrypted_value), + AssistantMessage(id='msg-1', content='Done'), + ] + ) + + assert messages == snapshot( + [ + ModelResponse( + parts=[ThinkingPart(content='Thinking...'), TextPart(content='Done')], + timestamp=IsDatetime(), + ) + ] + ) + + +def test_activity_message_file_part_missing_url() -> None: + """Test that ActivityMessage(pydantic_ai_file) with empty url raises ValueError.""" + with pytest.raises(ValueError, match='must have a non-empty url'): + AGUIAdapter.load_messages( + [ + ActivityMessage( + id='activity-1', + activity_type='pydantic_ai_file', + content={'url': '', 'media_type': 'image/png'}, + ), + ], + 
preserve_file_data=True, + ) + + +_TIMESTAMPED_PARTS = (UserPromptPart, RetryPromptPart, ToolReturnPart, BuiltinToolReturnPart, SystemPromptPart) + + +def _sync_part_timestamps( + original_part: ModelRequestPart | ModelResponsePart, + new_part: ModelRequestPart | ModelResponsePart, +) -> None: + """Sync timestamp attribute if both parts are request parts (which carry timestamps).""" + if isinstance(new_part, _TIMESTAMPED_PARTS) and isinstance(original_part, _TIMESTAMPED_PARTS): + object.__setattr__(new_part, 'timestamp', original_part.timestamp) + + +def _sync_timestamps(original: list[ModelMessage], reloaded: list[ModelMessage]) -> None: + """Sync timestamps between original and reloaded messages for comparison.""" + for o, n in zip(original, reloaded): + if isinstance(n, ModelResponse) and isinstance(o, ModelResponse): + n.timestamp = o.timestamp + for op, np in zip(o.parts, n.parts): + _sync_part_timestamps(op, np) + elif isinstance(n, ModelRequest) and isinstance(o, ModelRequest): # pragma: no branch + for op, np in zip(o.parts, n.parts): + _sync_part_timestamps(op, np) + + +def test_dump_load_roundtrip_basic() -> None: + """Test that load_messages(dump_messages(msgs)) preserves basic messages.""" + original: list[ModelMessage] = [ + ModelRequest(parts=[SystemPromptPart(content='You are helpful'), UserPromptPart(content='Hello')]), + ModelResponse(parts=[TextPart(content='Hi!')]), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + assert reloaded == original + + +def test_dump_load_roundtrip_thinking() -> None: + """Test full round-trip for thinking parts with all metadata.""" + original: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Think about this')]), + ModelResponse( + parts=[ + ThinkingPart( + content='Deep thoughts...', + id='think-001', + signature='sig_xyz', + provider_name='anthropic', + provider_details={'model': 
'claude-sonnet-4-5'}, + ), + TextPart(content='Conclusion'), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original, ag_ui_version='0.1.13') + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + assert reloaded == original + + +def test_dump_load_roundtrip_tools() -> None: + """Test full round-trip for tool calls and returns.""" + original: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Call tool')]), + ModelResponse(parts=[ToolCallPart(tool_name='my_tool', tool_call_id='call_abc', args='{"x": 1}')]), + ModelRequest(parts=[ToolReturnPart(tool_name='my_tool', tool_call_id='call_abc', content='result')]), + ModelResponse(parts=[TextPart(content='Done')]), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + assert reloaded == original + + +def test_dump_load_roundtrip_multiple_thinking_parts() -> None: + """Test round-trip preserves multiple ThinkingParts with their metadata.""" + original: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Think hard')]), + ModelResponse( + parts=[ + ThinkingPart(content='First thought', id='think-1', signature='sig_1'), + ThinkingPart(content='Second thought', id='think-2', signature='sig_2'), + TextPart(content='Final answer'), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original, ag_ui_version='0.1.13') + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + assert reloaded == original + + +def test_dump_load_roundtrip_binary_content() -> None: + """Test round-trip for binary content in user prompts (images, documents, etc.).""" + original: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart( + content=[ + 'Describe this image', + ImageUrl(url='https://example.com/image.png', media_type='image/png'), + BinaryContent(data=b'raw image data', media_type='image/jpeg'), + ] + ), + ] + ), 
+ ModelResponse(parts=[TextPart(content='I see an image.')]), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + assert reloaded == original + + +def test_dump_load_roundtrip_file_part() -> None: + """Test round-trip for FilePart in model responses. + + Note: BinaryImage is used because from_data_uri() returns BinaryImage for image/* media types. + """ + file_data = b'generated file content' + original: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Generate an image')]), + ModelResponse( + parts=[ + FilePart( + content=BinaryImage(data=file_data, media_type='image/png'), + id='file-001', + provider_name='openai', + provider_details={'model': 'gpt-image'}, + ), + TextPart(content='Here is your generated image.'), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original, preserve_file_data=True) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs, preserve_file_data=True) + _sync_timestamps(original, reloaded) + + assert reloaded == original + + +def test_dump_load_roundtrip_builtin_tool_return() -> None: + """Test round-trip for builtin tool calls with their return values. + + Note: The round-trip reorders parts within ModelResponse because AG-UI's AssistantMessage + has separate content and tool_calls fields. TextPart comes first (from content), then + BuiltinToolCallPart (from tool_calls), then BuiltinToolReturnPart (from subsequent ToolMessage). 
+ """ + original: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Search for info')]), + ModelResponse( + parts=[ + TextPart(content='Based on the search...'), + BuiltinToolCallPart( + tool_name='web_search', + tool_call_id='call_123', + args='{"query": "test"}', + provider_name='anthropic', + ), + BuiltinToolReturnPart( + tool_name='web_search', + tool_call_id='call_123', + content='Search results here', + provider_name='anthropic', + ), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + assert reloaded == original + + +def test_dump_builtin_tool_call_without_return() -> None: + """Test that BuiltinToolCallPart without a matching BuiltinToolReturnPart still dumps correctly.""" + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Search for info')]), + ModelResponse( + parts=[ + BuiltinToolCallPart( + tool_name='web_search', + tool_call_id='call_orphan', + args='{"query": "test"}', + provider_name='anthropic', + ), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(messages) + + assert len(ag_ui_msgs) == 2 + assistant_msg = ag_ui_msgs[1] + assert isinstance(assistant_msg, AssistantMessage) + assert assistant_msg.tool_calls is not None + assert len(assistant_msg.tool_calls) == 1 + assert assistant_msg.tool_calls[0].id == 'pyd_ai_builtin|anthropic|call_orphan' + + +def test_dump_load_roundtrip_cache_point() -> None: + """Test that CachePoint is filtered out during round-trip (it's metadata only).""" + original: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart(content=['Hello', CachePoint(), 'world']), + ] + ), + ModelResponse(parts=[TextPart(content='Hi!')]), + ] + expected: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content=['Hello', 'world'])]), + ModelResponse(parts=[TextPart(content='Hi!')]), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original) + reloaded = 
AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(expected, reloaded) + + assert reloaded == expected + + +def test_dump_load_roundtrip_uploaded_file() -> None: + """Test that UploadedFile is filtered out during round-trip (opaque provider file_id).""" + original: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart( + content=['Hello', UploadedFile(file_id='file-abc123', provider_name='anthropic'), 'world'] + ), + ] + ), + ModelResponse(parts=[TextPart(content='Hi!')]), + ] + expected: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content=['Hello', 'world'])]), + ModelResponse(parts=[TextPart(content='Hi!')]), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(expected, reloaded) + + assert reloaded == expected + + +def test_dump_load_roundtrip_retry_prompt_with_tool() -> None: + """Test round-trip for RetryPromptPart with tool_name (converted to ToolMessage with error).""" + original: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Call tool')]), + ModelResponse(parts=[ToolCallPart(tool_name='my_tool', tool_call_id='call_1', args='{}')]), + ModelRequest( + parts=[ + RetryPromptPart( + tool_name='my_tool', + tool_call_id='call_1', + content='Invalid args', + ) + ] + ), + ModelResponse(parts=[TextPart(content='OK')]), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + # RetryPromptPart becomes ToolReturnPart on reload (same tool_call_id mapping) + assert len(reloaded) == 4 + assert isinstance(reloaded[2], ModelRequest) + retry_part = reloaded[2].parts[0] + assert isinstance(retry_part, ToolReturnPart) + assert retry_part.tool_name == 'my_tool' + assert retry_part.tool_call_id == 'call_1' + + +def test_dump_load_roundtrip_retry_prompt_without_tool() -> None: + """Test round-trip for RetryPromptPart without tool_name (converted to 
UserMessage).""" + original: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Do something')]), + ModelResponse(parts=[TextPart(content='Done')]), + ModelRequest(parts=[RetryPromptPart(content='Please try again')]), + ModelResponse(parts=[TextPart(content='OK')]), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + # RetryPromptPart without tool becomes UserPromptPart on reload + # Content is formatted by RetryPromptPart.model_response() + assert len(reloaded) == 4 + assert isinstance(reloaded[2], ModelRequest) + retry_part = reloaded[2].parts[0] + assert isinstance(retry_part, UserPromptPart) + assert 'Please try again' in str(retry_part.content) + + +def test_dump_load_roundtrip_file_part_minimal() -> None: + """Test round-trip for FilePart without optional attributes (id, provider_name, provider_details).""" + file_data = b'minimal file' + original: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Generate')]), + ModelResponse( + parts=[ + FilePart(content=BinaryImage(data=file_data, media_type='image/png')), + TextPart(content='Done'), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original, preserve_file_data=True) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs, preserve_file_data=True) + _sync_timestamps(original, reloaded) + + assert reloaded == original + + +def test_dump_load_roundtrip_file_part_only() -> None: + """Test round-trip for response with only FilePart (no text, no tool calls).""" + file_data = b'only file' + original: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Generate image only')]), + ModelResponse(parts=[FilePart(content=BinaryImage(data=file_data, media_type='image/png'))]), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original, preserve_file_data=True) + reloaded = AGUIAdapter.load_messages(ag_ui_msgs, preserve_file_data=True) + _sync_timestamps(original, 
reloaded)
+
+    assert reloaded == original
+
+
+def test_file_part_dropped_by_default() -> None:
+    """Test that FilePart is silently dropped when preserve_file_data=False (default).
+
+    dump_messages drops FilePart from output, and load_messages ignores
+    ActivityMessage(pydantic_ai_file) — both without raising errors.
+    """
+    messages_with_file: list[ModelMessage] = [
+        ModelRequest(parts=[UserPromptPart(content='Generate an image')]),
+        ModelResponse(
+            parts=[
+                FilePart(content=BinaryImage(data=b'image data', media_type='image/png')),
+                TextPart(content='Here is your image.'),
+            ]
+        ),
+    ]
+
+    # dump_messages drops FilePart by default
+    ag_ui_msgs = AGUIAdapter.dump_messages(messages_with_file)
+    assert not any(isinstance(m, ActivityMessage) and m.activity_type == 'pydantic_ai_file' for m in ag_ui_msgs)
+
+    # load_messages ignores ActivityMessage(pydantic_ai_file) by default
+    ag_ui_msgs_with_activity = AGUIAdapter.dump_messages(messages_with_file, preserve_file_data=True)
+    reloaded = AGUIAdapter.load_messages(ag_ui_msgs_with_activity)
+    assert not any(isinstance(part, FilePart) for msg in reloaded for part in msg.parts)
+
+
+def test_dump_load_roundtrip_interleaved_text_and_tools() -> None:
+    """Test round-trip for response with text interleaved around tool calls.
+
+    When text appears after tool calls, the flush pattern splits them into
+    separate AssistantMessages to preserve ordering on round-trip.
+    """
+    original: list[ModelMessage] = [
+        ModelRequest(parts=[UserPromptPart(content='Do things')]),
+        ModelResponse(
+            parts=[
+                TextPart(content='Before tools'),
+                ToolCallPart(tool_name='search', args='{"q": "test"}', tool_call_id='call_1'),
+                TextPart(content='After tools'),
+            ]
+        ),
+    ]
+
+    ag_ui_msgs = AGUIAdapter.dump_messages(original)
+
+    # Text before tools shares an AssistantMessage with the tool call;
+    # text after tools gets its own AssistantMessage.
+ assert [m.model_dump(exclude={'id'}, exclude_none=True) for m in ag_ui_msgs] == snapshot( + [ + {'role': 'user', 'content': 'Do things'}, + { + 'role': 'assistant', + 'content': 'Before tools', + 'tool_calls': [ + { + 'id': 'call_1', + 'type': 'function', + 'function': {'name': 'search', 'arguments': '{"q": "test"}'}, + }, + ], + }, + {'role': 'assistant', 'content': 'After tools'}, + ] + ) + + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + # Round-trip splits into two ModelResponses due to the two AssistantMessages + assert reloaded == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='Do things', timestamp=IsDatetime())]), + ModelResponse( + parts=[ + TextPart(content='Before tools'), + ToolCallPart(tool_name='search', args='{"q": "test"}', tool_call_id='call_1'), + TextPart(content='After tools'), + ], + timestamp=IsDatetime(), + ), + ] + ) + + +async def test_reasoning_events_empty_content_with_metadata() -> None: + """Test REASONING_* events for ThinkingPart with no content but with metadata. + + This exercises the path in handle_thinking_end where _reasoning_started is False + (no content was streamed) but encrypted metadata is present — e.g. redacted thinking. 
+ """ + run_input = create_input(UserMessage(id='msg_1', content='test')) + event_stream = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.13') + + part = ThinkingPart( + content='', + id='think_redacted', + signature='sig_redacted', + ) + + events: list[BaseEvent] = [e async for e in event_stream.handle_thinking_start(part)] + async for e in event_stream.handle_thinking_end(part): + events.append(e) + + assert [e.model_dump(exclude_none=True) for e in events] == snapshot( + [ + {'type': 'REASONING_START', 'message_id': IsStr()}, + { + 'type': 'REASONING_ENCRYPTED_VALUE', + 'subtype': 'message', + 'entity_id': IsStr(), + 'encrypted_value': '{"id": "think_redacted", "signature": "sig_redacted"}', + }, + {'type': 'REASONING_END', 'message_id': IsStr()}, + ] + ) + + +@pytest.mark.vcr() +@pytest.mark.skipif(not anthropic_imports_successful(), reason='anthropic not installed') +async def test_thinking_roundtrip_anthropic(allow_model_requests: None, anthropic_api_key: str) -> None: + """Test that pydantic -> AG-UI -> pydantic round-trip preserves thinking metadata with real Anthropic responses.""" + m = AnthropicModel('claude-sonnet-4-5', provider=AnthropicProvider(api_key=anthropic_api_key)) + settings: AnthropicModelSettings = {'anthropic_thinking': {'type': 'enabled', 'budget_tokens': 1024}} + agent: Agent[None, str] = Agent(m, model_settings=settings) + + result = await agent.run('What is 1+1? Reply in one word.') + original = result.all_messages() + + ag_ui_msgs = AGUIAdapter.dump_messages(original, ag_ui_version='0.1.13') + reloaded = AGUIAdapter.load_messages(ag_ui_msgs) + _sync_timestamps(original, reloaded) + + assert reloaded == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='What is 1+1? Reply in one word.', timestamp=IsDatetime())]), + ModelResponse( + parts=[ + ThinkingPart( + content='The user is asking what 1+1 equals and wants a one-word reply. 
The answer is 2, which is one word.', + signature='EooCCkYICxgCKkDYW6Ka+Mo73ZE34HVijmFbdV6QH/iRdv+3WuisH3pR8D5aSFASMBsF1F1bZRQFQXuM0+G4H83czthKvHqdqWriEgwB0eJaWoXZWU18NKoaDMH4nN8ZwJ6W9DnYLyIwrdTWmfc5QTqDr8gye3/yrPpV2YPeZnUBoHBLOGl8MUaC6SuGmxcm8rGqf2s+P+ZtKnJPJJzQiTrvPcEkF3ij22w3bXC9yoyZCyJVPcibR2ZZpLYF/UOoZ+BRBs0FCdm/QFXUUe8W1tcQ/ZQgBaW44LTcdzwOSP5hJb25UrPiGWuTytGMxIr7QyG7INpVbmm8JRBIIEzj3gs2zlxdbl17yZ/yZXcYAQ==', + provider_name='anthropic', + ), + TextPart(content='Two'), + ], + timestamp=IsDatetime(), + ), + ] + ) + + async def test_tool_local_then_ag_ui() -> None: """Test mixed local and AG-UI tool calls.""" @@ -1759,37 +2578,19 @@ async def test_messages(image_content: BinaryContent, document_content: BinaryCo timestamp=IsDatetime(), ), UserPromptPart( - content=[ - ImageUrl( - url='http://example.com/image.png', _media_type='image/png', media_type='image/png' - ) - ], + content=[ImageUrl(url='http://example.com/image.png', _media_type='image/png')], timestamp=IsDatetime(), ), UserPromptPart( - content=[ - VideoUrl( - url='http://example.com/video.mp4', _media_type='video/mp4', media_type='video/mp4' - ) - ], + content=[VideoUrl(url='http://example.com/video.mp4', _media_type='video/mp4')], timestamp=IsDatetime(), ), UserPromptPart( - content=[ - AudioUrl( - url='http://example.com/audio.mp3', _media_type='audio/mpeg', media_type='audio/mpeg' - ) - ], + content=[AudioUrl(url='http://example.com/audio.mp3', _media_type='audio/mpeg')], timestamp=IsDatetime(), ), UserPromptPart( - content=[ - DocumentUrl( - url='http://example.com/doc.pdf', - _media_type='application/pdf', - media_type='application/pdf', - ) - ], + content=[DocumentUrl(url='http://example.com/doc.pdf', _media_type='application/pdf')], timestamp=IsDatetime(), ), UserPromptPart( @@ -2575,6 +3376,126 @@ async def send(data: MutableMapping[str, Any]) -> None: ) +async def test_stray_tool_call_delta_after_end() -> None: + """Test that TOOL_CALL_ARGS events are suppressed after TOOL_CALL_END for the 
same tool call.""" + run_input = create_input(UserMessage(id='msg_1', content='test')) + event_stream = AGUIEventStream(run_input=run_input) + + part = BuiltinToolCallPart( + tool_name='web_search', + tool_call_id='call_123', + args='{"query": "test"}', + provider_name='anthropic', + ) + + events: list[BaseEvent] = [] + async for e in event_stream.handle_builtin_tool_call_start(part): + events.append(e) + async for e in event_stream.handle_builtin_tool_call_end(part): + events.append(e) + + stray_delta = ToolCallPartDelta(tool_call_id='call_123', args_delta='{"extra": true}') + async for e in event_stream.handle_tool_call_delta(stray_delta): + events.append(e) # pragma: no cover + + event_types = [e.type.value for e in events] + assert 'TOOL_CALL_START' in event_types + assert 'TOOL_CALL_END' in event_types + # No TOOL_CALL_ARGS after TOOL_CALL_END + end_idx = event_types.index('TOOL_CALL_END') + assert 'TOOL_CALL_ARGS' not in event_types[end_idx + 1 :] + + +def test_dump_load_roundtrip_uploaded_file_preserved() -> None: + """Test UploadedFile round-trips via ActivityMessage when preserve_file_data=True.""" + original: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart( + content=[ + 'Describe this file', + UploadedFile( + file_id='file-abc123', + provider_name='anthropic', + media_type='application/pdf', + vendor_metadata={'source': 'upload'}, + identifier='my-doc.pdf', + ), + ] + ), + ] + ), + ModelResponse(parts=[TextPart(content='I see a PDF.')]), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(original, preserve_file_data=True) + + # Verify ActivityMessage was emitted + activity_msgs = [m for m in ag_ui_msgs if isinstance(m, ActivityMessage)] + assert len(activity_msgs) == 1 + assert activity_msgs[0].activity_type == 'pydantic_ai_uploaded_file' + assert activity_msgs[0].content['file_id'] == 'file-abc123' + + reloaded = AGUIAdapter.load_messages(ag_ui_msgs, preserve_file_data=True) + + # The text and UploadedFile come back as separate 
UserPromptParts + request_parts = [p for msg in reloaded if isinstance(msg, ModelRequest) for p in msg.parts] + user_parts = [p for p in request_parts if isinstance(p, UserPromptPart)] + assert len(user_parts) == 2 + + # First UserPromptPart has the text + assert user_parts[0].content == 'Describe this file' + + # Second UserPromptPart has the UploadedFile + assert isinstance(user_parts[1].content, list) + uploaded = user_parts[1].content[0] + assert isinstance(uploaded, UploadedFile) + assert uploaded.file_id == 'file-abc123' + assert uploaded.provider_name == 'anthropic' + assert uploaded.media_type == 'application/pdf' + assert uploaded.vendor_metadata == {'source': 'upload'} + assert uploaded.identifier == 'my-doc.pdf' + + +def test_dump_messages_v010_drops_thinking() -> None: + """Test that dump_messages with ag_ui_version='0.1.10' drops ThinkingPart.""" + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Think about this')]), + ModelResponse( + parts=[ + ThinkingPart(content='Deep thoughts...', signature='sig_xyz'), + TextPart(content='Conclusion'), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(messages, ag_ui_version='0.1.10') + # No ReasoningMessage in output + assert not any(isinstance(m, ReasoningMessage) for m in ag_ui_msgs) + # Text still present + assert any(isinstance(m, AssistantMessage) and m.content == 'Conclusion' for m in ag_ui_msgs) + + +def test_dump_messages_v013_includes_reasoning() -> None: + """Test that dump_messages with ag_ui_version='0.1.13' includes ThinkingPart as ReasoningMessage.""" + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content='Think about this')]), + ModelResponse( + parts=[ + ThinkingPart(content='Deep thoughts...', signature='sig_xyz'), + TextPart(content='Conclusion'), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(messages, ag_ui_version='0.1.13') + reasoning_msgs = [m for m in ag_ui_msgs if isinstance(m, ReasoningMessage)] + assert 
len(reasoning_msgs) == 1 + assert reasoning_msgs[0].content == 'Deep thoughts...' + assert reasoning_msgs[0].encrypted_value is not None + assert 'sig_xyz' in reasoning_msgs[0].encrypted_value + + async def test_tool_return_with_files(): """Test that tool returns with files include file descriptions in the output.""" @@ -2624,3 +3545,340 @@ async def event_generator(): }, ] ) + + +# region: Coverage — event_stream thinking version branches + + +async def test_thinking_events_v010_with_content() -> None: + """Test v0.1.10 THINKING_* events for ThinkingPart with content.""" + run_input = create_input(UserMessage(id='msg_1', content='test')) + event_stream = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.10') + + part = ThinkingPart(content='Some thoughts', signature='sig_abc') + + events: list[BaseEvent] = [] + async for e in event_stream.handle_thinking_start(part): + events.append(e) + async for e in event_stream.handle_thinking_end(part): + events.append(e) + + assert [e.model_dump(exclude_none=True) for e in events] == snapshot( + [ + {'type': 'THINKING_START'}, + {'type': 'THINKING_TEXT_MESSAGE_START'}, + {'type': 'THINKING_TEXT_MESSAGE_CONTENT', 'delta': 'Some thoughts'}, + {'type': 'THINKING_TEXT_MESSAGE_END'}, + {'type': 'THINKING_END'}, + ] + ) + + +async def test_thinking_events_v010_empty_content() -> None: + """Test v0.1.10 early return when ThinkingPart has no content.""" + run_input = create_input(UserMessage(id='msg_1', content='test')) + event_stream = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.10') + + part = ThinkingPart(content='', signature='sig_abc') + + events = [e async for e in event_stream.handle_thinking_start(part)] + events.extend([e async for e in event_stream.handle_thinking_end(part)]) + + assert events == [] + + +async def test_thinking_delta_v013() -> None: + """Test v0.1.13 REASONING_* events emitted via handle_thinking_delta.""" + run_input = create_input(UserMessage(id='msg_1', 
content='test')) + event_stream = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.13') + + start_part = ThinkingPart(content='') + events: list[BaseEvent] = [e async for e in event_stream.handle_thinking_start(start_part)] + + delta = ThinkingPartDelta(content_delta='chunk1') + async for e in event_stream.handle_thinking_delta(delta): + events.append(e) + + assert [e.model_dump(exclude_none=True) for e in events] == snapshot( + [ + {'type': 'REASONING_START', 'message_id': IsStr()}, + {'type': 'REASONING_MESSAGE_START', 'message_id': IsStr(), 'role': 'assistant'}, + {'type': 'REASONING_MESSAGE_CONTENT', 'message_id': IsStr(), 'delta': 'chunk1'}, + ] + ) + + +async def test_thinking_end_v013_no_content_no_metadata() -> None: + """Test v0.1.13 early return when ThinkingPart has no content and no encrypted metadata.""" + run_input = create_input(UserMessage(id='msg_1', content='test')) + event_stream = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.13') + + part = ThinkingPart(content='') + + events = [e async for e in event_stream.handle_thinking_start(part)] + events.extend([e async for e in event_stream.handle_thinking_end(part)]) + + assert events == [] + + +async def test_thinking_delta_v013_after_content_start() -> None: + """Test v0.1.13 delta skips START/MESSAGE_START when reasoning already started.""" + run_input = create_input(UserMessage(id='msg_1', content='test')) + event_stream = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.13') + + start_part = ThinkingPart(content='initial') + events = [e async for e in event_stream.handle_thinking_start(start_part)] + + delta = ThinkingPartDelta(content_delta='more') + events.extend([e async for e in event_stream.handle_thinking_delta(delta)]) + + assert [e.model_dump(exclude_none=True) for e in events] == snapshot( + [ + {'type': 'REASONING_START', 'message_id': IsStr()}, + {'type': 'REASONING_MESSAGE_START', 'message_id': IsStr(), 'role': 
'assistant'}, + {'type': 'REASONING_MESSAGE_CONTENT', 'message_id': IsStr(), 'delta': 'initial'}, + {'type': 'REASONING_MESSAGE_CONTENT', 'message_id': IsStr(), 'delta': 'more'}, + ] + ) + + +async def test_thinking_end_v010_with_content() -> None: + """Test v0.1.10 end emits TextMessageEnd when content was streamed, and ThinkingStart when not started.""" + run_input = create_input(UserMessage(id='msg_1', content='test')) + + # Case 1: start with content → _reasoning_started=True, _reasoning_text=True + # end should emit TextMessageEnd + ThinkingEnd + event_stream = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.10') + part = ThinkingPart(content='text') + events = [e async for e in event_stream.handle_thinking_start(part)] + events.extend([e async for e in event_stream.handle_thinking_end(part)]) + + assert [e.model_dump(exclude_none=True) for e in events] == snapshot( + [ + {'type': 'THINKING_START'}, + {'type': 'THINKING_TEXT_MESSAGE_START'}, + {'type': 'THINKING_TEXT_MESSAGE_CONTENT', 'delta': 'text'}, + {'type': 'THINKING_TEXT_MESSAGE_END'}, + {'type': 'THINKING_END'}, + ] + ) + + # Case 2: start with empty content → _reasoning_started=False + # end with content → hits ThinkingStartEvent at line 246 + event_stream2 = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.10') + empty_part = ThinkingPart(content='') + events2 = [e async for e in event_stream2.handle_thinking_start(empty_part)] + + full_part = ThinkingPart(content='non-empty') + events2.extend([e async for e in event_stream2.handle_thinking_end(full_part)]) + + assert [e.model_dump(exclude_none=True) for e in events2] == snapshot( + [ + {'type': 'THINKING_START'}, + {'type': 'THINKING_END'}, + ] + ) + + +async def test_thinking_end_v013_no_encrypted_metadata() -> None: + """Test v0.1.13 end skips encrypted_value event when part has no signature or metadata.""" + run_input = create_input(UserMessage(id='msg_1', content='test')) + event_stream = 
AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.13') + + part = ThinkingPart(content='text') + events = [e async for e in event_stream.handle_thinking_start(part)] + events.extend([e async for e in event_stream.handle_thinking_end(part)]) + + assert [e.model_dump(exclude_none=True) for e in events] == snapshot( + [ + {'type': 'REASONING_START', 'message_id': IsStr()}, + {'type': 'REASONING_MESSAGE_START', 'message_id': IsStr(), 'role': 'assistant'}, + {'type': 'REASONING_MESSAGE_CONTENT', 'message_id': IsStr(), 'delta': 'text'}, + {'type': 'REASONING_MESSAGE_END', 'message_id': IsStr()}, + {'type': 'REASONING_END', 'message_id': IsStr()}, + ] + ) + + +# endregion + +# region: Coverage — encrypted_metadata branch gap + + +async def test_thinking_encrypted_metadata_partial_fields() -> None: + """Test thinking_encrypted_metadata with signature but no provider_name.""" + run_input = create_input(UserMessage(id='msg_1', content='test')) + event_stream = AGUIEventStream(run_input, accept=SSE_CONTENT_TYPE, ag_ui_version='0.1.13') + + part = ThinkingPart(content='Thoughts', signature='sig_only') + + events: list[BaseEvent] = [] + async for e in event_stream.handle_thinking_start(part): + events.append(e) + async for e in event_stream.handle_thinking_end(part): + events.append(e) + + assert [e.model_dump(exclude_none=True) for e in events] == snapshot( + [ + {'type': 'REASONING_START', 'message_id': IsStr()}, + {'type': 'REASONING_MESSAGE_START', 'message_id': IsStr(), 'role': 'assistant'}, + {'type': 'REASONING_MESSAGE_CONTENT', 'message_id': IsStr(), 'delta': 'Thoughts'}, + {'type': 'REASONING_MESSAGE_END', 'message_id': IsStr()}, + { + 'type': 'REASONING_ENCRYPTED_VALUE', + 'subtype': 'message', + 'entity_id': IsStr(), + 'encrypted_value': '{"signature": "sig_only"}', + }, + {'type': 'REASONING_END', 'message_id': IsStr()}, + ] + ) + + +# endregion + +# region: Coverage — adapter uploaded file edge cases + + +def 
test_load_messages_uploaded_file_missing_fields() -> None: + """Test load_messages raises ValueError for malformed pydantic_ai_uploaded_file ActivityMessage.""" + with pytest.raises(ValueError, match='must have non-empty file_id and provider_name'): + AGUIAdapter.load_messages( + [ActivityMessage(id='msg_1', activity_type='pydantic_ai_uploaded_file', content={})], + preserve_file_data=True, + ) + + +def test_dump_messages_uploaded_file_with_vendor_metadata() -> None: + """Test dump_messages includes vendor_metadata in ActivityMessage when present on UploadedFile.""" + messages: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart( + content=[ + UploadedFile( + file_id='file-xyz', + provider_name='openai', + media_type='text/plain', + vendor_metadata={'custom': 'data'}, + ), + ] + ), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(messages, preserve_file_data=True) + activity_msgs = [m for m in ag_ui_msgs if isinstance(m, ActivityMessage)] + assert [m.model_dump() for m in activity_msgs] == snapshot( + [ + { + 'id': IsStr(), + 'role': 'activity', + 'activity_type': 'pydantic_ai_uploaded_file', + 'content': { + 'file_id': 'file-xyz', + 'provider_name': 'openai', + 'media_type': 'text/plain', + 'identifier': '6f0bbc', + 'vendor_metadata': {'custom': 'data'}, + }, + } + ] + ) + + +def test_dump_messages_uploaded_file_without_vendor_metadata() -> None: + """Test dump_messages omits vendor_metadata from ActivityMessage when None on UploadedFile.""" + messages: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart( + content=[ + UploadedFile( + file_id='file-xyz', + provider_name='openai', + media_type='text/plain', + ), + ] + ), + ] + ), + ] + + ag_ui_msgs = AGUIAdapter.dump_messages(messages, preserve_file_data=True) + activity_msgs = [m for m in ag_ui_msgs if isinstance(m, ActivityMessage)] + assert [m.model_dump() for m in activity_msgs] == snapshot( + [ + { + 'id': IsStr(), + 'role': 'activity', + 'activity_type': 
'pydantic_ai_uploaded_file', + 'content': { + 'file_id': 'file-xyz', + 'provider_name': 'openai', + 'media_type': 'text/plain', + 'identifier': '6f0bbc', + }, + } + ] + ) + + +# endregion + + +# region: Coverage — parse_ag_ui_version validation + TextContent + detect fallback + + +def test_parse_ag_ui_version_invalid() -> None: + """Test that parse_ag_ui_version raises UserError for malformed input.""" + with pytest.raises(UserError, match="Invalid AG-UI version 'latest'"): + parse_ag_ui_version('latest') + + with pytest.raises(UserError, match="Invalid AG-UI version ''"): + parse_ag_ui_version('') + + +def test_parse_ag_ui_version_prerelease() -> None: + """Test that parse_ag_ui_version strips pre-release suffixes.""" + assert parse_ag_ui_version('0.1.13a1') == snapshot((0, 1, 13)) + assert parse_ag_ui_version('0.1.13b2') == snapshot((0, 1, 13)) + assert parse_ag_ui_version('0.1.13rc1') == snapshot((0, 1, 13)) + assert parse_ag_ui_version('0.1.13.dev0') == snapshot((0, 1, 13)) + assert parse_ag_ui_version('0.1.x') == snapshot((0, 1)) + + +def test_detect_ag_ui_version_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that _detect_ag_ui_version returns '0.1.10' when package is not found.""" + + def _raise_not_found(_name: str) -> str: + raise importlib.metadata.PackageNotFoundError() + + monkeypatch.setattr('pydantic_ai.ui.ag_ui._event_stream.importlib.metadata.version', _raise_not_found) + assert _detect_ag_ui_version() == snapshot('0.1.10') + + +def test_detect_ag_ui_version_old(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that _detect_ag_ui_version returns '0.1.10' when installed version is below REASONING_VERSION.""" + + def _return_old_version(_name: str) -> str: + return '0.1.10' + + monkeypatch.setattr('pydantic_ai.ui.ag_ui._event_stream.importlib.metadata.version', _return_old_version) + assert _detect_ag_ui_version() == snapshot('0.1.10') + + +def test_dump_messages_text_content() -> None: + """Test that TextContent in UserPromptPart is 
converted to TextInputContent.""" + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content=[TextContent(content='hello')])]), + ] + + result = AGUIAdapter.dump_messages(messages) + assert [m.model_dump(exclude={'id'}, exclude_none=True) for m in result] == snapshot( + [{'role': 'user', 'content': 'hello'}] + ) + + +# endregion diff --git a/uv.lock b/uv.lock index 56c29ce613..dc679ba373 100644 --- a/uv.lock +++ b/uv.lock @@ -47,14 +47,14 @@ wheels = [ [[package]] name = "ag-ui-protocol" -version = "0.1.10" +version = "0.1.13" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/67/bb/5a5ec893eea5805fb9a3db76a9888c3429710dfb6f24bbb37568f2cf7320/ag_ui_protocol-0.1.10.tar.gz", hash = "sha256:3213991c6b2eb24bb1a8c362ee270c16705a07a4c5962267a083d0959ed894f4", size = 6945, upload-time = "2025-11-06T15:17:17.068Z" } +sdist = { url = "https://files.pythonhosted.org/packages/04/b5/fc0b65b561d00d88811c8a7d98ee735833f81554be244340950e7b65820c/ag_ui_protocol-0.1.13.tar.gz", hash = "sha256:811d7d7dcce4783dec252918f40b717ebfa559399bf6b071c4ba47c0c1e21bcb", size = 5671, upload-time = "2026-02-19T18:40:38.602Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/78/eb55fabaab41abc53f52c0918a9a8c0f747807e5306273f51120fd695957/ag_ui_protocol-0.1.10-py3-none-any.whl", hash = "sha256:c81e6981f30aabdf97a7ee312bfd4df0cd38e718d9fc10019c7d438128b93ab5", size = 7889, upload-time = "2025-11-06T15:17:15.325Z" }, + { url = "https://files.pythonhosted.org/packages/cd/9f/b833c1ab1999da35ebad54841ae85d2c2764c931da9a6f52d8541b6901b2/ag_ui_protocol-0.1.13-py3-none-any.whl", hash = "sha256:1393fa894c1e8416efe184168a50689e760d05b32f4646eebb8ff423dddf8e8f", size = 8053, upload-time = "2026-02-19T18:40:37.27Z" }, ] [[package]]