Skip to content

Commit dd6e6fd

Browse files
authored
chore: add csc-in-dask tests (#3870)
1 parent 819ca8c commit dd6e6fd

7 files changed

Lines changed: 156 additions & 32 deletions

File tree

src/testing/scanpy/_pytest/params.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from functools import wraps
5+
from functools import partial, wraps
66
from importlib.metadata import version
77
from typing import TYPE_CHECKING
88

@@ -29,6 +29,24 @@
2929
reason="scipy cs{rc}_array not supported in anndata<0.11",
3030
)
3131

32+
anndata_test_utils_supports_typ_kwarg = Version(version("anndata")) >= Version("0.12.6")
33+
34+
35+
def gen_csr_csc_params_wrapper(
36+
func: Callable,
37+
format: Literal["csr", "csc"],
38+
matrix_or_array: Literal["matrix", "array"],
39+
):
40+
def wrapper(arr):
41+
if anndata_test_utils_supports_typ_kwarg:
42+
return _chunked_1d(
43+
partial(func, typ=getattr(sparse, f"{format}_{matrix_or_array}"))
44+
)(arr)
45+
return _chunked_1d(func)(arr)
46+
47+
wrapper.__name__ = f"{func.__name__}-1d_chunked-{format}_{matrix_or_array}"
48+
return wrapper
49+
3250

3351
def param_with(
3452
at: ParameterSet,
@@ -48,7 +66,11 @@ def _chunked_1d(
4866
@wraps(f)
4967
def wrapper(a: np.ndarray) -> DaskArray:
5068
da = f(a)
51-
return da.rechunk((da.chunksize[0], -1))
69+
return da.rechunk(
70+
(da.chunksize[0], -1)
71+
if not hasattr(da._meta, "format") or da._meta.format == "csr"
72+
else (-1, da.chunksize[1])
73+
)
5274

5375
wrapper.__name__ = f"{wrapper.__name__}-1d_chunked"
5476
return wrapper
@@ -75,10 +97,31 @@ def wrapper(a: np.ndarray) -> DaskArray:
7597
("dask", "sparse"): tuple(
7698
pytest.param(
7799
wrapper(as_sparse_dask_matrix),
78-
marks=[needs.dask],
100+
marks=[needs.dask, skip_csc_mark]
101+
if skip_csc_mark is not None
102+
else [needs.dask],
79103
id=f"dask_array_sparse{suffix}",
80104
)
81-
for wrapper, suffix in [(lambda x: x, ""), (_chunked_1d, "-1d_chunked")]
105+
for wrapper, suffix, skip_csc_mark in [
106+
(lambda x: x, "", None),
107+
*(
108+
(
109+
partial(
110+
gen_csr_csc_params_wrapper,
111+
format=format,
112+
matrix_or_array=matrix_or_array,
113+
),
114+
f"-1d_chunked-{format}_{matrix_or_array}",
115+
pytest.mark.skipif(
116+
not anndata_test_utils_supports_typ_kwarg and format == "csc",
117+
reason="anndata < 0.12.6 lacked the required kwargs to enable csc matrix test utils.",
118+
),
119+
)
120+
for format in ["csr", "csc"]
121+
# TODO: use `array` as well once anndata 0.13 drops
122+
for matrix_or_array in ["matrix"]
123+
),
124+
]
82125
),
83126
}
84127

tests/test_aggregated.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,16 @@
2222

2323
from scanpy._compat import CSRBase
2424

25-
ARRAY_TYPES = [
25+
VALID_ARRAY_TYPES = [
2626
at
2727
for at in ARRAY_TYPES_ALL
28-
if at.id not in {"dask_array_dense", "dask_array_sparse"}
28+
if at.id
29+
not in {
30+
"dask_array_dense",
31+
"dask_array_sparse",
32+
"dask_array_sparse-1d_chunked-csc_array",
33+
"dask_array_sparse-1d_chunked-csc_matrix",
34+
}
2935
]
3036

3137

@@ -118,7 +124,7 @@ def test_mask(axis):
118124
assert np.all(by_name["0"].layers["sum"] == 0)
119125

120126

121-
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
127+
@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES)
122128
def test_aggregate_vs_pandas(
123129
metric: AggType, array_type, request: pytest.FixtureRequest
124130
):
@@ -160,7 +166,7 @@ def test_aggregate_vs_pandas(
160166
pd.testing.assert_frame_equal(result_df, expected, check_dtype=False, atol=1e-5)
161167

162168

163-
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
169+
@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES)
164170
def test_aggregate_axis(array_type, metric, request: pytest.FixtureRequest):
165171
adata = pbmc3k_processed().raw.to_adata()
166172
adata = adata[
@@ -456,7 +462,7 @@ def test_combine_categories(label_cols, cols, expected):
456462
pd.testing.assert_frame_equal(reconstructed_df, result_label_df)
457463

458464

459-
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
465+
@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES)
460466
def test_aggregate_arraytype(
461467
array_type, metric: AggType, request: pytest.FixtureRequest
462468
):

tests/test_highly_variable_genes.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,12 @@ def test_seurat_v3_bad_chunking(adata, array_type, flavor):
649649
],
650650
)
651651
@pytest.mark.parametrize(
652-
"array_type", [p for p in ARRAY_TYPES if "dask" not in p.id or "1d_chunked" in p.id]
652+
"array_type",
653+
[
654+
p
655+
for p in ARRAY_TYPES
656+
if "dask" not in p.id or ("1d_chunked" in p.id and "csr" in p.id)
657+
],
653658
)
654659
@pytest.mark.parametrize("batch_key", [None, "batch"])
655660
def test_subset_inplace_consistency(flavor, array_type, batch_key):
@@ -728,7 +733,9 @@ def test_subset_inplace_consistency(flavor, array_type, batch_key):
728733
],
729734
)
730735
@pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"])
731-
@pytest.mark.parametrize("to_dask", [p for p in ARRAY_TYPES if "1d_chunked" in p.id])
736+
@pytest.mark.parametrize(
737+
"to_dask", [p for p in ARRAY_TYPES if "1d_chunked" in p.id and "csr" in p.id]
738+
)
732739
def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask):
733740
# current blob produces singularities in loess....maybe a bad sign of the data?
734741
if "seurat_v3" in flavor:

