-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fix capability schema generation to include full parameter types #4867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
9f4b902
257a750
995e23f
762df2f
3094803
1f3fa6d
f474a64
1312568
86c59d7
5b221d3
e734712
6b638e7
60156b3
983ddfd
46e25b1
e8ec85b
ff2d812
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
| from __future__ import annotations | ||
|
|
||
| import inspect | ||
| import types # used at runtime in filter_serializable_type | ||
| import typing # used at runtime in filter_serializable_type | ||
|
||
| from collections.abc import Callable, Mapping, Sequence | ||
| from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast | ||
|
|
||
|
|
@@ -224,17 +226,54 @@ def load_from_registry( | |
| raise ValueError(f'Failed to instantiate {label} {spec.name!r}{detail}: {e}') from e | ||
|
|
||
|
|
||
| def filter_serializable_type(tp: Any) -> Any | None: | ||
| """Filter a type to only include members that can be represented in JSON schema. | ||
|
|
||
| For Union types, removes non-serializable members (TypeVars, Callables). | ||
| Returns None if the type is entirely non-serializable. | ||
| """ | ||
| # TypeVar is not serializable | ||
| if isinstance(tp, TypeVar): | ||
| return None | ||
|
|
||
| origin = typing.get_origin(tp) | ||
|
|
||
| # Callable is not serializable | ||
| if origin is Callable: | ||
| return None | ||
|
|
||
| # Union: filter members | ||
| if origin is typing.Union or isinstance(tp, types.UnionType): | ||
| args = typing.get_args(tp) | ||
| filtered = [fa for a in args if (fa := filter_serializable_type(a)) is not None] | ||
| if not filtered: | ||
| return None | ||
| if len(filtered) == 1: | ||
| return filtered[0] | ||
| return typing.Union[tuple(filtered)] # noqa: UP007 | ||
|
|
||
| # Other generics (list[X], dict[X, Y]): all args must be serializable | ||
| args = typing.get_args(tp) | ||
| if args and any(filter_serializable_type(a) is None for a in args): | ||
| return None | ||
|
|
||
| return tp | ||
|
|
||
|
|
||
| def build_schema_types( | ||
| registry: Mapping[str, type[Any]], | ||
| *, | ||
| get_schema_target: Callable[[type[Any]], Any] | None = None, | ||
| filter_type_hint: Callable[[Any], Any | None] | None = None, | ||
|
||
| ) -> list[Any]: | ||
| """Build a list of schema types from a registry for JSON schema generation. | ||
|
|
||
| Args: | ||
| registry: Mapping from names to classes. | ||
| get_schema_target: Optional callback to get the schema target (e.g. `from_spec` method) | ||
| from a class. Default: use the class itself. | ||
| filter_type_hint: Optional callback to filter type hints. Called on each resolved type hint; | ||
| return the (possibly modified) type, or None to exclude the parameter. | ||
|
|
||
| Returns: | ||
| A list of types suitable for use in a Union for JSON schema generation. | ||
|
|
@@ -244,13 +283,24 @@ def build_schema_types( | |
| target = get_schema_target(cls) if get_schema_target is not None else cls | ||
| type_hints = get_function_type_hints(target) | ||
| type_hints.pop('return', None) | ||
|
|
||
| # Apply type filtering if provided | ||
| if filter_type_hint is not None: | ||
| type_hints = {k: fv for k, v in type_hints.items() if (fv := filter_type_hint(v)) is not None} | ||
|
|
||
| required_type_hints: dict[str, Any] = {} | ||
|
|
||
| for p in inspect.signature(target).parameters.values(): | ||
| # Skip *args and **kwargs — they can't be represented as typed dict fields | ||
| # Skip self/cls (unbound instance/class methods) and *args/**kwargs | ||
| if p.name in ('self', 'cls') and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD): | ||
| type_hints.pop(p.name, None) | ||
| continue | ||
| if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): | ||
| type_hints.pop(p.name, None) | ||
| continue | ||
| # When filtering, skip params whose type was entirely filtered out | ||
| if filter_type_hint is not None and p.name not in type_hints: | ||
| continue | ||
| type_hints.setdefault(p.name, Any) | ||
| if p.default is not p.empty: | ||
| type_hints[p.name] = NotRequired[type_hints[p.name]] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,8 +12,9 @@ | |
| from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler | ||
|
|
||
| from pydantic_ai._agent_graph import EndStrategy | ||
| from pydantic_ai._spec import NamedSpec, build_registry, build_schema_types | ||
| from pydantic_ai._spec import NamedSpec, build_registry, build_schema_types, filter_serializable_type | ||
| from pydantic_ai._template import TemplateStr | ||
| from pydantic_ai._utils import get_function_type_hints | ||
|
|
||
| if TYPE_CHECKING: | ||
| from pydantic_ai.capabilities.abstract import AbstractCapability | ||
|
|
@@ -211,6 +212,7 @@ class _AgentSpecSchema(BaseModel, extra='forbid'): | |
| capabilities: list[Union[tuple(capability_schema_types)]] = [] # pyright: ignore # noqa: UP007 | ||
|
|
||
| json_schema = _AgentSpecSchema.model_json_schema() | ||
| json_schema['title'] = 'AgentSpec' | ||
| json_schema['properties']['$schema'] = {'type': 'string'} | ||
| return json_schema | ||
|
|
||
|
|
@@ -325,7 +327,21 @@ def load_capability_from_nested_spec(spec: dict[str, Any] | str) -> AbstractCapa | |
|
|
||
| def _build_capability_schema_types(registry: Mapping[str, type[Any]]) -> list[Any]: | ||
| """Build a list of schema types for capabilities from a registry.""" | ||
|
|
||
| def _get_schema_target(cls: type[Any]) -> Any: | ||
| # When from_spec is not overridden, it delegates to cls(*args, **kwargs). | ||
| # Use __init__ directly so build_schema_types sees the actual parameter types. | ||
| # Fall back to from_spec if __init__ hints can't be resolved (e.g. TYPE_CHECKING imports). | ||
| if 'from_spec' not in cls.__dict__: | ||
| try: | ||
| get_function_type_hints(cls.__init__) | ||
| return cls.__init__ | ||
| except (NameError, TypeError, AttributeError): | ||
| pass | ||
| return cls.from_spec | ||
|
Comment on lines
+338
to
+348
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚩 Schema precision varies by Python version and MCP installation status The combination of
This is by design but worth documenting, as users generating schemas in different environments could get inconsistent results. Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| return build_schema_types( | ||
| registry, | ||
| get_schema_target=lambda cls: cls.from_spec, | ||
| get_schema_target=_get_schema_target, | ||
| filter_type_hint=filter_serializable_type, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,13 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
| from collections.abc import Awaitable, Callable | ||
| from dataclasses import dataclass | ||
| from functools import cached_property | ||
| from typing import TYPE_CHECKING, Any, Literal | ||
| from urllib.parse import urlparse | ||
|
|
||
| from pydantic_ai.builtin_tools import MCPServerTool | ||
| from pydantic_ai.tools import AgentBuiltinTool, AgentDepsT, Tool | ||
| from pydantic_ai.tools import AgentDepsT, RunContext, Tool | ||
| from pydantic_ai.toolsets import AbstractToolset | ||
|
|
||
| from .builtin_or_local import BuiltinOrLocalTool | ||
|
|
@@ -47,7 +47,9 @@ def __init__( | |
| self, | ||
| url: str, | ||
| *, | ||
| builtin: MCPServerTool | AgentBuiltinTool[AgentDepsT] | bool = True, | ||
| builtin: MCPServerTool | ||
| | Callable[[RunContext[AgentDepsT]], Awaitable[MCPServerTool | None] | MCPServerTool | None] | ||
| | bool = True, | ||
| local: MCPServer | FastMCPToolset[AgentDepsT] | Callable[..., Any] | Literal[False] | None = None, | ||
| id: str | None = None, | ||
| authorization_token: str | None = None, | ||
|
|
@@ -65,6 +67,30 @@ def __init__( | |
| self.description = description | ||
| self.__post_init__() | ||
|
|
||
| @classmethod | ||
| def from_spec( | ||
|
||
| cls, | ||
| url: str, | ||
| *, | ||
| builtin: MCPServerTool | bool = True, | ||
| local: Literal[False] | None = None, | ||
| id: str | None = None, | ||
| authorization_token: str | None = None, | ||
| headers: dict[str, str] | None = None, | ||
| allowed_tools: list[str] | None = None, | ||
| description: str | None = None, | ||
| ) -> MCP[Any]: | ||
| return cls( | ||
| url=url, | ||
| builtin=builtin, | ||
| local=local, | ||
| id=id, | ||
| authorization_token=authorization_token, | ||
| headers=headers, | ||
| allowed_tools=allowed_tools, | ||
| description=description, | ||
| ) | ||
|
|
||
| @cached_property | ||
| def _resolved_id(self) -> str: | ||
| if self.id: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.