
Commit 8bb8547

change dask default client
1 parent 9d5b35c commit 8bb8547

File tree

4 files changed: +92 -2 lines changed

    src/pseudopeople/interface.py
    src/pseudopeople/utilities.py
    tests/unit/test_utils.py
    tests/utilities.py

src/pseudopeople/interface.py

Lines changed: 4 additions & 1 deletion
@@ -23,6 +23,7 @@
     configure_logging_to_terminal,
     get_engine_from_string,
     get_state_abbreviation,
+    set_up_dask_client,
 )


@@ -111,7 +112,7 @@ def _generate_dataset(
         noised_dataset = pd.concat(noised_datasets_list, ignore_index=True)

         noised_dataset = coerce_dtypes(noised_dataset, dataset_schema)
-    else:
+    else:  # dask
         try:
             from distributed.client import default_client

@@ -126,6 +127,8 @@ def _generate_dataset(
         import dask
         import dask.dataframe as dd

+        set_up_dask_client()
+
         # Our work depends on the particulars of how dtypes work, and is only
         # built to work with NumPy dtypes, so we turn off the Dask default behavior
         # of using PyArrow dtypes.
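
A minimal usage sketch of what this change enables (not part of the commit): on the Dask code path, _generate_dataset now calls set_up_dask_client() before building the Dask dataframe, so callers no longer have to stand up a cluster by hand. The generate_decennial_census call and its engine argument below are assumptions for illustration, not verified against the public API.

import pseudopeople as psp

# Hypothetical call; the engine argument is assumed to be forwarded to
# _generate_dataset, which calls set_up_dask_client() on the dask code path.
census = psp.generate_decennial_census(engine="dask")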

src/pseudopeople/utilities.py

Lines changed: 21 additions & 0 deletions
@@ -303,6 +303,27 @@ def get_engine_from_string(engine: str) -> Engine:
     DataFrame = pd.DataFrame  # type: ignore [misc]


+def set_up_dask_client() -> None:
+    """Sets up a Dask client if one is not already running."""
+    from dask.distributed import get_client
+
+    # Determine whether or not a Dask client is already running. If not,
+    # create a new one.
+    try:
+        client = get_client()
+    except ValueError:
+        # No Dask client is running, so we create one.
+        from dask.distributed import LocalCluster
+        from dask.system import CPU_COUNT
+
+        # Use Dask's default per-worker memory limit (total memory split across workers).
+        cluster = LocalCluster(
+            n_workers=CPU_COUNT,
+            threads_per_worker=1,
+        )
+        client = cluster.get_client()
+
+
 ##########################
 # Data utility functions #
 ##########################
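
Because set_up_dask_client() only builds a LocalCluster when dask.distributed.get_client() raises ValueError, a caller who needs different resources can register their own client first and the helper leaves it alone (the custom-client test below exercises exactly this). A sketch, with illustrative worker counts and memory limits:

from dask.distributed import Client, LocalCluster

# Illustrative values only; any pre-existing default client is respected.
cluster = LocalCluster(n_workers=8, threads_per_worker=2, memory_limit="4GiB")
client = Client(cluster)  # registers itself as the default client

# A later call to set_up_dask_client() finds this client via get_client()
# and does not create a second cluster.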

tests/unit/test_utils.py

Lines changed: 54 additions & 0 deletions
@@ -1,18 +1,25 @@
+import os
+
 import numpy as np
 import pandas as pd
+import psutil
 import pytest
+from dask.distributed import LocalCluster, get_client
+from dask.system import CPU_COUNT

 from pseudopeople.dataset import Dataset
 from pseudopeople.noise_functions import _corrupt_tokens
 from pseudopeople.schema_entities import DATASET_SCHEMAS, DtypeNames
 from pseudopeople.utilities import (
     get_hash,
     get_index_to_noise,
+    set_up_dask_client,
     to_string_as_integer,
     two_d_array_choice,
     vectorized_choice,
 )
 from tests.conftest import FuzzyChecker
+from tests.utilities import is_on_slurm


 @pytest.fixture()

@@ -268,3 +275,50 @@ def test_two_d_array_choice(
         target_proportion=1 / num_choices,
         name_additional=f"team {team} for sport {sport}",
     )
+
+
+def test_set_up_dask_client_default() -> None:
+
+    # There should be no dask client yet
+    with pytest.raises(ValueError):
+        client = get_client()
+
+    set_up_dask_client()
+    client = get_client()
+    workers = client.scheduler_info()["workers"]
+    assert len(workers) == CPU_COUNT
+    assert all(worker["nthreads"] == 1 for worker in workers.values())
+    if is_on_slurm():
+        try:
+            available_memory = float(os.environ["SLURM_MEM_PER_NODE"]) / 1024
+        except KeyError:
+            raise RuntimeError(
+                "You are on Slurm but SLURM_MEM_PER_NODE is not set. "
+                "It is likely that you are SSHed onto a node (perhaps using VSCode). "
+                "In this case, dask will assign the total memory of the node to each "
+                "worker instead of the allocated memory from the srun call. "
+                "Pseudopeople should only be used on Slurm directly on the node "
+                "assigned via an srun (both for pytests as well as actual work)."
+            )
+    else:
+        available_memory = psutil.virtual_memory().total / (1024 ** 3)
+    assert sum(worker["memory_limit"] / 1024**3 for worker in workers.values()) == available_memory
+
+
+def test_set_up_dask_client_custom() -> None:
+    memory_limit = 1  # gb
+    n_workers = 3
+    cluster = LocalCluster(
+        name="custom",
+        n_workers=n_workers,
+        threads_per_worker=2,
+        memory_limit=memory_limit * 1024**3,
+    )
+    client = cluster.get_client()
+    set_up_dask_client()
+    client = get_client()
+    assert client.cluster.name == "custom"
+    workers = client.scheduler_info()["workers"]
+    assert len(workers) == 3
+    assert all(worker["nthreads"] == 2 for worker in workers.values())
+    assert sum(worker["memory_limit"] / 1024**3 for worker in workers.values()) == memory_limit * n_workers
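
A note on the memory assertions above (a sketch, not part of the commit): LocalCluster's default memory_limit of "auto" splits the machine's total memory evenly across workers, so the per-worker limits should sum back to the node's total. Hypothetical numbers:

# Hypothetical 64 GB node with 4 workers (values for illustration only).
total_memory_gb = 64
n_workers = 4
per_worker_gb = total_memory_gb / n_workers  # 16 GB per worker under "auto"
assert per_worker_gb * n_workers == total_memory_gb  # what the default test checks, in GB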

tests/utilities.py

Lines changed: 13 additions & 1 deletion
@@ -1,8 +1,8 @@
 from __future__ import annotations

 import math
+import shutil
 from collections.abc import Callable
-from functools import partial
 from typing import Any

 import numpy as np

@@ -190,3 +190,15 @@ def get_single_noise_type_config(
     ] = new_probability

     return config_dict
+
+
+def is_on_slurm() -> bool:
+    """Returns True if the current environment is a SLURM cluster.
+
+    Notes
+    -----
+    This function simply checks for the presence of the `sbatch` command to _infer_
+    if SLURM is installed. It does _not_ check if SLURM is currently active or
+    managing jobs.
+    """
+    return shutil.which("sbatch") is not None
