|
1 | 1 | import inspect |
2 | 2 | from functools import wraps |
3 | | -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, get_type_hints |
| 3 | +from typing import Any, Callable, Dict, ForwardRef, Optional, Type, get_type_hints |
| 4 | + |
| 5 | +from pydantic import BaseModel, create_model, parse_obj_as |
| 6 | +from pydantic.config import BaseConfig as PydanticBaseConfig, Extra |
| 7 | +from pydantic.typing import evaluate_forwardref |
| 8 | + |
| 9 | +from .utils import is_pydantic_model |
| 10 | + |
| 11 | +DictStrAny = Dict[str, Any] |
| 12 | + |
| 13 | + |
| 14 | +class BaseConfig(PydanticBaseConfig): |
| 15 | + orm_mode = True |
| 16 | + |
| 17 | + |
| 18 | +class ParamsObject: |
| 19 | + def __init__(self, **kwargs: DictStrAny) -> None: |
| 20 | + for attr, param in kwargs.items(): |
| 21 | + setattr(self, attr, param) |
| 22 | + |
| 23 | + |
| 24 | +def get_typed_signature(call: Callable) -> inspect.Signature: |
| 25 | + """Finds call signature and resolves all forwardrefs""" |
| 26 | + signature = inspect.signature(call) |
| 27 | + globalns = getattr(call, '__globals__', {}) |
| 28 | + typed_params = [ |
| 29 | + inspect.Parameter( |
| 30 | + name=param.name, |
| 31 | + kind=param.kind, |
| 32 | + default=param.default, |
| 33 | + annotation=get_typed_annotation(param, globalns), |
| 34 | + ) |
| 35 | + for param in signature.parameters.values() |
| 36 | + ] |
| 37 | + typed_signature = inspect.Signature(typed_params) |
| 38 | + return typed_signature |
| 39 | + |
| 40 | + |
| 41 | +def get_typed_annotation(param: inspect.Parameter, globalns: DictStrAny) -> Any: |
| 42 | + annotation = param.annotation |
| 43 | + if isinstance(annotation, str): |
| 44 | + annotation = make_forwardref(annotation, globalns) |
| 45 | + return annotation |
| 46 | + |
| 47 | + |
| 48 | +def make_forwardref(annotation: str, globalns: DictStrAny) -> Any: |
| 49 | + forward_ref = ForwardRef(annotation) |
| 50 | + return evaluate_forwardref(forward_ref, globalns, globalns) |
| 51 | + |
| 52 | + |
| 53 | +class ParamsSerializer: |
| 54 | + __slot__ = ( |
| 55 | + 'signature', |
| 56 | + 'by_alias', |
| 57 | + 'exclude_unset', |
| 58 | + 'exclude_defaults', |
| 59 | + 'exclude_none', |
| 60 | + 'has_kwargs', |
| 61 | + 'model_param', |
| 62 | + ) |
| 63 | + |
| 64 | + def __init__( |
| 65 | + self, |
| 66 | + by_alias: bool = True, |
| 67 | + exclude_unset: bool = False, |
| 68 | + exclude_defaults: bool = False, |
| 69 | + exclude_none: bool = True, |
| 70 | + ): |
| 71 | + self.by_alias = by_alias |
| 72 | + self.exclude_unset = exclude_unset |
| 73 | + self.exclude_defaults = exclude_defaults |
| 74 | + self.exclude_none = exclude_none |
| 75 | + self.has_kwargs = False |
| 76 | + |
| 77 | + def __call__(self, func: Callable) -> Callable: |
| 78 | + attrs, self.signature = {}, get_typed_signature(func) |
| 79 | + new_signature_parameters: DictStrAny = {} |
| 80 | + |
| 81 | + for name, arg in self.signature.parameters.items(): |
| 82 | + if name == 'self': |
| 83 | + new_signature_parameters.setdefault(arg.name, arg) |
| 84 | + continue |
| 85 | + |
| 86 | + if arg.kind == arg.VAR_KEYWORD: |
| 87 | + # Skipping **kwargs |
| 88 | + self.has_kwargs = True |
| 89 | + continue |
| 90 | + |
| 91 | + if arg.kind == arg.VAR_POSITIONAL: |
| 92 | + # Skipping *args |
| 93 | + continue |
| 94 | + |
| 95 | + arg_type = self._get_param_type(arg) |
| 96 | + |
| 97 | + if name not in new_signature_parameters: |
| 98 | + if is_pydantic_model(arg_type): |
| 99 | + new_signature_parameters.update( |
| 100 | + {argument.name: argument for argument in inspect.signature(arg_type).parameters.values()} |
| 101 | + ) |
| 102 | + else: |
| 103 | + new_signature_parameters.setdefault(arg.name, arg) |
| 104 | + |
| 105 | + attrs[name] = (arg_type, ...) |
| 106 | + |
| 107 | + if attrs: |
| 108 | + self.model_param = create_model(f'{func.__name__}Params', __config__=BaseConfig, **attrs) # type: ignore |
4 | 109 |
|
5 | | -from pydantic import BaseModel, parse_obj_as |
| 110 | + @wraps(func) |
| 111 | + def wrap(*args, **kwargs): |
| 112 | + object_params = {} |
| 113 | + for name, fld in self.model_param.__fields__.items(): |
| 114 | + kw = kwargs if name not in kwargs else kwargs[name] |
6 | 115 |
|
| 116 | + object_params[name] = kw |
| 117 | + if is_pydantic_model(fld.type_) and fld.type_.Config.extra == Extra.forbid and isinstance(kw, dict): |
| 118 | + object_params[name] = {k: v for k, v in kw.items() if k in fld.type_.__fields__} |
7 | 119 |
|
8 | | -def serialize_request(schema: Optional[Type[BaseModel]] = None, extra_kwargs: Dict[str, Any] = None) -> Callable: |
9 | | - extra_kw = extra_kwargs or {'by_alias': True, 'exclude_none': True} |
| 120 | + params_object = ParamsObject(**object_params) |
10 | 121 |
|
11 | | - def decorator(func: Callable) -> Callable: |
12 | | - nonlocal schema |
13 | | - map_schemas = {} |
14 | | - parameters = [] |
15 | | - |
16 | | - if schema: |
17 | | - parameters.extend(list(inspect.signature(schema).parameters.values())) |
18 | | - else: |
19 | | - for arg_name, arg_type in get_type_hints(func).items(): |
20 | | - if arg_name == 'return': |
21 | | - continue |
22 | | - map_schemas[arg_name] = arg_type |
23 | | - if inspect.isclass(arg_type) and issubclass(arg_type, BaseModel): |
24 | | - parameters.extend(list(inspect.signature(arg_type).parameters.values())) |
| 122 | + result = self.model_param.from_orm(params_object).dict( |
| 123 | + by_alias=self.by_alias, |
| 124 | + exclude_unset=self.exclude_unset, |
| 125 | + exclude_defaults=self.exclude_defaults, |
| 126 | + exclude_none=self.exclude_none, |
| 127 | + ) |
25 | 128 |
|
26 | | - @wraps(func) |
27 | | - def wrap(*args, **kwargs): |
28 | | - if schema: |
29 | | - instance = data = parse_obj_as(schema, kwargs) |
30 | | - data = instance.dict(**extra_kw) |
31 | | - return func(*args, data) |
32 | | - elif map_schemas: |
33 | | - data, origin_kwargs = {}, {} |
34 | | - for arg_name, arg_type in map_schemas.items(): |
35 | | - if inspect.isclass(arg_type) and issubclass(arg_type, BaseModel): |
36 | | - data[arg_name] = parse_obj_as(arg_type, kwargs).dict(**extra_kw) |
37 | | - else: |
38 | | - val = kwargs.get(arg_name) |
39 | | - if val is not None: |
40 | | - origin_kwargs[arg_name] = val |
41 | | - new_kwargs = {**origin_kwargs, **data} or kwargs |
42 | | - return func(*args, **new_kwargs) |
43 | | - return func(*args, **kwargs) |
| 129 | + return func(*args, **result) |
44 | 130 |
|
45 | 131 | # Override signature |
46 | | - if parameters: |
| 132 | + if new_signature_parameters and attrs: |
47 | 133 | sig = inspect.signature(func) |
48 | | - _self_param = sig.parameters.get('self') |
49 | | - self_param = [_self_param] if _self_param else [] |
50 | | - sig = sig.replace(parameters=tuple(self_param + parameters)) |
| 134 | + sig = sig.replace(parameters=tuple(sorted(new_signature_parameters.values(), key=lambda x: x.kind))) |
51 | 135 | wrap.__signature__ = sig # type: ignore |
52 | | - return wrap |
53 | 136 |
|
54 | | - return decorator |
| 137 | + return wrap if attrs else func |
55 | 138 |
|
| 139 | + def _get_param_type(self, arg: inspect.Parameter) -> Any: |
| 140 | + annotation = arg.annotation |
| 141 | + |
| 142 | + if annotation == self.signature.empty: |
| 143 | + if arg.default == self.signature.empty: |
| 144 | + annotation = str |
| 145 | + else: |
| 146 | + annotation = type(arg.default) |
| 147 | + |
| 148 | + if annotation is type(None) or annotation is type(Ellipsis): # noqa: E721 |
| 149 | + annotation = str |
| 150 | + |
| 151 | + return annotation |
56 | 152 |
|
57 | | -def serialize_response(schema: Optional[Type[BaseModel]] = None) -> Callable: |
58 | | - def decorator(func: Callable) -> Callable: |
59 | | - nonlocal schema |
60 | | - if not schema: # pragma: no cover |
61 | | - schema = get_type_hints(func).get('return') |
| 153 | + |
| 154 | +serialize_request = params_serializer = ParamsSerializer |
| 155 | + |
| 156 | + |
| 157 | +class ResponseSerializer: |
| 158 | + __slot__ = ('response',) |
| 159 | + |
| 160 | + def __init__(self, response: Optional[Type[BaseModel]] = None): |
| 161 | + self.response = response |
| 162 | + |
| 163 | + def __call__(self, func: Callable) -> Callable: |
| 164 | + self.response = self.response or get_type_hints(func).get('return') |
62 | 165 |
|
63 | 166 | @wraps(func) |
64 | | - def wrap(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Union[BaseModel, Any]: |
65 | | - response = func(*args, **kwargs) |
66 | | - if isinstance(response, (list, dict, tuple, set)) and schema: |
67 | | - return parse_obj_as(schema, response) |
68 | | - return response |
| 167 | + def wrap(*args, **kwargs): |
| 168 | + result = func(*args, **kwargs) |
| 169 | + if result is not None: |
| 170 | + return parse_obj_as(self.response, result) |
| 171 | + return result |
69 | 172 |
|
70 | | - return wrap |
| 173 | + return wrap if self.response else func |
71 | 174 |
|
72 | | - return decorator |
| 175 | + |
| 176 | +serialize_response = response_serializer = ResponseSerializer |
73 | 177 |
|
74 | 178 |
|
75 | 179 | def serialize( |
76 | | - schema_request: Optional[Type[BaseModel]] = None, |
77 | | - schema_response: Optional[Type[BaseModel]] = None, |
78 | | - **base_kwargs: Dict[str, Any], |
| 180 | + response: Optional[Type[BaseModel]] = None, |
| 181 | + by_alias: bool = True, |
| 182 | + exclude_unset: bool = False, |
| 183 | + exclude_defaults: bool = False, |
| 184 | + exclude_none: bool = True, |
79 | 185 | ) -> Callable: |
80 | 186 | def decorator(func: Callable) -> Callable: |
81 | | - response = func |
82 | | - response = serialize_request(schema_request, extra_kwargs=base_kwargs)(func) |
83 | | - response = serialize_response(schema_response)(response) |
84 | | - |
85 | | - return response |
| 187 | + result_func = func |
| 188 | + result_func = ParamsSerializer( |
| 189 | + by_alias=by_alias, |
| 190 | + exclude_unset=exclude_unset, |
| 191 | + exclude_defaults=exclude_defaults, |
| 192 | + exclude_none=exclude_none, |
| 193 | + )(func) |
| 194 | + result_func = serialize_response(response=response)(result_func) |
| 195 | + |
| 196 | + return result_func |
86 | 197 |
|
87 | 198 | return decorator |
88 | 199 |
|
|
0 commit comments