Skip to content

Commit 9cfdaa3

Browse files
committed
fixes #9 and fixes #10
1 parent 4085ca5 commit 9cfdaa3

File tree

4 files changed

+110
-19
lines changed

4 files changed

+110
-19
lines changed

test/test_00_validate.py

+33
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,18 @@ def test_specific_invalid_cases(val: typing.Any, ts: typing.List[typing.Any]) ->
148148
except TypeError:
149149
pass
150150

151+
_union_cases: typing.Tuple[typing.Tuple[typing.Any, typing.List[typing.Any]], ...]
151152
_union_cases = (
152153
(0, [typing.Union[str, int], typing.Union[int, str], typing.Optional[int]]),
153154
("hello", [typing.Union[str, int], typing.Union[int, str], typing.Optional[str]]),
154155
)
155156

157+
if sys.version_info[1] >= 10:
158+
_union_cases += (
159+
(0, [str|int, int|str, int|None]),
160+
("hello", [str|int, int|str, str|None]),
161+
)
162+
156163
@pytest.mark.parametrize("val, ts", _union_cases)
157164
def test_union_cases(val: typing.Any, ts: typing.List[typing.Any]) -> None:
158165
for t in ts:
@@ -303,3 +310,29 @@ def test_invalid_typed_dict_cases(t: typing.Any, vals: typing.List[typing.Any])
303310
assert False, f"For type {repr(t)}, the following value shouldn't have been an instance: {repr(val)}"
304311
except TypeError:
305312
pass
313+
314+
315+
S = typing.TypeVar("S", bound=str)
316+
T = typing.TypeVar("T")
317+
U_co = typing.TypeVar("U_co", covariant=True)
318+
319+
class A:
320+
...
321+
class B(typing.Generic[T]):
322+
def __init__(self, t: T) -> None:
323+
...
324+
325+
class C(typing.Generic[S, T, U_co]):
326+
def __init__(self, s: S, t: T, u: U_co) -> None:
327+
...
328+
329+
_user_class_cases = (
330+
(A(), [A]),
331+
(B(10), [B, B[int]]),
332+
(C("hello", 20, 30), [C, C[str, int, typing.Union[int, str]]]),
333+
)
334+
335+
@pytest.mark.parametrize("val, ts", _user_class_cases)
336+
def test_user_class_cases(val: typing.Any, ts: typing.List[typing.Any]) -> None:
337+
for t in ts:
338+
validate(val, t)

test/test_01_can_validate.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing_validation.inspector import _typing_equiv
99
from typing_validation.validation import _pseudotypes_dict
1010

11-
from .test_00_validate import _test_cases, _union_cases, _literal_cases, _alias_cases,_typed_dict_cases, _validation_aliases
11+
from .test_00_validate import _test_cases, _union_cases, _literal_cases, _alias_cases,_typed_dict_cases, _user_class_cases, _validation_aliases
1212

1313
def assert_recorded_type(t: typing.Any) -> None:
1414
_t = can_validate(t).recorded_type
@@ -74,3 +74,14 @@ def test_typed_dict_cases(t: typing.Any) -> None:
7474
assert can_validate(typing.Optional[t]), f"Should be able to validate {typing.Optional[t]}"
7575
str(can_validate(t))
7676
assert_recorded_type(t)
77+
78+
_user_class_cases_ts = sorted({
79+
t for _, ts in _user_class_cases for t in typing.cast(typing.Any, ts)
80+
}, key=repr)
81+
82+
@pytest.mark.parametrize("t", _user_class_cases_ts)
83+
def test_user_class_cases(t: typing.Any) -> None:
84+
assert can_validate(t), f"Should be able to validate {t}"
85+
assert can_validate(typing.Optional[t]), f"Should be able to validate {typing.Optional[t]}"
86+
str(can_validate(t))
87+
assert_recorded_type(t)

typing_validation/inspector.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
typing.Tuple[Literal["typed-dict"], type],
4040
typing.Tuple[Literal["union"], int],
4141
typing.Tuple[Literal["tuple"], Optional[int]],
42+
typing.Tuple[Literal["user-class"], Optional[int]],
4243
typing.Tuple[Literal["alias"], str],
4344
typing.Tuple[Literal["unsupported"], Any],
4445
]
@@ -180,6 +181,17 @@ def _recorded_type(self, idx: int) -> typing.Tuple[Any, int]:
180181
item_ts = [tuple()]
181182
t = pending_type[tuple(item_ts)] if pending_type is not None else typing.Tuple[tuple(item_ts)]
182183
return t, idx
184+
if tag == "user-class":
185+
assert isinstance(param, int)
186+
assert pending_type is not None
187+
item_ts = []
188+
for _ in range(param):
189+
item_t, idx = self._recorded_type(idx+1)
190+
item_ts.append(item_t)
191+
if not item_ts:
192+
item_ts = [tuple()]
193+
t = pending_type[tuple(item_ts)]
194+
return t, idx
183195
assert False, f"Invalid type constructor tag: {repr(tag)}"
184196

185197
def _append_constructor_args(self, args: TypeConstructorArgs) -> None:
@@ -190,7 +202,7 @@ def _append_constructor_args(self, args: TypeConstructorArgs) -> None:
190202
pending_tag, pending_param = pending_generic_type_constr
191203
args_tag, args_param = args
192204
assert pending_tag == "type" and isinstance(pending_param, type)
193-
assert args_tag in ("tuple", "mapping", "collection"), f"Found unexpected tag '{args_tag}' with type constructor {pending_generic_type_constr} pending."
205+
assert args_tag in ("tuple", "mapping", "collection", "user-class"), f"Found unexpected tag '{args_tag}' with type constructor {pending_generic_type_constr} pending."
194206
if sys.version_info[1] >= 9:
195207
self._recorded_constructors.append(typing.cast(TypeConstructorArgs, ("type", (pending_param, args_tag, args_param))))
196208
else:
@@ -228,6 +240,9 @@ def _record_variadic_tuple(self, item_t: Any) -> None:
228240
def _record_fixed_tuple(self, *item_ts: Any) -> None:
229241
self._append_constructor_args(("tuple", len(item_ts)))
230242

