Skip to content

Commit 9f9dbff

Browse files
committed
feat: add custom DTO factory support for mixed model types
This allows for nesting mixed models, e.g. a SQLAlchemy model in a dataclass or any other mixture of models.
1 parent 1c83630 commit 9f9dbff

File tree

5 files changed

+238
-5
lines changed

5 files changed

+238
-5
lines changed

docs/usage/dto/0-basic-use.rst

+42
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,45 @@ Python to JSON (collection) ~5.4x
188188
.. seealso::
189189
If you are interested in technical details, check out
190190
https://github.com/litestar-org/litestar/pull/2388
191+
192+
Mixed model types
193+
~~~~~~~~~~~~~~~~~~~
194+
195+
It is sometimes required to parse data that cannot be handled by a single DTO factory.
196+
For example, a data container that is a ``dataclass`` but its inner data contains a
197+
``SQLAlchemy`` model. Using the ``DataclassDTO`` by itself will raise an error as it does
198+
not know how to handle non-native models (e.g. a ``SQLAlchemy`` model). Therefore, we configure
199+
the ``DataclassDTO`` to use a custom DTO factory for specific models. Here is an example:
200+
201+
.. code-block:: python
202+
203+
from sqlalchemy.orm import DeclarativeBase
204+
205+
from dataclasses import dataclass
206+
207+
from litestar.dto import DTOConfig, DataclassDTO
208+
from litestar.plugins.sqlalchemy import SQLAlchemyDTO
209+
210+
class Base(DeclarativeBase):
211+
pass
212+
213+
class User(Base):
214+
__tablename__ = "users"
215+
id: Mapped[int] = mapped_column(primary_key=True)
216+
name: Mapped[str]
217+
218+
@dataclass
219+
class Foo:
220+
user: User
221+
info: str
222+
223+
224+
class FooDTO(DataclassDTO[Foo]):
225+
config = DTOConfig(
226+
custom_dto_factories={
227+
User: SQLAlchemyDTO[User],
228+
}
229+
)
230+
231+
In the above example, whenever a ``User`` model is to be parsed, the ``SQLAlchemyDTO`` will be used
232+
instead of the outer ``DataclassDTO``.

