Skip to content

Commit e6fbb70

Browse files
authored
feat: Support sparrays in dask (#2087)
1 parent 6feb6d3 commit e6fbb70

8 files changed

Lines changed: 82 additions & 64 deletions

File tree

docs/release-notes/2087.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add support for Dask {class}`dask.array.Array`s containing {class}`~scipy.sparse.sparray`s {user}`flying-sheep`

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ test-min = [
101101
"pyarrow<21", # https://github.com/scikit-hep/awkward/issues/3579
102102
"anndata[dask]",
103103
]
104-
test = [ "anndata[test-min,lazy]" ]
104+
test = [ "anndata[test-min,lazy]", "fast-array-utils>=1.2.3" ]
105105
gpu = [ "cupy" ]
106106
cu12 = [ "cupy-cuda12x" ]
107107
cu11 = [ "cupy-cuda11x" ]

src/anndata/_core/merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def equal_dask_array(a, b) -> bool:
146146
return False
147147
if isinstance(b, DaskArray) and tokenize(a) == tokenize(b):
148148
return True
149-
if isinstance(a._meta, CSMatrix):
149+
if isinstance(a._meta, CSMatrix | CSArray):
150150
# TODO: Maybe also do this in the other case?
151151
return da.map_blocks(equal, a, b, drop_axis=(0, 1)).all()
152152
else:

src/anndata/_io/specs/methods.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,8 @@ def chunk_slice(start: int, stop: int) -> tuple[slice | None, slice | None]:
876876
for scipy_sparse_type, spec in [
877877
(sparse.csr_matrix, IOSpec("csr_matrix", "0.1.0")),
878878
(sparse.csc_matrix, IOSpec("csc_matrix", "0.1.0")),
879+
(sparse.csr_array, IOSpec("csr_matrix", "0.1.0")),
880+
(sparse.csc_array, IOSpec("csc_matrix", "0.1.0")),
879881
]:
880882
_REGISTRY.register_write(group_type, (array_type, scipy_sparse_type), spec)(
881883
write_dask_sparse

src/anndata/compat/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import h5py
1212
import numpy as np
1313
import pandas as pd
14-
import scipy
14+
import scipy.sparse
1515
from numpy.typing import NDArray
1616
from packaging.version import Version
1717
from zarr import Array as ZarrArray # noqa: F401

src/anndata/tests/helpers.py

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import Counter, defaultdict
77
from collections.abc import Mapping
88
from functools import partial, singledispatch, wraps
9+
from importlib.metadata import version
910
from importlib.util import find_spec
1011
from string import ascii_letters
1112
from typing import TYPE_CHECKING
@@ -14,6 +15,7 @@
1415
import numpy as np
1516
import pandas as pd
1617
import pytest
18+
from packaging.version import Version
1719
from pandas.api.types import is_numeric_dtype
1820
from scipy import sparse
1921

@@ -61,6 +63,14 @@
6163
*(pd.UInt8Dtype, pd.UInt16Dtype, pd.UInt32Dtype, pd.UInt64Dtype),
6264
)
6365

66+
try:
67+
import fast_array_utils as _
68+
except ImportError:
69+
# dask natively supports sparray since https://github.com/dask/dask/pull/11750
70+
DASK_CAN_SPARRAY = Version(version("dask")) >= Version("2025.3.0")
71+
else: # fast-array-utils monkeypatches dask to support sparrays
72+
DASK_CAN_SPARRAY = True
73+
6474

6575
DEFAULT_KEY_TYPES = (
6676
sparse.csr_matrix,
@@ -628,8 +638,9 @@ def assert_equal_arrayview(
628638

629639
@assert_equal.register(BaseCompressedSparseDataset)
630640
@assert_equal.register(sparse.spmatrix)
641+
@assert_equal.register(CSArray)
631642
def assert_equal_sparse(
632-
a: BaseCompressedSparseDataset | sparse.spmatrix,
643+
a: BaseCompressedSparseDataset | sparse.spmatrix | CSArray,
633644
b: object,
634645
*,
635646
exact: bool = False,
@@ -639,13 +650,6 @@ def assert_equal_sparse(
639650
assert_equal(b, a, exact=exact, elem_name=elem_name)
640651

641652

642-
@assert_equal.register(CSArray)
643-
def assert_equal_sparse_array(
644-
a: CSArray, b: object, *, exact: bool = False, elem_name: str | None = None
645-
):
646-
return assert_equal_sparse(a, b, exact=exact, elem_name=elem_name)
647-
648-
649653
@assert_equal.register(CupySparseMatrix)
650654
def assert_equal_cupy_sparse(
651655
a: CupySparseMatrix, b: object, *, exact: bool = False, elem_name: str | None = None
@@ -878,29 +882,37 @@ def _(a):
878882

879883

880884
@singledispatch
881-
def as_sparse_dask_array(a) -> DaskArray:
882-
import dask.array as da
883-
884-
return da.from_array(sparse.csr_matrix(a), chunks=_half_chunk_size(a.shape))
885+
def _as_sparse_dask(
886+
a: NDArray | CSArray | CSMatrix | DaskArray, *, typ: type[CSArray | CSMatrix]
887+
) -> DaskArray:
888+
"""Convert a to a sparse dask array, preserving sparse format and container (`cs{rc}_{array,matrix}`)."""
889+
raise NotImplementedError
885890

886891

887-
@as_sparse_dask_array.register(CSMatrix)
888-
def _(a):
892+
@_as_sparse_dask.register(CSArray | CSMatrix | np.ndarray)
893+
def _(a: CSArray | CSMatrix | NDArray, *, typ: type[CSArray | CSMatrix]) -> DaskArray:
889894
import dask.array as da
890895

891-
return da.from_array(a, _half_chunk_size(a.shape))
896+
return da.from_array(_as_sparse_dask_inner(a, typ=typ), _half_chunk_size(a.shape))
892897

893898

894-
@as_sparse_dask_array.register(CSArray)
895-
def _(a):
896-
import dask.array as da
899+
@_as_sparse_dask.register(DaskArray)
900+
def _(a: DaskArray, *, typ: type[CSArray | CSMatrix]) -> DaskArray:
901+
return a.map_blocks(_as_sparse_dask_inner, typ=typ, dtype=a.dtype, meta=typ((2, 2)))
897902

898-
return da.from_array(sparse.csr_matrix(a), _half_chunk_size(a.shape))
899903

904+
def _as_sparse_dask_inner(
905+
a: NDArray | CSArray | CSMatrix, *, typ: type[CSArray | CSMatrix]
906+
) -> CSArray | CSMatrix:
907+
"""Convert into a a sparse container that dask supports (or complain)."""
908+
if issubclass(typ, CSArray) and not DASK_CAN_SPARRAY: # convert sparray to spmatrix
909+
msg = "Dask <2025.3 without fast-array-utils doesn’t support sparse arrays"
910+
raise TypeError(msg)
911+
return typ(a)
900912

901-
@as_sparse_dask_array.register(DaskArray)
902-
def _(a):
903-
return a.map_blocks(sparse.csr_matrix)
913+
914+
as_sparse_dask_array = partial(_as_sparse_dask, typ=sparse.csr_array)
915+
as_sparse_dask_matrix = partial(_as_sparse_dask, typ=sparse.csr_matrix)
904916

905917

906918
@singledispatch
@@ -949,11 +961,8 @@ def _(a):
949961
# We should try and fix this upstream in dask/ cupy
950962
@singledispatch
951963
def as_cupy_sparse_dask_array(a, format="csr"):
952-
memory_class = format_to_memory_class[format]
953-
cpu_da = as_sparse_dask_array(a)
954-
return cpu_da.rechunk((cpu_da.chunks[0], -1)).map_blocks(
955-
memory_class, dtype=a.dtype, meta=memory_class(cpu_da._meta)
956-
)
964+
da = _as_sparse_dask(a, typ=format_to_memory_class[format])
965+
return da.rechunk((da.chunks[0], -1))
957966

958967

959968
@as_cupy_sparse_dask_array.register(CupyArray)
@@ -1003,7 +1012,7 @@ def as_cupy(val, typ=None):
10031012
if issubclass(typ, CupyArray):
10041013
import cupy as cp
10051014

1006-
if isinstance(val, CSMatrix):
1015+
if isinstance(val, CSMatrix | CSArray):
10071016
val = val.toarray()
10081017
return cp.array(val)
10091018
elif issubclass(typ, CupyCSRMatrix):
@@ -1059,7 +1068,14 @@ def shares_memory_sparse(x, y):
10591068

10601069
DASK_MATRIX_PARAMS = [
10611070
pytest.param(as_dense_dask_array, id="dense_dask_array"),
1062-
pytest.param(as_sparse_dask_array, id="sparse_dask_array"),
1071+
pytest.param(as_sparse_dask_matrix, id="sparse_dask_matrix"),
1072+
pytest.param(
1073+
as_sparse_dask_array,
1074+
marks=pytest.mark.skipif(
1075+
not DASK_CAN_SPARRAY, reason="Dask does not support sparrays"
1076+
),
1077+
id="sparse_dask_array",
1078+
),
10631079
]
10641080

10651081
CUPY_MATRIX_PARAMS = [

tests/test_dask.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,37 @@
44

55
from __future__ import annotations
66

7+
from pathlib import Path
78
from typing import TYPE_CHECKING
89

910
import numpy as np
1011
import pandas as pd
1112
import pytest
12-
from scipy import sparse
1313

1414
import anndata as ad
1515
from anndata._core.anndata import AnnData
16-
from anndata.compat import CupyArray, DaskArray
16+
from anndata.compat import CSArray, CSMatrix, CupyArray, DaskArray
1717
from anndata.experimental.merge import as_group
1818
from anndata.tests.helpers import (
19+
BASE_MATRIX_PARAMS,
20+
DASK_CAN_SPARRAY,
21+
DASK_MATRIX_PARAMS,
1922
GEN_ADATA_DASK_ARGS,
20-
as_cupy_sparse_dask_array,
2123
as_dense_cupy_dask_array,
2224
as_dense_dask_array,
2325
as_sparse_dask_array,
26+
as_sparse_dask_matrix,
2427
assert_equal,
2528
gen_adata,
2629
)
2730

2831
if TYPE_CHECKING:
32+
from collections.abc import Callable
2933
from pathlib import Path
3034
from typing import Literal
3135

36+
from numpy.typing import NDArray
37+
3238

3339
pytest.importorskip("dask.array")
3440

@@ -276,13 +282,18 @@ def test_assign_X(adata):
276282
@pytest.mark.parametrize(
277283
("array_func", "mem_type"),
278284
[
279-
pytest.param(as_dense_dask_array, np.ndarray, id="dense_dask_array"),
280-
pytest.param(as_sparse_dask_array, sparse.csr_matrix, id="sparse_dask_array"),
285+
pytest.param(as_dense_dask_array, np.ndarray, id="dense"),
286+
pytest.param(as_sparse_dask_matrix, CSMatrix, id="sparse_matrix"),
287+
pytest.param(
288+
as_sparse_dask_array,
289+
CSArray,
290+
marks=pytest.mark.skipif(
291+
not DASK_CAN_SPARRAY, reason="Dask does not support sparrays"
292+
),
293+
id="sparse_array",
294+
),
281295
pytest.param(
282-
as_dense_cupy_dask_array,
283-
CupyArray,
284-
id="cupy_dense_dask_array",
285-
marks=pytest.mark.gpu,
296+
as_dense_cupy_dask_array, CupyArray, id="cupy_dense", marks=pytest.mark.gpu
286297
),
287298
],
288299
)
@@ -311,29 +322,17 @@ def test_dask_to_memory_unbacked(array_func, mem_type):
311322
assert isinstance(orig.uns["da"]["da"], DaskArray)
312323

313324

314-
@pytest.mark.parametrize(
315-
"array_func",
316-
[
317-
pytest.param(as_dense_dask_array, id="dense_dask_array"),
318-
pytest.param(as_sparse_dask_array, id="sparse_dask_array"),
319-
pytest.param(
320-
as_dense_cupy_dask_array,
321-
id="cupy_dense_dask_array",
322-
marks=pytest.mark.gpu,
323-
),
324-
pytest.param(
325-
as_cupy_sparse_dask_array,
326-
id="cupy_sparse_dask_array",
327-
marks=pytest.mark.gpu,
328-
),
329-
],
330-
)
331-
def test_dask_to_disk_view(array_func, diskfmt, tmp_path):
325+
@pytest.mark.parametrize("to_dask", [*BASE_MATRIX_PARAMS, *DASK_MATRIX_PARAMS])
326+
def test_dask_to_disk_view(
327+
to_dask: Callable[[NDArray], DaskArray],
328+
diskfmt: Literal["h5ad", "zarr"],
329+
tmp_path: Path,
330+
) -> None:
332331
random_state = np.random.default_rng()
333-
orig = ad.AnnData(
334-
# need to change type for cupy
335-
array_func(random_state.binomial(100, 0.005, (20, 15)).astype("float32"))
336-
)
332+
arr = random_state.binomial(100, 0.005, (20, 15)).astype("float32")
333+
334+
# TODO: need to change type for cupy
335+
orig = ad.AnnData(to_dask(arr))
337336
orig = orig[orig.shape[0] // 2]
338337
path = tmp_path / f"test.{diskfmt}"
339338
getattr(orig, f"write_{diskfmt}")(path)

tests/test_io_elementwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def test_io_spec_compressed_scalars(store: G, value: np.ndarray, encoding_type:
252252
@pytest.mark.parametrize("as_dask", [False, True])
253253
def test_io_spec_cupy(store, value, encoding_type, as_dask):
254254
if as_dask:
255-
if isinstance(value, CSMatrix):
255+
if isinstance(value, CSMatrix | CSArray):
256256
value = as_cupy_sparse_dask_array(value, format=encoding_type[:3])
257257
else:
258258
value = as_dense_cupy_dask_array(value)

0 commit comments

Comments
 (0)