Skip to content

Commit 2a11dea

Browse files
committed
Closes #17
1 parent 4601fe5 commit 2a11dea

File tree

5 files changed

+170
-20
lines changed

5 files changed

+170
-20
lines changed

test/test_00_validate.py

+14
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,17 @@ def test_typevar() -> None:
388388
validate("Hello", IntStrSeqT)
389389
with pytest.raises(TypeError):
390390
validate(0, IntStrSeqT)
391+
392+
def test_subtype() -> None:
393+
validate(int, type)
394+
validate(int, typing.Type)
395+
validate(int, typing.Type[int])
396+
validate(int, typing.Type[typing.Any])
397+
validate(int, typing.Type[typing.Union[float,str,typing.Any]])
398+
validate(int, typing.Type[typing.Union[int,str]])
399+
with pytest.raises(TypeError):
400+
validate(int, typing.Type[typing.Union[str, float]])
401+
with pytest.raises(TypeError):
402+
validate(10, typing.Type[int])
403+
with pytest.raises(TypeError):
404+
validate(10, typing.Type[typing.Union[str, float]])

test/test_01_can_validate.py

+8
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,11 @@ def test_typevar() -> None:
111111
assert can_validate(IntT)
112112
IntStrSeqT = typing.TypeVar("IntStrSeqT", bound=typing.Sequence[typing.Union[int,str]])
113113
assert can_validate(IntStrSeqT)
114+
115+
def test_subtype() -> None:
116+
assert can_validate(type)
117+
assert can_validate(typing.Type)
118+
assert can_validate(typing.Type[int])
119+
assert can_validate(typing.Type[typing.Union[int,str]])
120+
assert can_validate(typing.Type[typing.Any])
121+
assert can_validate(typing.Type[typing.Union[typing.Any, str, int]])

