Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ test = [
"scanpy[test-min]",
# optional storage and processing modes
"scanpy[dask]",
"zappy",
"zarr<3",
"zarr>=2.18.7",
# additional tested algorithms
"scanpy[scrublet]",
"scanpy[leiden]",
Expand Down
8 changes: 0 additions & 8 deletions src/scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"CSRBase",
"DaskArray",
"SpBase",
"ZappyArray",
"_numba_threading_layer",
"deprecated",
"fullname",
Expand Down Expand Up @@ -58,13 +57,6 @@
DaskArray.__module__ = "dask.array"


if find_spec("zappy") or TYPE_CHECKING:
from zappy.base import ZappyArray
else:
ZappyArray = type("ZappyArray", (), {})
ZappyArray.__module__ = "zappy.base"


def fullname(typ: type) -> str:
module = typ.__module__
name = typ.__qualname__
Expand Down
4 changes: 1 addition & 3 deletions src/scanpy/preprocessing/_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
if TYPE_CHECKING:
from numpy.typing import ArrayLike

from .._compat import ZappyArray


@overload
def materialize_as_ndarray(a: ArrayLike) -> np.ndarray: ...
Expand All @@ -33,7 +31,7 @@ def materialize_as_ndarray(


def materialize_as_ndarray(
a: DaskArray | ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
a: DaskArray | ArrayLike | tuple[ArrayLike | DaskArray, ...],
) -> tuple[np.ndarray] | np.ndarray:
"""Compute distributed arrays and convert them to numpy ndarrays."""
if isinstance(a, DaskArray):
Expand Down
1 change: 0 additions & 1 deletion src/testing/scanpy/_pytest/marks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def _generate_next_value_(
skimage = "scikit-image"
skmisc = "scikit-misc"
zarr = auto()
zappy = auto()
# external
bbknn = auto()
harmony = "harmonyTS"
Expand Down
69 changes: 19 additions & 50 deletions tests/test_preprocessing_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from anndata import OldFormatWarning, read_zarr

from scanpy._compat import DaskArray, ZappyArray
from scanpy._compat import DaskArray
from scanpy.preprocessing import (
filter_cells,
filter_genes,
Expand All @@ -25,14 +25,12 @@
HERE = Path(__file__).parent / Path("_data/")
input_file = Path(HERE, "10x-10k-subset.zarr")

DIST_TYPES = (DaskArray, ZappyArray)


pytestmark = [needs.zarr]


@pytest.fixture
def adata(request: pytest.FixtureRequest) -> AnnData:
def adata() -> AnnData:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=OldFormatWarning)
warnings.filterwarnings("ignore", r"Variable names are not unique", UserWarning)
Expand All @@ -42,37 +40,25 @@ def adata(request: pytest.FixtureRequest) -> AnnData:
return a


@pytest.fixture(
params=[
pytest.param("direct", marks=[needs.zappy]),
pytest.param("dask", marks=[needs.dask]),
]
)
def adata_dist(request: pytest.FixtureRequest) -> AnnData:
@pytest.fixture
def adata_dist() -> AnnData:
import dask.array as da

# regular anndata except for X, which we replace farther down
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=OldFormatWarning)
warnings.filterwarnings("ignore", r"Variable names are not unique", UserWarning)
a = read_zarr(input_file)
a.var_names_make_unique()
a.uns["dist-mode"] = request.param
input_file_x = f"{input_file}/X"
if request.param == "direct":
import zappy.direct

a.X = zappy.direct.from_zarr(input_file_x)
return a

assert request.param == "dask"
import dask.array as da

a.X = da.from_zarr(input_file_x)
return a


def test_log1p(adata: AnnData, adata_dist: AnnData):
log1p(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
assert isinstance(adata_dist.X, DaskArray)
result = materialize_as_ndarray(adata_dist.X)
log1p(adata)
assert result.shape == adata.shape
Expand All @@ -87,7 +73,7 @@ def test_normalize_per_cell(
reason = "normalize_per_cell deprecated and broken for Dask"
request.applymarker(pytest.mark.xfail(reason=reason))
normalize_per_cell(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
assert isinstance(adata_dist.X, DaskArray)
result = materialize_as_ndarray(adata_dist.X)
normalize_per_cell(adata)
assert result.shape == adata.shape
Expand All @@ -97,7 +83,7 @@ def test_normalize_per_cell(
@pytest.mark.filterwarnings("ignore:Some cells have zero counts:UserWarning")
def test_normalize_total(adata: AnnData, adata_dist: AnnData) -> None:
normalize_total(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
assert isinstance(adata_dist.X, DaskArray)
result = materialize_as_ndarray(adata_dist.X)
normalize_total(adata)
assert result.shape == adata.shape
Expand All @@ -106,8 +92,8 @@ def test_normalize_total(adata: AnnData, adata_dist: AnnData) -> None:

def test_filter_cells_array(adata: AnnData, adata_dist: AnnData):
cell_subset_dist, number_per_cell_dist = filter_cells(adata_dist.X, min_genes=3)
assert isinstance(cell_subset_dist, DIST_TYPES)
assert isinstance(number_per_cell_dist, DIST_TYPES)
assert isinstance(cell_subset_dist, DaskArray)
assert isinstance(number_per_cell_dist, DaskArray)

cell_subset, number_per_cell = filter_cells(adata.X, min_genes=3)
npt.assert_allclose(materialize_as_ndarray(cell_subset_dist), cell_subset)
Expand All @@ -116,7 +102,7 @@ def test_filter_cells_array(adata: AnnData, adata_dist: AnnData):

def test_filter_cells(adata: AnnData, adata_dist: AnnData):
filter_cells(adata_dist, min_genes=3)
assert isinstance(adata_dist.X, DIST_TYPES)
assert isinstance(adata_dist.X, DaskArray)
result = materialize_as_ndarray(adata_dist.X)
filter_cells(adata, min_genes=3)

Expand All @@ -127,8 +113,8 @@ def test_filter_cells(adata: AnnData, adata_dist: AnnData):

def test_filter_genes_array(adata: AnnData, adata_dist: AnnData):
gene_subset_dist, number_per_gene_dist = filter_genes(adata_dist.X, min_cells=2)
assert isinstance(gene_subset_dist, DIST_TYPES)
assert isinstance(number_per_gene_dist, DIST_TYPES)
assert isinstance(gene_subset_dist, DaskArray)
assert isinstance(number_per_gene_dist, DaskArray)

gene_subset, number_per_gene = filter_genes(adata.X, min_cells=2)
npt.assert_allclose(materialize_as_ndarray(gene_subset_dist), gene_subset)
Expand All @@ -137,35 +123,18 @@ def test_filter_genes_array(adata: AnnData, adata_dist: AnnData):

def test_filter_genes(adata: AnnData, adata_dist: AnnData):
filter_genes(adata_dist, min_cells=2)
assert isinstance(adata_dist.X, DIST_TYPES)
assert isinstance(adata_dist.X, DaskArray)
result = materialize_as_ndarray(adata_dist.X)
filter_genes(adata, min_cells=2)
assert result.shape == adata.shape
npt.assert_allclose(result, adata.X)


@pytest.mark.filterwarnings("ignore::anndata.OldFormatWarning")
def test_write_zarr(adata: AnnData, adata_dist: AnnData):
import zarr

def test_write_zarr(adata: AnnData, adata_dist: AnnData, tmp_path: Path) -> None:
log1p(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
temp_store = zarr.TempStore()
chunks = adata_dist.X.chunks
if isinstance(chunks[0], tuple):
chunks = (chunks[0][0],) + chunks[1]

# write metadata using regular anndata
adata.write_zarr(temp_store, chunks=chunks)
if adata_dist.uns["dist-mode"] == "dask":
adata_dist.X.to_zarr(temp_store.dir_path("X"), overwrite=True)
elif adata_dist.uns["dist-mode"] == "direct":
adata_dist.X.to_zarr(temp_store.dir_path("X"), chunks=chunks)
else:
pytest.fail("add branch for new dist-mode")

# read back as zarr directly and check it is the same as adata.X
adata_log1p = read_zarr(temp_store)
assert isinstance(adata_dist.X, DaskArray)
adata_dist.write_zarr(tmp_path / "test.zarr")
adata_log1p = read_zarr(tmp_path / "test.zarr")

log1p(adata)
npt.assert_allclose(adata_log1p.X, adata.X)
Loading