|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from functools import singledispatch |
4 | | -from typing import TYPE_CHECKING, Literal |
| 3 | +from functools import partial, singledispatch |
| 4 | +from typing import TYPE_CHECKING, Literal, TypedDict, get_args |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import pandas as pd |
8 | 8 | from anndata import AnnData, utils |
| 9 | +from fast_array_utils.stats._power import power as fau_power # TODO: upstream |
9 | 10 | from scipy import sparse |
10 | 11 | from sklearn.utils.sparsefuncs import csc_median_axis_0 |
11 | 12 |
|
12 | | -from scanpy._compat import CSBase |
| 13 | +from scanpy._compat import CSBase, CSCBase, CSRBase, DaskArray |
13 | 14 |
|
14 | 15 | from .._utils import _resolve_axis, get_literal_vals |
15 | 16 | from .get import _check_mask |
|
19 | 20 |
|
20 | 21 | from numpy.typing import NDArray |
21 | 22 |
|
22 | | - Array = np.ndarray | CSBase |
| 23 | + Array = np.ndarray | CSBase | DaskArray |
23 | 24 |
|
24 | 25 | # Used with get_literal_vals |
25 | | -AggType = Literal["count_nonzero", "mean", "sum", "var", "median"] |
| 26 | +ConstantDtypeAgg = Literal["count_nonzero", "sum", "median"] |
| 27 | +AggType = ConstantDtypeAgg | Literal["mean", "var"] |
26 | 28 |
|
27 | 29 |
|
28 | 30 | class Aggregate: |
@@ -330,13 +332,106 @@ def _aggregate( |
330 | 332 | *, |
331 | 333 | mask: NDArray[np.bool_] | None = None, |
332 | 334 | dof: int = 1, |
333 | | -): |
| 335 | +) -> dict[AggType, np.ndarray | DaskArray]: |
334 | 336 | msg = f"Data type {type(data)} not supported for aggregation" |
335 | 337 | raise NotImplementedError(msg) |
336 | 338 |
|
337 | 339 |
|
class MeanVarDict(TypedDict):
    """Result mapping returned by ``aggregate_dask_mean_var``."""

    # Per-group mean, taken from ``aggregate_dask(..., "mean")``.
    mean: DaskArray
    # Per-group variance, computed as mean of squares minus squared mean.
    var: DaskArray
| 343 | + |
| 344 | + |
def aggregate_dask_mean_var(
    data: DaskArray,
    by: pd.Categorical,
    *,
    mask: NDArray[np.bool_] | None = None,
    dof: int = 1,
) -> MeanVarDict:
    """Compute per-group mean and variance of a dask array.

    The variance is obtained from two grouped means as
    ``mean(power(data, 2)) - power(mean(data), 2)`` and, unless ``dof`` is 0,
    rescaled per group by ``n / (n - dof)`` where ``n`` is the group size.

    Parameters
    ----------
    data
        Dask array to aggregate; forwarded to :func:`aggregate_dask`.
    by
        Categorical whose ``codes`` define the groups.
    mask
        Optional boolean mask, forwarded unchanged to :func:`aggregate_dask`.
    dof
        Correction term used in the ``n / (n - dof)`` rescaling of the variance.
        ``0`` disables the rescaling.

    Returns
    -------
    A :class:`MeanVarDict` with the keys ``"mean"`` and ``"var"``.

    Raises
    ------
    ValueError
        If the dask meta is a CSC matrix, or for any reason
        :func:`aggregate_dask` rejects ``data``.
    """
    mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"]
    # `fau_power(x, 2)` is assumed to be the element-wise square -- TODO confirm upstream.
    sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"]
    # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
    if isinstance(data._meta, CSRBase):
        sq_mean = sq_mean.compute()
    elif isinstance(data._meta, CSCBase):  # pragma: no cover
        msg = "Cannot handle CSC matrices as dask meta."
        raise ValueError(msg)
    var = sq_mean - fau_power(mean, 2)
    if dof != 0:
        # NOTE(review): group sizes come from `by.codes` alone; `mask` is not
        # applied to them here -- confirm this matches the in-memory path.
        group_counts = np.bincount(by.codes)
        var *= (group_counts / (group_counts - dof))[:, np.newaxis]
    return MeanVarDict(mean=mean, var=var)