typing_validation/inspector.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,10 @@ def _append_constructor_args(self, args: TypeConstructorArgs) -> None:
228228
"mapping",
229229
"collection",
230230
"user-class",
231-
), f"Found unexpected tag '{args_tag}' with type constructor {pending_generic_type_constr} pending."
231+
), (
232+
f"Found unexpected tag '{args_tag}' with "
233+
f"type constructor {pending_generic_type_constr} pending."
234+
)
232235
if sys.version_info[1] >= 9:
233236
self._recorded_constructors.append(
234237
typing.cast(

typing_validation/validation.py

+115-16
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from .validation_failure import (
2525
InvalidNumpyDTypeValidationFailure,
26+
SubtypeValidationFailure,
2627
TypeVarBoundValidationFailure,
2728
ValidationFailureAtIdx,
2829
ValidationFailureAtKey,
@@ -199,6 +200,29 @@ def validation_aliases(**aliases: Any) -> collections.abc.Iterator[None]:
199200
)
200201

201202

203+
class UnsupportedTypeError(ValueError):
204+
"""
205+
Class for errors raised when attempting to validate an unsupported type.
206+
207+
.. warning::
208+
209+
Currently extends :obj:`ValueError` for backwards compatibility.
210+
This will be changed to :obj:`NotImplementedError` in v1.3.0.
211+
"""
212+
213+
214+
def _unsupported_type_error(
215+
t: Any, explanation: Union[str, None] = None
216+
) -> UnsupportedTypeError:
217+
"""
218+
Error for unsupported types, with optional explanation.
219+
"""
220+
msg = "Unsupported validation for type {t!r}."
221+
if explanation is not None:
222+
msg += " " + explanation
223+
return UnsupportedTypeError(msg)
224+
225+
202226
def _type_error(
203227
val: Any, t: Any, *errors: TypeError, is_union: bool = False
204228
) -> TypeError:
@@ -272,6 +296,13 @@ def _missing_keys_type_error(val: Any, t: Any, *missing_keys: Any) -> TypeError:
272296
return error
273297

274298

299+
def _subtype_error(s: Any, t: Any) -> TypeError:
300+
validation_failure = SubtypeValidationFailure(s, t)
301+
error = TypeError(str(validation_failure))
302+
setattr(error, "validation_failure", validation_failure)
303+
return error
304+
305+
275306
def _type_alias_error(t_alias: str, cause: TypeError) -> TypeError:
276307
"""
277308
Repackages a validation error as a type alias error.
@@ -505,21 +536,88 @@ def _validate_typed_dict(val: Any, t: type) -> None:
505536
except TypeError as e:
506537
raise _key_type_error(val, t, e, key=k) from None
507538

508-
509539
def _validate_user_class(val: Any, t: Any) -> None:
510540
assert hasattr(t, "__args__"), _missing_args_msg(t)
511541
assert isinstance(
512542
t.__args__, tuple
513543
), f"For type {repr(t)}, expected '__args__' to be a tuple."
514544
if isinstance(val, TypeInspector):
545+
if t.__origin__ is type:
546+
if len(t.__args__) != 1 or not _can_validate_subtype_of(
547+
t.__args__[0]
548+
):
549+
val._record_unsupported_type(t)
550+
return
515551
val._record_pending_type_generic(t.__origin__)
516552
val._record_user_class(*t.__args__)
517553
for arg in t.__args__:
518554
validate(val, arg)
519555
return
520556
_validate_type(val, t.__origin__)
521-
# Generic type arguments cannot be validated
557+
if t.__origin__ is type:
558+
if len(t.__args__) != 1:
559+
raise _unsupported_type_error(t)
560+
_validate_subtype_of(val, t.__args__[0])
561+
return
562+
# TODO: Generic type arguments cannot be validated in general,
563+
# but in a future release it will be possible for classes to define
564+
# a dunder classmethod which can be used to validate type arguments.
565+
566+
def __extract_member_types(u: Any) -> tuple[Any, ...]|None:
567+
q = collections.deque([u])
568+
member_types: list[Any] = []
569+
while q:
570+
t = q.popleft()
571+
if t is Any:
572+
return None
573+
elif UnionType is not None and isinstance(t, UnionType):
574+
q.extend(t.__args__)
575+
elif hasattr(t, "__origin__") and t.__origin__ is Union:
576+
q.extend(t.__args__)
577+
else:
578+
member_types.append(t)
579+
return tuple(member_types)
580+
581+
def __check_can_validate_subtypes(*subtypes: Any) -> None:
582+
for s in subtypes:
583+
if not isinstance(s, type):
584+
raise ValueError(
585+
"validate(s, Type[t]) is only supported when 's' is "
586+
"an instance of 'type' or a union of instances of 'type'.\n"
587+
f"Found s = {'|'.join(str(s) for s in subtypes)}"
588+
)
589+
590+
def __check_can_validate_supertypes(*supertypes: Any) -> None:
591+
for t in supertypes:
592+
if not isinstance(t, type):
593+
raise ValueError(
594+
"validate(s, Type[t]) is only supported when 't' is "
595+
"an instance of 'type' or a union of instances of 'type'.\n"
596+
f"Found t = {'|'.join(str(t) for t in supertypes)}"
597+
)
598+
599+
def _can_validate_subtype_of(t: Any) -> bool:
600+
try:
601+
# This is the validation part of _validate_subtype:
602+
t_member_types = __extract_member_types(t)
603+
if t_member_types is not None:
604+
__check_can_validate_supertypes(*t_member_types)
605+
return True
606+
except ValueError:
607+
return False
522608

609+
def _validate_subtype_of(s: Any, t: Any) -> None:
610+
# 1. Validation:
611+
__check_can_validate_subtypes(s)
612+
t_member_types = __extract_member_types(t)
613+
if t_member_types is None:
614+
# An Any was found amongst the member types, all good.
615+
return
616+
__check_can_validate_supertypes(*t_member_types)
617+
# 2. Subtype check:
618+
if not issubclass(s, t_member_types):
619+
raise _subtype_error(s, t)
620+
# TODO: improve support for subtype checks.
523621

524622
def _extract_dtypes(t: Any) -> typing.Sequence[Any]:
525623
if t is Any:
@@ -575,8 +673,8 @@ def _validate_numpy_array(val: Any, t: Any) -> None:
575673
if isinstance(val, TypeInspector):
576674
val._record_unsupported_type(t)
577675
return
578-
raise ValueError(
579-
f"Unsupported validation for NumPy dtype {repr(dtype_t)}."
676+
raise _unsupported_type_error(
677+
t, f"Unsupported NumPy dtype {dtype_t!r}"
580678
) from None
581679
if isinstance(val, TypeInspector):
582680
val._record_pending_type_generic(t.__origin__)
@@ -669,21 +767,24 @@ def validate(val: Any, t: Any) -> Literal[True]:
669767
:param t: the type to type-check against
670768
:type t: :obj:`~typing.Any`
671769
:raises TypeError: if ``val`` is not of type ``t``
672-
:raises ValueError: if validation for type ``t`` is not supported
770+
:raises UnsupportedTypeError: if validation for type ``t`` is not supported
673771
:raises AssertionError: if things go unexpectedly wrong with ``__args__`` for parametric types
674772
675773
"""
676774
# pylint: disable = too-many-return-statements, too-many-branches, too-many-statements
677-
unsupported_type_error: Optional[ValueError] = None
775+
unsupported_type_error: Optional[UnsupportedTypeError] = None
678776
if not isinstance(t, Hashable):
679777
if isinstance(val, TypeInspector):
680778
val._record_unsupported_type(t)
681779
return True
682780
if unsupported_type_error is None:
683-
unsupported_type_error = ValueError(
684-
f"Unsupported validation for type {repr(t)}. Type is not hashable."
781+
unsupported_type_error = _unsupported_type_error(
782+
t, "Type is not hashable."
685783
) # pragma: nocover
686784
raise unsupported_type_error
785+
if t is typing.Type:
786+
# Replace non-generic 'Type' with non-generic 'type':
787+
t = type
687788
if t in _basic_types:
688789
# speed things up for the likely most common case
689790
_validate_type(val, typing.cast(type, t))
@@ -765,8 +866,8 @@ def validate(val: Any, t: Any) -> Literal[True]:
765866
if isinstance(val, TypeInspector):
766867
val._record_unsupported_type(t)
767868
return True
768-
unsupported_type_error = ValueError(
769-
f"Unsupported validation for Protocol {repr(t)}, because it is not runtime-checkable."
869+
unsupported_type_error = _unsupported_type_error(
870+
t, "Protocol class is not runtime-checkable."
770871
) # pragma: nocover
771872
elif _is_typed_dict(t):
772873
_validate_typed_dict(val, t)
@@ -788,8 +889,8 @@ def validate(val: Any, t: Any) -> Literal[True]:
788889
hint = f"Perhaps set it with validation_aliases({t_alias}=...)?"
789890
else:
790891
hint = f"Perhaps set it with validation_aliases(**{{'{t_alias}': ...}})?"
791-
unsupported_type_error = ValueError(
792-
f"Type alias '{t_alias}' is not known. {hint}"
892+
unsupported_type_error = _unsupported_type_error(
893+
t_alias, f"Type alias is not known. {hint}"
793894
) # pragma: nocover
794895
else:
795896
_validate_alias(val, t_alias)
@@ -798,15 +899,13 @@ def validate(val: Any, t: Any) -> Literal[True]:
798899
val._record_unsupported_type(t)
799900
return True
800901
if unsupported_type_error is None:
801-
unsupported_type_error = ValueError(
802-
f"Unsupported validation for type {repr(t)}."
803-
) # pragma: nocover
902+
unsupported_type_error = _unsupported_type_error(t) # pragma: nocover
804903
raise unsupported_type_error
805904

806905

807906
def can_validate(t: Any) -> TypeInspector:
808907
r"""
809-
Checks whether validation is supported for the given type ``t``: if not, :func:`validate` will raise :obj:`ValueError`.
908+
Checks whether validation is supported for the given type ``t``: if not, :func:`validate` will raise :obj:`UnsupportedTypeError`.
810909
811910
The returned :class:`TypeInspector` instance can be used wherever a boolean is expected, and will indicate whether the type is supported or not:
812911

typing_validation/validation_failure.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import sys
88
import typing
9-
from typing import Any, Mapping, Optional, TypeVar
9+
from typing import Any, Mapping, Optional, Type, TypeVar
1010

1111
if sys.version_info[1] >= 8:
1212
from typing import Protocol
@@ -452,8 +452,34 @@ def __new__(
452452

453453
def _str_main_msg(self, type_quals: tuple[str, ...] = ()) -> str:
454454
return (
455-
f"For {self._str_type_descr(type_quals)} {repr(self.t)}, "
456-
f"value is not valid for upper bound."
455+
f"For {self._str_type_descr(type_quals)} {self.t!r}, "
456+
f"value is not valid for upper bound: {self.val!r}"
457+
)
458+
459+
460+
class SubtypeValidationFailure(ValidationFailure):
461+
"""
462+
Validation failures arising from ``validate(s, Type[t])`` when ``s`` is not
463+
a subtype of ``t``.
464+
"""
465+
466+
def __new__(
467+
cls,
468+
s: Any,
469+
t: Any,
470+
*,
471+
type_aliases: Optional[Mapping[str, Any]] = None,
472+
) -> Self:
473+
# pylint: disable = too-many-arguments
474+
instance = super().__new__(cls, s, Type[t], type_aliases=type_aliases)
475+
return instance
476+
477+
def _str_main_msg(self, type_quals: tuple[str, ...] = ()) -> str:
478+
t = self.t
479+
bound_t = t.__args__[0]
480+
return (
481+
f"For {self._str_type_descr(type_quals)} {t!r}, "
482+
f"type bound is not a supertype of value: {self.val!r}"
457483
)
458484

459485

0 commit comments

Comments
 (0)