Skip to content

feat: add custom DTO factory support for mixed model types #4123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions docs/usage/dto/0-basic-use.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,45 @@ Python to JSON (collection) ~5.4x
.. seealso::
If you are interested in technical details, check out
https://github.com/litestar-org/litestar/pull/2388

Mixed model types
~~~~~~~~~~~~~~~~~~~

It is sometimes required to parse data that cannot be handled by a single DTO factory.
For example, a data container that is a ``dataclass`` but its inner data contains a
``SQLAlchemy`` model. Using the ``DataclassDTO`` by itself will raise an error as it does
not know how to handle non-native models (e.g. a ``SQLAlchemy`` model). Therefore, we configure
the ``DataclassDTO`` to use a custom DTO factory for specific models. Here is an example:

.. code-block:: python

from sqlalchemy.orm import DeclarativeBase

from dataclasses import dataclass

from litestar.dto import DTOConfig, DataclassDTO
from litestar.plugins.sqlalchemy import SQLAlchemyDTO

class Base(DeclarativeBase):
pass

class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]

@dataclass
class Foo:
user: User
info: str


class FooDTO(DataclassDTO[Foo]):
config = DTOConfig(
custom_dto_factories={
User: SQLAlchemyDTO[User],
}
)

In the above example, whenever a ``User`` model is to be parsed, the ``SQLAlchemyDTO`` will be used
instead of the outer ``DataclassDTO``.
10 changes: 6 additions & 4 deletions litestar/dto/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,15 @@ def parse_model(
"""
defined_fields = []
generic_field_definitions = list(FieldDefinition.from_annotation(model_type).generic_types or ())
for field_definition in self.dto_factory.generate_field_definitions(model_type):
dto_factory = self.dto_factory.get_dto_factory_for_type(model_type)
for field_definition in dto_factory.generate_field_definitions(model_type):
if field_definition.is_type_var:
base_arg_field = generic_field_definitions.pop()
field_definition = replace(
field_definition, annotation=base_arg_field.annotation, raw=base_arg_field.raw
)

if _should_mark_private(field_definition, self.dto_factory.config.underscore_fields_private):
if _should_mark_private(field_definition, dto_factory.config.underscore_fields_private):
field_definition.dto_field.mark = Mark.PRIVATE

try:
Expand All @@ -165,7 +166,7 @@ def parse_model(
field_definition=field_definition,
serialization_name=rename_fields.get(field_definition.name),
transfer_type=transfer_type,
is_partial=self.dto_factory.config.partial,
is_partial=dto_factory.config.partial,
is_excluded=_should_exclude_field(
field_definition=field_definition,
exclude=exclude,
Expand Down Expand Up @@ -406,7 +407,8 @@ def _create_transfer_type(

transfer_model: NestedFieldInfo | None = None

if self.dto_factory.detect_nested_field(field_definition):
dto_factory = self.dto_factory.get_dto_factory_for_type(field_definition.annotation)
if dto_factory.detect_nested_field(field_definition):
if nested_depth == self.dto_factory.config.max_nested_depth:
raise RecursionError

Expand Down
12 changes: 12 additions & 0 deletions litestar/dto/base_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ def get_config_for_model_type(cls, config: DTOConfig, model_type: type[Any]) ->
"""
return config

@classmethod
def get_dto_factory_for_type(cls, model_type: type[Any]) -> type[AbstractDTO]:
"""Get the appropriate DTO factory for a given model type.

Args:
model_type: The model type to get the DTO factory for.

Returns:
The DTO factory to use for the given model type.
"""
return cls.config.custom_dto_factories.get(model_type, cls)

def decode_builtins(self, value: dict[str, Any]) -> Any:
"""Decode a dictionary of Python values into an the DTO's datatype."""

Expand Down
5 changes: 4 additions & 1 deletion litestar/dto/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from litestar.exceptions import ImproperlyConfiguredException

if TYPE_CHECKING:
from typing import AbstractSet

from litestar.dto.base_dto import AbstractDTO
from litestar.dto.types import RenameStrategy

__all__ = ("DTOConfig",)
Expand Down Expand Up @@ -60,6 +61,8 @@ class DTOConfig:
"""Use the experimental codegen backend"""
forbid_unknown_fields: bool = False
"""Raise an exception for fields present in the raw data that are not defined on the model"""
custom_dto_factories: dict[Any, type[AbstractDTO]] = field(default_factory=dict)
"""Use custom dto factories for specific models."""

def __post_init__(self) -> None:
if self.include and self.exclude:
Expand Down
174 changes: 174 additions & 0 deletions tests/unit/test_dto/test_factory/test_custom_dto_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from dataclasses import dataclass
from typing import List

import msgspec
import pytest
from sqlalchemy import ForeignKey
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
mapped_column,
relationship,
)
from typing_extensions import Annotated

