Skip to content

Commit 58240e5

Browse files
(feat): sc.get.aggregate via dask (#3700)
Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent 77dd43c commit 58240e5

8 files changed

Lines changed: 270 additions & 83 deletions

File tree

docs/release-notes/3700.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make {func}`scanpy.get.aggregate` `dask` compatible with all aggregations except median. {smaller}`I Gold`

src/scanpy/get/_aggregated.py

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from __future__ import annotations
22

3-
from functools import singledispatch
4-
from typing import TYPE_CHECKING, Literal
3+
from functools import partial, singledispatch
4+
from typing import TYPE_CHECKING, Literal, TypedDict, get_args
55

66
import numpy as np
77
import pandas as pd
88
from anndata import AnnData, utils
9+
from fast_array_utils.stats._power import power as fau_power # TODO: upstream
910
from scipy import sparse
1011
from sklearn.utils.sparsefuncs import csc_median_axis_0
1112

12-
from scanpy._compat import CSBase
13+
from scanpy._compat import CSBase, CSCBase, CSRBase, DaskArray
1314

1415
from .._utils import _resolve_axis, get_literal_vals
1516
from .get import _check_mask
@@ -19,10 +20,11 @@
1920

2021
from numpy.typing import NDArray
2122

22-
Array = np.ndarray | CSBase
23+
Array = np.ndarray | CSBase | DaskArray
2324

2425
# Used with get_literal_vals
25-
AggType = Literal["count_nonzero", "mean", "sum", "var", "median"]
26+
ConstantDtypeAgg = Literal["count_nonzero", "sum", "median"]
27+
AggType = ConstantDtypeAgg | Literal["mean", "var"]
2628

2729

2830
class Aggregate:
@@ -330,13 +332,106 @@ def _aggregate(
330332
*,
331333
mask: NDArray[np.bool_] | None = None,
332334
dof: int = 1,
333-
):
335+
) -> dict[AggType, np.ndarray | DaskArray]:
334336
msg = f"Data type {type(data)} not supported for aggregation"
335337
raise NotImplementedError(msg)
336338

337339

340+
class MeanVarDict(TypedDict):
341+
mean: DaskArray
342+
var: DaskArray
343+
344+
345+
def aggregate_dask_mean_var(
346+
data: DaskArray,
347+
by: pd.Categorical,
348+
*,
349+
mask: NDArray[np.bool_] | None = None,
350+
dof: int = 1,
351+
) -> MeanVarDict:
352+
mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"]
353+
sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"]
354+
# TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
355+
if isinstance(data._meta, CSRBase):
356+
sq_mean = sq_mean.compute()
357+
elif isinstance(data._meta, CSCBase): # pragma: no-cover
358+
msg = "Cannot handle CSC matrices as dask meta."
359+
raise ValueError(msg)
360+
var = sq_mean - fau_power(mean, 2)
361+
if dof != 0:
362+
group_counts = np.bincount(by.codes)
363+
var *= (group_counts / (group_counts - dof))[:, np.newaxis]
364+
return MeanVarDict(mean=mean, var=var)
365+
366+
367+
@_aggregate.register(DaskArray)
368+
def aggregate_dask(
369+
data: DaskArray,
370+
by: pd.Categorical,
371+
func: AggType | Iterable[AggType],
372+
*,
373+
mask: NDArray[np.bool_] | None = None,
374+
dof: int = 1,
375+
) -> dict[AggType, DaskArray]:
376+
if not isinstance(data._meta, CSRBase | np.ndarray):
377+
msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
378+
raise ValueError(msg)
379+
if data.chunksize[1] != data.shape[1]:
380+
msg = "Feature axis must be unchunked"
381+
raise ValueError(msg)
382+
383+
def aggregate_chunk_sum_or_count_nonzero(
384+
chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None
385+
):
386+
# See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
387+
# for what is contained in `block_info`.
388+
subset = slice(*block_info[0]["array-location"][0])
389+
by_subsetted = by[subset]
390+
mask_subsetted = mask[subset] if mask is not None else mask
391+
res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func]
392+
return res[None, :]
393+
394+
funcs = set([func] if isinstance(func, str) else func)
395+
if "median" in funcs:
396+
msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
397+
raise NotImplementedError(msg)
398+
has_mean, has_var = (v in funcs for v in ["mean", "var"])
399+
funcs_no_var_or_mean = funcs - {"var", "mean"}
400+
# aggregate each row chunk individually,
401+
# producing a #chunks × #categories × #features array,
402+
# then aggregate the per-chunk results.
403+
aggregated = {
404+
f: data.map_blocks(
405+
partial(aggregate_chunk_sum_or_count_nonzero, func=func),
406+
new_axis=(1,),
407+
chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
408+
meta=np.array(
409+
[],
410+
dtype=np.float64
411+
if func not in get_args(ConstantDtypeAgg)
412+
else data.dtype, # TODO: figure out best dtype for aggs like sum where dtype can change from original
413+
),
414+
).sum(axis=0)
415+
for f in funcs_no_var_or_mean
416+
}
417+
if has_var:
418+
aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
419+
aggregated["var"] = aggredated_mean_var["var"]
420+
if has_mean:
421+
aggregated["mean"] = aggredated_mean_var["mean"]
422+
# division must come after, not before, the summation for numerical precision
423+
# i.e., we can't just call map blocks over the mean function.
424+
elif has_mean:
425+
group_counts = np.bincount(by.codes)
426+
aggregated["mean"] = (
427+
aggregate_dask(data, by, "sum", mask=mask, dof=dof)["sum"]
428+
/ group_counts[:, None]
429+
)
430+
return aggregated
431+
432+
338433
@_aggregate.register(pd.DataFrame)
339-
def aggregate_df(data, by, func, *, mask=None, dof=1):
434+
def aggregate_df(data, by, func, *, mask=None, dof=1) -> dict[AggType, np.ndarray]:
340435
return _aggregate(data.values, by, func, mask=mask, dof=dof)
341436

