5858 [- 1.50180389 , 5.56886849 , 1.64034442 , 2.24476032 , - 0.05109001 ],
5959])
6060
61-
62- ARRAY_TYPES = [
61+ # These are array types which are expected to work with the current PCA implementation.
62+ VALID_ARRAY_TYPES = [
6363 param_with (
6464 at ,
6565 marks = [needs .dask_ml ] if at .id == "dask_array_dense-1d_chunked" else [],
6666 )
6767 for at in ARRAY_TYPES_ALL
68- if at .id not in {"dask_array_dense" , "dask_array_sparse" }
68+ if at .id
69+ not in {
70+ "dask_array_dense" ,
71+ "dask_array_sparse" ,
72+ "dask_array_sparse-1d_chunked-csc_array" ,
73+ "dask_array_sparse-1d_chunked-csc_matrix" ,
74+ }
6975]
7076
7177
72- @pytest .fixture (params = ARRAY_TYPES )
78+ @pytest .fixture (params = VALID_ARRAY_TYPES )
7379def array_type (request : pytest .FixtureRequest ) -> ArrayType :
7480 return request .param
7581
@@ -93,10 +99,14 @@ def gen_pca_params(
9399 xfail_reason = "dask without 1d chunking scheme not supported"
94100 yield None , None , xfail_reason
95101 return
96- if id == "dask_array_sparse-1d_chunked" and not zero_center :
102+ if "dask_array_sparse-1d_chunked" in id and not zero_center :
97103 xfail_reason = "Sparse-in-dask with zero_center=False not implemented yet"
98104 yield None , None , xfail_reason
99105 return
106+ if "dask_array_sparse-1d_chunked-csc" in id :
107+ xfail_reason = "Sparse-in-dask with csc blocks not implemented yet"
108+ yield None , None , xfail_reason
109+ return
100110 if svd_solver_type is None :
101111 yield None , None , None
102112 return
@@ -137,7 +147,12 @@ def possible_solvers(
137147 svd_solvers = {"auto" , "full" , "tsqr" , "randomized" , "covariance_eigh" }
138148 case (dc , False ) if id == "dask_array_dense-1d_chunked" :
139149 svd_solvers = {"tsqr" , "randomized" }
140- case (dc , True ) if id == "dask_array_sparse-1d_chunked" :
150+ case (dc , True ) if (
151+ # See https://github.com/scverse/scanpy/blob/216b21d91312b899e939db9636d9ab20e7c29d77/src/testing/scanpy/_pytest/params.py#L88-L103
152+ # for why we need two checks (i.e., before and after allowing CSC matrices)
153+ "dask_array_sparse-1d_chunked-csr" in id
154+ or id == "dask_array_sparse-1d_chunked"
155+ ):
141156 svd_solvers = {"covariance_eigh" }
142157 case (type () as dc , True ) if issubclass (dc , CSBase ):
143158 svd_solvers = {"arpack" } | SKLEARN_ADDITIONAL
@@ -148,7 +163,7 @@ def possible_solvers(
148163 case (helpers .asarray , False ):
149164 svd_solvers = {"arpack" , "randomized" }
150165 case _:
151- pytest .fail (f"Unknown { array_type = } ({ zero_center = } )" )
166+ pytest .fail (f"Unknown { array_type = } ({ zero_center = } ) ( { id = } ) " )
152167
153168 if svd_solver_type == "invalid" :
154169 svd_solvers = all_svd_solvers - svd_solvers
@@ -178,7 +193,7 @@ def possible_solvers(
178193 f"{ svd_solver or svd_solver_type } -{ 'xfail' if xfail_reason else warn_pat_expected } "
179194 ),
180195 )
181- for array_type in ARRAY_TYPES
196+ for array_type in VALID_ARRAY_TYPES
182197 for zero_center in [True , False ]
183198 for svd_solver_type in [None , "valid" , "invalid" ]
184199 for svd_solver , warn_pat_expected , xfail_reason in gen_pca_params (
@@ -542,10 +557,13 @@ def test_pca_rep(rep: Literal["layer", "obsm"]) -> None:
542557@pytest .mark .parametrize (
543558 "other_array_type" ,
544559 [
545- lambda x : x .toarray (),
546- * (at .values [0 ] for at in ARRAY_TYPES if "1d_chunked" in at .id ),
560+ pytest .param (lambda x : x .toarray (), id = "dense" ),
561+ * (
562+ pytest .param (at .values [0 ], id = at .id )
563+ for at in VALID_ARRAY_TYPES
564+ if "1d_chunked" in at .id
565+ ),
547566 ],
548- ids = ["dense-mem" , "sparse-dask" , "dense-dask" ],
549567)
550568def test_covariance_eigh_impls (other_array_type ):
551569 warnings .filterwarnings ("error" )
@@ -590,8 +608,8 @@ def test_sparse_dask_input_errors(msg_re: str, op: Callable[[DaskArray], DaskArr
590608 adata_sparse .X = op (
591609 next (
592610 at .values [0 ]
593- for at in ARRAY_TYPES
594- if at . id == "dask_array_sparse-1d_chunked"
611+ for at in VALID_ARRAY_TYPES
612+ if "dask_array_sparse-1d_chunked" in at . id
595613 )(adata_sparse .X )
596614 )
597615
@@ -612,7 +630,9 @@ def test_sparse_dask_input_errors(msg_re: str, op: Callable[[DaskArray], DaskArr
612630def test_cov_sparse_dask (dtype , dtype_arg , rtol ):
613631 x_arr = A_list .astype (dtype )
614632 x = next (
615- at .values [0 ] for at in ARRAY_TYPES if at .id == "dask_array_sparse-1d_chunked"
633+ at .values [0 ]
634+ for at in VALID_ARRAY_TYPES
635+ if "dask_array_sparse-1d_chunked" in at .id
616636 )(x_arr )
617637 cov , gram , mean = _cov_sparse_dask (x , return_gram = True , dtype = dtype_arg )
618638 np .testing .assert_allclose (mean , np .mean (x_arr , axis = 0 ))
0 commit comments