23
23
24
24
Array : TypeAlias = CpuArray | GpuArray | DiskArray | types .CSDataset | types .DaskArray
25
25
26
- DTypeIn = type [np .float32 | np .float64 | np .int32 | np .bool ]
27
- DTypeOut = type [np .float32 | np .float64 | np .int64 ]
26
+ DTypeIn = np .float32 | np .float64 | np .int32 | np .bool
27
+ DTypeOut = np .float32 | np .float64 | np .int64
28
+
29
+ NdAndAx : TypeAlias = tuple [Literal [2 ], Literal [0 , 1 , None ]]
28
30
29
31
class BenchFun (Protocol ): # noqa: D101
30
32
def __call__ ( # noqa: D102
31
33
self ,
32
34
arr : CpuArray ,
33
35
* ,
34
36
axis : Literal [0 , 1 , None ] = None ,
35
- dtype : DTypeOut | None = None ,
37
+ dtype : type [ DTypeOut ] | None = None ,
36
38
) -> NDArray [Any ] | np .number [Any ] | types .DaskArray : ...
37
39
38
40
@@ -44,29 +46,53 @@ def __call__( # noqa: D102
44
46
ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str (at )}
45
47
46
48
47
- @pytest .fixture (scope = "session" , params = [0 , 1 , None ])
48
- def axis (request : pytest .FixtureRequest ) -> Literal [0 , 1 , None ]:
49
- return cast ("Literal[0, 1, None]" , request .param )
49
+ @pytest .fixture (
50
+ scope = "session" ,
51
+ params = [
52
+ pytest .param ((2 , None ), id = "2d-all" ),
53
+ pytest .param ((2 , 0 ), id = "2d-ax0" ),
54
+ pytest .param ((2 , 1 ), id = "2d-ax1" ),
55
+ ],
56
+ )
57
+ def ndim_and_axis (request : pytest .FixtureRequest ) -> NdAndAx :
58
+ return cast ("NdAndAx" , request .param )
59
+
60
+
61
+ @pytest .fixture
62
+ def ndim (ndim_and_axis : NdAndAx ) -> Literal [2 ]:
63
+ return ndim_and_axis [0 ]
64
+
65
+
66
+ @pytest .fixture (scope = "session" )
67
+ def axis (ndim_and_axis : NdAndAx ) -> Literal [0 , 1 , None ]:
68
+ return ndim_and_axis [1 ]
50
69
51
70
52
71
@pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , np .int32 , np .bool ])
53
- def dtype_in (request : pytest .FixtureRequest ) -> DTypeIn :
54
- return cast ("DTypeIn" , request .param )
72
+ def dtype_in (request : pytest .FixtureRequest ) -> type [ DTypeIn ] :
73
+ return cast ("type[ DTypeIn] " , request .param )
55
74
56
75
57
76
@pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , None ])
58
- def dtype_arg (request : pytest .FixtureRequest ) -> DTypeOut | None :
59
- return cast ("DTypeOut | None" , request .param )
77
+ def dtype_arg (request : pytest .FixtureRequest ) -> type [DTypeOut ] | None :
78
+ return cast ("type[DTypeOut] | None" , request .param )
79
+
80
+
81
+ @pytest .fixture
82
+ def np_arr (dtype_in : type [DTypeIn ]) -> NDArray [DTypeIn ]:
83
+ np_arr = cast ("NDArray[DTypeIn]" , np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype_in ))
84
+ np_arr .flags .writeable = False
85
+ return np_arr
60
86
61
87
62
88
@pytest .mark .array_type (skip = ATS_SPARSE_DS )
63
89
def test_sum (
64
90
array_type : ArrayType [Array ],
65
- dtype_in : DTypeIn ,
66
- dtype_arg : DTypeOut | None ,
91
+ dtype_in : type [ DTypeIn ] ,
92
+ dtype_arg : type [ DTypeOut ] | None ,
67
93
axis : Literal [0 , 1 , None ],
94
+ np_arr : NDArray [DTypeIn ],
68
95
) -> None :
69
- np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype_in )
70
96
if array_type in ATS_CUPY_SPARSE and np_arr .dtype .kind != "f" :
71
97
pytest .skip ("CuPy sparse matrices only support floats" )
72
98
arr = array_type (np_arr .copy ())
@@ -104,21 +130,20 @@ def test_sum(
104
130
105
131
106
132
@pytest .mark .array_type (skip = ATS_SPARSE_DS )
107
- @pytest .mark .parametrize (("axis" , "expected" ), [(None , 3.5 ), (0 , [2.5 , 3.5 , 4.5 ]), (1 , [2.0 , 5.0 ])])
108
133
def test_mean (
109
- array_type : ArrayType [Array ], axis : Literal [0 , 1 , None ], expected : float | list [ float ]
134
+ array_type : ArrayType [Array ], axis : Literal [0 , 1 , None ], np_arr : NDArray [ DTypeIn ]
110
135
) -> None :
111
- np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
112
136
if array_type in ATS_CUPY_SPARSE and np_arr .dtype .kind != "f" :
113
137
pytest .skip ("CuPy sparse matrices only support floats" )
114
- np .testing .assert_array_equal (np .mean (np_arr , axis = axis ), expected )
115
-
116
138
arr = array_type (np_arr )
139
+
117
140
result = stats .mean (arr , axis = axis ) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777
118
141
if isinstance (result , types .DaskArray ):
119
142
result = result .compute ()
120
143
if isinstance (result , types .CupyArray | types .CupyCSMatrix ):
121
144
result = result .get ()
145
+
146
+ expected = np .mean (np_arr , axis = axis ) # type: ignore[arg-type]
122
147
np .testing .assert_array_equal (result , expected )
123
148
124
149
0 commit comments