342437

src/scanpy/preprocessing/_qc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,9 @@ def _(mtx: DaskArray, ns: Collection[int]) -> DaskArray:
416416
if not isinstance(mtx._meta, CSRBase | np.ndarray):
417417
msg = f"DaskArray must have csr matrix or ndarray meta, got {mtx._meta}."
418418
raise ValueError(msg)
419+
if mtx.chunksize[1] != mtx.shape[1]:
420+
msg = f"{mtx} must not be chunked along the feature axis"
421+
raise ValueError(msg)
419422
return mtx.map_blocks(
420423
lambda x: top_segment_proportions(x, ns), meta=np.array([])
421424
).compute()

src/testing/scanpy/_pytest/params.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from functools import wraps
56
from importlib.metadata import version
67
from typing import TYPE_CHECKING
78

@@ -10,18 +11,18 @@
1011
from packaging.version import Version
1112
from scipy import sparse
1213

13-
from .._helpers import (
14-
as_dense_dask_array,
15-
as_sparse_dask_array,
16-
)
14+
from .._helpers import as_dense_dask_array, as_sparse_dask_array
1715
from .._pytest.marks import needs
1816

1917
if TYPE_CHECKING:
2018
from collections.abc import Callable, Iterable
2119
from typing import Any, Literal
2220

21+
import numpy as np
2322
from _pytest.mark.structures import ParameterSet
2423

24+
from ....scanpy._compat import DaskArray
25+
2526

2627
skipif_no_sparray = pytest.mark.skipif(
2728
Version(version("anndata")) < Version("0.11"),
@@ -41,6 +42,18 @@ def param_with(
4142
)
4243

4344

45+
def _chunked_1d(
46+
f: Callable[[np.ndarray], DaskArray],
47+
) -> Callable[[np.ndarray], DaskArray]:
48+
@wraps(f)
49+
def wrapper(a: np.ndarray) -> DaskArray:
50+
da = f(a)
51+
return da.rechunk((da.chunksize[0], -1))
52+
53+
wrapper.__name__ = f"{wrapper.__name__}-1d_chunked"
54+
return wrapper
55+
56+
4457
MAP_ARRAY_TYPES: dict[
4558
tuple[Literal["mem", "dask"], Literal["dense", "sparse"]],
4659
tuple[ParameterSet, ...],
@@ -51,20 +64,21 @@ def param_with(
5164
pytest.param(sparse.csc_matrix, id="scipy_csc_mat"), # noqa: TID251
5265
pytest.param(sparse.csr_array, id="scipy_csr_arr", marks=[skipif_no_sparray]), # noqa: TID251
5366
),
54-
("dask", "dense"): (
67+
("dask", "dense"): tuple(
5568
pytest.param(
56-
as_dense_dask_array,
69+
wrapper(as_dense_dask_array),
5770
marks=[needs.dask, pytest.mark.anndata_dask_support],
58-
id="dask_array_dense",
59-
),
71+
id=f"dask_array_dense{suffix}",
72+
)
73+
for wrapper, suffix in [(lambda x: x, ""), (_chunked_1d, "-1d_chunked")]
6074
),
61-
("dask", "sparse"): (
75+
("dask", "sparse"): tuple(
6276
pytest.param(
63-
as_sparse_dask_array,
77+
wrapper(as_sparse_dask_array),
6478
marks=[needs.dask, pytest.mark.anndata_dask_support],
65-
id="dask_array_sparse",
66-
),
67-
# probably not necessary to also do csc
79+
id=f"dask_array_sparse{suffix}",
80+
)
81+
for wrapper, suffix in [(lambda x: x, ""), (_chunked_1d, "-1d_chunked")]
6882
),
6983
}
7084

0 commit comments

Comments
 (0)