Skip to content

Commit 1709f35

Browse files
change default dask client
1 parent 9d5b35c commit 1709f35

File tree

4 files changed

+92
-2
lines changed

4 files changed

+92
-2
lines changed

src/pseudopeople/interface.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _generate_dataset(
111111
noised_dataset = pd.concat(noised_datasets_list, ignore_index=True)
112112

113113
noised_dataset = coerce_dtypes(noised_dataset, dataset_schema)
114-
else:
114+
else: # dask
115115
try:
116116
from distributed.client import default_client
117117

@@ -126,6 +126,8 @@ def _generate_dataset(
126126
import dask
127127
import dask.dataframe as dd
128128

129+
set_up_dask_client()
130+
129131
# Our work depends on the particulars of how dtypes work, and is only
130132
# built to work with NumPy dtypes, so we turn off the Dask default behavior
131133
# of using PyArrow dtypes.
@@ -931,3 +933,24 @@ def get_dataset_filepaths(source: Path, dataset_name: str) -> list[Path]:
931933
dataset_paths = [x for x in directory.glob(f"{dataset_name}*")]
932934
sorted_dataset_paths = sorted(dataset_paths)
933935
return sorted_dataset_paths
936+
937+
938+
def set_up_dask_client() -> None:
    """Set up a local Dask cluster and client if one is not already running.

    If a distributed client already exists (e.g. the user configured their own
    cluster before calling pseudopeople), it is left untouched. Otherwise a
    ``LocalCluster`` is created with one single-threaded worker per CPU; the
    resulting client registers itself as the default, so later calls to
    ``get_client()`` will find it.
    """
    from dask.distributed import get_client

    # Determine whether or not a Dask client is already running. If not,
    # create a new one. get_client() raises ValueError when no client exists.
    try:
        get_client()
    except ValueError:
        # No Dask client is running so we create one.
        from dask.distributed import LocalCluster
        from dask.system import CPU_COUNT

        # One worker per CPU core, each single-threaded; LocalCluster's
        # default memory_limit splits the available memory evenly across
        # the workers.
        cluster = LocalCluster(
            n_workers=CPU_COUNT,
            threads_per_worker=1,
        )
        # Creating the client registers it as the default client; we do not
        # need to keep a reference to it here.
        cluster.get_client()

tests/unit/test_interface.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1+
import os
12
from pathlib import Path
23

4+
import numpy as np
5+
import psutil
36
import pytest
47
from _pytest.tmpdir import TempPathFactory
8+
from dask.distributed import LocalCluster, get_client
9+
from dask.system import CPU_COUNT
510
from packaging.version import parse
611
from pytest_mock import MockerFixture
712

813
from pseudopeople.constants.metadata import DatasetNames
914
from pseudopeople.exceptions import DataSourceError
1015
from pseudopeople.interface import (
1116
_get_data_changelog_version,
17+
set_up_dask_client,
1218
validate_source_compatibility,
1319
)
1420
from pseudopeople.schema_entities import DATASET_SCHEMAS
21+
from tests.utilities import is_on_slurm
1522

1623
CENSUS = DATASET_SCHEMAS.get_dataset_schema(DatasetNames.CENSUS)
1724

@@ -94,3 +101,50 @@ def test_validate_source_compatibility_wrong_directory(tmp_path: Path) -> None:
94101
bad_path.mkdir()
95102
with pytest.raises(FileNotFoundError, match="Could not find 'decennial_census' in"):
96103
validate_source_compatibility(bad_path, CENSUS)
104+
105+
106+
def test_set_up_dask_client_default() -> None:
    """``set_up_dask_client`` creates a LocalCluster with one single-threaded
    worker per CPU when no client is already running, with worker memory
    limits summing (approximately) to the available system memory."""
    # There should be no dask client yet
    with pytest.raises(ValueError):
        get_client()

    set_up_dask_client()
    client = get_client()
    workers = client.scheduler_info()["workers"]
    assert len(workers) == CPU_COUNT
    assert all(worker["nthreads"] == 1 for worker in workers.values())
    if is_on_slurm():
        try:
            # SLURM_MEM_PER_NODE is in MB; convert to GB.
            available_memory = float(os.environ["SLURM_MEM_PER_NODE"]) / 1024
        except KeyError as err:
            raise RuntimeError(
                "You are on Slurm but SLURM_MEM_PER_NODE is not set. "
                "It is likely that you are SSHed onto a node (perhaps using VSCode). "
                "In this case, dask will assign the total memory of the node to each "
                "worker instead of the allocated memory from the srun call. "
                "Pseudopeople should only be used on Slurm directly on the node "
                "assigned via an srun (both for pytests as well as actual work)."
            ) from err
    else:
        available_memory = psutil.virtual_memory().total / (1024**3)
    # The workers' memory limits should collectively cover the available
    # memory (within 1%, to allow for rounding when dask splits it up).
    assert np.isclose(
        sum(worker["memory_limit"] / 1024**3 for worker in workers.values()),
        available_memory,
        rtol=0.01,
    )
132+
133+
134+
def test_set_up_dask_client_custom() -> None:
    """``set_up_dask_client`` must not replace a user-provided cluster.

    Start a custom LocalCluster first, then call ``set_up_dask_client`` and
    verify the custom cluster (name, worker count, threads, memory limits)
    is still the active one.
    """
    memory_limit = 1  # gb
    n_workers = 3
    cluster = LocalCluster(
        name="custom",
        n_workers=n_workers,
        threads_per_worker=2,
        memory_limit=memory_limit * 1024**3,
    )
    client = cluster.get_client()
    set_up_dask_client()
    client = get_client()
    # The pre-existing custom cluster should still be the default client.
    assert client.cluster.name == "custom"
    workers = client.scheduler_info()["workers"]
    # Use n_workers rather than a hard-coded count so the assertion tracks
    # the fixture defined above.
    assert len(workers) == n_workers
    assert all(worker["nthreads"] == 2 for worker in workers.values())
    # memory_limit is per worker, so the total is memory_limit * n_workers.
    assert (
        sum(worker["memory_limit"] / 1024**3 for worker in workers.values())
        == memory_limit * n_workers
    )

tests/unit/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
import numpy as np
23
import pandas as pd
34
import pytest

tests/utilities.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
import math
4+
import shutil
45
from collections.abc import Callable
5-
from functools import partial
66
from typing import Any
77

88
import numpy as np
@@ -190,3 +190,15 @@ def get_single_noise_type_config(
190190
] = new_probability
191191

192192
return config_dict
193+
194+
195+
def is_on_slurm() -> bool:
    """Returns True if the current environment is a SLURM cluster.

    Notes
    -----
    This function simply checks for the presence of the `sbatch` command to _infer_
    if SLURM is installed. It does _not_ check if SLURM is currently active or
    managing jobs.
    """
    # shutil.which returns the path to the executable, or None if not found.
    sbatch_path = shutil.which("sbatch")
    return sbatch_path is not None

0 commit comments

Comments
 (0)