litestar/dto/_backend.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,15 @@ def parse_model(
138138
"""
139139
defined_fields = []
140140
generic_field_definitions = list(FieldDefinition.from_annotation(model_type).generic_types or ())
141-
for field_definition in self.dto_factory.generate_field_definitions(model_type):
141+
dto_factory = self.dto_factory.get_dto_factory_for_type(model_type)
142+
for field_definition in dto_factory.generate_field_definitions(model_type):
142143
if field_definition.is_type_var:
143144
base_arg_field = generic_field_definitions.pop()
144145
field_definition = replace(
145146
field_definition, annotation=base_arg_field.annotation, raw=base_arg_field.raw
146147
)
147148

148-
if _should_mark_private(field_definition, self.dto_factory.config.underscore_fields_private):
149+
if _should_mark_private(field_definition, dto_factory.config.underscore_fields_private):
149150
field_definition.dto_field.mark = Mark.PRIVATE
150151

151152
try:
@@ -165,7 +166,7 @@ def parse_model(
165166
field_definition=field_definition,
166167
serialization_name=rename_fields.get(field_definition.name),
167168
transfer_type=transfer_type,
168-
is_partial=self.dto_factory.config.partial,
169+
is_partial=dto_factory.config.partial,
169170
is_excluded=_should_exclude_field(
170171
field_definition=field_definition,
171172
exclude=exclude,
@@ -406,7 +407,8 @@ def _create_transfer_type(
406407

407408
transfer_model: NestedFieldInfo | None = None
408409

409-
if self.dto_factory.detect_nested_field(field_definition):
410+
dto_factory = self.dto_factory.get_dto_factory_for_type(field_definition.annotation)
411+
if dto_factory.detect_nested_field(field_definition):
410412
if nested_depth == self.dto_factory.config.max_nested_depth:
411413
raise RecursionError
412414

litestar/dto/base_dto.py

+12
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,18 @@ def get_config_for_model_type(cls, config: DTOConfig, model_type: type[Any]) ->
104104
"""
105105
return config
106106

107+
@classmethod
108+
def get_dto_factory_for_type(cls, model_type: type[Any]) -> type[AbstractDTO]:
109+
"""Get the appropriate DTO factory for a given model type.
110+
111+
Args:
112+
model_type: The model type to get the DTO factory for.
113+
114+
Returns:
115+
The DTO factory to use for the given model type.
116+
"""
117+
return cls.config.custom_dto_factories.get(model_type, cls)
118+
107119
def decode_builtins(self, value: dict[str, Any]) -> Any:
108120
"""Decode a dictionary of Python values into an the DTO's datatype."""
109121

litestar/dto/config.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Any
55

66
from litestar.exceptions import ImproperlyConfiguredException
77

88
if TYPE_CHECKING:
99
from typing import AbstractSet
1010

11+
from litestar.dto.base_dto import AbstractDTO
1112
from litestar.dto.types import RenameStrategy
1213

1314
__all__ = ("DTOConfig",)
@@ -60,6 +61,8 @@ class DTOConfig:
6061
"""Use the experimental codegen backend"""
6162
forbid_unknown_fields: bool = False
6263
"""Raise an exception for fields present in the raw data that are not defined on the model"""
64+
custom_dto_factories: dict[Any, type[AbstractDTO]] = field(default_factory=dict)
65+
"""Use custom dto factories for specific models."""
6366

6467
def __post_init__(self) -> None:
6568
if self.include and self.exclude:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from dataclasses import dataclass
2+
from typing import List
3+
4+
import msgspec
5+
import pytest
6+
from sqlalchemy import ForeignKey
7+
from sqlalchemy.orm import (
8+
DeclarativeBase,
9+
Mapped,
10+
mapped_column,
11+
relationship,
12+
)
13+
from typing_extensions import Annotated
14+
15+
from litestar import post
16+
from litestar.dto import DataclassDTO, DTOConfig
17+
from litestar.testing import create_test_client
18+
19+
try:
20+
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
21+
except ImportError:
22+
from litestar.plugins.sqlalchemy import SQLAlchemyDTO
23+
24+
25+
class Base(DeclarativeBase):
26+
pass
27+
28+
29+
class Role(Base):
30+
__tablename__ = "roles" # type: ignore[assignment]
31+
id: Mapped[int] = mapped_column(primary_key=True)
32+
name: Mapped[str]
33+
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
34+
35+
36+
class User(Base):
37+
__tablename__ = "users" # type: ignore[assignment]
38+
id: Mapped[int] = mapped_column(primary_key=True)
39+
name: Mapped[str]
40+
email: Mapped[str]
41+
roles: Mapped[List[Role]] = relationship()
42+
43+
44+
@dataclass
45+
class SingleUserContainer:
46+
user: User
47+
additional_info: str
48+
49+
50+
@dataclass
51+
class MultiUsersContainer:
52+
users: List[User]
53+
additional_info: str
54+
55+
56+
@pytest.fixture
57+
def user_data() -> User:
58+
return User(
59+
id=1,
60+
name="Test User",
61+
62+
roles=[
63+
Role(id=1, name="Admin"),
64+
Role(id=2, name="SuperAdmin"),
65+
],
66+
)
67+
68+
69+
@pytest.fixture
70+
def single_user_container_data(user_data: User) -> SingleUserContainer:
71+
return SingleUserContainer(user=user_data, additional_info="Additional information")
72+
73+
74+
@pytest.fixture
75+
def multi_users_container_data(user_data: User) -> MultiUsersContainer:
76+
return MultiUsersContainer(users=[user_data, user_data], additional_info="Additional information")
77+
78+
79+
def test_dto_with_custom_dto_factories(
80+
single_user_container_data: SingleUserContainer,
81+
multi_users_container_data: MultiUsersContainer,
82+
use_experimental_dto_backend: bool,
83+
) -> None:
84+
custom_dto_factories = {
85+
User: SQLAlchemyDTO[
86+
Annotated[
87+
User,
88+
DTOConfig(experimental_codegen_backend=use_experimental_dto_backend),
89+
]
90+
],
91+
Role: SQLAlchemyDTO[
92+
Annotated[
93+
Role,
94+
DTOConfig(experimental_codegen_backend=use_experimental_dto_backend),
95+
]
96+
],
97+
}
98+
99+
@post(
100+
path="/single-user",
101+
dto=DataclassDTO[
102+
Annotated[
103+
SingleUserContainer,
104+
DTOConfig(
105+
max_nested_depth=2,
106+
experimental_codegen_backend=use_experimental_dto_backend,
107+
custom_dto_factories=custom_dto_factories,
108+
),
109+
]
110+
],
111+
)
112+
def handler1(data: SingleUserContainer) -> SingleUserContainer:
113+
return data
114+
115+
@post(
116+
path="/multi-users",
117+
dto=DataclassDTO[
118+
Annotated[
119+
MultiUsersContainer,
120+
DTOConfig(
121+
max_nested_depth=2,
122+
experimental_codegen_backend=use_experimental_dto_backend,
123+
custom_dto_factories=custom_dto_factories,
124+
),
125+
]
126+
],
127+
)
128+
def handler2(data: MultiUsersContainer) -> MultiUsersContainer:
129+
return data
130+
131+
with create_test_client(
132+
[
133+
handler1,
134+
handler2,
135+
]
136+
) as client:
137+
user_dict = {
138+
"additional_info": single_user_container_data.additional_info,
139+
"user": {
140+
"id": single_user_container_data.user.id,
141+
"name": single_user_container_data.user.name,
142+
"email": single_user_container_data.user.email,
143+
"roles": [
144+
{"id": role.id, "name": role.name, "user_id": single_user_container_data.user.id}
145+
for role in single_user_container_data.user.roles
146+
],
147+
},
148+
}
149+
received = client.post(
150+
"/single-user",
151+
headers={"Content-Type": "application/json; charset=utf-8"},
152+
content=msgspec.json.encode(user_dict),
153+
)
154+
assert received.json() == user_dict
155+
156+
multi_users_dict = {
157+
"additional_info": multi_users_container_data.additional_info,
158+
"users": [
159+
{
160+
"id": user.id,
161+
"name": user.name,
162+
"email": user.email,
163+
"roles": [{"id": role.id, "name": role.name, "user_id": user.id} for role in user.roles],
164+
}
165+
for user in multi_users_container_data.users
166+
],
167+
}
168+
received = client.post(
169+
"/multi-users",
170+
headers={"Content-Type": "application/json; charset=utf-8"},
171+
content=msgspec.json.encode(multi_users_dict),
172+
)
173+
174+
assert received.json() == multi_users_dict

0 commit comments

Comments
 (0)