Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0c6aaf6
(feat): aggregate via `dask`
ilan-gold Jul 3, 2025
5ba33b0
(feat): var operation
ilan-gold Jul 4, 2025
693aa09
Merge branch 'main' into ig/agg_dask
ilan-gold Jul 4, 2025
0e37ea7
(fix): relative import
ilan-gold Jul 4, 2025
de4ef27
Merge branch 'ig/agg_dask' of github.com:scverse/scanpy into ig/agg_dask
ilan-gold Jul 4, 2025
900ef94
(refactor): remove duplicated code
ilan-gold Jul 4, 2025
1feb7ab
refactor: do mean/var together if possible
ilan-gold Jul 4, 2025
166bf79
todo message
ilan-gold Jul 4, 2025
52b00b6
refactor: revert sum_sq thing
ilan-gold Jul 4, 2025
28d9899
chore: comment
ilan-gold Jul 4, 2025
7f80282
(fix): remove unnecessary todo
ilan-gold Jul 7, 2025
b6b4978
(fix): clarify potential failure case
ilan-gold Jul 7, 2025
a3a9365
fix: compute one of the means :/
ilan-gold Jul 10, 2025
2b36963
chore: relnote
ilan-gold Jul 10, 2025
aee1921
Merge branch 'main' into ig/agg_dask
ilan-gold Jul 14, 2025
644283a
Merge branch 'main' into ig/agg_dask
ilan-gold Jul 16, 2025
92ddcc5
fix: refactor 1d chunking
ilan-gold Jul 18, 2025
9022651
fix: relnote
ilan-gold Jul 18, 2025
650e909
chore: add failure-case tests
ilan-gold Jul 18, 2025
7ff6c72
fix: use `if` statement for `compute`
ilan-gold Jul 21, 2025
1b2787c
chore: dtypes
ilan-gold Jul 21, 2025
f9fa0e5
Update 3700.feature.md
ilan-gold Jul 23, 2025
13edf64
Merge branch 'main' into ig/agg_dask
flying-sheep Jul 23, 2025
2faef67
Apply suggestions from code review
ilan-gold Jul 23, 2025
bf9eccd
fix: `xfail_dask_median` handling
ilan-gold Jul 23, 2025
403d5e7
Merge branch 'main' into ig/agg_dask
ilan-gold Jul 24, 2025
56efa79
Update src/scanpy/get/_aggregated.py
ilan-gold Jul 24, 2025
a1e9f6a
fix: centralize check
ilan-gold Jul 24, 2025
ec2f1f2
Merge branch 'ig/agg_dask' of github.com:scverse/scanpy into ig/agg_dask
ilan-gold Jul 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

from functools import singledispatch
from functools import partial, singledispatch
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd
from anndata import AnnData, utils
from fast_array_utils.stats._power import power as fau_power # TODO: upstream
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

from scanpy._compat import CSBase
from scanpy._compat import CSBase, DaskArray

from .._utils import _resolve_axis, get_literal_vals
from .get import _check_mask
Expand Down Expand Up @@ -100,6 +101,9 @@ def mean(self) -> Array:
/ np.bincount(self.groupby.codes)[:, None]
)

def sum_sq(self) -> Array:
    """Return the indicator-matrix product of the element-wise squared data.

    The result is passed through ``utils.asarray`` before being returned.
    """
    # Squaring goes through the `_power` helper rather than `**`.
    squared = _power(self.data, 2)
    return utils.asarray(self.indicator_matrix @ squared)

