Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 1 addition & 19 deletions litestar/_signature/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import re
from functools import partial
from pathlib import Path, PurePath
from typing import (
TYPE_CHECKING,
Annotated,
Expand All @@ -15,25 +14,22 @@
Union,
cast,
)
from uuid import UUID

from msgspec import NODEFAULT, Meta, Struct, ValidationError, convert, defstruct
from msgspec.structs import asdict

from litestar._signature.types import ExtendedMsgSpecValidationError
from litestar._signature.utils import (
_get_decoder_for_type,
_normalize_annotation,
_validate_signature_dependencies,
)
from litestar.datastructures.state import ImmutableState
from litestar.datastructures.url import URL
from litestar.dto import AbstractDTO, DTOData
from litestar.enums import ParamType, ScopeType
from litestar.exceptions import InternalServerException, ValidationException
from litestar.params import KwargDefinition, ParameterKwarg
from litestar.typing import FieldDefinition # noqa
from litestar.utils import get_origin_or_inner_type, is_class_and_subclass
from litestar.utils import get_origin_or_inner_type
from litestar.utils.dataclass import simple_asdict

if TYPE_CHECKING:
Expand Down Expand Up @@ -73,10 +69,6 @@ class ErrorMessage(TypedDict):

ERR_RE = re.compile(r"`\$\.(.+)`$")

DEFAULT_TYPE_DECODERS = [
(lambda x: is_class_and_subclass(x, (Path, PurePath, ImmutableState, UUID)), lambda t, v: t(v)),
]


