Skip to content

Commit 6605dd3

Browse files
ilan-goldflying-sheeppre-commit-ci[bot]
authored
feat: support csc in dask arrays in get.aggregate (#3872)
Co-authored-by: Philipp A. <flying-sheep@web.de> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent dd6e6fd commit 6605dd3

3 files changed

Lines changed: 34 additions & 20 deletions

File tree

docs/release-notes/3872.feat.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add in `csc`-in-{doc}`dask:index` support for {func}`scanpy.get.aggregate` {smaller}`I Gold`

src/scanpy/get/_aggregated.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from scipy import sparse
1111
from sklearn.utils.sparsefuncs import csc_median_axis_0
1212

13-
from scanpy._compat import CSBase, CSCBase, CSRBase, DaskArray
13+
from scanpy._compat import CSBase, CSRBase, DaskArray
1414

1515
from .._utils import _resolve_axis, get_literal_vals
1616
from .get import _check_mask
@@ -357,9 +357,6 @@ def aggregate_dask_mean_var(
357357
# TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
358358
if isinstance(data._meta, CSRBase):
359359
sq_mean = sq_mean.compute()
360-
elif isinstance(data._meta, CSCBase): # pragma: no-cover
361-
msg = "Cannot handle CSC matrices as dask meta."
362-
raise ValueError(msg)
363360
var = sq_mean - fau_power(mean, 2)
364361
if dof != 0:
365362
group_counts = np.bincount(by.codes)
@@ -376,47 +373,66 @@ def aggregate_dask(
376373
mask: NDArray[np.bool_] | None = None,
377374
dof: int = 1,
378375
) -> dict[AggType, DaskArray]:
379-
if not isinstance(data._meta, CSRBase | np.ndarray):
376+
if not isinstance(data._meta, CSBase | np.ndarray):
380377
msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
381378
raise ValueError(msg)
382-
if data.chunksize[1] != data.shape[1]:
379+
chunked_axis, unchunked_axis = (
380+
(0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0)
381+
)
382+
if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]:
383383
msg = "Feature axis must be unchunked"
384384
raise ValueError(msg)
385385

386386
def aggregate_chunk_sum_or_count_nonzero(
387387
chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None
388388
):
389-
# See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
390-
# for what is contained in `block_info`.
391-
subset = slice(*block_info[0]["array-location"][0])
392-
by_subsetted = by[subset]
393-
mask_subsetted = mask[subset] if mask is not None else mask
389+
# only subset the mask and by if we need to i.e.,
390+
# there is chunking along the same axis as by and mask
391+
if chunked_axis == 0:
392+
# See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
393+
# for what is contained in `block_info`.
394+
subset = slice(*block_info[0]["array-location"][0])
395+
by_subsetted = by[subset]
396+
mask_subsetted = mask[subset] if mask is not None else mask
397+
else:
398+
by_subsetted = by
399+
mask_subsetted = mask
394400
res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func]
395-
return res[None, :]
401+
return res[None, :] if unchunked_axis == 1 else res
396402

397403
funcs = set([func] if isinstance(func, str) else func)
398404
if "median" in funcs:
399405
msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
400406
raise NotImplementedError(msg)
401407
has_mean, has_var = (v in funcs for v in ["mean", "var"])
402408
funcs_no_var_or_mean = funcs - {"var", "mean"}
403-
# aggregate each row chunk individually,
404-
# producing a #chunks × #categories × #features array,
409+
# aggregate each row chunk or column chunk individually,
410+
# producing a #chunks × #categories × #features or a #categories × #chunks array,
405411
# then aggregate the per-chunk results.
412+
chunks = (
413+
((1,) * data.blocks.size, (len(by.categories),), data.shape[1])
414+
if unchunked_axis == 1
415+
else (len(by.categories), data.chunks[1])
416+
)
406417
aggregated = {
407418
f: data.map_blocks(
408419
partial(aggregate_chunk_sum_or_count_nonzero, func=func),
409-
new_axis=(1,),
410-
chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
420+
new_axis=(1,) if unchunked_axis == 1 else None,
421+
chunks=chunks,
411422
meta=np.array(
412423
[],
413424
dtype=np.float64
414425
if func not in get_args(ConstantDtypeAgg)
415426
else data.dtype, # TODO: figure out best dtype for aggs like sum where dtype can change from original
416427
),
417-
).sum(axis=0)
428+
)
418429
for f in funcs_no_var_or_mean
419430
}
431+
# If we have row chunking, we need to handle the extra axis by summing over all category × feature matrices.
432+
# Otherwise, dask internally concatenates the #categories × #chunks arrays i.e., the column chunks are concatenated together to get a #categories × #features matrix.
433+
if unchunked_axis == 1:
434+
for k, v in aggregated.items():
435+
aggregated[k] = v.sum(axis=chunked_axis)
420436
if has_var:
421437
aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
422438
aggregated["var"] = aggredated_mean_var["var"]

tests/test_aggregated.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
not in {
3030
"dask_array_dense",
3131
"dask_array_sparse",
32-
"dask_array_sparse-1d_chunked-csc_array",
33-
"dask_array_sparse-1d_chunked-csc_matrix",
3432
}
3533
]
3634

@@ -246,7 +244,6 @@ def to_csc(x: CSRBase):
246244
@pytest.mark.parametrize(
247245
("func", "error_msg"),
248246
[
249-
pytest.param(to_csc, r"only csr_matrix", id="csc"),
250247
pytest.param(
251248
to_bad_chunking, r"Feature axis must be unchunked", id="bad_chunking"
252249
),

0 commit comments

Comments
 (0)