diff --git a/hatch.toml b/hatch.toml index 12d567fc3..b5516931d 100644 --- a/hatch.toml +++ b/hatch.toml @@ -29,6 +29,9 @@ overrides.matrix.deps.python = [ { if = [ "min" ], value = "3.11" }, { if = [ "stable", "pre" ], value = "3.13" }, ] +overrides.matrix.deps.features = [ + { if = [ "stable", "pre" ], value = "test-full" }, +] [[envs.hatch-test.matrix]] deps = [ "stable", "pre", "min" ] diff --git a/pyproject.toml b/pyproject.toml index c826b767e..a07fb52ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,6 @@ test = [ "pytest-randomly", "pytest-memray", "pytest-mock", - "filelock", "matplotlib", "scikit-learn", "openpyxl", diff --git a/tests/conftest.py b/tests/conftest.py index b38ab4a65..65d1b1fb8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ from __future__ import annotations +import atexit from functools import partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import dask import joblib @@ -13,16 +14,56 @@ from dask.base import normalize_seq else: from dask.tokenize import normalize_seq -from filelock import FileLock from scipy import sparse import anndata as ad from anndata.tests.helpers import subset_func # noqa: F401 if TYPE_CHECKING: - from collections.abc import Generator from types import EllipsisType + from xdist.workermanage import WorkerController + + +_dask_cluster_addr: str + + +def pytest_configure(config: pytest.Config) -> None: + # We use this hook because it is run on sequential sessions, + # and both pytest-xdist’s controller and workers + global _dask_cluster_addr # noqa: PLW0603 + + if "--collect-only" in config.args: + return # no need to do work + + _dask_cluster_addr = _get_cluster_address(config) + + +def pytest_configure_node(node: WorkerController) -> None: + # send the cluster address to the workers + node.workerinput["dask_cluster_addr"] = _dask_cluster_addr + + +def _get_cluster_address(config: pytest.Config) -> str: + """Start the dask cluster or (in a pytest-xdist worker) get its address.""" + # If we’re on a worker, we can use the data sent in `pytest_configure_node` above + if workerinput := cast("dict[str, str]", getattr(config, "workerinput", {})): + return workerinput["dask_cluster_addr"] + + # if we’re on the controller or running sequentially, we start the cluster + import dask.distributed as dd + + clust = dd.LocalCluster(n_workers=1, threads_per_worker=1) + clust.__enter__() + atexit.register(clust.close) + return clust.scheduler_address + + +@pytest.fixture(scope="session") +def local_cluster_addr() -> str: + """Get the dask cluster address""" + return _dask_cluster_addr + @pytest.fixture def backing_h5ad(tmp_path): @@ -77,39 +118,6 @@ def equivalent_ellipsis_index( return ellipsis_index_with_equivalent[1] -@pytest.fixture(scope="session") -def local_cluster_addr( - tmp_path_factory: pytest.TempPathFactory, worker_id: str -) -> Generator[str, None, None]: - # Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once - import dask.distributed as dd - - def make_cluster() -> dd.LocalCluster: - return dd.LocalCluster(n_workers=1, threads_per_worker=1) - - if worker_id == "master": - with make_cluster() as cluster: - yield cluster.scheduler_address - return - - # get the temp directory shared by all workers - root_tmp_dir = tmp_path_factory.getbasetemp().parent - - fn = root_tmp_dir / "dask_scheduler_address.txt" - lock = FileLock(str(fn) + ".lock") - lock.acquire() # can’t use context manager, because we need to release the lock before yielding - address = fn.read_text() if fn.is_file() else None - if address: - lock.release() - yield address - return - - with make_cluster() as cluster: - fn.write_text(cluster.scheduler_address) - lock.release() - yield cluster.scheduler_address - - ##################### # Dask tokenization # #####################