tests/test_pca.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,24 @@
5858
[-1.50180389, 5.56886849, 1.64034442, 2.24476032, -0.05109001],
5959
])
6060

61-
62-
ARRAY_TYPES = [
61+
# These are array types which are expected to work with the current PCA implementation.
62+
VALID_ARRAY_TYPES = [
6363
param_with(
6464
at,
6565
marks=[needs.dask_ml] if at.id == "dask_array_dense-1d_chunked" else [],
6666
)
6767
for at in ARRAY_TYPES_ALL
68-
if at.id not in {"dask_array_dense", "dask_array_sparse"}
68+
if at.id
69+
not in {
70+
"dask_array_dense",
71+
"dask_array_sparse",
72+
"dask_array_sparse-1d_chunked-csc_array",
73+
"dask_array_sparse-1d_chunked-csc_matrix",
74+
}
6975
]
7076

7177

72-
@pytest.fixture(params=ARRAY_TYPES)
78+
@pytest.fixture(params=VALID_ARRAY_TYPES)
7379
def array_type(request: pytest.FixtureRequest) -> ArrayType:
7480
return request.param
7581

@@ -93,10 +99,14 @@ def gen_pca_params(
9399
xfail_reason = "dask without 1d chunking scheme not supported"
94100
yield None, None, xfail_reason
95101
return
96-
if id == "dask_array_sparse-1d_chunked" and not zero_center:
102+
if "dask_array_sparse-1d_chunked" in id and not zero_center:
97103
xfail_reason = "Sparse-in-dask with zero_center=False not implemented yet"
98104
yield None, None, xfail_reason
99105
return
106+
if "dask_array_sparse-1d_chunked-csc" in id:
107+
xfail_reason = "Sparse-in-dask with csc blocks not implemented yet"
108+
yield None, None, xfail_reason
109+
return
100110
if svd_solver_type is None:
101111
yield None, None, None
102112
return
@@ -137,7 +147,12 @@ def possible_solvers(
137147
svd_solvers = {"auto", "full", "tsqr", "randomized", "covariance_eigh"}
138148
case (dc, False) if id == "dask_array_dense-1d_chunked":
139149
svd_solvers = {"tsqr", "randomized"}
140-
case (dc, True) if id == "dask_array_sparse-1d_chunked":
150+
case (dc, True) if (
151+
# See https://github.com/scverse/scanpy/blob/216b21d91312b899e939db9636d9ab20e7c29d77/src/testing/scanpy/_pytest/params.py#L88-L103
152+
# for why we need two checks (i.e., before and after allowing CSC matrices)
153+
"dask_array_sparse-1d_chunked-csr" in id
154+
or id == "dask_array_sparse-1d_chunked"
155+
):
141156
svd_solvers = {"covariance_eigh"}
142157
case (type() as dc, True) if issubclass(dc, CSBase):
143158
svd_solvers = {"arpack"} | SKLEARN_ADDITIONAL
@@ -148,7 +163,7 @@ def possible_solvers(
148163
case (helpers.asarray, False):
149164
svd_solvers = {"arpack", "randomized"}
150165
case _:
151-
pytest.fail(f"Unknown {array_type=} ({zero_center=})")
166+
pytest.fail(f"Unknown {array_type=} ({zero_center=}) ({id=})")
152167

153168
if svd_solver_type == "invalid":
154169
svd_solvers = all_svd_solvers - svd_solvers
@@ -178,7 +193,7 @@ def possible_solvers(
178193
f"{svd_solver or svd_solver_type}-{'xfail' if xfail_reason else warn_pat_expected}"
179194
),
180195
)
181-
for array_type in ARRAY_TYPES
196+
for array_type in VALID_ARRAY_TYPES
182197
for zero_center in [True, False]
183198
for svd_solver_type in [None, "valid", "invalid"]
184199
for svd_solver, warn_pat_expected, xfail_reason in gen_pca_params(
@@ -542,10 +557,13 @@ def test_pca_rep(rep: Literal["layer", "obsm"]) -> None:
542557
@pytest.mark.parametrize(
543558
"other_array_type",
544559
[
545-
lambda x: x.toarray(),
546-
*(at.values[0] for at in ARRAY_TYPES if "1d_chunked" in at.id),
560+
pytest.param(lambda x: x.toarray(), id="dense"),
561+
*(
562+
pytest.param(at.values[0], id=at.id)
563+
for at in VALID_ARRAY_TYPES
564+
if "1d_chunked" in at.id
565+
),
547566
],
548-
ids=["dense-mem", "sparse-dask", "dense-dask"],
549567
)
550568
def test_covariance_eigh_impls(other_array_type):
551569
warnings.filterwarnings("error")
@@ -590,8 +608,8 @@ def test_sparse_dask_input_errors(msg_re: str, op: Callable[[DaskArray], DaskArr
590608
adata_sparse.X = op(
591609
next(
592610
at.values[0]
593-
for at in ARRAY_TYPES
594-
if at.id == "dask_array_sparse-1d_chunked"
611+
for at in VALID_ARRAY_TYPES
612+
if "dask_array_sparse-1d_chunked" in at.id
595613
)(adata_sparse.X)
596614
)
597615

@@ -612,7 +630,9 @@ def test_sparse_dask_input_errors(msg_re: str, op: Callable[[DaskArray], DaskArr
612630
def test_cov_sparse_dask(dtype, dtype_arg, rtol):
613631
x_arr = A_list.astype(dtype)
614632
x = next(
615-
at.values[0] for at in ARRAY_TYPES if at.id == "dask_array_sparse-1d_chunked"
633+
at.values[0]
634+
for at in VALID_ARRAY_TYPES
635+
if "dask_array_sparse-1d_chunked" in at.id
616636
)(x_arr)
617637
cov, gram, mean = _cov_sparse_dask(x, return_gram=True, dtype=dtype_arg)
618638
np.testing.assert_allclose(mean, np.mean(x_arr, axis=0))

tests/test_preprocessing.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,12 @@ def test_sample_copy_backed_error(tmp_path):
280280

281281
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
282282
@pytest.mark.parametrize("max_value", [None, 1.0], ids=["no_clip", "clip"])
283-
def test_scale_matrix_types(array_type, zero_center, max_value):
283+
def test_scale_matrix_types(
284+
*,
285+
array_type: Callable,
286+
zero_center: bool,
287+
max_value: float | None,
288+
):
284289
adata = pbmc68k_reduced()
285290
adata.X = adata.raw.X
286291
adata_casted = adata.copy()

tests/test_qc_metrics.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from contextlib import nullcontext
4+
35
import numpy as np
46
import pandas as pd
57
import pytest
@@ -9,7 +11,7 @@
911
from scipy import sparse
1012

1113
import scanpy as sc
12-
from scanpy._compat import DaskArray
14+
from scanpy._compat import CSCBase, DaskArray
1315
from scanpy.preprocessing._qc import (
1416
describe_obs,
1517
describe_var,
@@ -83,8 +85,16 @@ def test_top_segments(request: pytest.FixtureRequest, array_type):
8385
reason = "DaskArray with feature axis chunking not yet supported"
8486
request.applymarker(pytest.mark.xfail(reason=reason))
8587
a = array_type(np.ones((300, 100)))
86-
with maybe_dask_process_context():
88+
is_csc_dask = isinstance(a, DaskArray) and isinstance(a._meta, CSCBase)
89+
with (
90+
maybe_dask_process_context(),
91+
pytest.raises(ValueError, match=r"DaskArray must have csr")
92+
if is_csc_dask
93+
else nullcontext(),
94+
):
8795
seg = top_segment_proportions(a, [50, 100])
96+
if is_csc_dask:
97+
return
8898
assert (seg[:, 0] == 0.5).all()
8999
assert (seg[:, 1] == 1.0).all()
90100

@@ -107,10 +117,22 @@ def test_top_proportions(request: pytest.FixtureRequest, array_type):
107117
# While many of these are trivial,
108118
# they’re also just making sure the metrics are there
109119
def test_qc_metrics(adata_prepared: AnnData):
110-
with maybe_dask_process_context():
120+
is_csc_dask = isinstance(adata_prepared.X, DaskArray) and isinstance(
121+
adata_prepared.X._meta, CSCBase
122+
)
123+
with (
124+
maybe_dask_process_context(),
125+
(
126+
pytest.raises(ValueError, match=r"DaskArray must have csr")
127+
if is_csc_dask
128+
else nullcontext()
129+
),
130+
):
111131
sc.pp.calculate_qc_metrics(
112132
adata_prepared, qc_vars=["mito", "negative"], inplace=True
113133
)
134+
if is_csc_dask:
135+
return
114136
x = (
115137
adata_prepared.X.compute()
116138
if isinstance(adata_prepared.X, DaskArray)
@@ -159,14 +181,26 @@ def test_qc_metrics(adata_prepared: AnnData):
159181

160182

161183
def test_qc_metrics_idempotent(adata_prepared: AnnData):
162-
with maybe_dask_process_context():
184+
is_csc_dask = isinstance(adata_prepared.X, DaskArray) and isinstance(
185+
adata_prepared.X._meta, CSCBase
186+
)
187+
with (
188+
maybe_dask_process_context(),
189+
(
190+
pytest.raises(ValueError, match=r"DaskArray must have csr")
191+
if is_csc_dask
192+
else nullcontext()
193+
),
194+
):
163195
sc.pp.calculate_qc_metrics(
164196
adata_prepared, qc_vars=["mito", "negative"], inplace=True
165197
)
166198
old_obs, old_var = adata_prepared.obs.copy(), adata_prepared.var.copy()
167199
sc.pp.calculate_qc_metrics(
168200
adata_prepared, qc_vars=["mito", "negative"], inplace=True
169201
)
202+
if is_csc_dask:
203+
return
170204
assert set(adata_prepared.obs.columns) == set(old_obs.columns)
171205
assert set(adata_prepared.var.columns) == set(old_var.columns)
172206
for col in adata_prepared.obs:
@@ -176,7 +210,15 @@ def test_qc_metrics_idempotent(adata_prepared: AnnData):
176210

177211

178212
def test_qc_metrics_no_log1p(adata_prepared: AnnData):
179-
with maybe_dask_process_context():
213+
with (
214+
maybe_dask_process_context(),
215+
(
216+
pytest.raises(ValueError, match=r"DaskArray must have csr")
217+
if isinstance(adata_prepared.X, DaskArray)
218+
and isinstance(adata_prepared.X._meta, CSCBase)
219+
else nullcontext()
220+
),
221+
):
180222
sc.pp.calculate_qc_metrics(
181223
adata_prepared, qc_vars=["mito", "negative"], log1p=False, inplace=True
182224
)

0 commit comments

Comments
 (0)