diff --git a/docs/usage/dto/0-basic-use.rst b/docs/usage/dto/0-basic-use.rst index 621e655df2..cc38db6968 100644 --- a/docs/usage/dto/0-basic-use.rst +++ b/docs/usage/dto/0-basic-use.rst @@ -188,3 +188,45 @@ Python to JSON (collection) ~5.4x .. seealso:: If you are interested in technical details, check out https://github.com/litestar-org/litestar/pull/2388 + +Mixed model types +~~~~~~~~~~~~~~~~~~~ + +It is sometimes required to parse data that cannot be handled by a single DTO factory. +For example, a data container that is a ``dataclass`` but its inner data contains a +``SQLAlchemy`` model. Using the ``DataclassDTO`` by itself will raise an error as it does +not know how to handle non-native models (e.g. a ``SQLAlchemy`` model). Therefore, we configure +the ``DataclassDTO`` to use a custom DTO factory for specific models. Here is an example: + + .. code-block:: python + + from sqlalchemy.orm import DeclarativeBase + + from dataclasses import dataclass + + from litestar.dto import DTOConfig, DataclassDTO + from litestar.plugins.sqlalchemy import SQLAlchemyDTO + + class Base(DeclarativeBase): + pass + + class User(Base): + __tablename__ = "users" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + @dataclass + class Foo: + user: User + info: str + + + class FooDTO(DataclassDTO[Foo]): + config = DTOConfig( + custom_dto_factories={ + User: SQLAlchemyDTO[User], + } + ) + +In the above example, whenever a ``User`` model is to be parsed, the ``SQLAlchemyDTO`` will be used +instead of the outer ``DataclassDTO``. diff --git a/litestar/dto/_backend.py b/litestar/dto/_backend.py index ba121e0f5c..e27bb666bd 100644 --- a/litestar/dto/_backend.py +++ b/litestar/dto/_backend.py @@ -138,14 +138,15 @@ def parse_model( """ defined_fields = [] generic_field_definitions = list(FieldDefinition.from_annotation(model_type).generic_types or ()) - for field_definition in self.dto_factory.generate_field_definitions(model_type): + dto_factory = self.dto_factory.get_dto_factory_for_type(model_type) + for field_definition in dto_factory.generate_field_definitions(model_type): if field_definition.is_type_var: base_arg_field = generic_field_definitions.pop() field_definition = replace( field_definition, annotation=base_arg_field.annotation, raw=base_arg_field.raw ) - if _should_mark_private(field_definition, self.dto_factory.config.underscore_fields_private): + if _should_mark_private(field_definition, dto_factory.config.underscore_fields_private): field_definition.dto_field.mark = Mark.PRIVATE try: @@ -165,7 +166,7 @@ def parse_model( field_definition=field_definition, serialization_name=rename_fields.get(field_definition.name), transfer_type=transfer_type, - is_partial=self.dto_factory.config.partial, + is_partial=dto_factory.config.partial, is_excluded=_should_exclude_field( field_definition=field_definition, exclude=exclude, @@ -406,7 +407,8 @@ def _create_transfer_type( transfer_model: NestedFieldInfo | None = None - if self.dto_factory.detect_nested_field(field_definition): + dto_factory = self.dto_factory.get_dto_factory_for_type(field_definition.annotation) + if dto_factory.detect_nested_field(field_definition): if nested_depth == self.dto_factory.config.max_nested_depth: raise RecursionError diff --git a/litestar/dto/base_dto.py b/litestar/dto/base_dto.py index 93498f3c99..beb6450bc9 100644 --- a/litestar/dto/base_dto.py +++ b/litestar/dto/base_dto.py @@ -104,6 +104,18 @@ def get_config_for_model_type(cls, config: DTOConfig, model_type: type[Any]) -> """ return config + @classmethod + def get_dto_factory_for_type(cls, model_type: type[Any]) -> type[AbstractDTO]: + """Get the appropriate DTO factory for a given model type. + + Args: + model_type: The model type to get the DTO factory for. + + Returns: + The DTO factory to use for the given model type. + """ + return cls.config.custom_dto_factories.get(model_type, cls) + def decode_builtins(self, value: dict[str, Any]) -> Any: """Decode a dictionary of Python values into an the DTO's datatype.""" diff --git a/litestar/dto/config.py b/litestar/dto/config.py index eae722f65f..87acfa2ee4 100644 --- a/litestar/dto/config.py +++ b/litestar/dto/config.py @@ -1,13 +1,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from litestar.exceptions import ImproperlyConfiguredException if TYPE_CHECKING: from typing import AbstractSet + from litestar.dto.base_dto import AbstractDTO from litestar.dto.types import RenameStrategy __all__ = ("DTOConfig",) @@ -60,6 +61,8 @@ class DTOConfig: """Use the experimental codegen backend""" forbid_unknown_fields: bool = False """Raise an exception for fields present in the raw data that are not defined on the model""" + custom_dto_factories: dict[Any, type[AbstractDTO]] = field(default_factory=dict) + """Use custom dto factories for specific models.""" def __post_init__(self) -> None: if self.include and self.exclude: diff --git a/tests/unit/test_dto/test_factory/test_custom_dto_factories.py b/tests/unit/test_dto/test_factory/test_custom_dto_factories.py new file mode 100644 index 0000000000..78dd25a55e --- /dev/null +++ b/tests/unit/test_dto/test_factory/test_custom_dto_factories.py @@ -0,0 +1,174 @@ +from dataclasses import dataclass +from typing import List + +import msgspec +import pytest +from sqlalchemy import ForeignKey +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + mapped_column, + relationship, +) +from typing_extensions import Annotated + +from litestar import post +from litestar.dto import DataclassDTO, DTOConfig +from litestar.testing import create_test_client + +try: + from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO +except ImportError: + from litestar.plugins.sqlalchemy import SQLAlchemyDTO + + +class Base(DeclarativeBase): + pass + + +class Role(Base): + __tablename__ = "roles" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + user_id: Mapped[int] = mapped_column(ForeignKey("users.id")) + + +class User(Base): + __tablename__ = "users" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + email: Mapped[str] + roles: Mapped[List[Role]] = relationship() + + +@dataclass +class SingleUserContainer: + user: User + additional_info: str + + +@dataclass +class MultiUsersContainer: + users: List[User] + additional_info: str + + +@pytest.fixture +def user_data() -> User: + return User( + id=1, + name="Test User", + email="test@example.com", + roles=[ + Role(id=1, name="Admin"), + Role(id=2, name="SuperAdmin"), + ], + ) + + +@pytest.fixture +def single_user_container_data(user_data: User) -> SingleUserContainer: + return SingleUserContainer(user=user_data, additional_info="Additional information") + + +@pytest.fixture +def multi_users_container_data(user_data: User) -> MultiUsersContainer: + return MultiUsersContainer(users=[user_data, user_data], additional_info="Additional information") + + +def test_dto_with_custom_dto_factories( + single_user_container_data: SingleUserContainer, + multi_users_container_data: MultiUsersContainer, + use_experimental_dto_backend: bool, +) -> None: + custom_dto_factories = { + User: SQLAlchemyDTO[ + Annotated[ + User, + DTOConfig(experimental_codegen_backend=use_experimental_dto_backend), + ] + ], + Role: SQLAlchemyDTO[ + Annotated[ + Role, + DTOConfig(experimental_codegen_backend=use_experimental_dto_backend), + ] + ], + } + + @post( + path="/single-user", + dto=DataclassDTO[ + Annotated[ + SingleUserContainer, + DTOConfig( + max_nested_depth=2, + experimental_codegen_backend=use_experimental_dto_backend, + custom_dto_factories=custom_dto_factories, + ), + ] + ], + ) + def handler1(data: SingleUserContainer) -> SingleUserContainer: + return data + + @post( + path="/multi-users", + dto=DataclassDTO[ + Annotated[ + MultiUsersContainer, + DTOConfig( + max_nested_depth=2, + experimental_codegen_backend=use_experimental_dto_backend, + custom_dto_factories=custom_dto_factories, + ), + ] + ], + ) + def handler2(data: MultiUsersContainer) -> MultiUsersContainer: + return data + + with create_test_client( + [ + handler1, + handler2, + ] + ) as client: + user_dict = { + "additional_info": single_user_container_data.additional_info, + "user": { + "id": single_user_container_data.user.id, + "name": single_user_container_data.user.name, + "email": single_user_container_data.user.email, + "roles": [ + {"id": role.id, "name": role.name, "user_id": single_user_container_data.user.id} + for role in single_user_container_data.user.roles + ], + }, + } + received = client.post( + "/single-user", + headers={"Content-Type": "application/json; charset=utf-8"}, + content=msgspec.json.encode(user_dict), + ) + assert received.json() == user_dict + + multi_users_dict = { + "additional_info": multi_users_container_data.additional_info, + "users": [ + { + "id": user.id, + "name": user.name, + "email": user.email, + "roles": [{"id": role.id, "name": role.name, "user_id": user.id} for role in user.roles], + } + for user in multi_users_container_data.users + ], + } + received = client.post( + "/multi-users", + headers={"Content-Type": "application/json; charset=utf-8"}, + content=msgspec.json.encode(multi_users_dict), + ) + + assert received.json() == multi_users_dict