| 365 | + |
| 366 | + |
@_aggregate.register(DaskArray)
def aggregate_dask(
    data: DaskArray,
    by: pd.Categorical,
    func: AggType | Iterable[AggType],
    *,
    mask: NDArray[np.bool_] | None = None,
    dof: int = 1,
) -> dict[AggType, DaskArray]:
    """Aggregate a row-chunked dask array per category of ``by``.

    Each row chunk is aggregated on its own through ``_aggregate`` (with ``by``
    and ``mask`` sliced to the chunk's row range), giving a
    #chunks × #categories × #features array that is then summed over the chunk
    axis. ``"mean"`` is derived from the summed result divided by the group
    sizes; ``"var"`` is delegated to :func:`aggregate_dask_mean_var`.

    Parameters
    ----------
    data
        Dask array whose meta is a CSR matrix/array or a NumPy array and whose
        second axis is a single chunk.
    by
        Categorical defining the groups; indexed with each chunk's row slice.
    func
        One aggregation name or an iterable of names.
    mask
        Optional boolean mask; indexed with each chunk's row slice.
    dof
        Forwarded to ``_aggregate`` and :func:`aggregate_dask_mean_var`.

    Returns
    -------
    Mapping from each requested aggregation name to its result.

    Raises
    ------
    ValueError
        If the meta type is unsupported or the feature axis is chunked.
    NotImplementedError
        If ``"median"`` is requested.
    """
    if not isinstance(data._meta, CSRBase | np.ndarray):
        msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
        raise ValueError(msg)
    if data.chunksize[1] != data.shape[1]:
        msg = "Feature axis must be unchunked"
        raise ValueError(msg)

    funcs = {func} if isinstance(func, str) else set(func)
    if "median" in funcs:
        msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
        raise NotImplementedError(msg)

    def aggregate_chunk(
        chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None
    ):
        # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
        # for what is contained in `block_info`.
        subset = slice(*block_info[0]["array-location"][0])
        mask_subsetted = None if mask is None else mask[subset]
        res = _aggregate(chunk, by[subset], func, mask=mask_subsetted, dof=dof)[func]
        # Add a leading length-1 axis so chunk results stack along axis 0.
        return res[None, :]

    def aggregate_over_chunks(agg: Literal["count_nonzero", "sum"]) -> DaskArray:
        # aggregate each row chunk individually,
        # producing a #chunks × #categories × #features array,
        # then aggregate the per-chunk results.
        per_chunk = data.map_blocks(
            # Bind the *individual* aggregation name, not the (possibly
            # iterable) outer `func` argument.
            partial(aggregate_chunk, func=agg),
            new_axis=(1,),
            chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
            meta=np.array(
                [],
                dtype=data.dtype
                if agg in get_args(ConstantDtypeAgg)
                else np.float64,  # TODO: figure out best dtype for aggs like sum where dtype can change from original
            ),
        )
        return per_chunk.sum(axis=0)

    has_mean = "mean" in funcs
    has_var = "var" in funcs
    aggregated = {f: aggregate_over_chunks(f) for f in funcs - {"var", "mean"}}
    if has_var:
        # Variance needs the mean anyway, so take both from one helper call.
        aggregated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
        aggregated["var"] = aggregated_mean_var["var"]
        if has_mean:
            aggregated["mean"] = aggregated_mean_var["mean"]
    # division must come after, not before, the summation for numerical precision
    # i.e., we can't just call map blocks over the mean function.
    elif has_mean:
        group_counts = np.bincount(by.codes)
        # Reuse the sum if it was requested; otherwise build it without
        # exposing it in the returned mapping.
        total = aggregated["sum"] if "sum" in aggregated else aggregate_over_chunks("sum")
        aggregated["mean"] = total / group_counts[:, None]
    return aggregated
| 431 | + |
| 432 | + |
@_aggregate.register(pd.DataFrame)
def aggregate_df(data, by, func, *, mask=None, dof=1) -> dict[AggType, np.ndarray]:
    """Aggregate a DataFrame by dispatching ``_aggregate`` on its ``.values``.

    All other arguments are passed through unchanged.
    """
    return _aggregate(data.values, by, func, mask=mask, dof=dof)
341 | 436 |
|
342 | 437 |
|
|
0 commit comments