Skip to content

Commit 753f058

Browse files
authored
chore: add sparse index dtype checks to assert_equal (#2362)
1 parent ef466f2 commit 753f058

2 files changed

Lines changed: 20 additions & 0 deletions

File tree

src/anndata/tests/helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,11 @@ def assert_equal_sparse(
688688
exact: bool = False,
689689
elem_name: str | None = None,
690690
):
691+
if exact and sparse.issparse(b) and hasattr(a, "indptr") and hasattr(b, "indptr"):
692+
assert a.indptr.dtype == b.indptr.dtype, f"{elem_name}: indptr dtype mismatch"
693+
assert a.indices.dtype == b.indices.dtype, (
694+
f"{elem_name}: indices dtype mismatch"
695+
)
691696
a = asarray(a)
692697
assert_equal(b, a, exact=exact, elem_name=elem_name)
693698

tests/test_helpers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,21 @@ def test_assert_equal_dask_arrays():
248248
assert_equal(c, d)
249249

250250

251+
@pytest.mark.parametrize("attr", ["indices", "indptr"])
252+
def test_assert_equal_sparse_index_dtype(attr):
253+
"""assert_equal(exact=True) should detect indptr/indices dtype mismatches."""
254+
a = sparse.csr_matrix(np.eye(3))
255+
b = sparse.csr_matrix(np.eye(3))
256+
setattr(b, attr, getattr(b, attr).astype(np.int64))
257+
258+
# Non-exact comparison should pass (values are identical)
259+
assert_equal(a, b, exact=False)
260+
261+
# Exact comparison should catch the dtype mismatch
262+
with pytest.raises(AssertionError, match=attr):
263+
assert_equal(a, b, exact=True)
264+
265+
251266
def test_assert_equal_dask_sparse_arrays():
252267
import dask.array as da
253268
from scipy import sparse

0 commit comments

Comments
 (0)