Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

from scanpy._compat import CSBase, CSCBase, CSRBase, DaskArray
from scanpy._compat import CSBase, CSRBase, DaskArray

from .._utils import _resolve_axis, get_literal_vals
from .get import _check_mask
Expand Down Expand Up @@ -357,9 +357,6 @@ def aggregate_dask_mean_var(
# TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
if isinstance(data._meta, CSRBase):
sq_mean = sq_mean.compute()
elif isinstance(data._meta, CSCBase): # pragma: no-cover
msg = "Cannot handle CSC matrices as dask meta."
raise ValueError(msg)
var = sq_mean - fau_power(mean, 2)
if dof != 0:
group_counts = np.bincount(by.codes)
Expand All @@ -376,47 +373,66 @@ def aggregate_dask(
mask: NDArray[np.bool_] | None = None,
dof: int = 1,
) -> dict[AggType, DaskArray]:
if not isinstance(data._meta, CSRBase | np.ndarray):
if not isinstance(data._meta, CSBase | np.ndarray):
msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
raise ValueError(msg)
if data.chunksize[1] != data.shape[1]:
chunked_axis, unchunked_axis = (
(0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0)
)
if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]:
msg = "Feature axis must be unchunked"
raise ValueError(msg)

def aggregate_chunk_sum_or_count_nonzero(
chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None
):
# See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
# for what is contained in `block_info`.
subset = slice(*block_info[0]["array-location"][0])
by_subsetted = by[subset]
mask_subsetted = mask[subset] if mask is not None else mask
# only subset the mask and by if we need to i.e.,
# there is chunking along the same axis as by and mask
if chunked_axis == 0:
# See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
# for what is contained in `block_info`.
subset = slice(*block_info[0]["array-location"][0])
by_subsetted = by[subset]
mask_subsetted = mask[subset] if mask is not None else mask
else:
by_subsetted = by
mask_subsetted = mask
res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func]
return res[None, :]
return res[None, :] if unchunked_axis == 1 else res

funcs = set([func] if isinstance(func, str) else func)
if "median" in funcs:
msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
raise NotImplementedError(msg)
has_mean, has_var = (v in funcs for v in ["mean", "var"])
funcs_no_var_or_mean = funcs - {"var", "mean"}
# aggregate each row chunk individually,
# producing a #chunks × #categories × #features array,
# aggregate each row chunk or column chunk individually,
# producing a #chunks × #categories × #features or a #categories × #chunks array,
# then aggregate the per-chunk results.
chunks = (
((1,) * data.blocks.size, (len(by.categories),), data.shape[1])
if unchunked_axis == 1
else (len(by.categories), data.chunks[1])
)
aggregated = {
f: data.map_blocks(
partial(aggregate_chunk_sum_or_count_nonzero, func=func),
new_axis=(1,),
chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
new_axis=(1,) if unchunked_axis == 1 else None,
chunks=chunks,
meta=np.array(
[],
dtype=np.float64
if func not in get_args(ConstantDtypeAgg)
else data.dtype, # TODO: figure out best dtype for aggs like sum where dtype can change from original
),
).sum(axis=0)
)
for f in funcs_no_var_or_mean
}
# If we have row chunking, we need to handle the extra axis by summing over all category × feature matrices.
# Otherwise, dask internally concatenates the #categories × #chunks arrays i.e., the column chunks are concatenated together to get a #categories × #features matrix.
if unchunked_axis == 1:
for k, v in aggregated.items():
aggregated[k] = v.sum(axis=chunked_axis)
if has_var:
aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
aggregated["var"] = aggredated_mean_var["var"]
Expand Down
51 changes: 47 additions & 4 deletions src/testing/scanpy/_pytest/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from functools import wraps
from functools import partial, wraps
from importlib.metadata import version
from typing import TYPE_CHECKING

Expand All @@ -29,6 +29,24 @@
reason="scipy cs{rc}_array not supported in anndata<0.11",
)

anndata_test_utils_supports_typ_kwarg = Version(version("anndata")) >= Version("0.12.6")


def gen_csr_csc_params_wrapper(
func: Callable,
format: Literal["csr", "csc"],
matrix_or_array: Literal["matrix", "array"],
):
def wrapper(arr):
if anndata_test_utils_supports_typ_kwarg:
return _chunked_1d(
partial(func, typ=getattr(sparse, f"{format}_{matrix_or_array}"))
)(arr)
return _chunked_1d(func)(arr)

wrapper.__name__ = f"{func.__name__}-1d_chunked-{format}_{matrix_or_array}"
return wrapper


