Skip to content

Commit 3502bb1

Browse files
authored
Use same local cluster (#1924)
1 parent 78a6d6b commit 3502bb1

5 files changed

Lines changed: 59 additions & 17 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ test = [
9191
"pytest-randomly",
9292
"pytest-memray",
9393
"pytest-mock",
94+
"filelock",
9495
"matplotlib",
9596
"scikit-learn",
9697
"openpyxl",

tests/conftest.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
from dask.base import normalize_seq
1414
else:
1515
from dask.tokenize import normalize_seq
16+
from filelock import FileLock
1617
from scipy import sparse
1718

1819
import anndata as ad
1920
from anndata.tests.helpers import subset_func # noqa: F401
2021

2122
if TYPE_CHECKING:
23+
from collections.abc import Generator
2224
from types import EllipsisType
2325

2426

@@ -75,6 +77,39 @@ def equivalent_ellipsis_index(
7577
return ellipsis_index_with_equivalent[1]
7678

7779

80+
@pytest.fixture(scope="session")
81+
def local_cluster_addr(
82+
tmp_path_factory: pytest.TempPathFactory, worker_id: str
83+
) -> Generator[str, None, None]:
84+
# Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
85+
import dask.distributed as dd
86+
87+
def make_cluster() -> dd.LocalCluster:
88+
return dd.LocalCluster(n_workers=1, threads_per_worker=1)
89+
90+
if worker_id == "master":
91+
with make_cluster() as cluster:
92+
yield cluster.scheduler_address
93+
return
94+
95+
# get the temp directory shared by all workers
96+
root_tmp_dir = tmp_path_factory.getbasetemp().parent
97+
98+
fn = root_tmp_dir / "dask_scheduler_address.txt"
99+
lock = FileLock(str(fn) + ".lock")
100+
lock.acquire() # can’t use context manager, because we need to release the lock before yielding
101+
address = fn.read_text() if fn.is_file() else None
102+
if address:
103+
lock.release()
104+
yield address
105+
return
106+
107+
with make_cluster() as cluster:
108+
fn.write_text(cluster.scheduler_address)
109+
lock.release()
110+
yield cluster.scheduler_address
111+
112+
78113
#####################
79114
# Dask tokenization #
80115
#####################

tests/lazy/test_concat.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,11 @@ def test_concat_to_memory_var(
215215

216216

217217
def test_concat_data_with_cluster_to_memory(
218-
adata_remote: AnnData, join: Join_T, *, load_annotation_index: bool
219-
):
218+
adata_remote: AnnData, join: Join_T, local_cluster_addr: str
219+
) -> None:
220220
import dask.distributed as dd
221221

222-
with (
223-
dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster,
224-
dd.Client(cluster),
225-
):
222+
with dd.Client(local_cluster_addr):
226223
ad.concat([adata_remote, adata_remote], join=join).to_memory()
227224

228225

tests/test_dask.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from __future__ import annotations
66

7+
from typing import TYPE_CHECKING
8+
79
import numpy as np
810
import pandas as pd
911
import pytest
@@ -22,6 +24,11 @@
2224
gen_adata,
2325
)
2426

27+
if TYPE_CHECKING:
28+
from pathlib import Path
29+
from typing import Literal
30+
31+
2532
pytest.importorskip("dask.array")
2633

2734

@@ -103,18 +110,20 @@ def test_dask_write(adata, tmp_path, diskfmt):
103110
assert isinstance(orig.varm["a"], DaskArray)
104111

105112

106-
def test_dask_distributed_write(adata, tmp_path, diskfmt):
113+
def test_dask_distributed_write(
114+
adata: AnnData,
115+
tmp_path: Path,
116+
diskfmt: Literal["h5ad", "zarr"],
117+
local_cluster_addr: str,
118+
) -> None:
107119
import dask.array as da
108120
import dask.distributed as dd
109121
import numpy as np
110122

111123
pth = tmp_path / f"test_write.{diskfmt}"
112124
g = as_group(pth, mode="w")
113125

114-
with (
115-
dd.LocalCluster(n_workers=1, threads_per_worker=1, processes=False) as cluster,
116-
dd.Client(cluster),
117-
):
126+
with dd.Client(local_cluster_addr):
118127
M, N = adata.X.shape
119128
adata.obsm["a"] = da.random.random((M, 10))
120129
adata.obsm["b"] = da.random.random((M, 10))

tests/test_io_elementwise.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import re
8+
from pathlib import Path
89
from typing import TYPE_CHECKING
910

1011
import h5py
@@ -77,7 +78,7 @@ def store(request, tmp_path) -> H5Group | ZarrGroup:
7778

7879

7980
@pytest.fixture(params=sparse_formats)
80-
def sparse_format(request):
81+
def sparse_format(request: pytest.FixtureRequest) -> Literal["csr", "csc"]:
8182
return request.param
8283

8384

@@ -346,18 +347,17 @@ def test_read_lazy_subsets_nd_dask(store, n_dims, chunks):
346347
assert_equal(X_from_disk[index], X_dask_from_disk[index])
347348

348349

349-
def test_read_lazy_h5_cluster(sparse_format, tmp_path):
350+
def test_read_lazy_h5_cluster(
351+
sparse_format: Literal["csr", "csc"], tmp_path: Path, local_cluster_addr: str
352+
) -> None:
350353
import dask.distributed as dd
351354

352355
with h5py.File(tmp_path / "test.h5", "w") as file:
353356
store = file["/"]
354357
arr_store = create_sparse_store(sparse_format, store)
355358
X_dask_from_disk = read_elem_lazy(arr_store["X"])
356359
X_from_disk = read_elem(arr_store["X"])
357-
with (
358-
dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster,
359-
dd.Client(cluster) as _client,
360-
):
360+
with dd.Client(local_cluster_addr):
361361
assert_equal(X_from_disk, X_dask_from_disk)
362362

363363

0 commit comments

Comments
 (0)