Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions litestar/_openapi/schema_generation/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
is_undefined_sentinel,
)
from litestar.utils.typing import (
get_origin_or_inner_type,
make_non_optional_union,
unwrap_and_get_origin,
unwrap_new_type,
)

Expand Down Expand Up @@ -171,7 +171,7 @@ def _iter_flat_literal_args(annotation: Any) -> Iterable[Any]:
The flattened arguments of the Literal.
"""
for arg in get_args(annotation):
if get_origin_or_inner_type(arg) is Literal:
if unwrap_and_get_origin(arg) is Literal: # pragma: no branch
yield from _iter_flat_literal_args(arg)
else:
yield arg.value if isinstance(arg, Enum) else arg
Expand Down
4 changes: 2 additions & 2 deletions litestar/_signature/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from litestar.exceptions import InternalServerException, ValidationException
from litestar.params import KwargDefinition, ParameterKwarg
from litestar.typing import FieldDefinition # noqa
from litestar.utils import get_origin_or_inner_type, is_class_and_subclass
from litestar.utils import is_class_and_subclass, unwrap_and_get_origin
from litestar.utils.dataclass import simple_asdict

if TYPE_CHECKING:
Expand Down Expand Up @@ -86,7 +86,7 @@ def _deserializer(target_type: Any, value: Any, default_deserializer: Callable[[
if isinstance(value, target_type):
return value
except TypeError as exc:
if (origin := get_origin_or_inner_type(target_type)) is not None:
if (origin := unwrap_and_get_origin(target_type)) is not None: # pragma: no branch
if isinstance(value, origin):
return value
else:
Expand Down
4 changes: 2 additions & 2 deletions litestar/plugins/pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from litestar.utils import is_class_and_subclass, is_generic, is_undefined_sentinel
from litestar.utils.typing import (
_substitute_typevars,
get_origin_or_inner_type,
get_safe_generic_origin,
get_type_hints_with_generics_resolved,
normalize_type_annotation,
unwrap_and_get_origin,
)

# isort: off
Expand Down Expand Up @@ -126,7 +126,7 @@ def is_pydantic_constrained_field(annotation: Any) -> bool:

def pydantic_unwrap_and_get_origin(annotation: Any) -> Any | None:
if pydantic_v2 is Empty or (pydantic_v1 is not Empty and is_class_and_subclass(annotation, pydantic_v1.BaseModel)):
return get_origin_or_inner_type(annotation)
return unwrap_and_get_origin(annotation)

origin = annotation.__pydantic_generic_metadata__["origin"]
return normalize_type_annotation(origin)
Expand Down
4 changes: 2 additions & 2 deletions litestar/serialization/msgspec_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from litestar.datastructures.secret_values import SecretBytes, SecretString
from litestar.exceptions import SerializationException
from litestar.types import Empty, EmptyType, Serializer, TypeDecodersSequence
from litestar.utils.typing import get_origin_or_inner_type
from litestar.utils.typing import unwrap_and_get_origin

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -118,7 +118,7 @@ def default_deserializer(
# we might get a TypeError here if target_type is a subscribed generic. For
# performance reasons, we let this happen and only unwrap this when we're
# certain this might be the case
if (origin := get_origin_or_inner_type(target_type)) is not None:
if (origin := unwrap_and_get_origin(target_type)) is not None:
target_type = origin
if isinstance(value, target_type):
return value
Expand Down
3 changes: 2 additions & 1 deletion litestar/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from .sequence import find_index, unique
from .sync import AsyncIteratorWrapper, ensure_async_callable
from .typing import get_origin_or_inner_type, make_non_optional_union
from .typing import get_origin_or_inner_type, make_non_optional_union, unwrap_and_get_origin

__all__ = (
"AsyncIteratorWrapper",
Expand Down Expand Up @@ -54,6 +54,7 @@
"normalize_path",
"unique",
"unique_name_for_scope",
"unwrap_and_get_origin",
"url_quote",
"warn_deprecation",
)
4 changes: 2 additions & 2 deletions litestar/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from urllib.parse import quote

from litestar.exceptions import LitestarException
from litestar.utils.typing import get_origin_or_inner_type
from litestar.utils.typing import unwrap_and_get_origin

if TYPE_CHECKING:
from collections.abc import Container
Expand Down Expand Up @@ -41,7 +41,7 @@ def get_name(value: object) -> str:
return cast("str", name)

# On Python 3.9, Foo[int] does not have the __name__ attribute.
if origin := get_origin_or_inner_type(value):
if origin := unwrap_and_get_origin(value):
return cast("str", origin.__name__)

return type(value).__name__
Expand Down
24 changes: 11 additions & 13 deletions litestar/utils/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from litestar.constants import UNDEFINED_SENTINELS
from litestar.types.builtin_types import NoneType, UnionTypes
from litestar.utils.helpers import unwrap_partial
from litestar.utils.typing import get_origin_or_inner_type
from litestar.utils.typing import unwrap_and_get_origin

if sys.version_info >= (3, 10):
from inspect import iscoroutinefunction
Expand Down Expand Up @@ -114,7 +114,7 @@ def is_dataclass_class(annotation: Any) -> TypeGuard[type[DataclassProtocol]]:
``True`` if instance or type of ``dataclass``.
"""
try:
origin = get_origin_or_inner_type(annotation)
origin = unwrap_and_get_origin(annotation)
annotation = origin or annotation

