Skip to content

Commit 3bcb643

Browse files
authored
Revert "(fix): manage dask cluster from controller (#1929)"
This reverts commit 8508109.
1 parent 8508109 commit 3bcb643

3 files changed

Lines changed: 37 additions & 47 deletions

File tree

hatch.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ overrides.matrix.deps.python = [
2929
{ if = [ "min" ], value = "3.11" },
3030
{ if = [ "stable", "pre" ], value = "3.13" },
3131
]
32-
overrides.matrix.deps.features = [
33-
{ if = [ "stable", "pre" ], value = "test-full" },
34-
]
3532

3633
[[envs.hatch-test.matrix]]
3734
deps = [ "stable", "pre", "min" ]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ test = [
9191
"pytest-randomly",
9292
"pytest-memray",
9393
"pytest-mock",
94+
"filelock",
9495
"matplotlib",
9596
"scikit-learn",
9697
"openpyxl",

tests/conftest.py

Lines changed: 36 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import annotations
22

3-
import atexit
43
from functools import partial
5-
from typing import TYPE_CHECKING, cast
4+
from typing import TYPE_CHECKING
65

76
import dask
87
import joblib
@@ -14,56 +13,16 @@
1413
from dask.base import normalize_seq
1514
else:
1615
from dask.tokenize import normalize_seq
16+
from filelock import FileLock
1717
from scipy import sparse
1818

1919
import anndata as ad
2020
from anndata.tests.helpers import subset_func # noqa: F401
2121

2222
if TYPE_CHECKING:
23+
from collections.abc import Generator
2324
from types import EllipsisType
2425

25-
from xdist.workermanage import WorkerController
26-
27-
28-
_dask_cluster_addr: str
29-
30-
31-
def pytest_configure(config: pytest.Config) -> None:
32-
# We use this hook because it is run on sequential sessions,
33-
# and both pytest-xdist’s controller and workers
34-
global _dask_cluster_addr # noqa: PLW0603
35-
36-
if "--collect-only" in config.args:
37-
return # no need to do work
38-
39-
_dask_cluster_addr = _get_cluster_address(config)
40-
41-
42-
def pytest_configure_node(node: WorkerController) -> None:
43-
# send the cluster address to the workers
44-
node.workerinput["dask_cluster_addr"] = _dask_cluster_addr
45-
46-
47-
def _get_cluster_address(config: pytest.Config) -> str:
48-
"""Start the dask cluster or (in a pytest-xdist worker) get its address."""
49-
# If we’re on a worker, we can use the data sent in `pytest_configure_node` above
50-
if workerinput := cast("dict[str, str]", getattr(config, "workerinput", {})):
51-
return workerinput["dask_cluster_addr"]
52-
53-
# if we’re on the controller or running sequentially, we start the cluster
54-
import dask.distributed as dd
55-
56-
clust = dd.LocalCluster(n_workers=1, threads_per_worker=1)
57-
clust.__enter__()
58-
atexit.register(clust.close)
59-
return clust.scheduler_address
60-
61-
62-
@pytest.fixture(scope="session")
63-
def local_cluster_addr() -> str:
64-
"""Get the dask cluster address"""
65-
return _dask_cluster_addr
66-
6726

6827
@pytest.fixture
6928
def backing_h5ad(tmp_path):
@@ -118,6 +77,39 @@ def equivalent_ellipsis_index(
11877
return ellipsis_index_with_equivalent[1]
11978

12079

80+
@pytest.fixture(scope="session")
81+
def local_cluster_addr(
82+
tmp_path_factory: pytest.TempPathFactory, worker_id: str
83+
) -> Generator[str, None, None]:
84+
# Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
85+
import dask.distributed as dd
86+
87+
def make_cluster() -> dd.LocalCluster:
88+
return dd.LocalCluster(n_workers=1, threads_per_worker=1)
89+
90+
if worker_id == "master":
91+
with make_cluster() as cluster:
92+
yield cluster.scheduler_address
93+
return
94+
95+
# get the temp directory shared by all workers
96+
root_tmp_dir = tmp_path_factory.getbasetemp().parent
97+
98+
fn = root_tmp_dir / "dask_scheduler_address.txt"
99+
lock = FileLock(str(fn) + ".lock")
100+
lock.acquire() # can’t use context manager, because we need to release the lock before yielding
101+
address = fn.read_text() if fn.is_file() else None
102+
if address:
103+
lock.release()
104+
yield address
105+
return
106+
107+
with make_cluster() as cluster:
108+
fn.write_text(cluster.scheduler_address)
109+
lock.release()
110+
yield cluster.scheduler_address
111+
112+
121113
#####################
122114
# Dask tokenization #
123115
#####################

0 commit comments

Comments
 (0)