Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
fail-fast: false
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down
6 changes: 0 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,10 @@ check-codestyle:
mypy:
poetry run mypy --config-file pyproject.toml $(filter-out $@,$(MAKECMDGOALS))

.PHONY: check-safety
check-safety:
poetry check
poetry run safety --disable-telemetry check --full-report --disable-audit-and-monitor

.PHONY: lint
lint:
@$(MAKE) -s check-codestyle $(filter-out $@,$(MAKECMDGOALS))
@$(MAKE) -s mypy $(filter-out $@,$(MAKECMDGOALS))
@$(MAKE) -s check-safety

#* Cleaning
.PHONY: clean
Expand Down
173 changes: 59 additions & 114 deletions apiclient_pydantic/serializers.py
Original file line number Diff line number Diff line change
@@ -1,165 +1,110 @@
from __future__ import annotations

import inspect
from functools import partial
from typing import Any, Awaitable, Callable, Optional, Set, Type, TypeVar, Union
from typing import TYPE_CHECKING, Annotated, Any, Callable, TypeVar, cast

from apiclient import APIClient
from pydantic import AfterValidator, BaseModel, ConfigDict
from pydantic._internal import _generate_schema
from pydantic._internal import _generate_schema, _typing_extra, _validate_call
from pydantic._internal._config import ConfigWrapper
from pydantic._internal._validate_call import ValidateCallWrapper as PydanticValidateCallWrapper
from pydantic._internal._generate_schema import GenerateSchema, ValidateCallSupportedTypes
from pydantic._internal._namespace_utils import MappingNamespace, NsResolver, ns_for_function
from pydantic._internal._validate_call import (
ValidateCallWrapper as PydanticValidateCallWrapper,
extract_function_qualname,
)
from pydantic.plugin._schema_validator import create_schema_validator
from typing_extensions import Annotated

try: # pragma: no cover
from pydantic._internal._typing_extra import get_module_ns_of as get_module
except ImportError: # pragma: no cover
from pydantic._internal._typing_extra import add_module_globals as get_module # type: ignore[attr-defined,no-redef]
from pydantic.validate_call_decorator import _check_function_type

if TYPE_CHECKING:
from collections.abc import Awaitable

AnyCallableT = TypeVar('AnyCallableT', bound=Callable[..., Any])
T = TypeVar('T', bound=APIClient)

AnyCallableT = TypeVar('AnyCallableT', bound=Callable[..., Any])
TModel = TypeVar('TModel', bound=BaseModel)
ModelDumped = Annotated[TModel, AfterValidator(lambda v: v.model_dump(exclude_none=True, by_alias=True))]


class ValidateCallWrapper(PydanticValidateCallWrapper):
__slots__ = ('_response',)
__slots__ = ()

def __init__(
self,
function: Callable[..., Any],
config: Optional[ConfigDict],
function: ValidateCallSupportedTypes,
config: ConfigDict | None,
validate_return: bool,
response: Optional[Type[BaseModel]] = None,
parent_namespace: MappingNamespace | None,
response: type[BaseModel] | None = None,
) -> None:
self.raw_function = function
self._config = config
self._validate_return = validate_return
self._response = response
self.__signature__ = inspect.signature(function)
if isinstance(function, partial):
func = function.func
schema_type = func
self.__name__ = f'partial({func.__name__})'
self.__qualname__ = f'partial({func.__qualname__})'
self.__annotations__ = func.__annotations__
self.__module__ = func.__module__
self.__doc__ = func.__doc__
else:
schema_type = function
self.__name__ = function.__name__
self.__qualname__ = function.__qualname__
self.__annotations__ = function.__annotations__
self.__module__ = function.__module__
self.__doc__ = function.__doc__

namespace = get_module(function)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
self.__pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)

self.__pydantic_validator__ = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)
if self._validate_return or self._response:
return_type: Any
if not (return_type := self._response):
return_type = (
self.__signature__.return_annotation is not self.__signature__.empty
and self.__signature__.return_annotation
or Any
)

gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
self.__return_pydantic_core_schema__ = schema
super().__init__(function, config, validate_return, parent_namespace)
if response:
if isinstance(function, partial): # pragma: no cover
schema_type = function.func
module = function.func.__module__
else:
schema_type = function
module = function.__module__
qualname = extract_function_qualname(function)

