Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions litestar/dto/msgspec_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
from dataclasses import replace
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

import msgspec.inspect
from msgspec import NODEFAULT, Struct, structs
Expand All @@ -12,13 +12,12 @@
from litestar.dto.field import DTO_FIELD_META_KEY, DTOField, extract_dto_field
from litestar.plugins.core._msgspec import kwarg_definition_from_field
from litestar.types.empty import Empty
from litestar.typing import FieldDefinition

if TYPE_CHECKING:
from collections.abc import Collection, Generator
from typing import Any

from litestar.typing import FieldDefinition


__all__ = ("MsgspecDTO",)

Expand All @@ -33,16 +32,36 @@ def _default_or_none(value: Any) -> Any:
return None if value is NODEFAULT else value


def _msgspec_attribute_accessor(obj: object, name: str) -> Any:
"""Like ``getattr``, but also resolves the synthetic tag field on msgspec Structs.

The tag field (e.g. ``"type"``) is not a real instance attribute — msgspec injects it
only during encoding. This accessor falls back to the struct's type-info when the
normal attribute lookup fails so the DTO transfer layer can read the tag value.
"""
try:
return getattr(obj, name)
except AttributeError:
if isinstance(obj, Struct):
type_info = msgspec.inspect.type_info(type(obj)) # type: ignore[arg-type]
if name == type_info.tag_field:
return type_info.tag
raise


class MsgspecDTO(AbstractDTO[T], Generic[T]):
"""Support for domain modelling with Msgspec."""

attribute_accessor = _msgspec_attribute_accessor

@classmethod
def generate_field_definitions(cls, model_type: type[Struct]) -> Generator[DTOFieldDefinition, None, None]:
msgspec_fields = {f.name: f for f in structs.fields(model_type)}

struct_info = msgspec.inspect.type_info(model_type) # type: ignore[arg-type]
inspect_fields: dict[str, msgspec.inspect.Field] = {
field.name: field
for field in msgspec.inspect.type_info(model_type).fields # type: ignore[attr-defined]
for field in struct_info.fields # type: ignore[attr-defined]
}

property_fields = cls.get_property_fields(model_type)
Expand All @@ -68,6 +87,21 @@ def generate_field_definitions(cls, model_type: type[Struct]) -> Generator[DTOFi
name=key,
)

if struct_info.tag is not None: # type: ignore[attr-defined]
tag_value = struct_info.tag # type: ignore[attr-defined]
tag_field_name = struct_info.tag_field # type: ignore[attr-defined]
tag_annotation = Literal[tag_value] # type: ignore[valid-type]
yield replace(
DTOFieldDefinition.from_field_definition(
field_definition=FieldDefinition.from_annotation(tag_annotation, name=tag_field_name),
dto_field=DTOField(mark="read-only"),
model_name=model_type.__name__,
default_factory=None,
),
default=tag_value,
name=tag_field_name,
)

for key, property_field in property_fields.items():
if key.startswith("_"):
continue
Expand Down
58 changes: 56 additions & 2 deletions tests/unit/test_contrib/test_msgspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import pytest
from msgspec import Meta, Struct, field

from litestar import Litestar, post
from litestar.dto import DTOField, Mark, MsgspecDTO, dto_field
from litestar import Litestar, get, post
from litestar.dto import DTOConfig, DTOField, Mark, MsgspecDTO, dto_field
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.testing import create_test_client
from litestar.typing import FieldDefinition

if TYPE_CHECKING:
Expand Down Expand Up @@ -228,3 +229,56 @@ class ModelWithClassVar(Struct):
# Only the regular field should be included, not the ClassVar
assert len(field_defs) == 1
assert field_defs[0].name == "regular_field"


@pytest.mark.parametrize("use_experimental_dto_backend", [False, True])
def test_msgspec_dto_tagged_union_tag_field_serialized(use_experimental_dto_backend: bool) -> None:
"""Tag field must be present in DTO-serialized output for tagged Struct types.

Regression: MsgspecDTO.generate_field_definitions iterates over
msgspec.inspect.type_info(model).fields which does NOT include the synthetic
tag field, so the tag is silently dropped when the DTO builds its transfer model.
"""

class Cat(Struct, tag=True):
name: str

class Dog(Struct, tag=True):
name: str

class CatDTO(MsgspecDTO[Cat]):
config = DTOConfig(experimental_codegen_backend=use_experimental_dto_backend)

@get("/cat", return_dto=CatDTO, signature_types=[Cat])
def handler() -> Cat:
return Cat(name="Whiskers")

with create_test_client([handler]) as client:
response = client.get("/cat")
assert response.status_code == 200
data = response.json()
# The tag field ("type") must be present and equal to the class name
assert data.get("type") == "Cat", f"Expected tag field 'type' = 'Cat' in response, got: {data!r}"
assert data.get("name") == "Whiskers"


@pytest.mark.parametrize("use_experimental_dto_backend", [False, True])
def test_msgspec_dto_tagged_union_custom_tag_field_serialized(use_experimental_dto_backend: bool) -> None:
"""Custom tag_field and tag value must be present in DTO-serialized output."""

class Widget(Struct, tag_field="kind", tag="widget"):
value: int

class WidgetDTO(MsgspecDTO[Widget]):
config = DTOConfig(experimental_codegen_backend=use_experimental_dto_backend)

@get("/widget", return_dto=WidgetDTO, signature_types=[Widget])
def handler() -> Widget:
return Widget(value=42)

with create_test_client([handler]) as client:
response = client.get("/widget")
assert response.status_code == 200
data = response.json()
assert data.get("kind") == "widget", f"Expected tag field 'kind' = 'widget' in response, got: {data!r}"
assert data.get("value") == 42
Loading