Skip to content

Commit e8058de

Browse files
committed
Fix: Route params (query or post) with future annotations gives "not fully defined" error #246
1 parent f786a5d commit e8058de

File tree

5 files changed

+361
-233
lines changed

5 files changed

+361
-233
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ Please follow [the Keep a Changelog standard](https://keepachangelog.com/en/1.0.
55

66
## [Unreleased]
77

8+
### Added
9+
10+
* Added support for more field attributes in `schema.had()` and `schema.didnt_have()`: `field_title_generator`, `fail_fast`, `coerce_numbers_to_str`, `union_mode`, `allow_mutation`, `pattern`, `discriminator`
11+
* Added support for forwardrefs in body fields (for example, when you use `from __future__ import annotations` in the file with your routes)
12+
813
## [4.5.0]
914

1015
### Added

cadwyn/schema_generation.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
RootValidatorDecoratorInfo,
4141
ValidatorDecoratorInfo,
4242
)
43+
from pydantic._internal._typing_extra import try_eval_type as pydantic_try_eval_type
4344
from pydantic.fields import ComputedFieldInfo, FieldInfo
4445
from typing_extensions import Doc, Self, _AnnotatedAlias, assert_never
4546

@@ -418,6 +419,7 @@ def __init__(self, generator: "SchemaGenerator") -> None:
418419
# because such copies could produce weird behaviors at runtime, especially if you/fastapi do any comparisons.
419420
# It's defined here and not on the method because of this: https://youtu.be/sVjtp6tGo0g
420421
self.generator = generator
422+
# TODO: Rewrite this to memoize
421423
self.change_versions_of_a_non_container_annotation = functools.cache(
422424
self._change_version_of_a_non_container_annotation
423425
)
@@ -492,6 +494,7 @@ def _change_version_of_a_non_container_annotation(self, annotation: Any) -> Any:
492494
) or isinstance(annotation, fastapi.security.base.SecurityBase):
493495
return annotation
494496

497+
# If we do not use modifier, we will get an unhashable module error
495498
def modifier(annotation: Any):
496499
return self.change_version_of_annotation(annotation)
497500