from litestar import post
from litestar.dto import DataclassDTO, DTOConfig
from litestar.testing import create_test_client

try:
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
except ImportError:
from litestar.plugins.sqlalchemy import SQLAlchemyDTO


class Base(DeclarativeBase):
pass


class Role(Base):
__tablename__ = "roles"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))


class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
email: Mapped[str]
roles: Mapped[List[Role]] = relationship()


@dataclass
class SingleUserContainer:
user: User
additional_info: str


@dataclass
class MultiUsersContainer:
users: List[User]
additional_info: str


@pytest.fixture
def user_data() -> User:
return User(
id=1,
name="Test User",
email="[email protected]",
roles=[
Role(id=1, name="Admin"),
Role(id=2, name="SuperAdmin"),
],
)


@pytest.fixture
def single_user_container_data(user_data: User) -> SingleUserContainer:
return SingleUserContainer(user=user_data, additional_info="Additional information")


@pytest.fixture
def multi_users_container_data(user_data: User) -> MultiUsersContainer:
return MultiUsersContainer(users=[user_data, user_data], additional_info="Additional information")


def test_dto_with_custom_dto_factories(
single_user_container_data: SingleUserContainer,
multi_users_container_data: MultiUsersContainer,
use_experimental_dto_backend: bool,
) -> None:
custom_dto_factories = {
User: SQLAlchemyDTO[
Annotated[
User,
DTOConfig(experimental_codegen_backend=use_experimental_dto_backend),
]
],
Role: SQLAlchemyDTO[
Annotated[
Role,
DTOConfig(experimental_codegen_backend=use_experimental_dto_backend),
]
],
}

@post(
path="/single-user",
dto=DataclassDTO[
Annotated[
SingleUserContainer,
DTOConfig(
max_nested_depth=2,
experimental_codegen_backend=use_experimental_dto_backend,
custom_dto_factories=custom_dto_factories,
),
]
],
)
def handler1(data: SingleUserContainer) -> SingleUserContainer:
return data

@post(
path="/multi-users",
dto=DataclassDTO[
Annotated[
MultiUsersContainer,
DTOConfig(
max_nested_depth=2,
experimental_codegen_backend=use_experimental_dto_backend,
custom_dto_factories=custom_dto_factories,
),
]
],
)
def handler2(data: MultiUsersContainer) -> MultiUsersContainer:
return data

with create_test_client(
[
handler1,
handler2,
]
) as client:
user_dict = {
"additional_info": single_user_container_data.additional_info,
"user": {
"id": single_user_container_data.user.id,
"name": single_user_container_data.user.name,
"email": single_user_container_data.user.email,
"roles": [
{"id": role.id, "name": role.name, "user_id": single_user_container_data.user.id}
for role in single_user_container_data.user.roles
],
},
}
received = client.post(
"/single-user",
headers={"Content-Type": "application/json; charset=utf-8"},
content=msgspec.json.encode(user_dict),
)
assert received.json() == user_dict

multi_users_dict = {
"additional_info": multi_users_container_data.additional_info,
"users": [
{
"id": user.id,
"name": user.name,
"email": user.email,
"roles": [{"id": role.id, "name": role.name, "user_id": user.id} for role in user.roles],
}
for user in multi_users_container_data.users
],
}
received = client.post(
"/multi-users",
headers={"Content-Type": "application/json; charset=utf-8"},
content=msgspec.json.encode(multi_users_dict),
)

assert received.json() == multi_users_dict
Loading