Skip to content

Commit 74007de

Browse files
committed
Inspector now distinguishes betwen UnionType and typing.Union
1 parent d2ff39d commit 74007de

File tree

4 files changed

+76
-14
lines changed

4 files changed

+76
-14
lines changed

test/test_00_validate.py

+18
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
else:
2222
from typing_extensions import TypedDict
2323

24+
if sys.version_info[1] >= 10:
25+
from types import UnionType
26+
else:
27+
UnionType = None
28+
29+
2430
_basic_types = [
2531
bool, int, float, complex, str, bytes, bytearray,
2632
list, tuple, set, frozenset, dict, type(None)
@@ -402,3 +408,15 @@ def test_subtype() -> None:
402408
validate(10, typing.Type[int])
403409
with pytest.raises(TypeError):
404410
validate(10, typing.Type[typing.Union[str, float]])
411+
412+
@pytest.mark.parametrize("val, ts", _union_cases)
413+
def test_union_type_cases(val: typing.Any, ts: typing.List[typing.Any]) -> None:
414+
if UnionType is not None:
415+
for t in ts:
416+
members = t.__args__
417+
if not members:
418+
continue
419+
u = members[0]
420+
for t in members[1:]:
421+
u |= t
422+
validate(val, u)

test/test_01_can_validate.py

+21
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
import typing
55
import pytest
66

7+
if sys.version_info[1] >= 10:
8+
from types import UnionType
9+
else:
10+
UnionType = None
11+
712
from typing_validation import can_validate, validation_aliases
813
from typing_validation.inspector import _typing_equiv
914
from typing_validation.validation import _pseudotypes_dict
@@ -119,3 +124,19 @@ def test_subtype() -> None:
119124
assert can_validate(typing.Type[typing.Union[int,str]])
120125
assert can_validate(typing.Type[typing.Any])
121126
assert can_validate(typing.Type[typing.Union[typing.Any, str, int]])
127+
128+
_union_cases_ts = sorted({
129+
t for _, ts in _union_cases for t in ts
130+
}, key=repr)
131+
132+
133+
@pytest.mark.parametrize("t", _union_cases_ts)
134+
def test_union_type_cases(t: typing.Any) -> None:
135+
if UnionType is not None:
136+
members = t.__args__
137+
if not members:
138+
return
139+
u = members[0]
140+
for t in members[1:]:
141+
u |= t
142+
assert can_validate(u)

typing_validation/inspector.py

+34-11
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
typing.Tuple[Literal["mapping"], None],
3636
typing.Tuple[Literal["typed-dict"], type],
3737
typing.Tuple[Literal["typevar"], TypeVar],
38-
typing.Tuple[Literal["union"], int],
38+
typing.Tuple[Literal["union"], tuple[int, bool]],
3939
typing.Tuple[Literal["tuple"], Optional[int]],
4040
typing.Tuple[Literal["user-class"], Optional[int]],
4141
typing.Tuple[Literal["alias"], str],
@@ -44,6 +44,11 @@
4444
else:
4545
TypeConstructorArgs = typing.Tuple[str, Any]
4646

47+
if sys.version_info[1] >= 10:
48+
from types import UnionType
49+
else:
50+
UnionType = None
51+
4752
if sys.version_info[1] >= 11:
4853
from typing import Self
4954
else:
@@ -186,12 +191,19 @@ def _recorded_type(self, idx: int) -> typing.Tuple[Any, int]:
186191
idx,
187192
) # pylint: disable = unnecessary-dunder-call
188193
if tag == "union":
189-
assert isinstance(param, int)
194+
assert isinstance(param, tuple)
195+
num_members, use_UnionType = param
196+
assert isinstance(num_members, int)
190197
member_ts: typing.List[Any] = []
191-
for _ in range(param):
198+
for _ in range(num_members):
192199
member_t, idx = self._recorded_type(idx + 1)
193200
member_ts.append(member_t)
194-
return typing.Union.__getitem__(tuple(member_ts)), idx
201+
if not use_UnionType:
202+
return typing.Union.__getitem__(tuple(member_ts)), idx
203+
union_type = member_ts[0]
204+
for t in member_ts[1:]:
205+
union_type |= t
206+
return union_type, idx
195207
if tag == "typed-dict":
196208
for _ in get_type_hints(param):
197209
_, idx = self._recorded_type(idx + 1)
@@ -302,8 +314,11 @@ def _record_collection(self, item_t: Any) -> None:
302314
def _record_mapping(self, key_t: Any, value_t: Any) -> None:
303315
self._append_constructor_args(("mapping", None))
304316

305-
def _record_union(self, *member_ts: Any) -> None:
306-
self._append_constructor_args(("union", len(member_ts)))
317+
def _record_union(self, *member_ts: Any, use_UnionType: bool = False) -> None:
318+
if use_UnionType:
319+
assert member_ts, "Cannot use UnionType with empty members."
320+
assert UnionType is not None, "Cannot use UnionType, version <= 3.9"
321+
self._append_constructor_args(("union", (len(member_ts), use_UnionType)))
307322

308323
def _record_variadic_tuple(self, item_t: Any) -> None:
309324
self._append_constructor_args(("tuple", None))
@@ -385,14 +400,22 @@ def _repr(
385400
]
386401
return lines, idx
387402
if tag == "union":
388-
assert isinstance(param, int)
389-
lines = [indent + "Union["]
390-
for _ in range(param):
403+
assert isinstance(param, tuple)
404+
num_members, use_UnionType = param
405+
assert isinstance(num_members, int)
406+
lines = []
407+
if not use_UnionType:
408+
lines.append(indent + "Union[")
409+
for _ in range(num_members):
391410
member_lines, idx = self._repr(idx + 1, level + 1)
392-
member_lines[-1] += ","
411+
if use_UnionType:
412+
member_lines[-1] += "|"
413+
else:
414+
member_lines[-1] += ","
393415
lines.extend(member_lines)
394416
assert len(lines) > 1, "Cannot take a union of no types."
395-
lines.append(indent + "]")
417+
if not use_UnionType:
418+
lines.append(indent + "]")
396419
return lines, idx
397420
if tag == "typed-dict":
398421
t = param

typing_validation/validation.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def _validate_tuple(val: Any, t: Any) -> None:
430430
) from None
431431

432432

433-
def _validate_union(val: Any, t: Any, *, union_type: bool = False) -> None:
433+
def _validate_union(val: Any, t: Any, *, use_UnionType: bool = False) -> None:
434434
"""
435435
Union type validation. Each type ``u`` listed in the union type ``t`` is checked:
436436
@@ -444,7 +444,7 @@ def _validate_union(val: Any, t: Any, *, union_type: bool = False) -> None:
444444
t.__args__, tuple
445445
), f"For type {repr(t)}, expected '__args__' to be a tuple."
446446
if isinstance(val, TypeInspector):
447-
val._record_union(*t.__args__)
447+
val._record_union(*t.__args__, use_UnionType=use_UnionType)
448448
for member_t in t.__args__:
449449
validate(val, member_t)
450450
return
@@ -815,7 +815,7 @@ def validate(val: Any, t: Any) -> Literal[True]:
815815
_validate_typevar(val, t)
816816
return True
817817
if UnionType is not None and isinstance(t, UnionType):
818-
_validate_union(val, t, union_type=True)
818+
_validate_union(val, t, use_UnionType=True)
819819
return True
820820
if hasattr(t, "__origin__"): # parametric types
821821
if t.__origin__ is Union:

0 commit comments

Comments
 (0)