return isclass(annotation) and is_dataclass(annotation)
Expand All @@ -134,7 +134,7 @@ def is_class_and_subclass(annotation: Any, type_or_type_tuple: type[T] | tuple[t
Returns:
bool
"""
origin = get_origin_or_inner_type(annotation)
origin = unwrap_and_get_origin(annotation)
if not origin and not isclass(annotation):
return False
try:
Expand Down Expand Up @@ -164,7 +164,7 @@ def is_mapping(annotation: Any) -> TypeGuard[Mapping[Any, Any]]:
Returns:
A typeguard determining whether the type can be cast as :class:`Mapping <typing.Mapping>`.
"""
_type = get_origin_or_inner_type(annotation) or annotation
_type = unwrap_and_get_origin(annotation) or annotation
return isclass(_type) and issubclass(_type, (dict, defaultdict, DefaultDict, Mapping))


Expand All @@ -177,7 +177,7 @@ def is_non_string_iterable(annotation: Any) -> TypeGuard[Iterable[Any]]:
Returns:
A typeguard determining whether the type can be cast as :class:`Iterable <typing.Iterable>` that is not a string.
"""
origin = get_origin_or_inner_type(annotation)
origin = unwrap_and_get_origin(annotation)
if not origin and not isclass(annotation):
return False
try:
Expand All @@ -198,7 +198,7 @@ def is_non_string_sequence(annotation: Any) -> TypeGuard[Sequence[Any]]:
Returns:
A typeguard determining whether the type can be cast as :class:`Sequence <typing.Sequence>` that is not a string.
"""
origin = get_origin_or_inner_type(annotation)
origin = unwrap_and_get_origin(annotation)
if not origin and not isclass(annotation):
return False
try:
Expand Down Expand Up @@ -234,7 +234,7 @@ def is_any(annotation: Any) -> TypeGuard[Any]:
return (
annotation is Any
or getattr(annotation, "_name", "") == "typing.Any"
or (get_origin_or_inner_type(annotation) in UnionTypes and Any in get_args(annotation))
or (unwrap_and_get_origin(annotation) in UnionTypes and Any in get_args(annotation))
)


Expand All @@ -247,7 +247,7 @@ def is_union(annotation: Any) -> bool:
Returns:
A boolean determining whether the type is :data:`Union typing.Union>`.
"""
return get_origin_or_inner_type(annotation) in UnionTypes
return unwrap_and_get_origin(annotation) in UnionTypes


def is_optional_union(annotation: Any) -> TypeGuard[Any | None]:
Expand All @@ -260,10 +260,8 @@ def is_optional_union(annotation: Any) -> TypeGuard[Any | None]:
A typeguard determining whether the type is :data:`Union typing.Union>` with a
None value or :data:`Optional <typing.Optional>` which is equivalent.
"""
origin = get_origin_or_inner_type(annotation)
return origin is Optional or (
get_origin_or_inner_type(annotation) in UnionTypes and NoneType in get_args(annotation)
)
origin = unwrap_and_get_origin(annotation)
return origin is Optional or (unwrap_and_get_origin(annotation) in UnionTypes and NoneType in get_args(annotation))


def is_class_var(annotation: Any) -> bool:
Expand All @@ -275,7 +273,7 @@ def is_class_var(annotation: Any) -> bool:
Returns:
A boolean.
"""
annotation = get_origin_or_inner_type(annotation) or annotation
annotation = unwrap_and_get_origin(annotation) or annotation
return annotation is ClassVar


Expand Down
20 changes: 16 additions & 4 deletions litestar/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)

from litestar.types.builtin_types import NoneType, UnionTypes
from litestar.utils.deprecation import warn_deprecation

__all__ = (
"get_instantiable_origin",
Expand All @@ -56,6 +57,7 @@
"instantiable_type_mapping",
"make_non_optional_union",
"safe_generic_origin_map",
"unwrap_and_get_origin",
"unwrap_annotation",
)

Expand Down Expand Up @@ -187,8 +189,8 @@ def unwrap_new_type(new_type: Any) -> Any:
return inner


def get_origin_or_inner_type(annotation: Any) -> Any:
"""Get origin or unwrap it. Returns None for non-generic types.
def unwrap_and_get_origin(annotation: Any) -> Any:
"""Unwrap annotation and get the instantiable origin type. Returns None for non-generic types.

Args:
annotation: A type annotation.
Expand All @@ -199,12 +201,22 @@ def get_origin_or_inner_type(annotation: Any) -> Any:
origin = get_origin(annotation)
if origin in wrapper_type_set:
inner, _, _ = unwrap_annotation(annotation)
# we need to recursively call here 'get_origin_or_inner_type' because we might be dealing
# we need to recursively call here 'unwrap_and_get_origin' because we might be dealing
# with a generic type alias e.g. Annotated[dict[str, list[int]]
origin = get_origin_or_inner_type(inner)
origin = unwrap_and_get_origin(inner)
return instantiable_type_mapping.get(origin, origin)


def get_origin_or_inner_type(annotation: Any) -> Any:
"""Deprecated alias for :func:`unwrap_and_get_origin`.

.. deprecated:: 3.0
Use :func:`unwrap_and_get_origin` instead.
"""
warn_deprecation("3.0", "get_origin_or_inner_type", "function", alternative="unwrap_and_get_origin")
return unwrap_and_get_origin(annotation)


def get_safe_generic_origin(origin_type: Any, annotation: Any) -> Any:
"""Get a type that is safe to use as a generic type across all supported Python versions.

Expand Down
15 changes: 11 additions & 4 deletions tests/unit/test_utils/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_origin_or_inner_type,
get_type_hints_with_generics_resolved,
make_non_optional_union,
unwrap_and_get_origin,
)
from tests.models import DataclassPerson, DataclassPet # noqa: F401

Expand Down Expand Up @@ -40,10 +41,10 @@ def test_make_non_optional_union(annotation: Any, expected: Any) -> None:
assert make_non_optional_union(annotation) == expected


def test_get_origin_or_inner_type() -> None:
assert get_origin_or_inner_type(List[DataclassPerson]) == list
assert get_origin_or_inner_type(Annotated[List[DataclassPerson], "foo"]) == list
assert get_origin_or_inner_type(Annotated[Dict[str, List[DataclassPerson]], "foo"]) == dict
def test_unwrap_and_get_origin() -> None:
assert unwrap_and_get_origin(List[DataclassPerson]) == list
assert unwrap_and_get_origin(Annotated[List[DataclassPerson], "foo"]) == list
assert unwrap_and_get_origin(Annotated[Dict[str, List[DataclassPerson]], "foo"]) == dict


T = TypeVar("T")
Expand Down Expand Up @@ -158,3 +159,9 @@ def test_expand_type_var_in_type_hints(
type_hint: dict[str, Any], namespace: dict[str, Any] | None, expected: dict[str, Any]
) -> None:
assert expand_type_var_in_type_hint(type_hint, namespace) == expected


def test_get_origin_or_inner_type_deprecated() -> None:
with pytest.warns(DeprecationWarning, match="get_origin_or_inner_type"):
result = get_origin_or_inner_type(List[int])
assert result == list
Loading