diff --git a/pyproject.toml b/pyproject.toml
index a07fb52ef..c826b767e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,6 +91,7 @@ test = [
     "pytest-randomly",
     "pytest-memray",
     "pytest-mock",
+    "filelock",
     "matplotlib",
     "scikit-learn",
     "openpyxl",
diff --git a/tests/conftest.py b/tests/conftest.py
index 683ca5ce7..b38ab4a65 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,12 +13,14 @@
     from dask.base import normalize_seq
 else:
     from dask.tokenize import normalize_seq
+from filelock import FileLock
 from scipy import sparse
 
 import anndata as ad
 from anndata.tests.helpers import subset_func  # noqa: F401
 
 if TYPE_CHECKING:
+    from collections.abc import Generator
     from types import EllipsisType
 
 
@@ -75,6 +77,43 @@ def equivalent_ellipsis_index(
     return ellipsis_index_with_equivalent[1]
 
 
+@pytest.fixture(scope="session")
+def local_cluster_addr(
+    tmp_path_factory: pytest.TempPathFactory, worker_id: str
+) -> Generator[str, None, None]:
+    """Yield the scheduler address of a dask cluster shared by all xdist workers."""
+    # Adapted from https://pytest-xdist.readthedocs.io/en/latest/how-to.html#making-session-scoped-fixtures-execute-only-once
+    from contextlib import ExitStack
+
+    import dask.distributed as dd
+
+    def make_cluster() -> dd.LocalCluster:
+        return dd.LocalCluster(n_workers=1, threads_per_worker=1)
+
+    if worker_id == "master":
+        # no xdist workers, so there is nobody to share the cluster with
+        with make_cluster() as cluster:
+            yield cluster.scheduler_address
+        return
+
+    # get the temp directory shared by all workers
+    root_tmp_dir = tmp_path_factory.getbasetemp().parent
+    fn = root_tmp_dir / "dask_scheduler_address.txt"
+
+    # NOTE(review): the worker that starts the cluster also closes it at the end
+    # of its own session; confirm that no other worker can outlive it.
+    with ExitStack() as stack:
+        # Hold the lock only while looking up or starting the cluster: the inner
+        # `with` releases it before we yield, and also if anything in it raises.
+        with FileLock(f"{fn}.lock"):
+            address = fn.read_text() if fn.is_file() else None
+            if not address:
+                cluster = stack.enter_context(make_cluster())
+                address = cluster.scheduler_address
+                fn.write_text(address)
+        yield address
+
+
 #####################
 # Dask tokenization #
 #####################
diff --git a/tests/lazy/test_concat.py b/tests/lazy/test_concat.py
index d2c5efdb4..7825a6e91 100644
--- a/tests/lazy/test_concat.py
+++ b/tests/lazy/test_concat.py
@@ -215,14 +215,11 @@ def test_concat_to_memory_var(
 
 
 def test_concat_data_with_cluster_to_memory(
-    adata_remote: AnnData, join: Join_T, *, load_annotation_index: bool
-):
+    adata_remote: AnnData, join: Join_T, local_cluster_addr: str
+) -> None:
     import dask.distributed as dd
 
-    with (
-        dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster,
-        dd.Client(cluster),
-    ):
+    with dd.Client(local_cluster_addr):
         ad.concat([adata_remote, adata_remote], join=join).to_memory()
 
 
diff --git a/tests/test_dask.py b/tests/test_dask.py
index 4865758e2..15af3bd39 100644
--- a/tests/test_dask.py
+++ b/tests/test_dask.py
@@ -4,6 +4,8 @@
 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -22,6 +24,11 @@
     gen_adata,
 )
 
+if TYPE_CHECKING:
+    from pathlib import Path
+    from typing import Literal
+
+
 pytest.importorskip("dask.array")
 
 
@@ -103,7 +110,12 @@ def test_dask_write(adata, tmp_path, diskfmt):
     assert isinstance(orig.varm["a"], DaskArray)
 
 
-def test_dask_distributed_write(adata, tmp_path, diskfmt):
+def test_dask_distributed_write(
+    adata: AnnData,
+    tmp_path: Path,
+    diskfmt: Literal["h5ad", "zarr"],
+    local_cluster_addr: str,
+) -> None:
     import dask.array as da
     import dask.distributed as dd
     import numpy as np
@@ -111,10 +123,7 @@ def test_dask_distributed_write(adata, tmp_path, diskfmt):
     pth = tmp_path / f"test_write.{diskfmt}"
     g = as_group(pth, mode="w")
 
-    with (
-        dd.LocalCluster(n_workers=1, threads_per_worker=1, processes=False) as cluster,
-        dd.Client(cluster),
-    ):
+    with dd.Client(local_cluster_addr):
         M, N = adata.X.shape
         adata.obsm["a"] = da.random.random((M, 10))
         adata.obsm["b"] = da.random.random((M, 10))
diff --git a/tests/test_io_elementwise.py b/tests/test_io_elementwise.py
index 3991b46f7..92f6abdb1 100644
--- a/tests/test_io_elementwise.py
+++ b/tests/test_io_elementwise.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import re
+from pathlib import Path
 from typing import TYPE_CHECKING
 
 import h5py
@@ -77,7 +78,7 @@ def store(request, tmp_path) -> H5Group | ZarrGroup:
 
 
 @pytest.fixture(params=sparse_formats)
-def sparse_format(request):
+def sparse_format(request: pytest.FixtureRequest) -> Literal["csr", "csc"]:
     return request.param
 
 
@@ -346,7 +347,9 @@ def test_read_lazy_subsets_nd_dask(store, n_dims, chunks):
     assert_equal(X_from_disk[index], X_dask_from_disk[index])
 
 
-def test_read_lazy_h5_cluster(sparse_format, tmp_path):
+def test_read_lazy_h5_cluster(
+    sparse_format: Literal["csr", "csc"], tmp_path: Path, local_cluster_addr: str
+) -> None:
     import dask.distributed as dd
 
     with h5py.File(tmp_path / "test.h5", "w") as file:
@@ -354,10 +357,7 @@ def test_read_lazy_h5_cluster(sparse_format, tmp_path):
         arr_store = create_sparse_store(sparse_format, store)
         X_dask_from_disk = read_elem_lazy(arr_store["X"])
         X_from_disk = read_elem(arr_store["X"])
-    with (
-        dd.LocalCluster(n_workers=1, threads_per_worker=1) as cluster,
-        dd.Client(cluster) as _client,
-    ):
+    with dd.Client(local_cluster_addr):
         assert_equal(X_from_disk, X_dask_from_disk)
 
 