ns_resolver = NsResolver(namespaces_tuple=ns_for_function(schema_type, parent_namespace=parent_namespace))
config_wrapper = ConfigWrapper(config)
core_config = config_wrapper.core_config(title=qualname)

gen_schema = GenerateSchema(config_wrapper, ns_resolver)
schema = gen_schema.clean_schema(gen_schema.generate_schema(response))
validator = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
module,
qualname,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)
if inspect.iscoroutinefunction(self.raw_function):
if inspect.iscoroutinefunction(function): # pragma: no cover

async def return_val_wrapper(aw: Awaitable[Any]) -> None:
return validator.validate_python(await aw)

self.__return_pydantic_validator__ = return_val_wrapper
else:
self.__return_pydantic_validator__ = validator.validate_python
else:
self.__return_pydantic_core_schema__ = None
self.__return_pydantic_validator__ = None # type: ignore[assignment]
self._name: Optional[str] = None # set by __get__, used to set the instance attribute when decorating methods)

def __get__(
self,
obj: Any, # noqa: ANN401
objtype: Optional[Type[Any]] = None,
) -> 'ValidateCallWrapper': # pragma: no cover
"""Bind the raw function and return another ValidateCallWrapper wrapping that."""
# Copy-paste to pass _response to the class
if obj is None:
try:
# Handle the case where a method is accessed as a class attribute
return objtype.__getattribute__(objtype, self._name) # type: ignore[call-arg, arg-type]
except AttributeError:
# This will happen the first time the attribute is accessed
pass

bound_function = self.raw_function.__get__(obj, objtype)
result = self.__class__(bound_function, self._config, self._validate_return, self._response)

# skip binding to instance when obj or objtype has __slots__ attribute
if hasattr(obj, '__slots__') or hasattr(objtype, '__slots__'):
return result

if self._name is not None:
if obj is not None:
object.__setattr__(obj, self._name, result)
else:
object.__setattr__(objtype, self._name, result)
return result


APICLIENT_METHODS: Set[str] = {i[0] for i in inspect.getmembers(APIClient, predicate=inspect.isfunction)}
APICLIENT_METHODS: set[str] = {i[0] for i in inspect.getmembers(APIClient, predicate=inspect.isfunction)}


def serialize(
__func: Optional[AnyCallableT] = None,
__func: AnyCallableT | None = None,
/,
*,
config: Optional[ConfigDict] = None,
config: ConfigDict | None = None,
validate_return: bool = True,
response: Optional[Type[BaseModel]] = None,
) -> Union[Callable[[AnyCallableT], ValidateCallWrapper], ValidateCallWrapper]:
def validate(function: AnyCallableT) -> ValidateCallWrapper:
if isinstance(function, (classmethod, staticmethod)):
name = type(function).__name__
msg = f'The `@{name}` decorator should be applied after `@serialize` (put `@{name}` on top)'
raise TypeError(msg)
return ValidateCallWrapper(function, config, validate_return, response)
response: type[BaseModel] | None = None,
) -> AnyCallableT | Callable[[AnyCallableT], AnyCallableT]:
parent_namespace = _typing_extra.parent_frame_namespace()

def validate(function: AnyCallableT) -> AnyCallableT:
_check_function_type(function)
validate_call_wrapper = ValidateCallWrapper(
cast(_generate_schema.ValidateCallSupportedTypes, function),
config,
validate_return,
parent_namespace,
response,
)
return _validate_call.update_wrapper_attributes(function, validate_call_wrapper.__call__) # type:ignore[arg-type]

if __func:
return validate(__func)
return validate


def serialize_all_methods(
__cls: Optional[Type[T]] = None, config: Optional[ConfigDict] = None
) -> Union[AnyCallableT, Callable[[AnyCallableT], AnyCallableT], Callable[[Type[T]], Type[T]]]:
def decorate(cls: Type[T]) -> Type[T]:
__cls: type[T] | None = None, /, *, config: ConfigDict | None = None
) -> AnyCallableT | Callable[[AnyCallableT], AnyCallableT] | Callable[[type[T]], type[T]]:
def decorate(cls: type[T]) -> type[T]:
for attr, value in vars(cls).items():
if not attr.startswith('_') and inspect.isfunction(value) and attr not in APICLIENT_METHODS:
setattr(cls, attr, serialize(value, config=config))
Expand Down
Loading
Loading