Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ test = [
"pytest-randomly",
"pytest-memray",
"pytest-mock",
"filelock",
"matplotlib",
"scikit-learn",
"openpyxl",
Expand Down
35 changes: 35 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
from dask.base import normalize_seq
else:
from dask.tokenize import normalize_seq
from filelock import FileLock
from scipy import sparse

import anndata as ad
from anndata.tests.helpers import subset_func # noqa: F401

if TYPE_CHECKING:
from collections.abc import Generator
from types import EllipsisType


Expand Down Expand Up @@ -75,6 +77,39 @@ def equivalent_ellipsis_index(
return ellipsis_index_with_equivalent[1]


@pytest.fixture(scope="session")
def local_cluster_addr(
tmp_path_factory: pytest.TempPathFactory, worker_id: str
) -> Generator[str, None, None]:
# Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
import dask.distributed as dd

def make_cluster() -> dd.LocalCluster:
return dd.LocalCluster(n_workers=1, threads_per_worker=1)

if worker_id == "master":
with make_cluster() as cluster:
yield cluster.scheduler_address
return

# get the temp directory shared by all workers
root_tmp_dir = tmp_path_factory.getbasetemp().parent

fn = root_tmp_dir / "dask_scheduler_address.txt"
lock = FileLock(str(fn) + ".lock")
lock.acquire() # can’t use context manager, because we need to release the lock before yielding
address = fn.read_text() if fn.is_file() else None
if address:
lock.release()
yield address
return

with make_cluster() as cluster:
fn.write_text(cluster.scheduler_address)
lock.release()
yield cluster.scheduler_address


#####################
# Dask tokenization #
#####################
Expand Down
9 changes: 3 additions & 6 deletions tests/lazy/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,11 @@ def test_concat_to_memory_var(


def test_concat_data_with_cluster_to_memory(
adata_remote: AnnData, join: Join_T, *, load_annotation_index: bool
):
adata_remote: AnnData, join: Join_T, local_cluster_addr: str
) -> None:
import dask.distributed as dd

with (
dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster,
dd.Client(cluster),
):
with dd.Client(local_cluster_addr):
ad.concat([adata_remote, adata_remote], join=join).to_memory()


Expand Down
19 changes: 14 additions & 5 deletions tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import pytest
Expand All @@ -22,6 +24,11 @@
gen_adata,
)

if TYPE_CHECKING:
from pathlib import Path
from typing import Literal


pytest.importorskip("dask.array")


Expand Down Expand Up @@ -103,18 +110,20 @@ def test_dask_write(adata, tmp_path, diskfmt):
assert isinstance(orig.varm["a"], DaskArray)


def test_dask_distributed_write(adata, tmp_path, diskfmt):
def test_dask_distributed_write(
adata: AnnData,
tmp_path: Path,
diskfmt: Literal["h5ad", "zarr"],
local_cluster_addr: str,
) -> None:
import dask.array as da
import dask.distributed as dd
import numpy as np

pth = tmp_path / f"test_write.{diskfmt}"
g = as_group(pth, mode="w")

with (
dd.LocalCluster(n_workers=1, threads_per_worker=1, processes=False) as cluster,
dd.Client(cluster),
):
with dd.Client(local_cluster_addr):
M, N = adata.X.shape
adata.obsm["a"] = da.random.random((M, 10))
adata.obsm["b"] = da.random.random((M, 10))
Expand Down
12 changes: 6 additions & 6 deletions tests/test_io_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import re
from pathlib import Path
from typing import TYPE_CHECKING

import h5py
Expand Down Expand Up @@ -77,7 +78,7 @@ def store(request, tmp_path) -> H5Group | ZarrGroup:


@pytest.fixture(params=sparse_formats)
def sparse_format(request):
def sparse_format(request: pytest.FixtureRequest) -> Literal["csr", "csc"]:
return request.param


Expand Down Expand Up @@ -346,18 +347,17 @@ def test_read_lazy_subsets_nd_dask(store, n_dims, chunks):
assert_equal(X_from_disk[index], X_dask_from_disk[index])


def test_read_lazy_h5_cluster(sparse_format, tmp_path):
def test_read_lazy_h5_cluster(
sparse_format: Literal["csr", "csc"], tmp_path: Path, local_cluster_addr: str
) -> None:
import dask.distributed as dd

with h5py.File(tmp_path / "test.h5", "w") as file:
store = file["/"]
arr_store = create_sparse_store(sparse_format, store)
X_dask_from_disk = read_elem_lazy(arr_store["X"])
X_from_disk = read_elem(arr_store["X"])
with (
dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster,
dd.Client(cluster) as _client,
):
with dd.Client(local_cluster_addr):
assert_equal(X_from_disk, X_dask_from_disk)


Expand Down