-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fix capability schema generation to include full parameter types #4867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
9f4b902
257a750
995e23f
762df2f
3094803
1f3fa6d
f474a64
1312568
86c59d7
5b221d3
e734712
6b638e7
60156b3
983ddfd
46e25b1
e8ec85b
ff2d812
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -224,17 +224,61 @@ def load_from_registry( | |
| raise ValueError(f'Failed to instantiate {label} {spec.name!r}{detail}: {e}') from e | ||
|
|
||
|
|
||
| def filter_serializable_type(tp: Any) -> Any | None: | ||
| """Filter a type to only include members that can be represented in JSON schema. | ||
|
|
||
| For Union types, removes non-serializable members (TypeVars, Callables). | ||
| Returns None if the type is entirely non-serializable. | ||
| """ | ||
| return _filter_serializable_type(tp) | ||
|
|
||
|
|
||
| def _filter_serializable_type(tp: object) -> object | None: | ||
| import types | ||
| import typing | ||
|
||
|
|
||
| # TypeVar is not serializable | ||
| if isinstance(tp, TypeVar): | ||
| return None | ||
|
|
||
| origin = typing.get_origin(tp) | ||
|
|
||
| # Callable is not serializable | ||
| if origin is Callable: | ||
| return None | ||
|
|
||
| # Union: filter members | ||
| if origin is typing.Union or isinstance(tp, types.UnionType): | ||
| args = typing.get_args(tp) | ||
| filtered = [fa for a in args if (fa := _filter_serializable_type(a)) is not None] | ||
| if not filtered: # pragma: no cover — requires union of only non-serializable types | ||
|
||
| return None | ||
| if len(filtered) == 1: | ||
| return filtered[0] | ||
| return typing.Union[tuple(filtered)] # noqa: UP007 | ||
|
|
||
| # Other generics (list[X], dict[X, Y]): all args must be serializable | ||
| args = typing.get_args(tp) | ||
| if args and not all(_filter_serializable_type(a) is not None for a in args): | ||
| return None | ||
|
|
||
| return tp | ||
|
|
||
|
|
||
| def build_schema_types( | ||
| registry: Mapping[str, type[Any]], | ||
| *, | ||
| get_schema_target: Callable[[type[Any]], Any] | None = None, | ||
| filter_type_hint: Callable[[Any], Any | None] | None = None, | ||
|
||
| ) -> list[Any]: | ||
| """Build a list of schema types from a registry for JSON schema generation. | ||
|
|
||
| Args: | ||
| registry: Mapping from names to classes. | ||
| get_schema_target: Optional callback to get the schema target (e.g. `from_spec` method) | ||
| from a class. Default: use the class itself. | ||
| filter_type_hint: Optional callback to filter type hints. Called on each resolved type hint; | ||
| return the (possibly modified) type, or None to exclude the parameter. | ||
|
|
||
| Returns: | ||
| A list of types suitable for use in a Union for JSON schema generation. | ||
|
|
@@ -244,13 +288,24 @@ def build_schema_types( | |
| target = get_schema_target(cls) if get_schema_target is not None else cls | ||
| type_hints = get_function_type_hints(target) | ||
| type_hints.pop('return', None) | ||
|
|
||
| # Apply type filtering if provided | ||
| if filter_type_hint is not None: | ||
| type_hints = {k: fv for k, v in type_hints.items() if (fv := filter_type_hint(v)) is not None} | ||
|
|
||
| required_type_hints: dict[str, Any] = {} | ||
|
|
||
| for p in inspect.signature(target).parameters.values(): | ||
| # Skip *args and **kwargs — they can't be represented as typed dict fields | ||
| # Skip self/cls (unbound instance/class methods) and *args/**kwargs | ||
| if p.name in ('self', 'cls') and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD): | ||
| type_hints.pop(p.name, None) | ||
| continue | ||
| if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): | ||
| type_hints.pop(p.name, None) | ||
| continue | ||
| # When filtering, skip params whose type was entirely filtered out | ||
| if filter_type_hint is not None and p.name not in type_hints: | ||
| continue | ||
| type_hints.setdefault(p.name, Any) | ||
| if p.default is not p.empty: | ||
| type_hints[p.name] = NotRequired[type_hints[p.name]] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -325,7 +325,24 @@ def load_capability_from_nested_spec(spec: dict[str, Any] | str) -> AbstractCapa | |
|
|
||
| def _build_capability_schema_types(registry: Mapping[str, type[Any]]) -> list[Any]: | ||
| """Build a list of schema types for capabilities from a registry.""" | ||
| from pydantic_ai._utils import get_function_type_hints | ||
|
||
|
|
||
| def _get_schema_target(cls: type[Any]) -> Any: | ||
| # When from_spec is not overridden, it delegates to cls(*args, **kwargs). | ||
| # Use __init__ directly so build_schema_types sees the actual parameter types. | ||
| # Fall back to from_spec if __init__ hints can't be resolved (e.g. TYPE_CHECKING imports). | ||
| if 'from_spec' not in cls.__dict__: | ||
| try: | ||
| get_function_type_hints(cls.__init__) | ||
| return cls.__init__ | ||
| except Exception: | ||
|
||
| pass | ||
| return cls.from_spec | ||
|
Comment on lines
+338
to
+348
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🚩 Schema precision varies by Python version and MCP installation status The combination of
This is by design but worth documenting, as users generating schemas in different environments could get inconsistent results. Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| import pydantic_ai._spec | ||
|
|
||
| return build_schema_types( | ||
| registry, | ||
| get_schema_target=lambda cls: cls.from_spec, | ||
| get_schema_target=_get_schema_target, | ||
| filter_type_hint=pydantic_ai._spec.filter_serializable_type, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,30 @@ def __init__( | |
| self.description = description | ||
| self.__post_init__() | ||
|
|
||
| @classmethod | ||
| def from_spec( | ||
|
||
| cls, | ||
| url: str, | ||
| *, | ||
| builtin: MCPServerTool | bool = True, | ||
| local: Literal[False] | None = None, | ||
| id: str | None = None, | ||
| authorization_token: str | None = None, | ||
| headers: dict[str, str] | None = None, | ||
| allowed_tools: list[str] | None = None, | ||
| description: str | None = None, | ||
| ) -> MCP[Any]: | ||
| return cls( | ||
| url=url, | ||
| builtin=builtin, | ||
| local=local, | ||
| id=id, | ||
| authorization_token=authorization_token, | ||
| headers=headers, | ||
| allowed_tools=allowed_tools, | ||
| description=description, | ||
| ) | ||
|
|
||
| @cached_property | ||
| def _resolved_id(self) -> str: | ||
| if self.id: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.