Skip to content

Commit a1643f9

Browse files
authored
prepare tests for ndim (#85)
1 parent 75506e9 commit a1643f9

File tree

1 file changed

+43
-18
lines changed

1 file changed

+43
-18
lines changed

tests/test_stats.py

+43-18
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,18 @@
2323

2424
Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray
2525

26-
DTypeIn = type[np.float32 | np.float64 | np.int32 | np.bool]
27-
DTypeOut = type[np.float32 | np.float64 | np.int64]
26+
DTypeIn = np.float32 | np.float64 | np.int32 | np.bool
27+
DTypeOut = np.float32 | np.float64 | np.int64
28+
29+
NdAndAx: TypeAlias = tuple[Literal[2], Literal[0, 1, None]]
2830

2931
class BenchFun(Protocol): # noqa: D101
3032
def __call__( # noqa: D102
3133
self,
3234
arr: CpuArray,
3335
*,
3436
axis: Literal[0, 1, None] = None,
35-
dtype: DTypeOut | None = None,
37+
dtype: type[DTypeOut] | None = None,
3638
) -> NDArray[Any] | np.number[Any] | types.DaskArray: ...
3739

3840

@@ -44,29 +46,53 @@ def __call__( # noqa: D102
4446
ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str(at)}
4547

4648

47-
@pytest.fixture(scope="session", params=[0, 1, None])
48-
def axis(request: pytest.FixtureRequest) -> Literal[0, 1, None]:
49-
return cast("Literal[0, 1, None]", request.param)
49+
@pytest.fixture(
50+
scope="session",
51+
params=[
52+
pytest.param((2, None), id="2d-all"),
53+
pytest.param((2, 0), id="2d-ax0"),
54+
pytest.param((2, 1), id="2d-ax1"),
55+
],
56+
)
57+
def ndim_and_axis(request: pytest.FixtureRequest) -> NdAndAx:
58+
return cast("NdAndAx", request.param)
59+
60+
61+
@pytest.fixture
62+
def ndim(ndim_and_axis: NdAndAx) -> Literal[2]:
63+
return ndim_and_axis[0]
64+
65+
66+
@pytest.fixture(scope="session")
67+
def axis(ndim_and_axis: NdAndAx) -> Literal[0, 1, None]:
68+
return ndim_and_axis[1]
5069

5170

5271
@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool])
53-
def dtype_in(request: pytest.FixtureRequest) -> DTypeIn:
54-
return cast("DTypeIn", request.param)
72+
def dtype_in(request: pytest.FixtureRequest) -> type[DTypeIn]:
73+
return cast("type[DTypeIn]", request.param)
5574

5675

5776
@pytest.fixture(scope="session", params=[np.float32, np.float64, None])
58-
def dtype_arg(request: pytest.FixtureRequest) -> DTypeOut | None:
59-
return cast("DTypeOut | None", request.param)
77+
def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
78+
return cast("type[DTypeOut] | None", request.param)
79+
80+
81+
@pytest.fixture
82+
def np_arr(dtype_in: type[DTypeIn]) -> NDArray[DTypeIn]:
83+
np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in))
84+
np_arr.flags.writeable = False
85+
return np_arr
6086

6187

6288
@pytest.mark.array_type(skip=ATS_SPARSE_DS)
6389
def test_sum(
6490
array_type: ArrayType[Array],
65-
dtype_in: DTypeIn,
66-
dtype_arg: DTypeOut | None,
91+
dtype_in: type[DTypeIn],
92+
dtype_arg: type[DTypeOut] | None,
6793
axis: Literal[0, 1, None],
94+
np_arr: NDArray[DTypeIn],
6895
) -> None:
69-
np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in)
7096
if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
7197
pytest.skip("CuPy sparse matrices only support floats")
7298
arr = array_type(np_arr.copy())
@@ -104,21 +130,20 @@ def test_sum(
104130

105131

106132
@pytest.mark.array_type(skip=ATS_SPARSE_DS)
107-
@pytest.mark.parametrize(("axis", "expected"), [(None, 3.5), (0, [2.5, 3.5, 4.5]), (1, [2.0, 5.0])])
108133
def test_mean(
109-
array_type: ArrayType[Array], axis: Literal[0, 1, None], expected: float | list[float]
134+
array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn]
110135
) -> None:
111-
np_arr = np.array([[1, 2, 3], [4, 5, 6]])
112136
if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
113137
pytest.skip("CuPy sparse matrices only support floats")
114-
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected)
115-
116138
arr = array_type(np_arr)
139+
117140
result = stats.mean(arr, axis=axis) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777
118141
if isinstance(result, types.DaskArray):
119142
result = result.compute()
120143
if isinstance(result, types.CupyArray | types.CupyCSMatrix):
121144
result = result.get()
145+
146+
expected = np.mean(np_arr, axis=axis) # type: ignore[arg-type]
122147
np.testing.assert_array_equal(result, expected)
123148

124149

0 commit comments

Comments
 (0)