Skip to content

Commit 193b50a

Browse files
authored
Backport PR #1924 on branch 0.11.x (Use same local Dask cluster in tests) (#1926)
1 parent 8b420a7 commit 193b50a

4 files changed

Lines changed: 56 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ test = [
9191
"pytest-memray",
9292
"pytest-mock",
9393
"zarr<3",
94+
"filelock",
9495
"matplotlib",
9596
"scikit-learn",
9697
"openpyxl",

tests/conftest.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
from dask.base import normalize_seq
1414
else:
1515
from dask.tokenize import normalize_seq
16+
from filelock import FileLock
1617
from scipy import sparse
1718

1819
import anndata as ad
1920
from anndata.tests.helpers import subset_func # noqa: F401
2021

2122
if TYPE_CHECKING:
23+
from collections.abc import Generator
2224
from types import EllipsisType
2325

2426

@@ -75,6 +77,39 @@ def equivalent_ellipsis_index(
7577
return ellipsis_index_with_equivalent[1]
7678

7779

80+
@pytest.fixture(scope="session")
81+
def local_cluster_addr(
82+
tmp_path_factory: pytest.TempPathFactory, worker_id: str
83+
) -> Generator[str, None, None]:
84+
# Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
85+
import dask.distributed as dd
86+
87+
def make_cluster() -> dd.LocalCluster:
88+
return dd.LocalCluster(n_workers=1, threads_per_worker=1)
89+
90+
if worker_id == "master":
91+
with make_cluster() as cluster:
92+
yield cluster.scheduler_address
93+
return
94+
95+
# get the temp directory shared by all workers
96+
root_tmp_dir = tmp_path_factory.getbasetemp().parent
97+
98+
fn = root_tmp_dir / "dask_scheduler_address.txt"
99+
lock = FileLock(str(fn) + ".lock")
100+
lock.acquire() # can’t use context manager, because we need to release the lock before yielding
101+
address = fn.read_text() if fn.is_file() else None
102+
if address:
103+
lock.release()
104+
yield address
105+
return
106+
107+
with make_cluster() as cluster:
108+
fn.write_text(cluster.scheduler_address)
109+
lock.release()
110+
yield cluster.scheduler_address
111+
112+
78113
#####################
79114
# Dask tokenization #
80115
#####################

tests/test_dask.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from __future__ import annotations
66

7+
from typing import TYPE_CHECKING
8+
79
import numpy as np
810
import pandas as pd
911
import pytest
@@ -22,6 +24,11 @@
2224
gen_adata,
2325
)
2426

27+
if TYPE_CHECKING:
28+
from pathlib import Path
29+
from typing import Literal
30+
31+
2532
pytest.importorskip("dask.array")
2633

2734

@@ -103,18 +110,20 @@ def test_dask_write(adata, tmp_path, diskfmt):
103110
assert isinstance(orig.varm["a"], DaskArray)
104111

105112

106-
def test_dask_distributed_write(adata, tmp_path, diskfmt):
113+
def test_dask_distributed_write(
114+
adata: AnnData,
115+
tmp_path: Path,
116+
diskfmt: Literal["h5ad", "zarr"],
117+
local_cluster_addr: str,
118+
) -> None:
107119
import dask.array as da
108120
import dask.distributed as dd
109121
import numpy as np
110122

111123
pth = tmp_path / f"test_write.{diskfmt}"
112124
g = as_group(pth, mode="w")
113125

114-
with (
115-
dd.LocalCluster(n_workers=1, threads_per_worker=1, processes=False) as cluster,
116-
dd.Client(cluster),
117-
):
126+
with dd.Client(local_cluster_addr):
118127
M, N = adata.X.shape
119128
adata.obsm["a"] = da.random.random((M, 10))
120129
adata.obsm["b"] = da.random.random((M, 10))

tests/test_io_elementwise.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import re
8+
from pathlib import Path
89
from typing import TYPE_CHECKING
910

1011
import h5py
@@ -76,7 +77,7 @@ def store(request, tmp_path) -> H5Group | ZarrGroup:
7677

7778

7879
@pytest.fixture(params=sparse_formats)
79-
def sparse_format(request):
80+
def sparse_format(request: pytest.FixtureRequest) -> Literal["csr", "csc"]:
8081
return request.param
8182

8283

@@ -344,18 +345,17 @@ def test_read_lazy_subsets_nd_dask(store, n_dims, chunks):
344345
assert_equal(X_from_disk[index], X_dask_from_disk[index])
345346

346347

347-
def test_read_lazy_h5_cluster(sparse_format, tmp_path):
348+
def test_read_lazy_h5_cluster(
349+
sparse_format: Literal["csr", "csc"], tmp_path: Path, local_cluster_addr: str
350+
) -> None:
348351
import dask.distributed as dd
349352

350353
with h5py.File(tmp_path / "test.h5", "w") as file:
351354
store = file["/"]
352355
arr_store = create_sparse_store(sparse_format, store)
353356
X_dask_from_disk = read_elem_as_dask(arr_store["X"])
354357
X_from_disk = read_elem(arr_store["X"])
355-
with (
356-
dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster,
357-
dd.Client(cluster) as _client,
358-
):
358+
with dd.Client(local_cluster_addr):
359359
assert_equal(X_from_disk, X_dask_from_disk)
360360

361361

0 commit comments

Comments
 (0)