Skip to content

Commit 81ba972

Browse files
authored
Merge pull request #20 from FFY00/fix-bad-typing_extensions-import
Fix typing_extensions always being imported
2 parents d87c635 + eec6a1e commit 81ba972

File tree

2 files changed

+51
-12
lines changed

2 files changed

+51
-12
lines changed

test/test_module.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import importlib.machinery
2+
import sys
3+
import types
4+
5+
from collections.abc import Sequence
6+
from typing import List
7+
8+
import pytest
9+
10+
11+
class FailMockFinder:
12+
def __init__(self, modules: List[str]) -> None:
13+
self.modules = modules
14+
15+
def find_spec(
16+
self,
17+
fullname: str,
18+
path: Sequence[str] | None,
19+
target: types.ModuleType | None,
20+
) -> importlib.machinery.ModuleSpec | None:
21+
if fullname in self.modules:
22+
raise ModuleNotFoundError(f"No module named '{fullname}'", name=fullname)
23+
return None
24+
25+
26+
@pytest.mark.skipif(sys.version_info < (3, 11), reason='Python <3.11 needs typing_extensions')
27+
def test_typing_extensions_availability(monkeypatch: pytest.MonkeyPatch) -> None:
28+
finder = FailMockFinder('typing_extensions')
29+
monkeypatch.setattr(sys, 'meta_path', [finder] + sys.meta_path)
30+
monkeypatch.delitem(sys.modules, 'typing_extensions')
31+
32+
import typing_validation

typing_validation/validation.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
Union,
2020
get_type_hints,
2121
)
22-
import typing_extensions
2322

2423
from .validation_failure import (
2524
InvalidNumpyDTypeValidationFailure,
@@ -54,6 +53,15 @@ def issoftkeyword(s: str) -> bool:
5453
NoneType = type(None)
5554
UnionType = None
5655

56+
57+
try:
58+
import typing_extensions
59+
except ModuleNotFoundError:
60+
_typing_modules = [typing]
61+
else:
62+
_typing_modules = [typing, typing_extensions]
63+
64+
5765
_validation_aliases: typing.Dict[str, Any] = {}
5866
r"""
5967
Current context of type aliases, used to resolve forward references to type aliases in :func:`validate`.
@@ -84,6 +92,14 @@ def validation_aliases(**aliases: Any) -> collections.abc.Iterator[None]:
8492
_validation_aliases = outer_validation_aliases
8593

8694

95+
def _get_type_classes(name: str) -> typing.List[typing.Type[Any]]:
96+
"""Get the classes for the specified type from typing and its possible backport modules."""
97+
return [
98+
getattr(module, name)
99+
for module in _typing_modules
100+
if hasattr(module, name)
101+
]
102+
87103
# basic types
88104
_basic_types = frozenset(
89105
{bool, int, float, complex, bytes, bytearray, memoryview, str, range, slice}
@@ -496,16 +512,7 @@ def _is_typed_dict(t: type) -> bool:
496512
"""
497513
Determines whether a type is a subclass of :class:`TypedDict`.
498514
"""
499-
if (
500-
hasattr(typing_extensions, "_TypedDictMeta")
501-
and t.__class__ == typing_extensions._TypedDictMeta
502-
):
503-
return True
504-
if hasattr(typing, "_TypedDictMeta") and t.__class__ == getattr(
505-
typing, "_TypedDictMeta"
506-
):
507-
return True
508-
return False
515+
return t.__class__ in _get_type_classes('_TypedDictMeta')
509516

510517

511518
def _validate_typed_dict(val: Any, t: type) -> None:
@@ -821,7 +828,7 @@ def validate(val: Any, t: Any) -> Literal[True]:
821828
if t.__origin__ is Union:
822829
_validate_union(val, t)
823830
return True
824-
if t.__origin__ is Literal or t.__origin__ is typing_extensions.Literal:
831+
if t.__origin__ in _get_type_classes('Literal'):
825832
_validate_literal(val, t)
826833
return True
827834
if t.__origin__ in _origins:

0 commit comments

Comments
 (0)