Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ overrides.matrix.deps.python = [
{ if = [ "min" ], value = "3.11" },
{ if = [ "stable", "pre" ], value = "3.13" },
]
overrides.matrix.deps.features = [
{ if = [ "stable", "pre" ], value = "test-full" },
]

[[envs.hatch-test.matrix]]
deps = [ "stable", "pre", "min" ]
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ test = [
"pytest-randomly",
"pytest-memray",
"pytest-mock",
"filelock",
"matplotlib",
"scikit-learn",
"openpyxl",
Expand Down
80 changes: 44 additions & 36 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import atexit
from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import dask
import joblib
Expand All @@ -13,16 +14,56 @@
from dask.base import normalize_seq
else:
from dask.tokenize import normalize_seq
from filelock import FileLock
from scipy import sparse

import anndata as ad
from anndata.tests.helpers import subset_func # noqa: F401

if TYPE_CHECKING:
from collections.abc import Generator
from types import EllipsisType

from xdist.workermanage import WorkerController


_dask_cluster_addr: str


def pytest_configure(config: pytest.Config) -> None:
# We use this hook because it is run on sequential sessions,
# and both pytest-xdist’s controller and workers
global _dask_cluster_addr # noqa: PLW0603

if "--collect-only" in config.args:
return # no need to do work

_dask_cluster_addr = _get_cluster_address(config)


def pytest_configure_node(node: WorkerController) -> None:
# send the cluster address to the workers
node.workerinput["dask_cluster_addr"] = _dask_cluster_addr


def _get_cluster_address(config: pytest.Config) -> str:
"""Start the dask cluster or (in a pytest-xdist worker) get its address."""
# If we’re on a worker, we can use the data sent in `pytest_configure_node` above
if workerinput := cast("dict[str, str]", getattr(config, "workerinput", {})):
return workerinput["dask_cluster_addr"]

# if we’re on the controller or running sequentially, we start the cluster
import dask.distributed as dd

clust = dd.LocalCluster(n_workers=1, threads_per_worker=1)
clust.__enter__()
atexit.register(clust.close)
return clust.scheduler_address


@pytest.fixture(scope="session")
def local_cluster_addr() -> str:
"""Get the dask cluster address"""
return _dask_cluster_addr


@pytest.fixture
def backing_h5ad(tmp_path):
Expand Down Expand Up @@ -77,39 +118,6 @@ def equivalent_ellipsis_index(
return ellipsis_index_with_equivalent[1]


@pytest.fixture(scope="session")
def local_cluster_addr(
tmp_path_factory: pytest.TempPathFactory, worker_id: str
) -> Generator[str, None, None]:
# Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
import dask.distributed as dd

def make_cluster() -> dd.LocalCluster:
return dd.LocalCluster(n_workers=1, threads_per_worker=1)

if worker_id == "master":
with make_cluster() as cluster:
yield cluster.scheduler_address
return

# get the temp directory shared by all workers
root_tmp_dir = tmp_path_factory.getbasetemp().parent

fn = root_tmp_dir / "dask_scheduler_address.txt"
lock = FileLock(str(fn) + ".lock")
lock.acquire() # can’t use context manager, because we need to release the lock before yielding
address = fn.read_text() if fn.is_file() else None
if address:
lock.release()
yield address
return

with make_cluster() as cluster:
fn.write_text(cluster.scheduler_address)
lock.release()
yield cluster.scheduler_address


#####################
# Dask tokenization #
#####################
Expand Down