def param_with(
at: ParameterSet,
Expand All @@ -48,7 +66,11 @@ def _chunked_1d(
@wraps(f)
def wrapper(a: np.ndarray) -> DaskArray:
da = f(a)
return da.rechunk((da.chunksize[0], -1))
return da.rechunk(
(da.chunksize[0], -1)
if not hasattr(da._meta, "format") or da._meta.format == "csr"
else (-1, da.chunksize[1])
)

wrapper.__name__ = f"{wrapper.__name__}-1d_chunked"
return wrapper
Expand All @@ -75,10 +97,31 @@ def wrapper(a: np.ndarray) -> DaskArray:
("dask", "sparse"): tuple(
pytest.param(
wrapper(as_sparse_dask_matrix),
marks=[needs.dask],
marks=[needs.dask, skip_csc_mark]
if skip_csc_mark is not None
else [needs.dask],
id=f"dask_array_sparse{suffix}",
)
for wrapper, suffix in [(lambda x: x, ""), (_chunked_1d, "-1d_chunked")]
for wrapper, suffix, skip_csc_mark in [
(lambda x: x, "", None),
*(
(
partial(
gen_csr_csc_params_wrapper,
format=format,
matrix_or_array=matrix_or_array,
),
f"-1d_chunked-{format}_{matrix_or_array}",
pytest.mark.skipif(
not anndata_test_utils_supports_typ_kwarg and format == "csc",
reason="anndata < 0.12.6 lacked the required kwargs to enable csc matrix test utils.",
),
)
for format in ["csr", "csc"]
# TODO: use `array` as well once anndata 0.13 drops
for matrix_or_array in ["matrix"]
),
]
),
}

Expand Down
15 changes: 9 additions & 6 deletions tests/test_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@

from scanpy._compat import CSRBase

ARRAY_TYPES = [
VALID_ARRAY_TYPES = [
at
for at in ARRAY_TYPES_ALL
if at.id not in {"dask_array_dense", "dask_array_sparse"}
if at.id
not in {
"dask_array_dense",
"dask_array_sparse",
}
]


Expand Down Expand Up @@ -118,7 +122,7 @@ def test_mask(axis):
assert np.all(by_name["0"].layers["sum"] == 0)


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES)
def test_aggregate_vs_pandas(
metric: AggType, array_type, request: pytest.FixtureRequest
):
Expand Down Expand Up @@ -160,7 +164,7 @@ def test_aggregate_vs_pandas(
pd.testing.assert_frame_equal(result_df, expected, check_dtype=False, atol=1e-5)


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES)
def test_aggregate_axis(array_type, metric, request: pytest.FixtureRequest):
adata = pbmc3k_processed().raw.to_adata()
adata = adata[
Expand Down Expand Up @@ -240,7 +244,6 @@ def to_csc(x: CSRBase):
@pytest.mark.parametrize(
("func", "error_msg"),
[
pytest.param(to_csc, r"only csr_matrix", id="csc"),
pytest.param(
to_bad_chunking, r"Feature axis must be unchunked", id="bad_chunking"
),
Expand Down Expand Up @@ -456,7 +459,7 @@ def test_combine_categories(label_cols, cols, expected):
pd.testing.assert_frame_equal(reconstructed_df, result_label_df)


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES)
def test_aggregate_arraytype(
array_type, metric: AggType, request: pytest.FixtureRequest
):
Expand Down
11 changes: 9 additions & 2 deletions tests/test_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,12 @@ def test_seurat_v3_bad_chunking(adata, array_type, flavor):
],
)
@pytest.mark.parametrize(
"array_type", [p for p in ARRAY_TYPES if "dask" not in p.id or "1d_chunked" in p.id]
"array_type",
[
p
for p in ARRAY_TYPES
if "dask" not in p.id or ("1d_chunked" in p.id and "csr" in p.id)
],
)
@pytest.mark.parametrize("batch_key", [None, "batch"])
def test_subset_inplace_consistency(flavor, array_type, batch_key):
Expand Down Expand Up @@ -728,7 +733,9 @@ def test_subset_inplace_consistency(flavor, array_type, batch_key):
],
)
@pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"])
@pytest.mark.parametrize("to_dask", [p for p in ARRAY_TYPES if "1d_chunked" in p.id])
@pytest.mark.parametrize(
"to_dask", [p for p in ARRAY_TYPES if "1d_chunked" in p.id and "csr" in p.id]
)
def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask):
# current blob produces singularities in loess....maybe a bad sign of the data?
if "seurat_v3" in flavor:
Expand Down
48 changes: 34 additions & 14 deletions tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,24 @@
[-1.50180389, 5.56886849, 1.64034442, 2.24476032, -0.05109001],
])


