Skip to content

Commit e68dd0b

Browse files
authored
fix(dto): apply __schema_name__ only to root transfer model (#4606)
1 parent 30ae1fa commit e68dd0b

File tree

3 files changed

+61
-2
lines changed

3 files changed

+61
-2
lines changed

litestar/dto/_backend.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,16 @@ def create_transfer_model_type(
208208
Returns:
209209
A ``BackendT`` class.
210210
"""
211-
struct_name = self.dto_factory.__schema_name__ or self._create_transfer_model_name(model_name)
211+
# Only apply the custom __schema_name__ to the root transfer model.
212+
# Nested models get their own generated name to avoid child $refs
213+
# pointing to the parent's schema.
214+
# model_name only equals model_type.__name__ for the root model.
215+
# Nested models have a different name, so __schema_name__ won't
216+
# override their generated schema name.
217+
if model_name == self.model_type.__name__ and self.dto_factory.__schema_name__:
218+
struct_name = self.dto_factory.__schema_name__
219+
else:
220+
struct_name = self._create_transfer_model_name(model_name)
212221

213222
struct = _create_struct_for_field_definitions(
214223
model_name=struct_name,

tests/unit/test_dto/test_factory/test_backends/test_base_dto.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def generate_field_definitions(
9292

9393
class _Backend(backend_cls): # type: ignore[valid-type,misc]
9494
def create_transfer_model_type(
95-
self, model_name: str, field_definitions: tuple[TransferDTOFieldDefinition, ...]
95+
self,
96+
model_name: str,
97+
field_definitions: tuple[TransferDTOFieldDefinition, ...],
9698
) -> type[Any]:
9799
"""Create a model for data transfer.
98100

tests/unit/test_openapi/test_integration.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111

1212
from litestar import Controller, Litestar, delete, get, patch, post
1313
from litestar._openapi.plugin import OpenAPIPlugin
14+
from litestar.dto import DataclassDTO, DTOConfig
1415
from litestar.enums import MediaType, OpenAPIMediaType, ParamType
1516
from litestar.openapi import OpenAPIConfig
1617
from litestar.openapi.plugins import YamlRenderPlugin
1718
from litestar.openapi.spec import Parameter as OpenAPIParameter
19+
from litestar.openapi.spec import Schema
1820
from litestar.params import Parameter
1921
from litestar.serialization.msgspec_hooks import decode_json, encode_json, get_serializer
2022
from litestar.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND
@@ -360,6 +362,52 @@ def handler_b() -> module_b.Model: # type: ignore[name-defined]
360362
# TODO: expand this test to cover more cases
361363

362364

365+
@dataclass
366+
class _SchemaNameChild:
367+
name: str
368+
369+
370+
@dataclass
371+
class _SchemaNamePerson:
372+
name: str
373+
age: int
374+
child: _SchemaNameChild
375+
376+
377+
def test_dto_schema_name_with_nested_model() -> None:
378+
"""__schema_name__ should only apply to the root DTO, not nested child models."""
379+
380+
class PersonDTO(DataclassDTO[_SchemaNamePerson]):
381+
__schema_name__ = "PersonPublic"
382+
config = DTOConfig(exclude={"age"})
383+
384+
@get("/person", return_dto=PersonDTO, sync_to_thread=False)
385+
def get_person() -> _SchemaNamePerson:
386+
return _SchemaNamePerson(name="test", age=30, child=_SchemaNameChild(name="kid"))
387+
388+
app = Litestar(route_handlers=[get_person])
389+
openapi_plugin = app.plugins.get(OpenAPIPlugin)
390+
schemas = openapi_plugin.provide_openapi().components.schemas
391+
392+
# parent should use the custom name
393+
assert schemas is not None
394+
assert "PersonPublic" in schemas
395+
396+
# child should NOT use the parent's custom name
397+
# it should have its own schema
398+
person_schema = schemas["PersonPublic"]
399+
assert isinstance(person_schema, Schema)
400+
assert person_schema.properties is not None
401+
child_ref = person_schema.properties["child"].ref # type: ignore[union-attr]
402+
assert child_ref != "#/components/schemas/PersonPublic", (
403+
"child $ref should not point to the parent's __schema_name__"
404+
)
405+
406+
# a separate child schema should exist
407+
child_schema_name = child_ref.split("/")[-1]
408+
assert child_schema_name in schemas
409+
410+
363411
def test_multiple_handlers_for_same_route() -> None:
364412
@post("/", sync_to_thread=False)
365413
def post_handler() -> None: ...

0 commit comments

Comments
 (0)