|
19 | 19 | Union,
|
20 | 20 | get_type_hints,
|
21 | 21 | )
|
22 |
| -import typing_extensions |
23 | 22 |
|
24 | 23 | from .validation_failure import (
|
25 | 24 | InvalidNumpyDTypeValidationFailure,
|
@@ -54,6 +53,15 @@ def issoftkeyword(s: str) -> bool:
|
54 | 53 | NoneType = type(None)
|
55 | 54 | UnionType = None
|
56 | 55 |
|
| 56 | + |
| 57 | +try: |
| 58 | + import typing_extensions |
| 59 | +except ModuleNotFoundError: |
| 60 | + _typing_modules = [typing] |
| 61 | +else: |
| 62 | + _typing_modules = [typing, typing_extensions] |
| 63 | + |
| 64 | + |
57 | 65 | _validation_aliases: typing.Dict[str, Any] = {}
|
58 | 66 | r"""
|
59 | 67 | Current context of type aliases, used to resolve forward references to type aliases in :func:`validate`.
|
@@ -84,6 +92,14 @@ def validation_aliases(**aliases: Any) -> collections.abc.Iterator[None]:
|
84 | 92 | _validation_aliases = outer_validation_aliases
|
85 | 93 |
|
86 | 94 |
|
| 95 | +def _get_type_classes(name: str) -> typing.List[typing.Type[Any]]: |
| 96 | + """Get the classes for the specified type from typing and its possible backport modules.""" |
| 97 | + return [ |
| 98 | + getattr(module, name) |
| 99 | + for module in _typing_modules |
| 100 | + if hasattr(module, name) |
| 101 | + ] |
| 102 | + |
87 | 103 | # basic types
|
88 | 104 | _basic_types = frozenset(
|
89 | 105 | {bool, int, float, complex, bytes, bytearray, memoryview, str, range, slice}
|
@@ -496,16 +512,7 @@ def _is_typed_dict(t: type) -> bool:
|
496 | 512 | """
|
497 | 513 | Determines whether a type is a subclass of :class:`TypedDict`.
|
498 | 514 | """
|
499 |
| - if ( |
500 |
| - hasattr(typing_extensions, "_TypedDictMeta") |
501 |
| - and t.__class__ == typing_extensions._TypedDictMeta |
502 |
| - ): |
503 |
| - return True |
504 |
| - if hasattr(typing, "_TypedDictMeta") and t.__class__ == getattr( |
505 |
| - typing, "_TypedDictMeta" |
506 |
| - ): |
507 |
| - return True |
508 |
| - return False |
| 515 | + return t.__class__ in _get_type_classes('_TypedDictMeta') |
509 | 516 |
|
510 | 517 |
|
511 | 518 | def _validate_typed_dict(val: Any, t: type) -> None:
|
@@ -821,7 +828,7 @@ def validate(val: Any, t: Any) -> Literal[True]:
|
821 | 828 | if t.__origin__ is Union:
|
822 | 829 | _validate_union(val, t)
|
823 | 830 | return True
|
824 |
| - if t.__origin__ is Literal or t.__origin__ is typing_extensions.Literal: |
| 831 | + if t.__origin__ in _get_type_classes('Literal'): |
825 | 832 | _validate_literal(val, t)
|
826 | 833 | return True
|
827 | 834 | if t.__origin__ in _origins:
|
|
0 commit comments