def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
"""Compute the count, as well as mean and variance per feature, per group of observations.

Expand All @@ -124,10 +128,7 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
group_counts = np.bincount(self.groupby.codes)
mean_ = self.mean()
# sparse matrices do not support ** for elementwise power.
mean_sq = (
utils.asarray(self.indicator_matrix @ _power(self.data, 2))
/ group_counts[:, None]
)
mean_sq = self.sum_sq() / group_counts[:, None]
sq_mean = mean_**2
var_ = mean_sq - sq_mean
# TODO: Why these values exactly? Because they are high relative to the datatype?
Expand Down Expand Up @@ -335,6 +336,63 @@ def _aggregate(
raise NotImplementedError(msg)


def aggregate_dask_var(
    data: DaskArray,
    by: pd.Categorical,
    *,
    mask: NDArray[np.bool_] | None = None,
    dof: int = 1,
):
    """Compute the per-group variance of a dask array.

    The variance is formed as ``mean(data**2) - mean(data)**2``, where both
    means come from ``aggregate_dask``.  When ``dof`` is non-zero the result is
    rescaled by ``n / (n - dof)`` with ``n = np.bincount(by.codes)``.

    Parameters
    ----------
    data
        Dask array to aggregate.
    by
        Categorical whose codes define the groups.
    mask
        Forwarded unchanged to ``aggregate_dask``.
    dof
        Delta degrees of freedom used for the rescaling; ``0`` skips it.

    Returns
    -------
    The variance array, one row per entry of ``np.bincount(by.codes)`` when
    the rescaling is applied.
    """

    def _grouped_mean(values: DaskArray):
        # Both means below use exactly the same grouping, mask and dof.
        return aggregate_dask(values, by, "mean", mask=mask, dof=dof)["mean"]

    mean_of_data = _grouped_mean(data)
    mean_of_squares = _grouped_mean(fau_power(data, 2))
    variance = mean_of_squares - mean_of_data**2
    if dof == 0:
        return variance
    group_sizes = np.bincount(by.codes)
    rescale = group_sizes / (group_sizes - dof)
    variance *= rescale[:, np.newaxis]
    return variance


@_aggregate.register(DaskArray)
def aggregate_dask(
    data: DaskArray,
    by: pd.Categorical,
    func: AggType | Iterable[AggType],
    *,
    mask: NDArray[np.bool_] | None = None,
    dof: int = 1,
):
    """Aggregate a dask array per category of ``by``, one block at a time.

    Each block of ``data`` is reduced with ``_aggregate`` using the slice of
    ``by`` (and ``mask``) that matches the block's position along axis 0.  The
    per-block results are stacked along a new leading axis and summed.
    ``"var"`` is delegated to ``aggregate_dask_var``; ``"mean"`` is
    accumulated as a sum and divided by the group sizes at the very end.

    Parameters
    ----------
    data
        Dask array to aggregate.
    by
        Categorical that is sliced alongside the blocks of ``data``.
    func
        A single aggregation name or an iterable of names.
    mask
        Optional boolean mask, sliced the same way as ``by``.
    dof
        Forwarded to ``_aggregate`` and ``aggregate_dask_var``.

    Returns
    -------
    Dict mapping every requested aggregation name to its result.

    Raises
    ------
    NotImplementedError
        If ``"median"`` is among the requested aggregations.
    """

    def aggregate_chunk_no_var(chunk: Array, block_info=None, *, func: AggType):
        # Means are summed per block here; the division happens after the
        # cross-block summation (see below).
        func = "sum" if func == "mean" else func
        # (start, stop) of this block along axis 0 of `data`.
        subset = slice(*block_info[0]["array-location"][0])
        by_subsetted = by[subset]
        mask_subsetted = mask[subset] if mask is not None else mask
        res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func]
        # Add a leading axis so the blocks can be stacked and summed.
        return res[None, :]

    funcs = {func} if isinstance(func, str) else set(func)
    if "median" in funcs:
        msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
        raise NotImplementedError(msg)
    has_var = "var" in funcs
    funcs_no_var = funcs - {"var"}
    # NOTE(review): the `chunks` spec below appears to assume `data` is only
    # chunked along axis 0 (one output block per input block, full width) —
    # confirm callers guarantee this.
    aggregated = {
        f: data.map_blocks(
            # Bind the aggregation of *this* key, not the outer `func`
            # argument, which may be an iterable of several names.
            partial(aggregate_chunk_no_var, func=f),
            new_axis=(1,),
            chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
            meta=np.array([], dtype=np.float64),  # TODO: figure out dtype
        ).sum(axis=0)
        for f in funcs_no_var
    }
    if has_var:
        aggregated["var"] = aggregate_dask_var(data, by, mask=mask, dof=dof)
    # division must come after, not before, the summation for numerical precision.
    if "mean" in aggregated:
        group_counts = np.bincount(by.codes)
        aggregated["mean"] /= group_counts[:, None]
    return aggregated


@_aggregate.register(pd.DataFrame)
def aggregate_df(data, by, func, *, mask=None, dof=1):
    """Aggregate a DataFrame by re-dispatching ``_aggregate`` on ``data.values``."""
    values = data.values
    return _aggregate(values, by, func, mask=mask, dof=dof)
Expand Down
44 changes: 35 additions & 9 deletions tests/test_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from scipy import sparse

import scanpy as sc
from scanpy._compat import DaskArray
from scanpy._utils import _resolve_axis, get_literal_vals
from scanpy.get._aggregated import AggType
from testing.scanpy._helpers import assert_equal
from testing.scanpy._helpers.data import pbmc3k_processed
from testing.scanpy._pytest.params import ARRAY_TYPES_MEM

from .test_pca import ARRAY_TYPES
Comment thread
ilan-gold marked this conversation as resolved.
Outdated


@pytest.fixture(params=get_literal_vals(AggType))
Expand Down Expand Up @@ -93,16 +95,19 @@ def test_mask(axis):
assert np.all(by_name["0"].layers["sum"] == 0)


@pytest.mark.parametrize("array_type", ARRAY_TYPES_MEM)
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
def test_aggregate_vs_pandas(metric, array_type):
adata = pbmc3k_processed().raw.to_adata()
adata = adata[
adata.obs["louvain"].isin(adata.obs["louvain"].cat.categories[:5]), :1_000
].copy()
adata.X = array_type(adata.X)
xfail_dask_median(adata, metric)
adata.obs["percent_mito_binned"] = pd.cut(adata.obs["percent_mito"], bins=5)
result = sc.get.aggregate(adata, ["louvain", "percent_mito_binned"], metric)

# TODO: upstream
if isinstance(adata.X, DaskArray):
adata.X = adata.X.compute(scheduler="single-threaded")
if metric == "count_nonzero":
expected = (
(adata.to_df() != 0)
Expand All @@ -124,7 +129,10 @@ def test_aggregate_vs_pandas(metric, array_type):
)
expected.index.name = None
expected.columns.name = None

if isinstance(result.layers[metric], DaskArray):
result.layers[metric] = result.layers[metric].compute(
scheduler="single-threaded"
)
result_df = result.to_df(layer=metric)
result_df.index.name = None
result_df.columns.name = None
Expand All @@ -139,16 +147,25 @@ def test_aggregate_vs_pandas(metric, array_type):
pd.testing.assert_frame_equal(result_df, expected, check_dtype=False, atol=1e-5)


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
def test_aggregate_axis(array_type, metric):
    """Aggregating along ``axis=1`` of the transpose must match ``axis=0``."""
    adata = pbmc3k_processed().raw.to_adata()
    adata = adata[
        adata.obs["louvain"].isin(adata.obs["louvain"].cat.categories[:5]), :1_000
    ].copy()
    # TODO: disallow transposing dask sparse matrices in anndata
    # This test actually passes in all cases except with sparse var calculation,
    # even though I'm not clear on the behavior of transpose with sparse matrices in dask.
    adata_T = adata.T
    adata_T.X = array_type(adata_T.X)
    xfail_dask_median(adata_T, metric)
    adata.X = array_type(adata.X)
    expected = sc.get.aggregate(adata, ["louvain"], metric)
    actual = sc.get.aggregate(adata.T, ["louvain"], metric, axis=1)
    # Materialize lazy layers before transposing and comparing.
    if isinstance(adata.X, DaskArray):
        for d in [expected, actual]:
            d.layers[metric] = d.layers[metric].compute(scheduler="single-threaded")
    actual = actual.T
    assert_equal(expected, actual)


Expand Down Expand Up @@ -387,15 +404,24 @@ def test_combine_categories(label_cols, cols, expected):
pd.testing.assert_frame_equal(reconstructed_df, result_label_df)


def xfail_dask_median(adata, metric):
    """Mark the running test as xfail when ``adata.X`` is dask and ``metric`` is ``"median"``."""
    if isinstance(adata.X, DaskArray) and metric == "median":
        pytest.xfail("Median calculation not implemented for Dask")
Comment thread
ilan-gold marked this conversation as resolved.
Outdated


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
def test_aggregate_arraytype(array_type, metric):
    """The aggregated layer is a dask array for dask input and an ndarray otherwise."""
    adata = pbmc3k_processed().raw.to_adata()
    adata = adata[
        adata.obs["louvain"].isin(adata.obs["louvain"].cat.categories[:5]), :1_000
    ].copy()
    adata.X = array_type(adata.X)
    xfail_dask_median(adata, metric)
    aggregate = sc.get.aggregate(adata, ["louvain"], metric)
    assert isinstance(
        aggregate.layers[metric],
        DaskArray if isinstance(adata.X, DaskArray) else np.ndarray,
    )


def test_aggregate_obsm_varm():
Expand Down
Loading