243+
def _record_user_class(self, *item_ts: Any) -> None:
244+
self._append_constructor_args(("user-class", len(item_ts)))
245+
231246
def _record_literal(self, *literals: Any) -> None:
232247
self._append_constructor_args(("literal", literals))
233248

@@ -338,4 +353,16 @@ def _repr(self, idx: int = 0, level: int = 0) -> typing.Tuple[typing.List[str],
338353
lines.append("tuple()")
339354
lines.append(indent+"]")
340355
return lines, idx
356+
if tag == "user-class":
357+
assert isinstance(param, int)
358+
assert pending_type is not None
359+
lines = [indent+f"{pending_type.__name__}["]
360+
for _ in range(param):
361+
item_lines, idx = self._repr(idx+1, level+1)
362+
item_lines[-1] += ","
363+
lines.extend(item_lines)
364+
if len(lines) == 1:
365+
lines.append("tuple()")
366+
lines.append(indent+"]")
367+
return lines, idx
341368
assert False, f"Invalid type constructor tag: {repr(tag)}"

typing_validation/validation.py

+37-17
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ def issoftkeyword(s: str) -> bool:
2929
return s == "_"
3030

3131
if sys.version_info[1] >= 10:
32-
from types import NoneType
32+
from types import NoneType, UnionType
3333
else:
3434
NoneType = type(None)
35+
UnionType = None
3536

3637
_validation_aliases: typing.Dict[str, Any] = {}
3738
r"""
@@ -292,7 +293,7 @@ def _validate_tuple(val: Any, t: Any) -> None:
292293
except TypeError as e:
293294
raise _idx_type_error(val, t, e, idx=idx, ordered=True) from None
294295

295-
def _validate_union(val: Any, t: Any) -> None:
296+
def _validate_union(val: Any, t: Any, *, union_type: bool = False) -> None:
296297
"""
297298
Union type validation. Each type ``u`` listed in the union type ``t`` is checked:
298299
@@ -387,6 +388,18 @@ def _validate_typed_dict(val: Any, t: type) -> None:
387388
except TypeError as e:
388389
raise _key_type_error(val, t, e, key=k) from None
389390

391+
def _validate_user_class(val: Any, t: Any) -> None:
392+
assert hasattr(t, "__args__"), _missing_args_msg(t)
393+
assert isinstance(t.__args__, tuple), f"For type {repr(t)}, expected '__args__' to be a tuple."
394+
if isinstance(val, TypeInspector):
395+
val._record_pending_type_generic(t.__origin__)
396+
val._record_user_class(*t.__args__)
397+
for arg in t.__args__:
398+
validate(val, arg)
399+
return
400+
_validate_type(val, t.__origin__)
401+
# Generic type arguments cannot be validated
402+
390403
# def _validate_callable(val: Any, t: Any) -> None:
391404
# """
392405
# Callable validation
@@ -471,6 +484,9 @@ def validate(val: Any, t: Any) -> None:
471484
val._record_any()
472485
return
473486
return
487+
if UnionType is not None and isinstance(t, UnionType):
488+
_validate_union(val, t, union_type=True)
489+
return
474490
if hasattr(t, "__origin__"): # parametric types
475491
if t.__origin__ is Union:
476492
_validate_union(val, t)
@@ -483,22 +499,26 @@ def validate(val: Any, t: Any) -> None:
483499
val._record_pending_type_generic(t.__origin__)
484500
else:
485501
_validate_type(val, t.__origin__)
486-
if t.__origin__ in _collection_origins:
487-
ordered = t.__origin__ in _ordered_collection_origins
488-
_validate_collection(val, t, ordered)
489-
return
490-
if t.__origin__ in _mapping_origins:
491-
_validate_mapping(val, t)
492-
return
493-
if t.__origin__ == tuple:
494-
_validate_tuple(val, t)
495-
return
496-
if t.__origin__ in _iterator_origins:
497-
if isinstance(val, TypeInspector):
502+
if t.__origin__ in _collection_origins:
503+
ordered = t.__origin__ in _ordered_collection_origins
504+
_validate_collection(val, t, ordered)
505+
return
506+
if t.__origin__ in _mapping_origins:
507+
_validate_mapping(val, t)
508+
return
509+
if t.__origin__ == tuple:
510+
_validate_tuple(val, t)
511+
return
512+
if t.__origin__ in _iterator_origins:
513+
if isinstance(val, TypeInspector):
514+
_validate_collection(val, t, ordered=False)
515+
# Item type cannot be validated for iterators (use validated_iter)
516+
return
517+
if t.__origin__ in _maybe_collection_origins and isinstance(val, typing.Collection):
498518
_validate_collection(val, t, ordered=False)
499-
return
500-
if t.__origin__ in _maybe_collection_origins and isinstance(val, typing.Collection):
501-
_validate_collection(val, t, ordered=False)
519+
return
520+
elif isinstance(t.__origin__, type):
521+
_validate_user_class(val, t)
502522
return
503523
elif isinstance(t, type):
504524
# The `isinstance(t, type)` case goes after the `hasattr(t, "__origin__")` case:

0 commit comments

Comments
 (0)