Skip to content

Commit 50d9de0

Browse files
committed
adapt tests
1 parent c963b3c commit 50d9de0

File tree

2 files changed

+30
-26
lines changed

2 files changed

+30
-26
lines changed

src/fast_array_utils/_plugins/numba_sparse.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,24 +62,18 @@ def instance_class(
6262
) -> CSBase:
6363
return cls.cls((data, indices, indptr), shape, copy=False)
6464

65-
def __init__(
66-
self,
67-
ndim: int,
68-
*,
69-
dtype_data: nbtypes.Type,
70-
dtype_indices: nbtypes.Type,
71-
dtype_indptr: nbtypes.Type,
72-
) -> None:
73-
self.dtype = nbtypes.DType(dtype_data)
74-
self.data = nbtypes.Array(dtype_data, 1, "A")
75-
self.indices = nbtypes.Array(dtype_indices, 1, "A")
76-
self.indptr = nbtypes.Array(dtype_indptr, 1, "A")
65+
def __init__(self, ndim: int, *, dtype: nbtypes.Type, dtype_ind: nbtypes.Type) -> None:
66+
self.dtype = nbtypes.DType(dtype)
67+
self.dtype_ind = nbtypes.DType(dtype_ind)
68+
self.data = nbtypes.Array(dtype, 1, "A")
69+
self.indices = nbtypes.Array(dtype_ind, 1, "A")
70+
self.indptr = nbtypes.Array(dtype_ind, 1, "A")
7771
self.shape = nbtypes.UniTuple(nbtypes.intp, ndim)
7872
super().__init__(self.name)
7973

8074
@property
8175
def key(self) -> tuple[str | nbtypes.Type, ...]:
82-
return (self.name, self.dtype, self.indices.dtype, self.indptr.dtype)
76+
return (self.name, self.dtype, self.dtype_ind)
8377

8478

8579
# make data model attributes available in numba functions
@@ -88,13 +82,15 @@ def key(self) -> tuple[str | nbtypes.Type, ...]:
8882

8983

9084
def make_typeof_fn(typ: type[CSType]) -> Callable[[CSBase, _TypeofContext], CSType]:
85+
"""Create a `typeof` function that maps a scipy matrix/array type to a numba `Type`."""
86+
9187
def typeof(val: CSBase, c: _TypeofContext) -> CSType:
88+
if val.indptr.dtype != val.indices.dtype:
89+
msg = "indptr and indices must have the same dtype"
90+
raise TypeError(msg)
9291
data = cast("nbtypes.Array", typeof_impl(val.data, c))
93-
indices = cast("nbtypes.Array", typeof_impl(val.indices, c))
9492
indptr = cast("nbtypes.Array", typeof_impl(val.indptr, c))
95-
return typ(
96-
val.ndim, dtype_data=data.dtype, dtype_indices=indices.dtype, dtype_indptr=indptr.dtype
97-
)
93+
return typ(val.ndim, dtype=data.dtype, dtype_ind=indptr.dtype)
9894

9995
return typeof
10096

@@ -106,6 +102,11 @@ def typeof(val: CSBase, c: _TypeofContext) -> CSType:
106102

107103

108104
class CSModel(_Base):
105+
"""Numba data model for compressed sparse matrices.
106+
107+
This is the class that is used by numba to lower the array types.
108+
"""
109+
109110
def __init__(self, dmm: DataModelManager, fe_type: CSType) -> None:
110111
members = [
111112
("data", fe_type.data),
@@ -116,6 +117,7 @@ def __init__(self, dmm: DataModelManager, fe_type: CSType) -> None:
116117
super().__init__(dmm, fe_type, members)
117118

118119

120+
# create all the actual types and data models
119121
CLASSES: Sequence[type[CSBase]] = [
120122
sparse.csr_matrix,
121123
sparse.csc_matrix,

tests/test_numpy_scipy_sparse.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,24 @@ def copy_mat(mat: CSBase) -> CSBase:
4747

4848

4949
@pytest.mark.array_type(select=Flags.Sparse, skip=Flags.Dask | Flags.Disk | Flags.Gpu)
50-
@pytest.mark.parametrize("dtype_indptr", [np.int32, np.int64], ids=["p=32", "p=64"])
51-
@pytest.mark.parametrize("dtype_index", [np.int32, np.int64], ids=["i=32", "i=64"])
50+
@pytest.mark.parametrize("dtype_ind", [np.int32, np.int64], ids=["i=32", "i=64"])
5251
@pytest.mark.parametrize("dtype_data", [np.int64, np.float64], ids=["d=i64", "d=f64"])
5352
def test_copy(
5453
array_type: ArrayType[CSBase, None],
5554
dtype_data: type[np.int64 | np.float64],
56-
dtype_index: type[np.int32 | np.int64],
57-
dtype_indptr: type[np.int32 | np.int64],
55+
dtype_ind: type[np.int32 | np.int64],
5856
) -> None:
5957
mat = array_type.random((10, 10), density=0.1, dtype=dtype_data)
60-
mat.indices = mat.indices.astype(dtype_index)
61-
mat.indptr = mat.indptr.astype(dtype_indptr)
58+
mat.indptr = mat.indptr.astype(dtype_ind)
59+
mat.indices = mat.indices.astype(dtype_ind)
60+
6261
copied = copy_mat(mat)
63-
assert mat.data is not copied.data
64-
assert mat.indices is not copied.indices
65-
assert mat.indptr is not copied.indptr
62+
63+
# check that the copied arrays point to different memory locations
64+
assert mat.data.ctypes.data != copied.data.ctypes.data
65+
assert mat.indices.ctypes.data != copied.indices.ctypes.data
66+
assert mat.indptr.ctypes.data != copied.indptr.ctypes.data
67+
# check that the array contents and dtypes are the same
6668
assert mat.shape == copied.shape
6769
np.testing.assert_equal(mat.toarray(), copied.toarray(), strict=True)
6870
np.testing.assert_equal(mat.data, copied.data, strict=True)

0 commit comments

Comments
 (0)