Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0c6aaf6
(feat): aggregate via `dask`
ilan-gold Jul 3, 2025
5ba33b0
(feat): var operation
ilan-gold Jul 4, 2025
693aa09
Merge branch 'main' into ig/agg_dask
ilan-gold Jul 4, 2025
0e37ea7
(fix): relative import
ilan-gold Jul 4, 2025
de4ef27
Merge branch 'ig/agg_dask' of github.com:scverse/scanpy into ig/agg_dask
ilan-gold Jul 4, 2025
900ef94
(refactor): remove duplicated code
ilan-gold Jul 4, 2025
1feb7ab
refactor: do mean/var together if possible
ilan-gold Jul 4, 2025
166bf79
todo message
ilan-gold Jul 4, 2025
52b00b6
refactor: revert sum_sq thing
ilan-gold Jul 4, 2025
28d9899
chore: comment
ilan-gold Jul 4, 2025
7f80282
(fix): reomve unnecessary todo
ilan-gold Jul 7, 2025
b6b4978
(fix): clarify potential failure case
ilan-gold Jul 7, 2025
a3a9365
fix: compute one of the means :/
ilan-gold Jul 10, 2025
2b36963
chore: relnote
ilan-gold Jul 10, 2025
aee1921
Merge branch 'main' into ig/agg_dask
ilan-gold Jul 14, 2025
644283a
Merge branch 'main' into ig/agg_dask
ilan-gold Jul 16, 2025
92ddcc5
fix: refactor 1d chunking
ilan-gold Jul 18, 2025
9022651
fix: relnote
ilan-gold Jul 18, 2025
650e909
chore: add failure-case tests
ilan-gold Jul 18, 2025
7ff6c72
fix: use `if` statement for `compute`
ilan-gold Jul 21, 2025
1b2787c
chore: dtypes
ilan-gold Jul 21, 2025
f9fa0e5
Update 3700.feature.md
ilan-gold Jul 23, 2025
13edf64
Merge branch 'main' into ig/agg_dask
flying-sheep Jul 23, 2025
2faef67
Apply suggestions from code review
ilan-gold Jul 23, 2025
bf9eccd
fix: `xfail_dask_median` handling
ilan-gold Jul 23, 2025
403d5e7
Merge branch 'main' into ig/agg_dask
ilan-gold Jul 24, 2025
56efa79
Update src/scanpy/get/_aggregated.py
ilan-gold Jul 24, 2025
a1e9f6a
fix: centralize check
ilan-gold Jul 24, 2025
ec2f1f2
Merge branch 'ig/agg_dask' of github.com:scverse/scanpy into ig/agg_dask
ilan-gold Jul 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3700.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make {func}`scanpy.get.aggregate` `dask` compatible with all aggregations except median. {smaller}`I Gold`
109 changes: 102 additions & 7 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING, Literal
from functools import partial, singledispatch
from typing import TYPE_CHECKING, Literal, TypedDict, get_args

import numpy as np
import pandas as pd
from anndata import AnnData, utils
from fast_array_utils.stats._power import power as fau_power # TODO: upstream
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

from scanpy._compat import CSBase
from scanpy._compat import CSBase, CSCBase, CSRBase, DaskArray

from .._utils import _resolve_axis, get_literal_vals
from .get import _check_mask
Expand All @@ -19,10 +20,11 @@

from numpy.typing import NDArray

Array = np.ndarray | CSBase
Array = np.ndarray | CSBase | DaskArray

# Used with get_literal_vals
AggType = Literal["count_nonzero", "mean", "sum", "var", "median"]
ConstantDtypeAgg = Literal["count_nonzero", "sum", "median"]
AggType = ConstantDtypeAgg | Literal["mean", "var"]


class Aggregate:
Expand Down Expand Up @@ -330,13 +332,106 @@ def _aggregate(
*,
mask: NDArray[np.bool_] | None = None,
dof: int = 1,
):
) -> dict[AggType, np.ndarray | DaskArray]:
msg = f"Data type {type(data)} not supported for aggregation"
raise NotImplementedError(msg)


class MeanVarDict(TypedDict):
    """Typed result of :func:`aggregate_dask_mean_var`.

    Both entries are always present.
    """

    # Grouped mean, as returned under the ``"mean"`` key.
    mean: DaskArray
    # Grouped variance, as returned under the ``"var"`` key.
    var: DaskArray


