Skip to content

Commit 75ce93c

Browse files
committed
Add support for forwardrefs in fields
1 parent e8058de commit 75ce93c

File tree

4 files changed

+35
-13
lines changed

4 files changed

+35
-13
lines changed

cadwyn/applications.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ async def openapi_jsons(self, req: Request) -> JSONResponse:
279279
except (ValueError, TypeError):
280280
version = raw_version
281281

282-
if version in self.router.versioned_routers:
282+
if isinstance(version, date) and version in self.router.versioned_routers:
283283
routes = self.router.versioned_routers[version].routes
284284
formatted_version = version.isoformat()
285285
elif version == "unversioned" and self._there_are_public_unversioned_routes():
@@ -296,7 +296,7 @@ async def openapi_jsons(self, req: Request) -> JSONResponse:
296296
self.servers.insert(0, {"url": root_path})
297297

298298
webhook_routes = None
299-
if version in self._versioned_webhook_routers:
299+
if isinstance(version, date) and version in self._versioned_webhook_routers:
300300
webhook_routes = self._versioned_webhook_routers[version].routes
301301

302302
return JSONResponse(

cadwyn/schema_generation.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,12 @@ def _is_dunder(attr_name: str):
232232

233233

234234
def _wrap_pydantic_model(model: type[_T_PYDANTIC_MODEL]) -> "_PydanticModelWrapper[_T_PYDANTIC_MODEL]":
235+
# In case we have a forwardref within one of the fields
236+
# For example, when "from __future__ import annotations" is used in the file with the schema
237+
if model is not BaseModel:
238+
model.model_rebuild(raise_errors=False)
239+
model = cast(type[_T_PYDANTIC_MODEL], model)
240+
235241
decorators = _get_model_decorators(model)
236242
validators = {}
237243
for decorator_wrapper in decorators:
@@ -240,8 +246,20 @@ def _wrap_pydantic_model(model: type[_T_PYDANTIC_MODEL]) -> "_PydanticModelWrapp
240246

241247
wrapped_validator = _wrap_validator(decorator_wrapper.func, decorator_wrapper.shim, decorator_wrapper.info)
242248
validators[decorator_wrapper.cls_var_name] = wrapped_validator
249+
250+
annotations = {
251+
name: value
252+
if not isinstance(value, str)
253+
else model.model_fields[name].annotation or model.__annotations__[name]
254+
for name, value in model.__annotations__.items()
255+
}
256+
243257
fields = {
244-
field_name: PydanticFieldWrapper(model.model_fields[field_name], model.__annotations__[field_name], field_name)
258+
field_name: PydanticFieldWrapper(
259+
model.model_fields[field_name],
260+
annotations[field_name],
261+
field_name,
262+
)
245263
for field_name in model.__annotations__
246264
}
247265

@@ -264,7 +282,7 @@ def _wrap_pydantic_model(model: type[_T_PYDANTIC_MODEL]) -> "_PydanticModelWrapp
264282
fields=fields,
265283
other_attributes=other_attributes,
266284
validators=validators,
267-
annotations=model.__annotations__.copy(),
285+
annotations=annotations,
268286
)
269287

270288

@@ -356,6 +374,7 @@ def generate_model_copy(self, generator: "SchemaGenerator") -> type[_T_PYDANTIC_
356374
if not validator.is_deleted and type(validator) == _ValidatorWrapper # noqa: E721
357375
}
358376
fields = {name: field.generate_field_copy(generator) for name, field in self.fields.items()}
377+
359378
model_copy = type(self.cls)(
360379
self.name,
361380
tuple(generator[cast(type[BaseModel], base)] for base in self.cls.__bases__),
@@ -494,7 +513,6 @@ def _change_version_of_a_non_container_annotation(self, annotation: Any) -> Any:
494513
) or isinstance(annotation, fastapi.security.base.SecurityBase):
495514
return annotation
496515

497-
# If we do not use modifier, we will get an unhashable module error
498516
def modifier(annotation: Any):
499517
return self.change_version_of_annotation(annotation)
500518

@@ -637,8 +655,6 @@ def __getitem__(self, model: type[_T_ANY_MODEL], /) -> type[_T_ANY_MODEL]:
637655

