Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ jobs:
working-directory: ${{ env.ASV_DIR }}
run: |
asv machine --yes
asv run --quick --show-stderr --verbose
asv run --dry-run --quick --show-stderr --verbose HEAD^!
67 changes: 58 additions & 9 deletions benchmarks/benchmarks/sparse_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

from types import MappingProxyType
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import zarr
from dask.array.core import Array as DaskArray
from scipy import sparse
Expand All @@ -12,6 +14,9 @@
from anndata._io.specs import write_elem
from anndata.experimental import read_elem_lazy

if TYPE_CHECKING:
from typing import Literal


def make_alternating_mask(n):
mask_alternating = np.ones(10_000, dtype=bool)
Expand Down Expand Up @@ -79,9 +84,12 @@ def peakmem_getitem_adata(self, *_):
res.compute()


class SparseCSRDask:
class SparseCSRDaskConcat:
filepath = "data.zarr"

params = (["inner", "outer"], [0, -1])
param_names = ("join", "fill_value")

def setup_cache(self):
X = sparse.random(
10_000,
Expand All @@ -93,18 +101,59 @@ def setup_cache(self):
g = zarr.group(self.filepath)
write_elem(g, "X", X)

def setup(self):
def setup(self, *_):
self.group = zarr.group(self.filepath)
self.adata = AnnData(X=read_elem_lazy(self.group["X"]))
self.adatas = [
AnnData(
var=pd.DataFrame(
index=[
f"gene_{j}{f'_{i}' if (j % 100 == 0) else ''}"
for j in range(10_000)
]
),
X=read_elem_lazy(self.group["X"]),
)
for i in range(10)
]

def time_concat(self, join: Literal["inner", "outer"], fill_value: Literal[0, -1]):
concat(self.adatas, join=join, fill_value=fill_value)

def peakmem_concat(
self, join: Literal["inner", "outer"], fill_value: Literal[0, -1]
):
concat(self.adatas, join=join, fill_value=fill_value)

def time_concat_with_mem(
self, join: Literal["inner", "outer"], fill_value: Literal[0, -1]
):
concat(self.adatas, join=join, fill_value=fill_value).to_memory()

def peakmem_concat_with_mem(
self, join: Literal["inner", "outer"], fill_value: Literal[0, -1]
):
concat(self.adatas, join=join, fill_value=fill_value).to_memory()

def time_concat(self):
concat([self.adata for i in range(100)])

def peakmem_concat(self):
concat([self.adata for i in range(100)])
class SparseCSRDask:
filepath = "data.zarr"

def setup_cache(self):
X = sparse.random(
10_000,
10_000,
density=0.01,
format="csr",
random_state=np.random.default_rng(42),
)
g = zarr.group(self.filepath)
write_elem(g, "X", X)

def setup(self, *_):
self.group = zarr.group(self.filepath)

def time_read(self):
def time_read(self, *_):
AnnData(X=read_elem_lazy(self.group["X"]))

def peakmem_read(self):
def peakmem_read(self, *_):
AnnData(X=read_elem_lazy(self.group["X"]))
1 change: 1 addition & 0 deletions docs/release-notes/2395.perf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Accelerate outer joins on dask-sparse matrices with unchunked minor axes in {func}`anndata.concat` {user}`ilan-gold`
33 changes: 28 additions & 5 deletions src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,29 @@ def _apply_to_df_like(self, el: pd.DataFrame | Dataset2D, *, axis, fill_value=No
def _apply_to_dask_array(self, el: DaskArray, *, axis, fill_value=None):
import dask.array as da

indexer = self.idx
is_outer = any(indexer == -1)
# Fast path for the majority of sparse matrices whose minor-axis is unchunked and is being reindexed.
# This prevents 0's from being stored explicitly in the sparse matrices when outer joining, for example (see below).
if (
is_sparse_sub := isinstance(el._meta, CSArray | CSMatrix)
and el.chunksize[minor_axis := int(el._meta.format == "csr")]
== el.shape[minor_axis]
and axis == minor_axis
and is_outer
):
return el.map_blocks(
partial(
self._apply_to_sparse,
axis=axis,
fill_value=fill_value,
keep_format=True,
),
chunks=(el.chunks[0], len(self.new_idx))
if minor_axis == 1
else (len(self.new_idx), el.chunks[1]),
meta=el._meta,
)
if fill_value is None:
fill_value = default_fill_value([el])
shape = list(el.shape)
Expand All @@ -591,12 +614,11 @@ def _apply_to_dask_array(self, el: DaskArray, *, axis, fill_value=None):
shape[axis] = len(self.new_idx)
return da.broadcast_to(fill_value, tuple(shape))

indexer = self.idx
sub_el = _subset(el, make_slice(indexer, axis, len(shape)))

if any(indexer == -1):
if is_outer:
# TODO: Remove this condition once https://github.com/dask/dask/pull/12078 is released
if isinstance(sub_el._meta, CSArray | CSMatrix) and np.isscalar(fill_value):
if is_sparse_sub and np.isscalar(fill_value):
fill_value = np.array([[fill_value]])
sub_el[make_slice(indexer == -1, axis, len(shape))] = fill_value

Expand Down Expand Up @@ -658,7 +680,7 @@ def _apply_to_array_api(
return xp.where(mask, fv, taken)

def _apply_to_sparse( # noqa: PLR0912
self, el: CSMatrix | CSArray, *, axis, fill_value=None
self, el: CSMatrix | CSArray, *, axis, fill_value=None, keep_format: bool = True
) -> CSMatrix:
if isinstance(el, CupySparseMatrix):
from cupyx.scipy import sparse
Expand Down Expand Up @@ -730,7 +752,8 @@ def _apply_to_sparse( # noqa: PLR0912

if fill_idxer is not None:
out[fill_idxer] = fill_value

if keep_format:
out = out.tocsr() if el.format == "csr" else out.tocsc()
return out

def _apply_to_awkward(self, el: AwkArray, *, axis, fill_value=None):
Expand Down
60 changes: 49 additions & 11 deletions tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,22 +1814,60 @@ def test_concat_on_var_outer_join(array_type):
_ = concat([a, b], join="outer", axis=1)


def test_concat_dask_sparse_matches_memory(join_type, merge_strategy):
@pytest.mark.parametrize("format", ["csr", "csc"])
@pytest.mark.parametrize(
"unchunked_minor_axis", [True, False], ids=["unchunked_minor", "chunked_minor"]
)
@pytest.mark.parametrize("fill_value", [0, -1])
def test_concat_dask_sparse_matches_memory(
join_type,
merge_strategy,
format: Literal["csr", "csc"],
axis_name: Literal["obs", "var"],
fill_value: Literal[-1, 0],
*,
unchunked_minor_axis: bool,
):
import dask.array as da

X = sparse.random(50, 20, density=0.5, format="csr")
X_dask = da.from_array(X, chunks=(5, 20))
var_names_1 = [f"gene_{i}" for i in range(20)]
var_names_2 = [f"gene_{i}{'_foo' if (i % 2) else ''}" for i in range(20)]
X = sparse.random(50, 20, density=0.5, format=format)
X_dask = da.from_array(
X,
chunks=(
X.shape[0] if format == "csc" else 10,
X.shape[1] if format == "csr" else 5,
)
if unchunked_minor_axis
else (5, 10),
)
off_axis_idx = int(axis_name == "obs")
concat_axis_idx = int(axis_name == "var")
off_axis = "var" if axis_name == "obs" else "obs"
axis_names_1 = [f"off_axis_{i}" for i in range(X.shape[off_axis_idx])]
axis_names_2 = [
f"off_axis_{i}{'_foo' if (i % 2) else ''}" for i in range(X.shape[off_axis_idx])
]

ad1 = AnnData(X=X, var=pd.DataFrame(index=var_names_1))
ad2 = AnnData(X=X, var=pd.DataFrame(index=var_names_2))
ad1 = AnnData(X=X, **{off_axis: pd.DataFrame(index=axis_names_1)})
ad2 = AnnData(X=X, **{off_axis: pd.DataFrame(index=axis_names_2)})

ad1_dask = AnnData(X=X_dask, var=pd.DataFrame(index=var_names_1))
ad2_dask = AnnData(X=X_dask, var=pd.DataFrame(index=var_names_2))
ad1_dask = AnnData(X=X_dask, **{off_axis: pd.DataFrame(index=axis_names_1)})
ad2_dask = AnnData(X=X_dask, **{off_axis: pd.DataFrame(index=axis_names_2)})

res_in_memory = concat([ad1, ad2], join=join_type, merge=merge_strategy)
res_dask = concat([ad1_dask, ad2_dask], join=join_type, merge=merge_strategy)
res_in_memory = concat(
[ad1, ad2],
join=join_type,
merge=merge_strategy,
axis=concat_axis_idx,
fill_value=fill_value,
)
res_dask = concat(
[ad1_dask, ad2_dask],
join=join_type,
merge=merge_strategy,
axis=concat_axis_idx,
fill_value=fill_value,
)
assert_equal(res_in_memory, res_dask)


Expand Down
Loading