def aggregate_dask_mean_var(
    data: DaskArray,
    by: pd.Categorical,
    *,
    mask: NDArray[np.bool_] | None = None,
    dof: int = 1,
) -> MeanVarDict:
    """Compute the grouped mean and variance of a dask array together.

    The variance is derived from two grouped means as ``mean(x**2) - mean(x)**2``
    and, unless ``dof`` is 0, rescaled by ``n / (n - dof)`` where ``n`` is
    ``np.bincount(by.codes)``.

    Parameters
    ----------
    data
        Dask array to aggregate; forwarded to :func:`aggregate_dask`.
    by
        Grouping categorical; forwarded to :func:`aggregate_dask`.
    mask
        Optional boolean mask; forwarded to :func:`aggregate_dask`.
    dof
        Delta degrees of freedom applied to the variance.

    Returns
    -------
    Dictionary with the ``"mean"`` and ``"var"`` results.

    Raises
    ------
    ValueError
        If ``data._meta`` is a CSC matrix/array.
    """
    mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"]
    sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"]
    # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
    if isinstance(data._meta, CSRBase):
        sq_mean = sq_mean.compute()
    # Defensive guard only: `aggregate_dask` above already rejects metas that are
    # neither CSR nor ndarray, so this branch is not expected to run.
    # The marker must be spelled `no cover` for coverage.py's default exclusion to apply.
    elif isinstance(data._meta, CSCBase):  # pragma: no cover
        msg = "Cannot handle CSC matrices as dask meta."
        raise ValueError(msg)
    var = sq_mean - fau_power(mean, 2)
    if dof != 0:
        group_counts = np.bincount(by.codes)
        # Rescale the dof=0 variance to the requested degrees of freedom, one factor per group.
        var *= (group_counts / (group_counts - dof))[:, np.newaxis]
    return MeanVarDict(mean=mean, var=var)


@_aggregate.register(DaskArray)
def aggregate_dask(
    data: DaskArray,
    by: pd.Categorical,
    func: AggType | Iterable[AggType],
    *,
    mask: NDArray[np.bool_] | None = None,
    dof: int = 1,
) -> dict[AggType, DaskArray]:
    """Aggregate a dask array per category of ``by``.

    ``"sum"`` and ``"count_nonzero"`` are aggregated per row chunk and the
    per-chunk results are summed. ``"mean"`` is the grouped sum divided by
    ``np.bincount(by.codes)``. ``"var"`` (and ``"mean"`` when requested together
    with it) is delegated to :func:`aggregate_dask_mean_var`.

    Parameters
    ----------
    data
        Dask array whose ``_meta`` is a CSR matrix/array or an ``ndarray`` and
        whose second axis consists of a single chunk.
    by
        Grouping categorical; it is sliced with each chunk's row range.
    func
        One aggregation name or an iterable of names.
    mask
        Optional boolean mask, sliced with each chunk's row range and passed to
        the per-chunk aggregation.
    dof
        Delta degrees of freedom, forwarded to the per-chunk aggregation and to
        the variance computation.

    Returns
    -------
    Dictionary mapping every requested aggregation name to its result.

    Raises
    ------
    ValueError
        If ``data._meta`` is neither CSR nor ``ndarray``, or if the second axis
        is chunked.
    NotImplementedError
        If ``"median"`` is requested.
    """
    if not isinstance(data._meta, CSRBase | np.ndarray):
        msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
        raise ValueError(msg)
    if data.chunksize[1] != data.shape[1]:
        msg = "Feature axis must be unchunked"
        raise ValueError(msg)

    def aggregate_chunk_sum_or_count_nonzero(
        chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None
    ):
        # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
        # for what is contained in `block_info`.
        subset = slice(*block_info[0]["array-location"][0])
        by_subsetted = by[subset]
        mask_subsetted = mask[subset] if mask is not None else mask
        res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func]
        # Add a leading axis so the per-chunk results can be stacked and summed.
        return res[None, :]

    funcs = set([func] if isinstance(func, str) else func)
    if "median" in funcs:
        msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
        raise NotImplementedError(msg)
    has_mean, has_var = (v in funcs for v in ["mean", "var"])
    funcs_no_var_or_mean = funcs - {"var", "mean"}
    # aggregate each row chunk individually,
    # producing a #chunks × #categories × #features array,
    # then aggregate the per-chunk results.
    aggregated = {
        f: data.map_blocks(
            # Bind the single aggregation `f`, not the (possibly iterable) `func`
            # argument: the chunk helper indexes its result with this value.
            partial(aggregate_chunk_sum_or_count_nonzero, func=f),
            new_axis=(1,),
            chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
            meta=np.array(
                [],
                dtype=np.float64
                if f not in get_args(ConstantDtypeAgg)
                else data.dtype,  # TODO: figure out best dtype for aggs like sum where dtype can change from original
            ),
        ).sum(axis=0)
        for f in funcs_no_var_or_mean
    }
    if has_var:
        aggregated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
        aggregated["var"] = aggregated_mean_var["var"]
        if has_mean:
            aggregated["mean"] = aggregated_mean_var["mean"]
    # division must come after, not before, the summation for numerical precision
    # i.e., we can't just call map blocks over the mean function.
    elif has_mean:
        group_counts = np.bincount(by.codes)
        # Reuse the grouped sum when it was requested as well instead of rebuilding it.
        if "sum" in aggregated:
            grouped_sum = aggregated["sum"]
        else:
            grouped_sum = aggregate_dask(data, by, "sum", mask=mask, dof=dof)["sum"]
        aggregated["mean"] = grouped_sum / group_counts[:, None]
    return aggregated


