Skip to content

Commit 5ef874a

Browse files
committed
feat: add support for nupy array shape validation
1 parent 2723eb3 commit 5ef874a

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

test/test_00_validate.py

+21
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,27 @@ def test_numpy_array_error() -> None:
695695
with pytest.raises(TypeError):
696696
validate(val, npt.NDArray[np.str_])
697697

698+
def test_numpy_array_shape() -> None:
699+
# pylint: disable = import-outside-toplevel
700+
import numpy as np
701+
val = np.zeros(5, dtype=np.uint8)
702+
validate(val, np.ndarray[typing.Any, np.dtype[np.uint8]])
703+
validate(val, np.ndarray[tuple, np.dtype[np.uint8]])
704+
validate(val, np.ndarray[tuple[typing.Any, ...], np.dtype[np.uint8]])
705+
validate(val, np.ndarray[tuple[typing.Any], np.dtype[np.uint8]])
706+
validate(val, np.ndarray[tuple[int, ...], np.dtype[np.uint8]])
707+
validate(val, np.ndarray[tuple[int], np.dtype[np.uint8]])
708+
validate(val, np.ndarray[tuple[Literal[5], ...], np.dtype[np.uint8]])
709+
validate(val, np.ndarray[tuple[Literal[5]], np.dtype[np.uint8]])
710+
with pytest.raises(TypeError):
711+
validate(val, np.ndarray[tuple[int, int], np.dtype[np.uint8]])
712+
with pytest.raises(TypeError):
713+
validate(val, np.ndarray[tuple[typing.Any, typing.Any], np.dtype[np.uint8]])
714+
with pytest.raises(TypeError):
715+
validate(val, np.ndarray[tuple[Literal[5], int], np.dtype[np.uint8]])
716+
with pytest.raises(TypeError):
717+
validate(val, np.ndarray[tuple[int, Literal[5]], np.dtype[np.uint8]])
718+
698719

699720
def test_typevar() -> None:
700721
T = typing.TypeVar("T")

typing_validation/validation.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _unsupported_type_error(
228228
"""
229229
Error for unsupported types, with optional explanation.
230230
"""
231-
msg = "Unsupported validation for type {t!r}."
231+
msg = f"Unsupported validation for type {t!r}."
232232
if explanation is not None:
233233
msg += " " + explanation
234234
return UnsupportedTypeError(msg)
@@ -658,6 +658,7 @@ def _extract_dtypes(t: Any) -> typing.Sequence[Any]:
658658

659659
def _validate_numpy_array(val: Any, t: Any) -> None:
660660
import numpy as np # pylint: disable = import-outside-toplevel
661+
661662
if not isinstance(val, TypeInspector):
662663
_validate_type(val, np.ndarray)
663664
# assert hasattr(t, "__args__"), _missing_args_msg(t)
@@ -676,7 +677,7 @@ def _validate_numpy_array(val: Any, t: Any) -> None:
676677
val._record_unsupported_type(t)
677678
return
678679
raise _unsupported_type_error(
679-
t, f"Unsupported NumPy dtype {dtype_t!r}"
680+
t, f"Unsupported NumPy dtype {dtype_t!r}."
680681
) from None
681682
if isinstance(val, TypeInspector):
682683
val._record_pending_type_generic(t.__origin__)
@@ -685,9 +686,10 @@ def _validate_numpy_array(val: Any, t: Any) -> None:
685686
validate(val, arg)
686687
return
687688
val_dtype = val.dtype
688-
if any(dtype is Any or np.issubdtype(val_dtype, dtype) for dtype in dtypes):
689-
return
690-
raise _numpy_dtype_error(val, t)
689+
if not any(dtype is Any or np.issubdtype(val_dtype, dtype) for dtype in dtypes):
690+
raise _numpy_dtype_error(val, t)
691+
validate(val.shape, t__args__[0])
692+
return
691693

692694

693695
def _validate_typevar(val: Any, t: TypeVar) -> None:

0 commit comments

Comments
 (0)