Skip to content

Commit a4a6546

Browse files
committed
(fix): use dask cluster as a session param
1 parent 098639e commit a4a6546

1 file changed

Lines changed: 13 additions & 33 deletions

File tree

tests/conftest.py

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
from dask.base import normalize_seq
1414
else:
1515
from dask.tokenize import normalize_seq
16-
from filelock import FileLock
16+
17+
import dask.distributed as dd
1718
from scipy import sparse
1819

1920
import anndata as ad
2021
from anndata.tests.helpers import subset_func # noqa: F401
2122

2223
if TYPE_CHECKING:
23-
from collections.abc import Generator
2424
from types import EllipsisType
2525

2626

@@ -77,37 +77,17 @@ def equivalent_ellipsis_index(
7777
return ellipsis_index_with_equivalent[1]
7878

7979

80-
@pytest.fixture(scope="session")
81-
def local_cluster_addr(
82-
tmp_path_factory: pytest.TempPathFactory, worker_id: str
83-
) -> Generator[str, None, None]:
84-
# Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
85-
import dask.distributed as dd
86-
87-
def make_cluster() -> dd.LocalCluster:
88-
return dd.LocalCluster(n_workers=1, threads_per_worker=1)
89-
90-
if worker_id == "master":
91-
with make_cluster() as cluster:
92-
yield cluster.scheduler_address
93-
return
94-
95-
# get the temp directory shared by all workers
96-
root_tmp_dir = tmp_path_factory.getbasetemp().parent
97-
98-
fn = root_tmp_dir / "dask_scheduler_address.txt"
99-
lock = FileLock(str(fn) + ".lock")
100-
lock.acquire() # can’t use context manager, because we need to release the lock before yielding
101-
address = fn.read_text() if fn.is_file() else None
102-
if address:
103-
lock.release()
104-
yield address
105-
return
106-
107-
with make_cluster() as cluster:
108-
fn.write_text(cluster.scheduler_address)
109-
lock.release()
110-
yield cluster.scheduler_address
80+
@pytest.fixture(
81+
scope="session",
82+
params=[
83+
pytest.param(
84+
dd.LocalCluster(n_workers=1, threads_per_worker=1),
85+
# marks=pytest.mark.xdist_group("dask"), # if we ever want to do this setup only once
86+
)
87+
],
88+
)
89+
def local_cluster_addr(request):
90+
return request.param
11191

11292

11393
#####################

0 commit comments

Comments
 (0)