11from __future__ import annotations
22
33from functools import partial , singledispatch
4- from typing import TYPE_CHECKING , Literal
4+ from typing import TYPE_CHECKING , Literal , TypedDict
55
66import numpy as np
77import pandas as pd
@@ -331,25 +331,30 @@ def _aggregate(
331331 * ,
332332 mask : NDArray [np .bool_ ] | None = None ,
333333 dof : int = 1 ,
334- ):
334+ ) -> dict [ AggType , np . ndarray ] :
335335 msg = f"Data type { type (data )} not supported for aggregation"
336336 raise NotImplementedError (msg )
337337
338338
339- def aggregate_dask_var (
class MeanVarDict(TypedDict):
    """Typed result of :func:`aggregate_dask_mean_var`.

    Holds the two aggregates that function returns together, keyed by
    aggregation name.
    """

    # Aggregated mean, stored under the key "mean".
    mean: DaskArray
    # Aggregated variance, stored under the key "var".
    var: DaskArray
342+
343+
def aggregate_dask_mean_var(
    data: DaskArray,
    by: pd.Categorical,
    *,
    mask: NDArray[np.bool_] | None = None,
    dof: int = 1,
) -> MeanVarDict:
    """Compute the mean and variance of a dask array per category of ``by``.

    The variance is obtained from two ``"mean"`` aggregations: the mean of
    ``fau_power(data, 2)`` minus the square of the mean of ``data``. Unless
    ``dof`` is 0, it is then rescaled per group by ``n / (n - dof)``, where
    ``n`` is the number of entries of ``by.codes`` in that group.

    Parameters
    ----------
    data
        Dask array to aggregate.
    by
        Categorical whose codes define the groups.
    mask
        Optional boolean mask, forwarded unchanged to :func:`aggregate_dask`.
    dof
        Degrees of freedom used in the variance correction; ``0`` disables
        the correction.

    Returns
    -------
    A dictionary with exactly the keys ``"mean"`` and ``"var"``.
    """
    mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"]
    sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"]
    var = sq_mean - (mean**2)
    if dof != 0:
        # Group sizes are taken from `by.codes` alone; `mask` is not applied
        # to the counts here.
        # NOTE(review): confirm this is intended when a mask is passed.
        group_counts = np.bincount(by.codes)
        var *= (group_counts / (group_counts - dof))[:, np.newaxis]
    # Keys must match `MeanVarDict` exactly: callers index the result with "var"
    # and "mean".
    return {"mean": mean, "var": var}
353358
354359
355360@_aggregate .register (DaskArray )
@@ -360,9 +365,8 @@ def aggregate_dask(
360365 * ,
361366 mask : NDArray [np .bool_ ] | None = None ,
362367 dof : int = 1 ,
363- ):
368+ ) -> dict [ AggType , np . ndarray ] :
364369 def aggregate_chunk_no_var (chunk : Array , block_info = None , * , func : AggType = func ):
365- func = "sum" if func == "mean" else func
366370 subset = slice (* block_info [0 ]["array-location" ][0 ])
367371 by_subsetted = by [subset ]
368372 mask_subsetted = mask [subset ] if mask is not None else mask
@@ -374,27 +378,34 @@ def aggregate_chunk_no_var(chunk: Array, block_info=None, *, func: AggType = fun
374378 msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
375379 raise NotImplementedError (msg )
376380 has_var = "var" in funcs
377- funcs_no_var = funcs - {"var" }
381+ has_mean = "mean" in funcs
382+ funcs_no_var_or_mean = funcs - {"var" , "mean" }
378383 aggregated = {
379384 f : data .map_blocks (
380385 partial (aggregate_chunk_no_var , func = func ),
381386 new_axis = (1 ,),
382387 chunks = ((1 ,) * data .blocks .size , (len (by .categories ),), (data .shape [1 ],)),
383388 meta = np .array ([], dtype = np .float64 ), # TODO: figure out dtype
384389 ).sum (axis = 0 )
385- for f in funcs_no_var
390+ for f in funcs_no_var_or_mean
386391 }
387392 if has_var :
388- aggregated ["var" ] = aggregate_dask_var (data , by , mask = mask , dof = dof )
393+ aggredated_mean_var = aggregate_dask_mean_var (data , by , mask = mask , dof = dof )
394+ aggregated ["var" ] = aggredated_mean_var ["var" ]
395+ if has_mean :
396+ aggregated ["mean" ] = aggredated_mean_var ["mean" ]
389397 # division must come after, not before, the summation for numerical precision.
390- if "mean" in aggregated :
398+ elif has_mean :
391399 group_counts = np .bincount (by .codes )
392- aggregated ["mean" ] /= group_counts [:, None ]
400+ aggregated ["mean" ] = (
401+ aggregate_dask (data , by , "sum" , mask = mask , dof = dof )["sum" ]
402+ / group_counts [:, None ]
403+ )
393404 return aggregated
394405
395406
@_aggregate.register(pd.DataFrame)
def aggregate_df(
    data: pd.DataFrame, by, func, *, mask=None, dof=1
) -> dict[AggType, np.ndarray]:
    """Aggregate a :class:`pandas.DataFrame` by delegating to ``_aggregate``.

    The frame is converted to a NumPy array and re-dispatched through
    ``_aggregate``; ``by``, ``func``, ``mask`` and ``dof`` are passed through
    unchanged.
    """
    # `to_numpy()` is the pandas-recommended spelling of `.values`.
    return _aggregate(data.to_numpy(), by, func, mask=mask, dof=dof)
399410
400411
0 commit comments