Skip to content

Commit 1feb7ab

Browse files
committed
refactor: do mean/var together if possible
1 parent 900ef94 commit 1feb7ab

1 file changed

Lines changed: 24 additions & 14 deletions

File tree

src/scanpy/get/_aggregated.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from functools import partial, singledispatch
4-
from typing import TYPE_CHECKING, Literal
4+
from typing import TYPE_CHECKING, Literal, TypedDict
55

66
import numpy as np
77
import pandas as pd
@@ -331,25 +331,30 @@ def _aggregate(
331331
*,
332332
mask: NDArray[np.bool_] | None = None,
333333
dof: int = 1,
334-
):
334+
) -> dict[AggType, np.ndarray]:
335335
msg = f"Data type {type(data)} not supported for aggregation"
336336
raise NotImplementedError(msg)
337337

338338

339-
def aggregate_dask_var(
339+
class MeanVarDict(TypedDict):
340+
mean: DaskArray
341+
var: DaskArray
342+
343+
344+
def aggregate_dask_mean_var(
340345
data: DaskArray,
341346
by: pd.Categorical,
342347
*,
343348
mask: NDArray[np.bool_] | None = None,
344349
dof: int = 1,
345-
):
350+
) -> MeanVarDict:
346351
mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"]
347352
sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"]
348353
var = sq_mean - (mean**2)
349354
if dof != 0:
350355
group_counts = np.bincount(by.codes)
351356
var *= (group_counts / (group_counts - dof))[:, np.newaxis]
352-
return var
357+
return {"var": var, "mean": mean}
353358

354359

355360
@_aggregate.register(DaskArray)
@@ -360,9 +365,8 @@ def aggregate_dask(
360365
*,
361366
mask: NDArray[np.bool_] | None = None,
362367
dof: int = 1,
363-
):
368+
) -> dict[AggType, np.ndarray]:
364369
def aggregate_chunk_no_var(chunk: Array, block_info=None, *, func: AggType = func):
365-
func = "sum" if func == "mean" else func
366370
subset = slice(*block_info[0]["array-location"][0])
367371
by_subsetted = by[subset]
368372
mask_subsetted = mask[subset] if mask is not None else mask
@@ -373,28 +377,34 @@ def aggregate_chunk_no_var(chunk: Array, block_info=None, *, func: AggType = fun
373377
if "median" in funcs:
374378
msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
375379
raise NotImplementedError(msg)
376-
has_var = "var" in funcs
377-
funcs_no_var = funcs - {"var"}
380+
has_mean, has_var = (v in funcs for v in ["mean", "var"])
381+
funcs_no_var_or_mean = funcs - {"var", "mean"}
378382
aggregated = {
379383
f: data.map_blocks(
380384
partial(aggregate_chunk_no_var, func=func),
381385
new_axis=(1,),
382386
chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
383387
meta=np.array([], dtype=np.float64), # TODO: figure out dtype
384388
).sum(axis=0)
385-
for f in funcs_no_var
389+
for f in funcs_no_var_or_mean
386390
}
387391
if has_var:
388-
aggregated["var"] = aggregate_dask_var(data, by, mask=mask, dof=dof)
392+
aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
393+
aggregated["var"] = aggredated_mean_var["var"]
394+
if has_mean:
395+
aggregated["mean"] = aggredated_mean_var["mean"]
389396
# division must come after, not before, the summation for numerical precision.
390-
if "mean" in aggregated:
397+
elif has_mean:
391398
group_counts = np.bincount(by.codes)
392-
aggregated["mean"] /= group_counts[:, None]
399+
aggregated["mean"] = (
400+
aggregate_dask(data, by, "sum", mask=mask, dof=dof)["sum"]
401+
/ group_counts[:, None]
402+
)
393403
return aggregated
394404

395405

396406
@_aggregate.register(pd.DataFrame)
397-
def aggregate_df(data, by, func, *, mask=None, dof=1):
407+
def aggregate_df(data, by, func, *, mask=None, dof=1) -> dict[AggType, np.ndarray]:
398408
return _aggregate(data.values, by, func, mask=mask, dof=dof)
399409

400410

0 commit comments

Comments
 (0)