ARRAY_TYPES = [
# These are array types which are expected to work with the current PCA implementation.
VALID_ARRAY_TYPES = [
param_with(
at,
marks=[needs.dask_ml] if at.id == "dask_array_dense-1d_chunked" else [],
)
for at in ARRAY_TYPES_ALL
if at.id not in {"dask_array_dense", "dask_array_sparse"}
if at.id
not in {
"dask_array_dense",
"dask_array_sparse",
"dask_array_sparse-1d_chunked-csc_array",
"dask_array_sparse-1d_chunked-csc_matrix",
}
]


@pytest.fixture(params=ARRAY_TYPES)
@pytest.fixture(params=VALID_ARRAY_TYPES)
def array_type(request: pytest.FixtureRequest) -> ArrayType:
return request.param

Expand All @@ -93,10 +99,14 @@ def gen_pca_params(
xfail_reason = "dask without 1d chunking scheme not supported"
yield None, None, xfail_reason
return
if id == "dask_array_sparse-1d_chunked" and not zero_center:
if "dask_array_sparse-1d_chunked" in id and not zero_center:
xfail_reason = "Sparse-in-dask with zero_center=False not implemented yet"
yield None, None, xfail_reason
return
if "dask_array_sparse-1d_chunked-csc" in id:
xfail_reason = "Sparse-in-dask with csc blocks not implemented yet"
yield None, None, xfail_reason
return
if svd_solver_type is None:
yield None, None, None
return
Expand Down Expand Up @@ -137,7 +147,12 @@ def possible_solvers(
svd_solvers = {"auto", "full", "tsqr", "randomized", "covariance_eigh"}
case (dc, False) if id == "dask_array_dense-1d_chunked":
svd_solvers = {"tsqr", "randomized"}
case (dc, True) if id == "dask_array_sparse-1d_chunked":
case (dc, True) if (
# See https://github.com/scverse/scanpy/blob/216b21d91312b899e939db9636d9ab20e7c29d77/src/testing/scanpy/_pytest/params.py#L88-L103
# for why we need two checks (i.e., before and after allowing CSC matrices)
"dask_array_sparse-1d_chunked-csr" in id
or id == "dask_array_sparse-1d_chunked"
):
svd_solvers = {"covariance_eigh"}
case (type() as dc, True) if issubclass(dc, CSBase):
svd_solvers = {"arpack"} | SKLEARN_ADDITIONAL
Expand All @@ -148,7 +163,7 @@ def possible_solvers(
case (helpers.asarray, False):
svd_solvers = {"arpack", "randomized"}
case _:
pytest.fail(f"Unknown {array_type=} ({zero_center=})")
pytest.fail(f"Unknown {array_type=} ({zero_center=}) ({id=})")

if svd_solver_type == "invalid":
svd_solvers = all_svd_solvers - svd_solvers
Expand Down Expand Up @@ -178,7 +193,7 @@ def possible_solvers(
f"{svd_solver or svd_solver_type}-{'xfail' if xfail_reason else warn_pat_expected}"
),
)
for array_type in ARRAY_TYPES
for array_type in VALID_ARRAY_TYPES
for zero_center in [True, False]
for svd_solver_type in [None, "valid", "invalid"]
for svd_solver, warn_pat_expected, xfail_reason in gen_pca_params(
Expand Down Expand Up @@ -542,10 +557,13 @@ def test_pca_rep(rep: Literal["layer", "obsm"]) -> None:
@pytest.mark.parametrize(
"other_array_type",
[
lambda x: x.toarray(),
*(at.values[0] for at in ARRAY_TYPES if "1d_chunked" in at.id),
pytest.param(lambda x: x.toarray(), id="dense"),
*(
pytest.param(at.values[0], id=at.id)
for at in VALID_ARRAY_TYPES
if "1d_chunked" in at.id
),
],
ids=["dense-mem", "sparse-dask", "dense-dask"],
)
def test_covariance_eigh_impls(other_array_type):
warnings.filterwarnings("error")
Expand Down Expand Up @@ -590,8 +608,8 @@ def test_sparse_dask_input_errors(msg_re: str, op: Callable[[DaskArray], DaskArr
adata_sparse.X = op(
next(
at.values[0]
for at in ARRAY_TYPES
if at.id == "dask_array_sparse-1d_chunked"
for at in VALID_ARRAY_TYPES
if "dask_array_sparse-1d_chunked" in at.id
)(adata_sparse.X)
)

Expand All @@ -612,7 +630,9 @@ def test_sparse_dask_input_errors(msg_re: str, op: Callable[[DaskArray], DaskArr
def test_cov_sparse_dask(dtype, dtype_arg, rtol):
x_arr = A_list.astype(dtype)
x = next(
at.values[0] for at in ARRAY_TYPES if at.id == "dask_array_sparse-1d_chunked"
at.values[0]
for at in VALID_ARRAY_TYPES
if "dask_array_sparse-1d_chunked" in at.id
)(x_arr)
cov, gram, mean = _cov_sparse_dask(x, return_gram=True, dtype=dtype_arg)
np.testing.assert_allclose(mean, np.mean(x_arr, axis=0))
Expand Down
7 changes: 6 additions & 1 deletion tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,12 @@ def test_sample_copy_backed_error(tmp_path):

@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("max_value", [None, 1.0], ids=["no_clip", "clip"])
def test_scale_matrix_types(array_type, zero_center, max_value):
def test_scale_matrix_types(
*,
array_type: Callable,
zero_center: bool,
max_value: float | None,
):
adata = pbmc68k_reduced()
adata.X = adata.raw.X
adata_casted = adata.copy()
Expand Down
Loading
Loading