@@ -531,6 +534,9 @@ def _modify_callable_annotations( # pragma: no branch # because of lambdas
531534
annotation_modifying_wrapper = annotation_modifying_wrapper_factory(call)
532535
old_params = inspect.signature(call).parameters
533536
callable_annotations = annotation_modifying_wrapper.__annotations__
537+
callable_annotations = {
538+
k: v if type(v) is not str else _try_eval_type(v, call.__globals__) for k, v in callable_annotations.items()
539+
}
534540
annotation_modifying_wrapper.__annotations__ = modify_annotations(callable_annotations)
535541
annotation_modifying_wrapper.__defaults__ = modify_defaults(
536542
tuple(p.default for p in old_params.values() if p.default is not inspect.Signature.empty),
@@ -977,3 +983,11 @@ def _get_initialization_namespace_for_enum(enum_cls: type[Enum]):
977983
and k not in _DummyEnum.__dict__
978984
and (k not in mro_dict or mro_dict[k] is not v)
979985
}
986+
987+
988+
def _try_eval_type(value: Any, globals: dict[str, Any]) -> Any:
989+
new_value, success = pydantic_try_eval_type(value, globals)
990+
if success:
991+
return new_value
992+
else:
993+
return value

cadwyn/structure/schemas.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from typing import TYPE_CHECKING, Any, Literal, cast
44

55
from issubclass import issubclass as lenient_issubclass
6-
from pydantic import BaseModel, Field
6+
from pydantic import AliasChoices, AliasPath, BaseModel, Field
77
from pydantic._internal._decorators import PydanticDescriptorProxy, unwrap_wrapped_function
8+
from pydantic.config import JsonDict
89
from pydantic.fields import FieldInfo
910

1011
from cadwyn._utils import Sentinel, fully_unwrap_decorator
@@ -19,57 +20,80 @@
1920
PossibleFieldAttributes = Literal[
2021
"default",
2122
"default_factory",
22-
"alias",
23+
"alias_priority",
24+
"validation_alias",
25+
"serialization_alias",
2326
"title",
27+
"field_title_generator",
2428
"description",
29+
"examples",
2530
"exclude",
2631
"const",
32+
"deprecated",
33+
"frozen",
34+
"validate_default",
35+
"repr",
36+
"init",
37+
"init_var",
38+
"kw_only",
39+
"fail_fast",
2740
"gt",
2841
"ge",
2942
"lt",
3043
"le",
31-
"deprecated",
32-
"fail_fast",
3344
"strict",
45+
"coerce_numbers_to_str",
3446
"multiple_of",
3547
"allow_inf_nan",
3648
"max_digits",
3749
"decimal_places",
3850
"min_length",
3951
"max_length",
52+
"union_mode",
4053
"allow_mutation",
4154
"pattern",
4255
"discriminator",
43-
"repr",
4456
]
4557

4658

59+
# TODO: Add json_schema_extra as a breaking change in a major version
4760
@dataclass(slots=True)
4861
class FieldChanges:
4962
default: Any
5063
default_factory: Any
51-
alias: str
52-
title: str
64+
alias_priority: int | None
65+
validation_alias: str | AliasPath | AliasChoices | None
66+
serialization_alias: str | None
67+
title: str | None
68+
field_title_generator: Callable[[str, FieldInfo], str] | None
5369
description: str
70+
examples: list[Any] | None
5471
exclude: "AbstractSetIntStr | MappingIntStrAny | Any"
5572
const: bool
5673
deprecated: bool
74+
frozen: bool | None
75+
validate_default: bool | None
76+
repr: bool
77+
init: bool | None
78+
init_var: bool | None
79+
kw_only: bool | None
5780
fail_fast: bool
5881
gt: float
5982
ge: float
6083
lt: float
6184
le: float
6285
strict: bool
86+
coerce_numbers_to_str: bool | None
6387
multiple_of: float
6488
allow_inf_nan: bool
6589
max_digits: int
6690
decimal_places: int
6791
min_length: int
6892
max_length: int
93+
union_mode: Literal["smart", "left_to_right"]
6994
allow_mutation: bool
7095
pattern: str
7196
discriminator: str
72-
repr: bool
7397

7498

7599
@dataclass(slots=True)
@@ -114,28 +138,39 @@ def had(
114138
type: Any = Sentinel,
115139
default: Any = Sentinel,
116140
default_factory: Callable = Sentinel,
117-
alias: str = Sentinel,
141+
alias_priority: int = Sentinel,
142+
validation_alias: str = Sentinel,
143+
serialization_alias: str = Sentinel,
118144
title: str = Sentinel,
145+
field_title_generator: Callable[[str, FieldInfo], str] = Sentinel,
119146
description: str = Sentinel,
147+
examples: list[Any] = Sentinel,
120148
exclude: "AbstractSetIntStr | MappingIntStrAny | Any" = Sentinel,
121149
const: bool = Sentinel,
150+
deprecated: bool = Sentinel,
151+
frozen: bool = Sentinel,
152+
validate_default: bool = Sentinel,
153+
repr: bool = Sentinel,
154+
init: bool = Sentinel,
155+
init_var: bool = Sentinel,
156+
kw_only: bool = Sentinel,
157+
fail_fast: bool = Sentinel,
122158
gt: float = Sentinel,
123159
ge: float = Sentinel,
124160
lt: float = Sentinel,
125161
le: float = Sentinel,
126162
strict: bool = Sentinel,
127-
deprecated: bool = Sentinel,
163+
coerce_numbers_to_str: bool = Sentinel,
128164
multiple_of: float = Sentinel,
129165
allow_inf_nan: bool = Sentinel,
130166
max_digits: int = Sentinel,
131167
decimal_places: int = Sentinel,
132168
min_length: int = Sentinel,
133169
max_length: int = Sentinel,
170+
union_mode: Literal["smart", "left_to_right"] = Sentinel,
134171
allow_mutation: bool = Sentinel,
135172
pattern: str = Sentinel,
136173
discriminator: str = Sentinel,
137-
repr: bool = Sentinel,
138-
fail_fast: bool = Sentinel,
139174
) -> FieldHadInstruction:
140175
return FieldHadInstruction(
141176
schema=self.schema,
@@ -145,28 +180,39 @@ def had(
145180
field_changes=FieldChanges(
146181
default=default,
147182
default_factory=default_factory,
148-
alias=alias,
183+
alias_priority=alias_priority,
184+
validation_alias=validation_alias,
185+
serialization_alias=serialization_alias,
149186
title=title,
187+
field_title_generator=field_title_generator,
150188
description=description,
189+
examples=examples,
151190
exclude=exclude,
152191
const=const,
192+
deprecated=deprecated,
193+
frozen=frozen,
194+
validate_default=validate_default,
195+
repr=repr,
196+
init=init,
197+
init_var=init_var,
198+
kw_only=kw_only,
199+
fail_fast=fail_fast,
153200
gt=gt,
154201
ge=ge,
155202
lt=lt,
156203
le=le,
157-
deprecated=deprecated,
158204
strict=strict,
205+
coerce_numbers_to_str=coerce_numbers_to_str,
159206
multiple_of=multiple_of,
160207
allow_inf_nan=allow_inf_nan,
161208
max_digits=max_digits,
162209
decimal_places=decimal_places,
163210
min_length=min_length,
164211
max_length=max_length,
212+
union_mode=union_mode,
165213
allow_mutation=allow_mutation,
166214
pattern=pattern,
167215
discriminator=discriminator,
168-
repr=repr,
169-
fail_fast=fail_fast,
170216
),
171217
)
172218

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from __future__ import annotations
2+
3+
from fastapi.testclient import TestClient
4+
from pydantic import BaseModel, Field
5+
6+
from cadwyn.applications import Cadwyn
7+
from cadwyn.route_generation import VersionedAPIRouter
8+
from cadwyn.structure.schemas import schema
9+
from cadwyn.structure.versions import Version, VersionBundle, VersionChange
10+
11+
12+
class OuterSchema(BaseModel):
13+
bar: MySchema
14+
15+
16+
class MySchema(BaseModel):
17+
foo: str = Field(coerce_numbers_to_str=True)
18+
19+
20+
class MyVersionChange(VersionChange):
21+
description = "Hello"
22+
instructions_to_migrate_to_previous_version = (
23+
schema(MySchema).field("foo").had(type=int),
24+
schema(MySchema).field("foo").didnt_have("coerce_numbers_to_str"),
25+
)
26+
27+
28+
app = Cadwyn(versions=VersionBundle(Version("2001-01-01", MyVersionChange), Version("2000-01-01")))
29+
router = VersionedAPIRouter()
30+
31+
32+
@router.post("/test")
33+
async def test_with_inner_schema_forwardref(dep: MySchema) -> MySchema:
34+
return dep
35+
36+
37+
@router.post("/test2")
38+
async def test_with_outer_schema_forwardref(dep: OuterSchema) -> OuterSchema:
39+
return dep
40+
41+
42+
app.generate_and_include_versioned_routers(router)
43+
44+
45+
def test__router_generation__using_forwardref_inner_global_schema_in_body():
46+
unversioned_client = TestClient(app)
47+
client_2000 = TestClient(app, headers={app.router.api_version_header_name: "2000-01-01"})
48+
client_2001 = TestClient(app, headers={app.router.api_version_header_name: "2001-01-01"})
49+
assert client_2000.post("/test", json={"foo": 1}).json() == {"foo": 1}
50+
assert client_2001.post("/test", json={"foo": 1}).json() == {"foo": "1"}
51+
assert unversioned_client.get("/openapi.json?version=2000-01-01").status_code == 200
52+
assert unversioned_client.get("/openapi.json?version=2001-01-01").status_code == 200
53+
54+
55+
def test__router_generation__using_forwardref_outer_global_schema_in_body():
56+
unversioned_client = TestClient(app)
57+
client_2000 = TestClient(app, headers={app.router.api_version_header_name: "2000-01-01"})
58+
client_2001 = TestClient(app, headers={app.router.api_version_header_name: "2001-01-01"})
59+
assert client_2000.post("/test", json={"bar": {"foo": 1}}).json() == {"bar": {"foo": 1}}
60+
assert client_2001.post("/test", json={"bar": {"foo": 1}}).json() == {"bar": {"foo": "1"}}
61+
assert unversioned_client.get("/openapi.json?version=2000-01-01").status_code == 200
62+
assert unversioned_client.get("/openapi.json?version=2001-01-01").status_code == 200

0 commit comments

Comments
 (0)