@@ -695,6 +695,27 @@ def test_numpy_array_error() -> None:
695
695
with pytest .raises (TypeError ):
696
696
validate (val , npt .NDArray [np .str_ ])
697
697
698
+ def test_numpy_array_shape () -> None :
699
+ # pylint: disable = import-outside-toplevel
700
+ import numpy as np
701
+ val = np .zeros (5 , dtype = np .uint8 )
702
+ validate (val , np .ndarray [typing .Any , np .dtype [np .uint8 ]])
703
+ validate (val , np .ndarray [tuple , np .dtype [np .uint8 ]])
704
+ validate (val , np .ndarray [tuple [typing .Any , ...], np .dtype [np .uint8 ]])
705
+ validate (val , np .ndarray [tuple [typing .Any ], np .dtype [np .uint8 ]])
706
+ validate (val , np .ndarray [tuple [int , ...], np .dtype [np .uint8 ]])
707
+ validate (val , np .ndarray [tuple [int ], np .dtype [np .uint8 ]])
708
+ validate (val , np .ndarray [tuple [Literal [5 ], ...], np .dtype [np .uint8 ]])
709
+ validate (val , np .ndarray [tuple [Literal [5 ]], np .dtype [np .uint8 ]])
710
+ with pytest .raises (TypeError ):
711
+ validate (val , np .ndarray [tuple [int , int ], np .dtype [np .uint8 ]])
712
+ with pytest .raises (TypeError ):
713
+ validate (val , np .ndarray [tuple [typing .Any , typing .Any ], np .dtype [np .uint8 ]])
714
+ with pytest .raises (TypeError ):
715
+ validate (val , np .ndarray [tuple [Literal [5 ], int ], np .dtype [np .uint8 ]])
716
+ with pytest .raises (TypeError ):
717
+ validate (val , np .ndarray [tuple [int , Literal [5 ]], np .dtype [np .uint8 ]])
718
+
698
719
699
720
def test_typevar () -> None :
700
721
T = typing .TypeVar ("T" )
0 commit comments