-
Notifications
You must be signed in to change notification settings - Fork 742
(feat): sc.get.aggregate via dask
#3700
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
0c6aaf6
5ba33b0
693aa09
0e37ea7
de4ef27
900ef94
1feb7ab
166bf79
52b00b6
28d9899
7f80282
b6b4978
a3a9365
2b36963
aee1921
644283a
92ddcc5
9022651
650e909
7ff6c72
1b2787c
f9fa0e5
13edf64
2faef67
bf9eccd
403d5e7
56efa79
a1e9f6a
ec2f1f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Make {func}`scanpy.get.aggregate` `dask` compatible over all aggregations except median. {smaller}`I Gold` | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,15 +1,16 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from functools import singledispatch | ||
| from typing import TYPE_CHECKING, Literal | ||
| from functools import partial, singledispatch | ||
| from typing import TYPE_CHECKING, Literal, TypedDict | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
| from anndata import AnnData, utils | ||
| from fast_array_utils.stats._power import power as fau_power # TODO: upstream | ||
| from scipy import sparse | ||
| from sklearn.utils.sparsefuncs import csc_median_axis_0 | ||
|
|
||
| from scanpy._compat import CSBase | ||
| from scanpy._compat import CSBase, CSRBase, DaskArray | ||
|
|
||
| from .._utils import _resolve_axis, get_literal_vals | ||
| from .get import _check_mask | ||
|
|
@@ -19,7 +20,7 @@ | |
|
|
||
| from numpy.typing import NDArray | ||
|
|
||
| Array = np.ndarray | CSBase | ||
| Array = np.ndarray | CSBase | DaskArray | ||
|
|
||
| # Used with get_literal_vals | ||
| AggType = Literal["count_nonzero", "mean", "sum", "var", "median"] | ||
|
|
@@ -330,13 +331,98 @@ def _aggregate( | |
| *, | ||
| mask: NDArray[np.bool_] | None = None, | ||
| dof: int = 1, | ||
| ): | ||
| ) -> dict[AggType, np.ndarray | DaskArray]: | ||
| msg = f"Data type {type(data)} not supported for aggregation" | ||
| raise NotImplementedError(msg) | ||
|
|
||
|
|
||
| class MeanVarDict(TypedDict): | ||
| mean: DaskArray | ||
| var: DaskArray | ||
|
|
||
|
|
||
| def aggregate_dask_mean_var( | ||
| data: DaskArray, | ||
| by: pd.Categorical, | ||
| *, | ||
| mask: NDArray[np.bool_] | None = None, | ||
| dof: int = 1, | ||
| ) -> MeanVarDict: | ||
| mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"] | ||
| sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"] | ||
| # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse. | ||
| if isinstance(data._meta, CSRBase): | ||
|
ilan-gold marked this conversation as resolved.
|
||
| sq_mean = sq_mean.compute() | ||
| var = sq_mean - fau_power(mean, 2) | ||
| if dof != 0: | ||
| group_counts = np.bincount(by.codes) | ||
| var *= (group_counts / (group_counts - dof))[:, np.newaxis] | ||
| return {"var": var, "mean": mean} | ||
|
ilan-gold marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| @_aggregate.register(DaskArray) | ||
| def aggregate_dask( | ||
| data: DaskArray, | ||
| by: pd.Categorical, | ||
| func: AggType | Iterable[AggType], | ||
| *, | ||
| mask: NDArray[np.bool_] | None = None, | ||
| dof: int = 1, | ||
| ) -> dict[AggType, DaskArray]: | ||
| if not isinstance(data._meta, CSRBase | np.ndarray): | ||
| msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported." | ||
| raise ValueError(msg) | ||
| if data.chunksize[1] != data.shape[1]: | ||
| msg = "Feature axis must be unchunked" | ||
| raise ValueError(msg) | ||
|
|
||
| def aggregate_chunk_sum_or_count_nonzero( | ||
| chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None | ||
| ): | ||
| subset = slice(*block_info[0]["array-location"][0]) | ||
|
ilan-gold marked this conversation as resolved.
|
||
| by_subsetted = by[subset] | ||
| mask_subsetted = mask[subset] if mask is not None else mask | ||
| res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func] | ||
| return res[None, :] | ||
|
|
||
| funcs = set([func] if isinstance(func, str) else func) | ||
| if "median" in funcs: | ||
| msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue." | ||
| raise NotImplementedError(msg) | ||
| has_mean, has_var = (v in funcs for v in ["mean", "var"]) | ||
| funcs_no_var_or_mean = funcs - {"var", "mean"} | ||
| aggregated = { | ||
|
ilan-gold marked this conversation as resolved.
|
||
| f: data.map_blocks( | ||
| partial(aggregate_chunk_sum_or_count_nonzero, func=func), | ||
| new_axis=(1,), | ||
| chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)), | ||
| meta=np.array( | ||
| [], | ||
| dtype=np.float64 | ||
| if func not in ["count_nonzero", "sum"] | ||
| else data.dtype, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this info be centralized? this looks fragile because it’s an inline check based on the currently available functions.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This really should be a TODO. To be correct, we would need to handle the overflow potential of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First, I still stand by my comment. Please don’t bi-partition an extensible set by specifying one partition. This should be a lookup, either central or by defining and accessing an inline dict/switch-case that spells out all options. regarding the rest: hmm, I guess I conceived fau to be more low-level, i.e. if you fear that something can overflow, set does that make sense or do you think it should be more opinionated?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair point, the two things are separate. I will add a TODO as well then. Have a look and let me know if this is what you had in mind or if I misunderstood |
||
| ), | ||
| ).sum(axis=0) | ||
| for f in funcs_no_var_or_mean | ||
| } | ||
| if has_var: | ||
| aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof) | ||
| aggregated["var"] = aggredated_mean_var["var"] | ||
| if has_mean: | ||
| aggregated["mean"] = aggredated_mean_var["mean"] | ||
| # division must come after, not before, the summation for numerical precision | ||
| # i.e., we can't just call map blocks over the mean function. | ||
| elif has_mean: | ||
| group_counts = np.bincount(by.codes) | ||
| aggregated["mean"] = ( | ||
| aggregate_dask(data, by, "sum", mask=mask, dof=dof)["sum"] | ||
| / group_counts[:, None] | ||
| ) | ||
| return aggregated | ||
|
|
||
|
|
||
| @_aggregate.register(pd.DataFrame) | ||
| def aggregate_df(data, by, func, *, mask=None, dof=1): | ||
| def aggregate_df(data, by, func, *, mask=None, dof=1) -> dict[AggType, np.ndarray]: | ||
| return _aggregate(data.values, by, func, mask=mask, dof=dof) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.