def _deserializer(target_type: Any, value: Any, default_deserializer: Callable[[Any, Any], Any]) -> Any:
if isinstance(value, DTOData):
Expand All @@ -92,9 +84,6 @@ def _deserializer(target_type: Any, value: Any, default_deserializer: Callable[[
else:
raise exc

if decoder := getattr(target_type, "_decoder", None):
return decoder(target_type, value)

return default_deserializer(target_type, value)


Expand Down Expand Up @@ -263,7 +252,6 @@ def create(

annotation = cls._create_annotation(
field_definition=field_definition,
type_decoders=[*(type_decoders or []), *DEFAULT_TYPE_DECODERS],
meta_data=meta_data,
data_dto=data_dto,
)
Expand All @@ -289,7 +277,6 @@ def create(
def _create_annotation(
cls,
field_definition: FieldDefinition,
type_decoders: TypeDecodersSequence,
meta_data: Meta | None = None,
data_dto: type[AbstractDTO] | None = None,
) -> Any:
Expand All @@ -306,18 +293,13 @@ def _create_annotation(
types = [
cls._create_annotation(
field_definition=inner_type,
type_decoders=type_decoders,
meta_data=meta_data,
)
for inner_type in field_definition.inner_types
if not inner_type.is_none_type
]
return Optional[Union[tuple(types)]] if field_definition.is_optional else Union[tuple(types)] # pyright: ignore

if decoder := _get_decoder_for_type(annotation, type_decoders=type_decoders):
# FIXME: temporary (hopefully) hack, see: https://github.com/jcrist/msgspec/issues/497
setattr(annotation, "_decoder", decoder)

if meta_data:
annotation = Annotated[annotation, meta_data] # pyright: ignore

Expand Down
13 changes: 3 additions & 10 deletions litestar/_signature/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

from litestar.constants import SKIP_VALIDATION_NAMES
from litestar.exceptions import ImproperlyConfiguredException
from litestar.params import DependencyKwarg
from litestar.types import Empty, TypeDecodersSequence
from litestar.types import Empty

if TYPE_CHECKING:
from litestar.typing import FieldDefinition
from litestar.utils.signature import ParsedSignature


__all__ = ("_get_decoder_for_type", "_normalize_annotation", "_validate_signature_dependencies")
__all__ = ("_normalize_annotation", "_validate_signature_dependencies")


def _validate_signature_dependencies(
Expand Down Expand Up @@ -49,10 +49,3 @@ def _normalize_annotation(field_definition: FieldDefinition) -> Any:
return Any

return field_definition.annotation


def _get_decoder_for_type(target_type: Any, type_decoders: TypeDecodersSequence) -> Callable[[type, Any], Any] | None:
return next(
(decoder for predicate, decoder in type_decoders if predicate(target_type)),
None,
)
133 changes: 132 additions & 1 deletion tests/unit/test_signature/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import msgspec
import pytest

from litestar import get
from litestar import Litestar, get
from litestar._signature import SignatureModel
from litestar.di import Provide
from litestar.dto import DataclassDTO
from litestar.params import Body, Parameter
from litestar.status_codes import HTTP_200_OK, HTTP_204_NO_CONTENT
Expand All @@ -17,6 +18,25 @@
from litestar.utils.signature import ParsedSignature


def _make_prefixed_decoder(
target_type: type[Any], prefix: str
) -> tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]:
def predicate(annotation: Any) -> bool:
return annotation is target_type

def decoder(annotation: type[Any], value: Any) -> Any:
return annotation(f"{prefix}:{value}")

return predicate, decoder


def _assert_unsupported_query_type(response: Any, key: str = "user_id") -> None:
assert response.status_code == 400
payload = response.json()
assert payload["detail"].startswith("Validation failed for GET")
assert payload["extra"] == [{"message": "Unsupported type: <class 'str'>", "key": key, "source": "query"}]


def test_create_function_signature_model_parameter_parsing() -> None:
@get()
def my_fn(a: int, b: str, c: Optional[bytes], d: bytes = b"123", e: Optional[dict] = None) -> None:
Expand Down Expand Up @@ -212,3 +232,114 @@ def fn(data: Test) -> None:
(field,) = msgspec.structs.fields(model)
assert field.name == "data"
assert field.type is Any


def test_same_app_type_decoder_does_not_leak_to_handler_without_decoder() -> None:
class UserId:
def __init__(self, value: str) -> None:
self.value = value

@get("/decoded", type_decoders=[_make_prefixed_decoder(UserId, "decoded")], sync_to_thread=False)
def decoded(user_id: UserId) -> str:
return user_id.value

@get("/plain", sync_to_thread=False)
def plain(user_id: UserId) -> str:
return user_id.value

with create_test_client(route_handlers=[decoded, plain]) as client:
assert client.get("/decoded?user_id=1").text == "decoded:1"
_assert_unsupported_query_type(client.get("/plain?user_id=1"))


def test_conflicting_type_decoders_do_not_overwrite_each_other() -> None:
class UserId:
def __init__(self, value: str) -> None:
self.value = value

@get("/a", type_decoders=[_make_prefixed_decoder(UserId, "handler-a")], sync_to_thread=False)
def handler_a(user_id: UserId) -> str:
return user_id.value

@get("/b", type_decoders=[_make_prefixed_decoder(UserId, "handler-b")], sync_to_thread=False)
def handler_b(user_id: UserId) -> str:
return user_id.value

with create_test_client(route_handlers=[handler_a, handler_b]) as client:
assert client.get("/a?user_id=1").text == "handler-a:1"
assert client.get("/b?user_id=1").text == "handler-b:1"

with create_test_client(route_handlers=[handler_b, handler_a]) as client:
assert client.get("/a?user_id=1").text == "handler-a:1"
assert client.get("/b?user_id=1").text == "handler-b:1"


def test_type_decoder_does_not_leak_across_apps() -> None:
class UserId:
def __init__(self, value: str) -> None:
self.value = value

@get("/", type_decoders=[_make_prefixed_decoder(UserId, "app-a")], sync_to_thread=False)
def app_a_handler(user_id: UserId) -> str:
return user_id.value

@get("/", sync_to_thread=False)
def app_b_handler(user_id: UserId) -> str:
return user_id.value

app_a = Litestar([app_a_handler], openapi_config=None)
app_b = Litestar([app_b_handler], openapi_config=None)

with TestClient(app_a) as client:
assert client.get("/?user_id=1").text == "app-a:1"

with TestClient(app_b) as client:
_assert_unsupported_query_type(client.get("/?user_id=1"))

with TestClient(app_a) as client:
assert client.get("/?user_id=2").text == "app-a:2"


def test_provider_signature_model_decoder_does_not_leak() -> None:
class UserId:
def __init__(self, value: str) -> None:
self.value = value

async def provide_token(user_id: UserId) -> str:
return user_id.value

@get(
"/provided",
dependencies={"token": Provide(provide_token)},
type_decoders=[_make_prefixed_decoder(UserId, "provider")],
sync_to_thread=False,
)
def provided(token: str) -> str:
return token

@get("/plain", sync_to_thread=False)
def plain(user_id: UserId) -> str:
return user_id.value

with create_test_client(route_handlers=[provided, plain]) as client:
assert client.get("/provided?user_id=1").text == "provider:1"
_assert_unsupported_query_type(client.get("/plain?user_id=1"))


def test_signature_model_creation_does_not_mutate_user_type() -> None:
class UserId:
def __init__(self, value: str) -> None:
self.value = value

def fn(user_id: UserId) -> None:
pass

SignatureModel.create(
dependency_name_set=set(),
fn=fn,
data_dto=None,
parsed_signature=ParsedSignature.from_fn(fn, {}),
type_decoders=[_make_prefixed_decoder(UserId, "model")],
)

assert not hasattr(UserId, "_decoder")
Loading