Skip to content

Commit 8508109

Browse files
authored
(fix): manage dask cluster from controller (#1929)
1 parent 6786336 commit 8508109

3 files changed

Lines changed: 47 additions & 37 deletions

File tree

hatch.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ overrides.matrix.deps.python = [
2929
{ if = [ "min" ], value = "3.11" },
3030
{ if = [ "stable", "pre" ], value = "3.13" },
3131
]
32+
overrides.matrix.deps.features = [
33+
{ if = [ "stable", "pre" ], value = "test-full" },
34+
]
3235

3336
[[envs.hatch-test.matrix]]
3437
deps = [ "stable", "pre", "min" ]

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ test = [
9191
"pytest-randomly",
9292
"pytest-memray",
9393
"pytest-mock",
94-
"filelock",
9594
"matplotlib",
9695
"scikit-learn",
9796
"openpyxl",

tests/conftest.py

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
import atexit
34
from functools import partial
4-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, cast
56

67
import dask
78
import joblib
@@ -13,16 +14,56 @@
1314
from dask.base import normalize_seq
1415
else:
1516
from dask.tokenize import normalize_seq
16-
from filelock import FileLock
1717
from scipy import sparse
1818

1919
import anndata as ad
2020
from anndata.tests.helpers import subset_func # noqa: F401
2121

2222
if TYPE_CHECKING:
23-
from collections.abc import Generator
2423
from types import EllipsisType
2524

25+
from xdist.workermanage import WorkerController
26+
27+
28+
_dask_cluster_addr: str
29+
30+
31+
def pytest_configure(config: pytest.Config) -> None:
32+
# We use this hook because it is run on sequential sessions,
33+
# and both pytest-xdist’s controller and workers
34+
global _dask_cluster_addr # noqa: PLW0603
35+
36+
if "--collect-only" in config.args:
37+
return # no need to do work
38+
39+
_dask_cluster_addr = _get_cluster_address(config)
40+
41+
42+
def pytest_configure_node(node: WorkerController) -> None:
43+
# send the cluster address to the workers
44+
node.workerinput["dask_cluster_addr"] = _dask_cluster_addr
45+
46+
47+
def _get_cluster_address(config: pytest.Config) -> str:
48+
"""Start the dask cluster or (in a pytest-xdist worker) get its address."""
49+
# If we’re on a worker, we can use the data sent in `pytest_configure_node` above
50+
if workerinput := cast("dict[str, str]", getattr(config, "workerinput", {})):
51+
return workerinput["dask_cluster_addr"]
52+
53+
# if we’re on the controller or running sequentially, we start the cluster
54+
import dask.distributed as dd
55+
56+
clust = dd.LocalCluster(n_workers=1, threads_per_worker=1)
57+
clust.__enter__()
58+
atexit.register(clust.close)
59+
return clust.scheduler_address
60+
61+
62+
@pytest.fixture(scope="session")
63+
def local_cluster_addr() -> str:
64+
"""Get the dask cluster address"""
65+
return _dask_cluster_addr
66+
2667

2768
@pytest.fixture
2869
def backing_h5ad(tmp_path):
@@ -77,39 +118,6 @@ def equivalent_ellipsis_index(
77118
return ellipsis_index_with_equivalent[1]
78119

79120

80-
@pytest.fixture(scope="session")
81-
def local_cluster_addr(
82-
tmp_path_factory: pytest.TempPathFactory, worker_id: str
83-
) -> Generator[str, None, None]:
84-
# Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
85-
import dask.distributed as dd
86-
87-
def make_cluster() -> dd.LocalCluster:
88-
return dd.LocalCluster(n_workers=1, threads_per_worker=1)
89-
90-
if worker_id == "master":
91-
with make_cluster() as cluster:
92-
yield cluster.scheduler_address
93-
return
94-
95-
# get the temp directory shared by all workers
96-
root_tmp_dir = tmp_path_factory.getbasetemp().parent
97-
98-
fn = root_tmp_dir / "dask_scheduler_address.txt"
99-
lock = FileLock(str(fn) + ".lock")
100-
lock.acquire() # can’t use context manager, because we need to release the lock before yielding
101-
address = fn.read_text() if fn.is_file() else None
102-
if address:
103-
lock.release()
104-
yield address
105-
return
106-
107-
with make_cluster() as cluster:
108-
fn.write_text(cluster.scheduler_address)
109-
lock.release()
110-
yield cluster.scheduler_address
111-
112-
113121
#####################
114122
# Dask tokenization #
115123
#####################

0 commit comments

Comments
 (0)