Commit 4688654

Sbachmei/mic-6020/fix breaking test (#513)

* fix test_set_up_dask_client_existing_cluster
* ignore mypy error
* check cluster prior to setting up as well
* improve test
1 parent d360c83 commit 4688654

2 files changed: 63 additions & 21 deletions

tests/integration/test_interface.py

Lines changed: 8 additions & 0 deletions
@@ -497,3 +497,11 @@ def test_generate_dataset_with_bad_year(
     )
     if engine == "dask":
         df.compute()
+
+
+@pytest.mark.skip(reason="TODO: mic-6023")
+def test_dask_cluster_is_used():
+    """Set up a dask cluster manually, then run psp, then make sure that the
+    resulting dataframe was computed from that same cluster
+    """
+    ...
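For reference, the test added above is only a stub pending mic-6023. A minimal sketch of one way it could eventually be written, assuming pseudopeople's dataset generators accept engine="dask" as the integration tests suggest (the cluster name, sizing, and task-stream check below are illustrative assumptions, not part of this commit):

import pseudopeople as psp
from dask.distributed import LocalCluster, get_task_stream


def test_dask_cluster_is_used() -> None:
    # Set up a dask cluster manually (name and sizing are arbitrary here)
    cluster = LocalCluster(name="manual_cluster", n_workers=2, threads_per_worker=1)
    client = cluster.get_client()
    # Record every task this cluster's scheduler executes while psp runs
    with get_task_stream(client=client) as ts:
        df = psp.generate_decennial_census(engine="dask")  # assumed call signature
        df.compute()
    # If the manually created cluster did the work, its task stream is non-empty
    assert len(ts.data) > 0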

tests/unit/test_interface.py

Lines changed: 55 additions & 21 deletions
@@ -107,7 +107,7 @@ def test_set_up_dask_client_default() -> None:
     # Shut down a client if it exists
     try:
         client = get_client()
-        client.shutdown()
+        client.shutdown()  # type: ignore[no-untyped-call]
     except ValueError:
         pass
     finally:
@@ -116,43 +116,77 @@ def test_set_up_dask_client_default() -> None:
         client = get_client()

     set_up_dask_client()
-    client = get_client()
-    assert isinstance(client.cluster, LocalCluster)
-    assert client.cluster.name == "pseudopeople_dask_cluster"
-    workers = client.scheduler_info()["workers"]  # type: ignore[no-untyped-call]
-    assert len(workers) == CPU_COUNT
-    assert all(worker["nthreads"] == 1 for worker in workers.values())
+
     if is_on_slurm():
         try:
             available_memory = float(os.environ["SLURM_MEM_PER_NODE"]) / 1024
         except KeyError:
             raise RuntimeError(
-                "You are on Slurm but SLURM_MEM_PER_NODE is not set. "
-                "It is likely that you are SSHed onto a node (perhaps using VSCode). "
+                "NOTE: This RuntimeError is expected if you are using VSCode on the cluster!\n\n"
+                "You are on Slurm but SLURM_MEM_PER_NODE is not set; it is likely "
+                "that you are SSHed onto a node (perhaps using VSCode?). "
                 "In this case, dask will assign the total memory of the node to the "
                 "cluster instead of the allocated memory from the srun call. "
                 "Pseudopeople should only be used on Slurm directly on the node "
                 "assigned via an srun (both for pytests as well as actual work)."
             )
     else:
-        available_memory = psutil.virtual_memory().total / (1024 ** 3)
-    assert np.isclose(sum(worker["memory_limit"] / 1024**3 for worker in workers.values()), available_memory, rtol=0.01)
+        available_memory = psutil.virtual_memory().total / (1024**3)
+
+    _check_cluster_attrs(
+        cluster_name="pseudopeople_dask_cluster",
+        memory_limit=available_memory,
+        n_workers=CPU_COUNT,
+        threads_per_worker=1,
+    )


-def test_set_up_dask_client_custom() -> None:
+def test_set_up_dask_client_existing_cluster() -> None:
+    cluster_name = "custom"
     memory_limit = 1  # gb
     n_workers = 3
+    threads_per_worker = 2
+
+    # Manually create a cluster
     cluster = LocalCluster(  # type: ignore[no-untyped-call]
-        name="custom",
+        name=cluster_name,
         n_workers=n_workers,
-        threads_per_worker=2,
+        threads_per_worker=threads_per_worker,
         memory_limit=memory_limit * 1024**3,
     )
-    client = cluster.get_client()  # type: ignore[no-untyped-call]
+    cluster.get_client()  # type: ignore[no-untyped-call]
+    _check_cluster_attrs(
+        cluster_name=cluster_name,
+        memory_limit=memory_limit * n_workers,
+        n_workers=n_workers,
+        threads_per_worker=threads_per_worker
+    )
+
+    # Call the dask client setup function
     set_up_dask_client()
-    client = get_client()
-    assert client.cluster.name == "custom"
-    workers = client.scheduler_info()["workers"]
-    assert len(workers) == 3
-    assert all(worker["nthreads"] == 2 for worker in workers.values())
-    assert sum(worker["memory_limit"] / 1024**3 for worker in workers.values()) == memory_limit * n_workers
+
+    # Make sure that the cluster hasn't been changed
+    assert get_client().cluster == cluster
+    _check_cluster_attrs(
+        cluster_name=cluster_name,
+        memory_limit=memory_limit * n_workers,
+        n_workers=n_workers,
+        threads_per_worker=threads_per_worker
+    )
+
+####################
+# Helper Functions #
+####################
+
+def _check_cluster_attrs(cluster_name: str, memory_limit: int | float, n_workers: int, threads_per_worker: int) -> None:
+    cluster = get_client().cluster
+    assert isinstance(cluster, LocalCluster)
+    assert cluster.name == cluster_name
+    workers = cluster.scheduler_info["workers"]
+    assert len(workers) == n_workers
+    assert all(worker["nthreads"] == threads_per_worker for worker in workers.values())
+    assert np.isclose(
+        sum(worker["memory_limit"] / 1024**3 for worker in workers.values()),
+        memory_limit,
+        rtol=0.01,
+    )
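Taken together, the two tests pin down the intended behavior of set_up_dask_client: reuse whatever cluster is already attached, otherwise build a LocalCluster named "pseudopeople_dask_cluster" sized to the machine (or to the Slurm allocation via SLURM_MEM_PER_NODE, which Slurm reports in MB). The following is a rough reconstruction inferred from the assertions above, not the actual pseudopeople implementation:

import os
from multiprocessing import cpu_count

import psutil
from dask.distributed import LocalCluster, get_client

CPU_COUNT = cpu_count()


def set_up_dask_client() -> None:
    # If a client/cluster already exists, leave it untouched
    # (get_client raises ValueError when no client is running)
    try:
        get_client()
        return
    except ValueError:
        pass
    # SLURM_MEM_PER_NODE is in MB -> convert to GB; otherwise use total system memory
    if "SLURM_MEM_PER_NODE" in os.environ:
        available_gb = float(os.environ["SLURM_MEM_PER_NODE"]) / 1024
    else:
        available_gb = psutil.virtual_memory().total / (1024**3)
    cluster = LocalCluster(
        name="pseudopeople_dask_cluster",
        n_workers=CPU_COUNT,
        threads_per_worker=1,
        # memory_limit is per worker; the workers' limits should sum to available_gb
        memory_limit=(available_gb / CPU_COUNT) * 1024**3,
    )
    cluster.get_client()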