@_aggregate.register(pd.DataFrame)
def aggregate_df(data, by, func, *, mask=None, dof=1) -> dict[AggType, np.ndarray]:
    """Aggregate a ``DataFrame`` by dispatching again on its ``.values``."""
    values = data.values
    return _aggregate(values, by, func, mask=mask, dof=dof)


Expand Down
3 changes: 3 additions & 0 deletions src/scanpy/preprocessing/_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,9 @@ def _(mtx: DaskArray, ns: Collection[int]) -> DaskArray:
if not isinstance(mtx._meta, CSRBase | np.ndarray):
msg = f"DaskArray must have csr matrix or ndarray meta, got {mtx._meta}."
raise ValueError(msg)
if mtx.chunksize[1] != mtx.shape[1]:
msg = f"{mtx} must not be chunked along the feature axis"
raise ValueError(msg)
return mtx.map_blocks(
lambda x: top_segment_proportions(x, ns), meta=np.array([])
).compute()
Expand Down
40 changes: 27 additions & 13 deletions src/testing/scanpy/_pytest/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from functools import wraps
from importlib.metadata import version
from typing import TYPE_CHECKING

Expand All @@ -10,18 +11,18 @@
from packaging.version import Version
from scipy import sparse

from .._helpers import (
as_dense_dask_array,
as_sparse_dask_array,
)
from .._helpers import as_dense_dask_array, as_sparse_dask_array
from .._pytest.marks import needs

if TYPE_CHECKING:
from collections.abc import Callable, Iterable
from typing import Any, Literal

import numpy as np
from _pytest.mark.structures import ParameterSet

from ....scanpy._compat import DaskArray


skipif_no_sparray = pytest.mark.skipif(
Version(version("anndata")) < Version("0.11"),
Expand All @@ -41,6 +42,18 @@ def param_with(
)


def _chunked_1d(
f: Callable[[np.ndarray], DaskArray],
) -> Callable[[np.ndarray], DaskArray]:
@wraps(f)
def wrapper(a: np.ndarray) -> DaskArray:
da = f(a)
return da.rechunk((da.chunksize[0], -1))

wrapper.__name__ = f"{wrapper.__name__}-1d_chunked"
return wrapper


MAP_ARRAY_TYPES: dict[
tuple[Literal["mem", "dask"], Literal["dense", "sparse"]],
tuple[ParameterSet, ...],
Expand All @@ -51,20 +64,21 @@ def param_with(
pytest.param(sparse.csc_matrix, id="scipy_csc_mat"), # noqa: TID251
pytest.param(sparse.csr_array, id="scipy_csr_arr", marks=[skipif_no_sparray]), # noqa: TID251
),
("dask", "dense"): (
("dask", "dense"): tuple(
pytest.param(
as_dense_dask_array,
wrapper(as_dense_dask_array),
marks=[needs.dask, pytest.mark.anndata_dask_support],
id="dask_array_dense",
),
id=f"dask_array_dense{suffix}",
)
for wrapper, suffix in [(lambda x: x, ""), (_chunked_1d, "-1d_chunked")]
),
("dask", "sparse"): (
("dask", "sparse"): tuple(
pytest.param(
as_sparse_dask_array,
wrapper(as_sparse_dask_array),
marks=[needs.dask, pytest.mark.anndata_dask_support],
id="dask_array_sparse",
),
# probably not necessary to also do csc
id=f"dask_array_sparse{suffix}",
)
for wrapper, suffix in [(lambda x: x, ""), (_chunked_1d, "-1d_chunked")]
),
}

Expand Down
Loading
Loading