Skip to content

Commit 4cf9baf

Browse files
authored
Deprecate default value for Client.wait_for_workers (#6942)
1 parent 16748b7 commit 4cf9baf

2 files changed

Lines changed: 48 additions & 3 deletions

File tree

distributed/client.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,7 +1329,9 @@ async def _update_scheduler_info(self):
13291329
except OSError:
13301330
logger.debug("Not able to query scheduler for identity")
13311331

1332-
async def _wait_for_workers(self, n_workers=0, timeout=None):
1332+
async def _wait_for_workers(
1333+
self, n_workers: int, timeout: float | None = None
1334+
) -> None:
13331335
info = await self.scheduler.identity()
13341336
self._scheduler_identity = SchedulerInfo(info)
13351337
if timeout:
@@ -1346,7 +1348,7 @@ def running_workers(info):
13461348
]
13471349
)
13481350

1349-
while n_workers and running_workers(info) < n_workers:
1351+
while running_workers(info) < n_workers:
13501352
if deadline and time() > deadline:
13511353
raise TimeoutError(
13521354
"Only %d/%d workers arrived after %s"
@@ -1356,7 +1358,11 @@ def running_workers(info):
13561358
info = await self.scheduler.identity()
13571359
self._scheduler_identity = SchedulerInfo(info)
13581360

1359-
def wait_for_workers(self, n_workers=0, timeout=None):
1361+
def wait_for_workers(
1362+
self,
1363+
n_workers: int | str = no_default,
1364+
timeout: float | None = None,
1365+
) -> None:
13601366
"""Blocking call to wait for n workers before continuing
13611367
13621368
Parameters
@@ -1367,6 +1373,16 @@ def wait_for_workers(self, n_workers=0, timeout=None):
13671373
Time in seconds after which to raise a
13681374
``dask.distributed.TimeoutError``
13691375
"""
1376+
if n_workers is no_default:
1377+
warnings.warn(
1378+
"Please specify the `n_workers` argument when using `Client.wait_for_workers`. Not specifying `n_workers` will no longer be supported in future versions.",
1379+
FutureWarning,
1380+
)
1381+
n_workers = 0
1382+
elif not isinstance(n_workers, int) or n_workers < 1:
1383+
raise ValueError(
1384+
f"`n_workers` must be a positive integer. Instead got {n_workers}."
1385+
)
13701386
return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
13711387

13721388
def _heartbeat(self):

distributed/tests/test_client.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7547,3 +7547,32 @@ def __init__(self, *args, **kwargs):
75477547
(DeprecationWarning, "The io_loop property is deprecated"),
75487548
(DeprecationWarning, "setting the loop property is deprecated"),
75497549
]
7550+
7551+
7552+
@gen_cluster(client=True, nthreads=[])
7553+
async def test_wait_for_workers_no_default(c, s):
7554+
with pytest.warns(
7555+
FutureWarning,
7556+
match="specify the `n_workers` argument when using `Client.wait_for_workers`",
7557+
):
7558+
await c.wait_for_workers()
7559+
7560+
7561+
@pytest.mark.parametrize(
7562+
"value, exception",
7563+
[
7564+
(None, ValueError),
7565+
(0, ValueError),
7566+
(1.0, ValueError),
7567+
(1, None),
7568+
(2, None),
7569+
],
7570+
)
7571+
@gen_cluster(client=True)
7572+
async def test_wait_for_workers_n_workers_value_check(c, s, a, b, value, exception):
7573+
if exception:
7574+
ctx = pytest.raises(exception)
7575+
else:
7576+
ctx = nullcontext()
7577+
with ctx:
7578+
await c.wait_for_workers(value)

0 commit comments

Comments
 (0)