Skip to content

Commit 38f9d22

Browse files
committed
fix tests
1 parent 50d9de0 commit 38f9d22

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

src/fast_array_utils/_plugins/numba_sparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def instance_class(
5858
data: NDArray[np.number[Any]],
5959
indices: NDArray[np.integer[Any]],
6060
indptr: NDArray[np.integer[Any]],
61-
shape: tuple[int, int],
61+
shape: tuple[int, int], # actually tuple[int, ...] for sparray subclasses
6262
) -> CSBase:
6363
return cls.cls((data, indices, indptr), shape, copy=False)
6464

tests/test_numpy_scipy_sparse.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pytest
88

9+
from fast_array_utils import types
910
from testing.fast_array_utils._array_type import Flags
1011

1112

@@ -15,42 +16,41 @@
1516

1617

1718
if TYPE_CHECKING:
18-
from fast_array_utils.types import CSBase
1919
from testing.fast_array_utils._array_type import ArrayType
2020

2121

2222
@numba.njit(cache=True)
23-
def mat_ndim(mat: CSBase) -> int:
23+
def mat_ndim(mat: types.CSBase) -> int:
2424
return mat.ndim
2525

2626

2727
@pytest.mark.array_type(select=Flags.Sparse, skip=Flags.Dask | Flags.Disk | Flags.Gpu)
28-
def test_ndim(array_type: ArrayType[CSBase, None]) -> None:
28+
def test_ndim(array_type: ArrayType[types.CSBase, None]) -> None:
2929
mat = array_type.random((10, 10), density=0.1)
3030
assert mat_ndim(mat) == mat.ndim
3131

3232

3333
@numba.njit(cache=True)
34-
def mat_shape(mat: CSBase) -> tuple[int, ...]:
34+
def mat_shape(mat: types.CSBase) -> tuple[int, ...]:
3535
return np.shape(mat) # type: ignore[arg-type]
3636

3737

3838
@pytest.mark.array_type(select=Flags.Sparse, skip=Flags.Dask | Flags.Disk | Flags.Gpu)
39-
def test_shape(array_type: ArrayType[CSBase, None]) -> None:
39+
def test_shape(array_type: ArrayType[types.CSBase, None]) -> None:
4040
mat = array_type.random((10, 10), density=0.1)
4141
assert mat_shape(mat) == mat.shape
4242

4343

4444
@numba.njit(cache=True)
45-
def copy_mat(mat: CSBase) -> CSBase:
45+
def copy_mat(mat: types.CSBase) -> types.CSBase:
4646
return mat.copy()
4747

4848

4949
@pytest.mark.array_type(select=Flags.Sparse, skip=Flags.Dask | Flags.Disk | Flags.Gpu)
5050
@pytest.mark.parametrize("dtype_ind", [np.int32, np.int64], ids=["i=32", "i=64"])
5151
@pytest.mark.parametrize("dtype_data", [np.int64, np.float64], ids=["d=i64", "d=f64"])
5252
def test_copy(
53-
array_type: ArrayType[CSBase, None],
53+
array_type: ArrayType[types.CSBase, None],
5454
dtype_data: type[np.int64 | np.float64],
5555
dtype_ind: type[np.int32 | np.int64],
5656
) -> None:
@@ -66,7 +66,15 @@ def test_copy(
6666
assert mat.indptr.ctypes.data != copied.indptr.ctypes.data
6767
# check that the array contents and dtypes are the same
6868
assert mat.shape == copied.shape
69-
np.testing.assert_equal(mat.toarray(), copied.toarray(), strict=True)
70-
np.testing.assert_equal(mat.data, copied.data, strict=True)
71-
np.testing.assert_equal(mat.indices, copied.indices, strict=True)
72-
np.testing.assert_equal(mat.indptr, copied.indptr, strict=True)
69+
np.testing.assert_equal(copied.toarray(), mat.toarray(), strict=True)
70+
np.testing.assert_equal(copied.data, mat.data, strict=True)
71+
np.testing.assert_equal(copied.indices, mat.indices, strict=not downcasts_idx(mat))
72+
np.testing.assert_equal(copied.indptr, mat.indptr, strict=not downcasts_idx(mat))
73+
74+
75+
def downcasts_idx(mat: types.CSBase) -> bool:
76+
"""Check if `mat`’s class downcast’s indices to 32-bit.
77+
78+
See https://github.com/scipy/scipy/pull/18509
79+
"""
80+
return isinstance(mat, types.CSMatrix)

0 commit comments

Comments
 (0)