Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ test = [
"pytest-randomly",
"pytest-memray",
"pytest-mock",
"filelock",
"matplotlib",
"scikit-learn",
"openpyxl",
Expand All @@ -102,7 +101,7 @@ test = [
"dask[distributed]",
"awkward>=2.3",
"pyarrow",
"anndata[dask]",
"anndata[dask,lazy]",
Comment thread
flying-sheep marked this conversation as resolved.
Outdated
]
dev-test = [ "pytest-xdist[psutil]" ] # local test speedups
gpu = [ "cupy" ]
Expand Down
80 changes: 44 additions & 36 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import atexit
from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import dask
import joblib
Expand All @@ -13,16 +14,56 @@
from dask.base import normalize_seq
else:
from dask.tokenize import normalize_seq
from filelock import FileLock
from scipy import sparse

import anndata as ad
from anndata.tests.helpers import subset_func # noqa: F401

if TYPE_CHECKING:
from collections.abc import Generator
from types import EllipsisType

from xdist.workermanage import WorkerController


_dask_cluster_addr: str


def pytest_configure(config: pytest.Config) -> None:
    """Store the dask cluster address for this pytest process.

    On a normal run, the address returned by ``_get_cluster_address`` is
    saved in the module-level ``_dask_cluster_addr``. On collect-only runs
    nothing is started and the global is left unset.
    """
    # We use this hook because it is run on sequential sessions,
    # and both pytest-xdist’s controller and workers
    global _dask_cluster_addr  # noqa: PLW0603

    # `config.args` only contains the positional arguments (test paths) that
    # remain after option parsing, so the flag has to be read from the parsed
    # options. This also covers the `--co` and `--collectonly` spellings.
    if config.getoption("collectonly", default=False):
        return  # no need to do work

    _dask_cluster_addr = _get_cluster_address(config)


# This hook is specified by pytest-xdist. Marking it optional keeps pytest's
# hook validation from failing when pytest-xdist is not installed.
@pytest.hookimpl(optionalhook=True)
def pytest_configure_node(node: WorkerController) -> None:
    """Pass the dask cluster address to a pytest-xdist worker.

    The address is stored under the ``"dask_cluster_addr"`` key of the
    worker's ``workerinput`` mapping.
    """
    # send the cluster address to the workers
    node.workerinput["dask_cluster_addr"] = _dask_cluster_addr


def _get_cluster_address(config: pytest.Config) -> str:
"""Start the dask cluster or (in a pytest-xdist worker) get its address."""
# If we’re on a worker, we can use the data sent in `pytest_configure_node` above
if workerinput := cast("dict[str, str]", getattr(config, "workerinput", {})):
return workerinput["dask_cluster_addr"]

# if we’re on the controller or running sequentially, we start the cluster
import dask.distributed as dd

clust = dd.LocalCluster(n_workers=1, threads_per_worker=1)
clust.__enter__()
atexit.register(clust.close)
return clust.scheduler_address


@pytest.fixture(scope="session")
def local_cluster_addr() -> str:
    """Session-scoped fixture returning the module-level dask cluster address."""
    address = _dask_cluster_addr
    return address


@pytest.fixture
def backing_h5ad(tmp_path):
Expand Down Expand Up @@ -77,39 +118,6 @@ def equivalent_ellipsis_index(
return ellipsis_index_with_equivalent[1]


@pytest.fixture(scope="session")
def local_cluster_addr(
    tmp_path_factory: pytest.TempPathFactory, worker_id: str
) -> Generator[str, None, None]:
    """Yield the scheduler address of a dask ``LocalCluster``.

    When ``worker_id`` is ``"master"``, a cluster is started, its address is
    yielded, and the cluster is closed afterwards.

    Otherwise the address is exchanged through the file
    ``dask_scheduler_address.txt`` in the parent of the base temp directory,
    guarded by a file lock. If the file already holds an address, that
    address is yielded. If not, this call starts the cluster, writes its
    address to the file, and keeps the cluster open until the fixture is
    torn down.
    """
    # Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
    from contextlib import ExitStack

    import dask.distributed as dd

    def make_cluster() -> dd.LocalCluster:
        return dd.LocalCluster(n_workers=1, threads_per_worker=1)

    if worker_id == "master":
        with make_cluster() as cluster:
            yield cluster.scheduler_address
        return

    # get the temp directory shared by all workers
    root_tmp_dir = tmp_path_factory.getbasetemp().parent
    fn = root_tmp_dir / "dask_scheduler_address.txt"

    # The ExitStack keeps a cluster started here alive across the yield,
    # while the lock is held only for the read/start/write section. The
    # `with` block releases the lock even if one of those steps raises.
    with ExitStack() as stack:
        with FileLock(str(fn) + ".lock"):
            address = fn.read_text() if fn.is_file() else None
            if not address:
                cluster = stack.enter_context(make_cluster())
                address = cluster.scheduler_address
                fn.write_text(address)
        yield address


#####################
# Dask tokenization #
#####################
Expand Down