|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import atexit |
4 | 3 | from functools import partial |
5 | | -from typing import TYPE_CHECKING, cast |
| 4 | +from typing import TYPE_CHECKING |
6 | 5 |
|
7 | 6 | import dask |
8 | 7 | import joblib |
|
14 | 13 | from dask.base import normalize_seq |
15 | 14 | else: |
16 | 15 | from dask.tokenize import normalize_seq |
| 16 | +from filelock import FileLock |
17 | 17 | from scipy import sparse |
18 | 18 |
|
19 | 19 | import anndata as ad |
20 | 20 | from anndata.tests.helpers import subset_func # noqa: F401 |
21 | 21 |
|
22 | 22 | if TYPE_CHECKING: |
| 23 | + from collections.abc import Generator |
23 | 24 | from types import EllipsisType |
24 | 25 |
|
25 | | - from xdist.workermanage import WorkerController |
26 | | - |
27 | | - |
28 | | -_dask_cluster_addr: str |
29 | | - |
30 | | - |
31 | | -def pytest_configure(config: pytest.Config) -> None: |
32 | | - # We use this hook because it is run on sequential sessions, |
33 | | - # and both pytest-xdist’s controller and workers |
34 | | - global _dask_cluster_addr # noqa: PLW0603 |
35 | | - |
36 | | - if "--collect-only" in config.args: |
37 | | - return # no need to do work |
38 | | - |
39 | | - _dask_cluster_addr = _get_cluster_address(config) |
40 | | - |
41 | | - |
42 | | -def pytest_configure_node(node: WorkerController) -> None: |
43 | | - # send the cluster address to the workers |
44 | | - node.workerinput["dask_cluster_addr"] = _dask_cluster_addr |
45 | | - |
46 | | - |
47 | | -def _get_cluster_address(config: pytest.Config) -> str: |
48 | | - """Start the dask cluster or (in a pytest-xdist worker) get its address.""" |
49 | | - # If we’re on a worker, we can use the data sent in `pytest_configure_node` above |
50 | | - if workerinput := cast("dict[str, str]", getattr(config, "workerinput", {})): |
51 | | - return workerinput["dask_cluster_addr"] |
52 | | - |
53 | | - # if we’re on the controller or running sequentially, we start the cluster |
54 | | - import dask.distributed as dd |
55 | | - |
56 | | - clust = dd.LocalCluster(n_workers=1, threads_per_worker=1) |
57 | | - clust.__enter__() |
58 | | - atexit.register(clust.close) |
59 | | - return clust.scheduler_address |
60 | | - |
61 | | - |
62 | | -@pytest.fixture(scope="session") |
63 | | -def local_cluster_addr() -> str: |
64 | | - """Get the dask cluster address""" |
65 | | - return _dask_cluster_addr |
66 | | - |
67 | 26 |
|
68 | 27 | @pytest.fixture |
69 | 28 | def backing_h5ad(tmp_path): |
@@ -118,6 +77,39 @@ def equivalent_ellipsis_index( |
118 | 77 | return ellipsis_index_with_equivalent[1] |
119 | 78 |
|
120 | 79 |
|
| 80 | +@pytest.fixture(scope="session") |
| 81 | +def local_cluster_addr( |
| 82 | + tmp_path_factory: pytest.TempPathFactory, worker_id: str |
| 83 | +) -> Generator[str, None, None]: |
| 84 | + # Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once |
| 85 | + import dask.distributed as dd |
| 86 | + |
| 87 | + def make_cluster() -> dd.LocalCluster: |
| 88 | + return dd.LocalCluster(n_workers=1, threads_per_worker=1) |
| 89 | + |
| 90 | + if worker_id == "master": |
| 91 | + with make_cluster() as cluster: |
| 92 | + yield cluster.scheduler_address |
| 93 | + return |
| 94 | + |
| 95 | + # get the temp directory shared by all workers |
| 96 | + root_tmp_dir = tmp_path_factory.getbasetemp().parent |
| 97 | + |
| 98 | + fn = root_tmp_dir / "dask_scheduler_address.txt" |
| 99 | + lock = FileLock(str(fn) + ".lock") |
| 100 | + lock.acquire() # can’t use context manager, because we need to release the lock before yielding |
| 101 | + address = fn.read_text() if fn.is_file() else None |
| 102 | + if address: |
| 103 | + lock.release() |
| 104 | + yield address |
| 105 | + return |
| 106 | + |
| 107 | + with make_cluster() as cluster: |
| 108 | + fn.write_text(cluster.scheduler_address) |
| 109 | + lock.release() |
| 110 | + yield cluster.scheduler_address |
| 111 | + |
| 112 | + |
121 | 113 | ##################### |
122 | 114 | # Dask tokenization # |
123 | 115 | ##################### |
|
0 commit comments