638656
if model in self.concrete_models:
639657
return self.concrete_models[model]
640-
else:
641-
wrapper = self._get_wrapper_for_model(model)
642658

643659
wrapper = self._get_wrapper_for_model(model)
644660
model_copy = wrapper.generate_model_copy(self)
@@ -661,6 +677,8 @@ def _get_wrapper_for_model(
661677
return self.model_bundle.enums[model]
662678

663679
if lenient_issubclass(model, BaseModel):
680+
# TODO: My god, what if one of its fields is in our concrete schemas and we don't use it? :O
681+
# TODO: Add an argument with our concrete schemas for _wrap_pydantic_model
664682
wrapper = _wrap_pydantic_model(model)
665683
self.model_bundle.schemas[model] = wrapper
666684
elif lenient_issubclass(model, Enum):
@@ -989,5 +1007,5 @@ def _try_eval_type(value: Any, globals: dict[str, Any]) -> Any:
9891007
new_value, success = pydantic_try_eval_type(value, globals)
9901008
if success:
9911009
return new_value
992-
else:
1010+
else: # pragma: no cover # Can't imagine when this would happen
9931011
return value

cadwyn/structure/schemas.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919

2020
PossibleFieldAttributes = Literal[
2121
"default",
22-
"default_factory",
22+
"alias",
2323
"alias_priority",
24+
"default_factory",
2425
"validation_alias",
2526
"serialization_alias",
2627
"title",
@@ -60,6 +61,7 @@
6061
@dataclass(slots=True)
6162
class FieldChanges:
6263
default: Any
64+
alias: str | None
6365
default_factory: Any
6466
alias_priority: int | None
6567
validation_alias: str | AliasPath | AliasChoices | None
@@ -137,6 +139,7 @@ def had(
137139
name: str = Sentinel,
138140
type: Any = Sentinel,
139141
default: Any = Sentinel,
142+
alias: str | None = Sentinel,
140143
default_factory: Callable = Sentinel,
141144
alias_priority: int = Sentinel,
142145
validation_alias: str = Sentinel,
@@ -181,6 +184,7 @@ def had(
181184
default=default,
182185
default_factory=default_factory,
183186
alias_priority=alias_priority,
187+
alias=alias,
184188
validation_alias=validation_alias,
185189
serialization_alias=serialization_alias,
186190
title=title,

tests/test_router_generation_with_from_future_annotations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ class MyVersionChange(VersionChange):
3030

3131

3232
@router.post("/test")
33-
async def test_with_inner_schema_forwardref(dep: MySchema) -> MySchema:
33+
async def route_with_inner_schema_forwardref(dep: MySchema) -> MySchema:
3434
return dep
3535

3636

3737
@router.post("/test2")
38-
async def test_with_outer_schema_forwardref(dep: OuterSchema) -> OuterSchema:
38+
async def route_with_outer_schema_forwardref(dep: OuterSchema) -> OuterSchema:
3939
return dep
4040

4141

@@ -56,7 +56,7 @@ def test__router_generation__using_forwardref_outer_global_schema_in_body():
5656
unversioned_client = TestClient(app)
5757
client_2000 = TestClient(app, headers={app.router.api_version_header_name: "2000-01-01"})
5858
client_2001 = TestClient(app, headers={app.router.api_version_header_name: "2001-01-01"})
59-
assert client_2000.post("/test", json={"bar": {"foo": 1}}).json() == {"bar": {"foo": 1}}
60-
assert client_2001.post("/test", json={"bar": {"foo": 1}}).json() == {"bar": {"foo": "1"}}
59+
assert client_2000.post("/test2", json={"bar": {"foo": 1}}).json() == {"bar": {"foo": 1}}
60+
assert client_2001.post("/test2", json={"bar": {"foo": 1}}).json() == {"bar": {"foo": "1"}}
6161
assert unversioned_client.get("/openapi.json?version=2000-01-01").status_code == 200
6262
assert unversioned_client.get("/openapi.json?version=2001-01-01").status_code == 200

0 commit comments

Comments
 (0)