88import pytest
99from anndata import OldFormatWarning , read_zarr
1010
11- from scanpy ._compat import DaskArray , ZappyArray
11+ from scanpy ._compat import DaskArray
1212from scanpy .preprocessing import (
1313 filter_cells ,
1414 filter_genes ,
# Location of the on-disk test data, resolved relative to this test module.
HERE = Path(__file__).parent.joinpath("_data/")
# Zarr store that the fixtures in this module read from.
input_file = HERE / "10x-10k-subset.zarr"

# Marks applied to every test in this module.
# NOTE(review): presumably these skip the tests when zarr or dask is not
# installed -- confirm against the definition of `needs`.
pytestmark = [needs.zarr, needs.dask]
3230
3331
3432@pytest .fixture
35- def adata (request : pytest . FixtureRequest ) -> AnnData :
33+ def adata () -> AnnData :
3634 with warnings .catch_warnings ():
3735 warnings .filterwarnings ("ignore" , category = OldFormatWarning )
3836 warnings .filterwarnings ("ignore" , r"Variable names are not unique" , UserWarning )
@@ -42,37 +40,25 @@ def adata(request: pytest.FixtureRequest) -> AnnData:
4240 return a
4341
4442
@pytest.fixture
def adata_dist() -> AnnData:
    """Return the test dataset with ``X`` replaced by a Dask array.

    Everything except ``X`` is loaded through ``read_zarr``; ``X`` is then
    re-opened from the ``X`` group of the same store via
    ``dask.array.from_zarr``.
    """
    import dask.array as da

    # NOTE(review): the extent of this `with` block was reconstructed from a
    # view that lost indentation -- confirm which statements belong inside it.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=OldFormatWarning)
        warnings.filterwarnings("ignore", r"Variable names are not unique", UserWarning)
        dist = read_zarr(input_file)
    dist.var_names_make_unique()

    # Swap the loaded X for a Dask array opened from the same store.
    dist.X = da.from_zarr(f"{input_file}/X")
    return dist
7157
7258
7359def test_log1p (adata : AnnData , adata_dist : AnnData ):
7460 log1p (adata_dist )
75- assert isinstance (adata_dist .X , DIST_TYPES )
61+ assert isinstance (adata_dist .X , DaskArray )
7662 result = materialize_as_ndarray (adata_dist .X )
7763 log1p (adata )
7864 assert result .shape == adata .shape
@@ -87,7 +73,7 @@ def test_normalize_per_cell(
8773 reason = "normalize_per_cell deprecated and broken for Dask"
8874 request .applymarker (pytest .mark .xfail (reason = reason ))
8975 normalize_per_cell (adata_dist )
90- assert isinstance (adata_dist .X , DIST_TYPES )
76+ assert isinstance (adata_dist .X , DaskArray )
9177 result = materialize_as_ndarray (adata_dist .X )
9278 normalize_per_cell (adata )
9379 assert result .shape == adata .shape
@@ -97,7 +83,7 @@ def test_normalize_per_cell(
9783@pytest .mark .filterwarnings ("ignore:Some cells have zero counts:UserWarning" )
9884def test_normalize_total (adata : AnnData , adata_dist : AnnData ) -> None :
9985 normalize_total (adata_dist )
100- assert isinstance (adata_dist .X , DIST_TYPES )
86+ assert isinstance (adata_dist .X , DaskArray )
10187 result = materialize_as_ndarray (adata_dist .X )
10288 normalize_total (adata )
10389 assert result .shape == adata .shape
@@ -106,8 +92,8 @@ def test_normalize_total(adata: AnnData, adata_dist: AnnData) -> None:
10692
10793def test_filter_cells_array (adata : AnnData , adata_dist : AnnData ):
10894 cell_subset_dist , number_per_cell_dist = filter_cells (adata_dist .X , min_genes = 3 )
109- assert isinstance (cell_subset_dist , DIST_TYPES )
110- assert isinstance (number_per_cell_dist , DIST_TYPES )
95+ assert isinstance (cell_subset_dist , DaskArray )
96+ assert isinstance (number_per_cell_dist , DaskArray )
11197
11298 cell_subset , number_per_cell = filter_cells (adata .X , min_genes = 3 )
11399 npt .assert_allclose (materialize_as_ndarray (cell_subset_dist ), cell_subset )
@@ -116,7 +102,7 @@ def test_filter_cells_array(adata: AnnData, adata_dist: AnnData):
116102
117103def test_filter_cells (adata : AnnData , adata_dist : AnnData ):
118104 filter_cells (adata_dist , min_genes = 3 )
119- assert isinstance (adata_dist .X , DIST_TYPES )
105+ assert isinstance (adata_dist .X , DaskArray )
120106 result = materialize_as_ndarray (adata_dist .X )
121107 filter_cells (adata , min_genes = 3 )
122108
@@ -127,8 +113,8 @@ def test_filter_cells(adata: AnnData, adata_dist: AnnData):
127113
128114def test_filter_genes_array (adata : AnnData , adata_dist : AnnData ):
129115 gene_subset_dist , number_per_gene_dist = filter_genes (adata_dist .X , min_cells = 2 )
130- assert isinstance (gene_subset_dist , DIST_TYPES )
131- assert isinstance (number_per_gene_dist , DIST_TYPES )
116+ assert isinstance (gene_subset_dist , DaskArray )
117+ assert isinstance (number_per_gene_dist , DaskArray )
132118
133119 gene_subset , number_per_gene = filter_genes (adata .X , min_cells = 2 )
134120 npt .assert_allclose (materialize_as_ndarray (gene_subset_dist ), gene_subset )
@@ -137,35 +123,18 @@ def test_filter_genes_array(adata: AnnData, adata_dist: AnnData):
137123
def test_filter_genes(adata: AnnData, adata_dist: AnnData):
    """Check that ``filter_genes`` gives the same result for both fixtures."""
    # Filter the Dask-backed object in place; X must still be a Dask array.
    filter_genes(adata_dist, min_cells=2)
    dist_x = adata_dist.X
    assert isinstance(dist_x, DaskArray)
    dense = materialize_as_ndarray(dist_x)

    # Apply the identical filter to the `adata` fixture and compare.
    filter_genes(adata, min_cells=2)
    assert dense.shape == adata.shape
    npt.assert_allclose(dense, adata.X)
145131
146132
def test_write_zarr(adata: AnnData, adata_dist: AnnData, tmp_path: Path) -> None:
    """Round-trip a log1p-transformed, Dask-backed AnnData through a zarr store.

    The object is written under pytest's ``tmp_path``, read back with
    ``read_zarr``, and its ``X`` compared against ``log1p`` applied to the
    ``adata`` fixture.
    """
    store = tmp_path / "test.zarr"

    log1p(adata_dist)
    assert isinstance(adata_dist.X, DaskArray)

    # Write, then read back from the same path.
    adata_dist.write_zarr(store)
    roundtripped = read_zarr(store)

    log1p(adata)
    npt.assert_allclose(roundtripped.X, adata.X)
0 commit comments