|
| 1 | +import os |
1 | 2 | from pathlib import Path |
2 | 3 |
|
| 4 | +import numpy as np |
| 5 | +import psutil |
3 | 6 | import pytest |
4 | 7 | from _pytest.tmpdir import TempPathFactory |
| 8 | +from dask.distributed import LocalCluster, get_client |
| 9 | +from dask.system import CPU_COUNT |
5 | 10 | from packaging.version import parse |
6 | 11 | from pytest_mock import MockerFixture |
7 | 12 |
|
8 | 13 | from pseudopeople.constants.metadata import DatasetNames |
9 | 14 | from pseudopeople.exceptions import DataSourceError |
10 | 15 | from pseudopeople.interface import ( |
11 | 16 | _get_data_changelog_version, |
| 17 | + set_up_dask_client, |
12 | 18 | validate_source_compatibility, |
13 | 19 | ) |
14 | 20 | from pseudopeople.schema_entities import DATASET_SCHEMAS |
| 21 | +from tests.utilities import is_on_slurm |
15 | 22 |
|
16 | 23 | CENSUS = DATASET_SCHEMAS.get_dataset_schema(DatasetNames.CENSUS) |
17 | 24 |
|
@@ -94,3 +101,50 @@ def test_validate_source_compatibility_wrong_directory(tmp_path: Path) -> None: |
94 | 101 | bad_path.mkdir() |
95 | 102 | with pytest.raises(FileNotFoundError, match="Could not find 'decennial_census' in"): |
96 | 103 | validate_source_compatibility(bad_path, CENSUS) |
| 104 | + |
| 105 | + |
def test_set_up_dask_client_default() -> None:
    """Check that ``set_up_dask_client`` starts a sensible default cluster.

    The default client should have one single-threaded worker per CPU, and the
    workers' combined memory limit should approximately equal the memory
    available to the process (the Slurm allocation when running under Slurm,
    otherwise the machine's total memory).
    """
    # There should be no dask client yet -- get_client raises when none exists.
    with pytest.raises(ValueError):
        get_client()

    set_up_dask_client()
    client = get_client()
    workers = client.scheduler_info()["workers"]
    assert len(workers) == CPU_COUNT
    assert all(worker["nthreads"] == 1 for worker in workers.values())
    if is_on_slurm():
        try:
            # SLURM_MEM_PER_NODE is reported in MB; convert to GB.
            available_memory = float(os.environ["SLURM_MEM_PER_NODE"]) / 1024
        except KeyError as err:
            # Chain the original KeyError so the missing-variable cause is visible.
            raise RuntimeError(
                "You are on Slurm but SLURM_MEM_PER_NODE is not set. "
                "It is likely that you are SSHed onto a node (perhaps using VSCode). "
                "In this case, dask will assign the total memory of the node to each "
                "worker instead of the allocated memory from the srun call. "
                "Pseudopeople should only be used on Slurm directly on the node "
                "assigned via an srun (both for pytests as well as actual work)."
            ) from err
    else:
        available_memory = psutil.virtual_memory().total / (1024**3)
    # Workers should collectively be limited to (roughly) the available memory.
    total_worker_memory_gb = sum(
        worker["memory_limit"] / 1024**3 for worker in workers.values()
    )
    assert np.isclose(total_worker_memory_gb, available_memory, rtol=0.01)
| 132 | + |
| 133 | + |
def test_set_up_dask_client_custom() -> None:
    """Check that ``set_up_dask_client`` respects a pre-existing client.

    When the user has already created their own cluster and client,
    ``set_up_dask_client`` should leave them in place rather than replacing
    them with the default configuration.
    """
    memory_limit_gb = 1
    n_workers = 3
    threads_per_worker = 2
    # Use a context manager so the cluster is shut down at the end of the test
    # instead of leaking workers into subsequent tests.
    with LocalCluster(
        name="custom",
        n_workers=n_workers,
        threads_per_worker=threads_per_worker,
        memory_limit=memory_limit_gb * 1024**3,
    ) as cluster:
        # Registering a client makes this cluster the "current" one.
        cluster.get_client()
        set_up_dask_client()
        client = get_client()
        # The pre-existing custom cluster should still be the active one.
        assert client.cluster.name == "custom"
        workers = client.scheduler_info()["workers"]
        assert len(workers) == n_workers
        assert all(
            worker["nthreads"] == threads_per_worker for worker in workers.values()
        )
        total_memory_gb = sum(
            worker["memory_limit"] / 1024**3 for worker in workers.values()
        )
        assert total_memory_gb == memory_limit_gb * n_workers
0 commit comments