1010from scipy import sparse
1111from sklearn .utils .sparsefuncs import csc_median_axis_0
1212
13- from scanpy ._compat import CSBase , CSCBase , CSRBase , DaskArray
13+ from scanpy ._compat import CSBase , CSRBase , DaskArray
1414
1515from .._utils import _resolve_axis , get_literal_vals
1616from .get import _check_mask
@@ -357,9 +357,6 @@ def aggregate_dask_mean_var(
357357 # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
358358 if isinstance (data ._meta , CSRBase ):
359359 sq_mean = sq_mean .compute ()
360- elif isinstance (data ._meta , CSCBase ): # pragma: no-cover
361- msg = "Cannot handle CSC matrices as dask meta."
362- raise ValueError (msg )
363360 var = sq_mean - fau_power (mean , 2 )
364361 if dof != 0 :
365362 group_counts = np .bincount (by .codes )
@@ -376,47 +373,66 @@ def aggregate_dask(
376373 mask : NDArray [np .bool_ ] | None = None ,
377374 dof : int = 1 ,
378375) -> dict [AggType , DaskArray ]:
379- if not isinstance (data ._meta , CSRBase | np .ndarray ):
376+ if not isinstance (data ._meta , CSBase | np .ndarray ):
380377 msg = f"Got { type (data ._meta )} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
381378 raise ValueError (msg )
382- if data .chunksize [1 ] != data .shape [1 ]:
379+ chunked_axis , unchunked_axis = (
380+ (0 , 1 ) if isinstance (data ._meta , CSRBase | np .ndarray ) else (1 , 0 )
381+ )
382+ if data .chunksize [unchunked_axis ] != data .shape [unchunked_axis ]:
383383 msg = "Feature axis must be unchunked"
384384 raise ValueError (msg )
385385
386386 def aggregate_chunk_sum_or_count_nonzero (
387387 chunk : Array , * , func : Literal ["count_nonzero" , "sum" ], block_info = None
388388 ):
389- # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
390- # for what is contained in `block_info`.
391- subset = slice (* block_info [0 ]["array-location" ][0 ])
392- by_subsetted = by [subset ]
393- mask_subsetted = mask [subset ] if mask is not None else mask
389+ # only subset the mask and by if we need to i.e.,
390+ # there is chunking along the same axis as by and mask
391+ if chunked_axis == 0 :
392+ # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
393+ # for what is contained in `block_info`.
394+ subset = slice (* block_info [0 ]["array-location" ][0 ])
395+ by_subsetted = by [subset ]
396+ mask_subsetted = mask [subset ] if mask is not None else mask
397+ else :
398+ by_subsetted = by
399+ mask_subsetted = mask
394400 res = _aggregate (chunk , by_subsetted , func , mask = mask_subsetted , dof = dof )[func ]
395- return res [None , :]
401+ return res [None , :] if unchunked_axis == 1 else res
396402
397403 funcs = set ([func ] if isinstance (func , str ) else func )
398404 if "median" in funcs :
399405 msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
400406 raise NotImplementedError (msg )
401407 has_mean , has_var = (v in funcs for v in ["mean" , "var" ])
402408 funcs_no_var_or_mean = funcs - {"var" , "mean" }
403- # aggregate each row chunk individually,
404- # producing a #chunks × #categories × #features array,
409+ # aggregate each row chunk or column chunk individually,
410+ # producing a #chunks × #categories × #features or a #categories × #chunks array,
405411 # then aggregate the per-chunk results.
412+ chunks = (
413+ ((1 ,) * data .blocks .size , (len (by .categories ),), data .shape [1 ])
414+ if unchunked_axis == 1
415+ else (len (by .categories ), data .chunks [1 ])
416+ )
406417 aggregated = {
407418 f : data .map_blocks (
408419 partial (aggregate_chunk_sum_or_count_nonzero , func = func ),
409- new_axis = (1 ,),
410- chunks = (( 1 ,) * data . blocks . size , ( len ( by . categories ),), ( data . shape [ 1 ],)) ,
420+ new_axis = (1 ,) if unchunked_axis == 1 else None ,
421+ chunks = chunks ,
411422 meta = np .array (
412423 [],
413424 dtype = np .float64
414425 if func not in get_args (ConstantDtypeAgg )
415426 else data .dtype , # TODO: figure out best dtype for aggs like sum where dtype can change from original
416427 ),
417- ). sum ( axis = 0 )
428+ )
418429 for f in funcs_no_var_or_mean
419430 }
431+ # If we have row chunking, we need to handle the extra axis by summing over all category × feature matrices.
432+ # Otherwise, dask internally concatenates the #categories × #chunks arrays i.e., the column chunks are concatenated together to get a #categories × #features matrix.
433+ if unchunked_axis == 1 :
434+ for k , v in aggregated .items ():
435+ aggregated [k ] = v .sum (axis = chunked_axis )
420436 if has_var :
421437 aggredated_mean_var = aggregate_dask_mean_var (data , by , mask = mask , dof = dof )
422438 aggregated ["var" ] = aggredated_mean_var ["var" ]
0 commit comments