Skip to content

Commit aae79ef

Browse files
authored
perf: fast path for unchunked minor axis CSR dask sparse reindexing when concating (#2395)
1 parent 3aefe2e commit aae79ef

5 files changed

Lines changed: 137 additions & 26 deletions

File tree

.github/workflows/benchmark.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,4 @@ jobs:
6464
working-directory: ${{ env.ASV_DIR }}
6565
run: |
6666
asv machine --yes
67-
asv run --quick --show-stderr --verbose
67+
asv run --dry-run --quick --show-stderr --verbose HEAD^!

benchmarks/benchmarks/sparse_dataset.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
from types import MappingProxyType
4+
from typing import TYPE_CHECKING
45

56
import numpy as np
7+
import pandas as pd
68
import zarr
79
from dask.array.core import Array as DaskArray
810
from scipy import sparse
@@ -12,6 +14,9 @@
1214
from anndata._io.specs import write_elem
1315
from anndata.experimental import read_elem_lazy
1416

17+
if TYPE_CHECKING:
18+
from typing import Literal
19+
1520

1621
def make_alternating_mask(n):
1722
mask_alternating = np.ones(10_000, dtype=bool)
@@ -79,9 +84,12 @@ def peakmem_getitem_adata(self, *_):
7984
res.compute()
8085

8186

82-
class SparseCSRDask:
87+
class SparseCSRDaskConcat:
8388
filepath = "data.zarr"
8489

90+
params = (["inner", "outer"], [0, -1])
91+
param_names = ("join", "fill_value")
92+
8593
def setup_cache(self):
8694
X = sparse.random(
8795
10_000,
@@ -93,18 +101,59 @@ def setup_cache(self):
93101
g = zarr.group(self.filepath)
94102
write_elem(g, "X", X)
95103

96-
def setup(self):
104+
def setup(self, *_):
97105
self.group = zarr.group(self.filepath)
98-
self.adata = AnnData(X=read_elem_lazy(self.group["X"]))
106+
self.adatas = [
107+
AnnData(
108+
var=pd.DataFrame(
109+
index=[
110+
f"gene_{j}{f'_{i}' if (j % 100 == 0) else ''}"
111+
for j in range(10_000)
112+
]
113+
),
114+
X=read_elem_lazy(self.group["X"]),
115+
)
116+
for i in range(10)
117+
]
118+
119+
def time_concat(self, join: Literal["inner", "outer"], fill_value: Literal[0, -1]):
120+
concat(self.adatas, join=join, fill_value=fill_value)
121+
122+
def peakmem_concat(
123+
self, join: Literal["inner", "outer"], fill_value: Literal[0, -1]
124+
):
125+
concat(self.adatas, join=join, fill_value=fill_value)
126+
127+
def time_concat_with_mem(
128+
self, join: Literal["inner", "outer"], fill_value: Literal[0, -1]
129+
):
130+
concat(self.adatas, join=join, fill_value=fill_value).to_memory()
131+
132+
def peakmem_concat_with_mem(
133+
self, join: Literal["inner", "outer"], fill_value: Literal[0, -1]
134+
):
135+
concat(self.adatas, join=join, fill_value=fill_value).to_memory()
99136

100-
def time_concat(self):
101-
concat([self.adata for i in range(100)])
102137

103-
def peakmem_concat(self):
104-
concat([self.adata for i in range(100)])
138+
class SparseCSRDask:
139+
filepath = "data.zarr"
140+
141+
def setup_cache(self):
142+
X = sparse.random(
143+
10_000,
144+
10_000,
145+
density=0.01,
146+
format="csr",
147+
random_state=np.random.default_rng(42),
148+
)
149+
g = zarr.group(self.filepath)
150+
write_elem(g, "X", X)
151+
152+
def setup(self, *_):
153+
self.group = zarr.group(self.filepath)
105154

106-
def time_read(self):
155+
def time_read(self, *_):
107156
AnnData(X=read_elem_lazy(self.group["X"]))
108157

109-
def peakmem_read(self):
158+
def peakmem_read(self, *_):
110159
AnnData(X=read_elem_lazy(self.group["X"]))

docs/release-notes/2395.perf.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Accelerate outer joins on dask-sparse matrices with unchunked minor axes in {func}`anndata.concat` {user}`ilan-gold`

src/anndata/_core/merge.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,29 @@ def _apply_to_df_like(self, el: pd.DataFrame | Dataset2D, *, axis, fill_value=No
583583
def _apply_to_dask_array(self, el: DaskArray, *, axis, fill_value=None):
584584
import dask.array as da
585585

586+
indexer = self.idx
587+
is_outer = any(indexer == -1)
588+
# Fast path for the majority of sparse matrices whose minor-axis is unchunked and is being reindexed.
589+
# This prevents 0's from being stored explicitly in the sparse matrices when outer joining, for example (see below).
590+
if (
591+
is_sparse_sub := isinstance(el._meta, CSArray | CSMatrix)
592+
and el.chunksize[minor_axis := int(el._meta.format == "csr")]
593+
== el.shape[minor_axis]
594+
and axis == minor_axis
595+
and is_outer
596+
):
597+
return el.map_blocks(
598+
partial(
599+
self._apply_to_sparse,
600+
axis=axis,
601+
fill_value=fill_value,
602+
keep_format=True,
603+
),
604+
chunks=(el.chunks[0], len(self.new_idx))
605+
if minor_axis == 1
606+
else (len(self.new_idx), el.chunks[1]),
607+
meta=el._meta,
608+
)
586609
if fill_value is None:
587610
fill_value = default_fill_value([el])
588611
shape = list(el.shape)
@@ -591,12 +614,11 @@ def _apply_to_dask_array(self, el: DaskArray, *, axis, fill_value=None):
591614
shape[axis] = len(self.new_idx)
592615
return da.broadcast_to(fill_value, tuple(shape))
593616

594-
indexer = self.idx
595617
sub_el = _subset(el, make_slice(indexer, axis, len(shape)))
596618

597-
if any(indexer == -1):
619+
if is_outer:
598620
# TODO: Remove this condition once https://github.com/dask/dask/pull/12078 is released
599-
if isinstance(sub_el._meta, CSArray | CSMatrix) and np.isscalar(fill_value):
621+
if is_sparse_sub and np.isscalar(fill_value):
600622
fill_value = np.array([[fill_value]])
601623
sub_el[make_slice(indexer == -1, axis, len(shape))] = fill_value
602624

@@ -658,7 +680,7 @@ def _apply_to_array_api(
658680
return xp.where(mask, fv, taken)
659681

660682
def _apply_to_sparse( # noqa: PLR0912
661-
self, el: CSMatrix | CSArray, *, axis, fill_value=None
683+
self, el: CSMatrix | CSArray, *, axis, fill_value=None, keep_format: bool = True
662684
) -> CSMatrix:
663685
if isinstance(el, CupySparseMatrix):
664686
from cupyx.scipy import sparse
@@ -730,7 +752,8 @@ def _apply_to_sparse( # noqa: PLR0912
730752

731753
if fill_idxer is not None:
732754
out[fill_idxer] = fill_value
733-
755+
if keep_format:
756+
out = out.tocsr() if el.format == "csr" else out.tocsc()
734757
return out
735758

736759
def _apply_to_awkward(self, el: AwkArray, *, axis, fill_value=None):

tests/test_concatenate.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1814,22 +1814,60 @@ def test_concat_on_var_outer_join(array_type):
18141814
_ = concat([a, b], join="outer", axis=1)
18151815

18161816

1817-
def test_concat_dask_sparse_matches_memory(join_type, merge_strategy):
1817+
@pytest.mark.parametrize("format", ["csr", "csc"])
1818+
@pytest.mark.parametrize(
1819+
"unchunked_minor_axis", [True, False], ids=["unchunked_minor", "chunked_minor"]
1820+
)
1821+
@pytest.mark.parametrize("fill_value", [0, -1])
1822+
def test_concat_dask_sparse_matches_memory(
1823+
join_type,
1824+
merge_strategy,
1825+
format: Literal["csr", "csc"],
1826+
axis_name: Literal["obs", "var"],
1827+
fill_value: Literal[-1, 0],
1828+
*,
1829+
unchunked_minor_axis: bool,
1830+
):
18181831
import dask.array as da
18191832

1820-
X = sparse.random(50, 20, density=0.5, format="csr")
1821-
X_dask = da.from_array(X, chunks=(5, 20))
1822-
var_names_1 = [f"gene_{i}" for i in range(20)]
1823-
var_names_2 = [f"gene_{i}{'_foo' if (i % 2) else ''}" for i in range(20)]
1833+
X = sparse.random(50, 20, density=0.5, format=format)
1834+
X_dask = da.from_array(
1835+
X,
1836+
chunks=(
1837+
X.shape[0] if format == "csc" else 10,
1838+
X.shape[1] if format == "csr" else 5,
1839+
)
1840+
if unchunked_minor_axis
1841+
else (5, 10),
1842+
)
1843+
off_axis_idx = int(axis_name == "obs")
1844+
concat_axis_idx = int(axis_name == "var")
1845+
off_axis = "var" if axis_name == "obs" else "obs"
1846+
axis_names_1 = [f"off_axis_{i}" for i in range(X.shape[off_axis_idx])]
1847+
axis_names_2 = [
1848+
f"off_axis_{i}{'_foo' if (i % 2) else ''}" for i in range(X.shape[off_axis_idx])
1849+
]
18241850

1825-
ad1 = AnnData(X=X, var=pd.DataFrame(index=var_names_1))
1826-
ad2 = AnnData(X=X, var=pd.DataFrame(index=var_names_2))
1851+
ad1 = AnnData(X=X, **{off_axis: pd.DataFrame(index=axis_names_1)})
1852+
ad2 = AnnData(X=X, **{off_axis: pd.DataFrame(index=axis_names_2)})
18271853

1828-
ad1_dask = AnnData(X=X_dask, var=pd.DataFrame(index=var_names_1))
1829-
ad2_dask = AnnData(X=X_dask, var=pd.DataFrame(index=var_names_2))
1854+
ad1_dask = AnnData(X=X_dask, **{off_axis: pd.DataFrame(index=axis_names_1)})
1855+
ad2_dask = AnnData(X=X_dask, **{off_axis: pd.DataFrame(index=axis_names_2)})
18301856

1831-
res_in_memory = concat([ad1, ad2], join=join_type, merge=merge_strategy)
1832-
res_dask = concat([ad1_dask, ad2_dask], join=join_type, merge=merge_strategy)
1857+
res_in_memory = concat(
1858+
[ad1, ad2],
1859+
join=join_type,
1860+
merge=merge_strategy,
1861+
axis=concat_axis_idx,
1862+
fill_value=fill_value,
1863+
)
1864+
res_dask = concat(
1865+
[ad1_dask, ad2_dask],
1866+
join=join_type,
1867+
merge=merge_strategy,
1868+
axis=concat_axis_idx,
1869+
fill_value=fill_value,
1870+
)
18331871
assert_equal(res_in_memory, res_dask)
18341872

18351873

0 commit comments

Comments
 (0)