|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import atexit |
3 | 4 | from functools import partial |
4 | | -from typing import TYPE_CHECKING |
| 5 | +from typing import TYPE_CHECKING, cast |
5 | 6 |
|
6 | 7 | import dask |
7 | 8 | import joblib |
|
13 | 14 | from dask.base import normalize_seq |
14 | 15 | else: |
15 | 16 | from dask.tokenize import normalize_seq |
16 | | -from filelock import FileLock |
17 | 17 | from scipy import sparse |
18 | 18 |
|
19 | 19 | import anndata as ad |
20 | 20 | from anndata.tests.helpers import subset_func # noqa: F401 |
21 | 21 |
|
22 | 22 | if TYPE_CHECKING: |
23 | | - from collections.abc import Generator |
24 | 23 | from types import EllipsisType |
25 | 24 |
|
| 25 | + from xdist.workermanage import WorkerController |
| 26 | + |
| 27 | + |
| 28 | +_dask_cluster_addr: str |
| 29 | + |
| 30 | + |
| 31 | +def pytest_configure(config: pytest.Config) -> None: |
| 32 | + # We use this hook because it is run on sequential sessions, |
| 33 | + # and both pytest-xdist’s controller and workers |
| 34 | + global _dask_cluster_addr # noqa: PLW0603 |
| 35 | + |
| 36 | + if "--collect-only" in config.args: |
| 37 | + return # no need to do work |
| 38 | + |
| 39 | + _dask_cluster_addr = _get_cluster_address(config) |
| 40 | + |
| 41 | + |
| 42 | +def pytest_configure_node(node: WorkerController) -> None: |
| 43 | + # send the cluster address to the workers |
| 44 | + node.workerinput["dask_cluster_addr"] = _dask_cluster_addr |
| 45 | + |
| 46 | + |
| 47 | +def _get_cluster_address(config: pytest.Config) -> str: |
| 48 | + """Start the dask cluster or (in a pytest-xdist worker) get its address.""" |
| 49 | + # If we’re on a worker, we can use the data sent in `pytest_configure_node` above |
| 50 | + if workerinput := cast("dict[str, str]", getattr(config, "workerinput", {})): |
| 51 | + return workerinput["dask_cluster_addr"] |
| 52 | + |
| 53 | + # if we’re on the controller or running sequentially, we start the cluster |
| 54 | + import dask.distributed as dd |
| 55 | + |
| 56 | + clust = dd.LocalCluster(n_workers=1, threads_per_worker=1) |
| 57 | + clust.__enter__() |
| 58 | + atexit.register(clust.close) |
| 59 | + return clust.scheduler_address |
| 60 | + |
| 61 | + |
| 62 | +@pytest.fixture(scope="session") |
| 63 | +def local_cluster_addr() -> str: |
| 64 | + """Get the dask cluster address""" |
| 65 | + return _dask_cluster_addr |
| 66 | + |
26 | 67 |
|
27 | 68 | @pytest.fixture |
28 | 69 | def backing_h5ad(tmp_path): |
@@ -77,39 +118,6 @@ def equivalent_ellipsis_index( |
77 | 118 | return ellipsis_index_with_equivalent[1] |
78 | 119 |
|
79 | 120 |
|
80 | | -@pytest.fixture(scope="session") |
81 | | -def local_cluster_addr( |
82 | | - tmp_path_factory: pytest.TempPathFactory, worker_id: str |
83 | | -) -> Generator[str, None, None]: |
84 | | - # Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once |
85 | | - import dask.distributed as dd |
86 | | - |
87 | | - def make_cluster() -> dd.LocalCluster: |
88 | | - return dd.LocalCluster(n_workers=1, threads_per_worker=1) |
89 | | - |
90 | | - if worker_id == "master": |
91 | | - with make_cluster() as cluster: |
92 | | - yield cluster.scheduler_address |
93 | | - return |
94 | | - |
95 | | - # get the temp directory shared by all workers |
96 | | - root_tmp_dir = tmp_path_factory.getbasetemp().parent |
97 | | - |
98 | | - fn = root_tmp_dir / "dask_scheduler_address.txt" |
99 | | - lock = FileLock(str(fn) + ".lock") |
100 | | - lock.acquire() # can’t use context manager, because we need to release the lock before yielding |
101 | | - address = fn.read_text() if fn.is_file() else None |
102 | | - if address: |
103 | | - lock.release() |
104 | | - yield address |
105 | | - return |
106 | | - |
107 | | - with make_cluster() as cluster: |
108 | | - fn.write_text(cluster.scheduler_address) |
109 | | - lock.release() |
110 | | - yield cluster.scheduler_address |
111 | | - |
112 | | - |
113 | 121 | ##################### |
114 | 122 | # Dask tokenization # |
115 | 123 | ##################### |
|
0 commit comments