Skip to content

Commit 358402d

Browse files
authored
Refactor timeouts in start cluster (#9062)
1 parent bf6d37f commit 358402d

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

distributed/tests/test_tls_functional.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77

88
import asyncio
99

10-
import pytest
1110
from tlz import merge
1211

1312
from distributed import Client, Nanny, Queue, Scheduler, Worker, wait, worker_client
14-
from distributed.compatibility import LINUX
1513
from distributed.core import Status
1614
from distributed.metrics import time
1715
from distributed.utils_test import (
@@ -91,7 +89,6 @@ async def test_scatter(c, s, a, b):
9189
assert yy == [20]
9290

9391

94-
@pytest.mark.skipif(LINUX, reason="https://github.com/dask/distributed/issues/9052")
9592
@gen_tls_cluster(client=True, Worker=Nanny)
9693
async def test_nanny(c, s, a, b):
9794
assert s.address.startswith("tls://")
@@ -191,7 +188,6 @@ def mysum():
191188
assert result == 30 * 29
192189

193190

194-
@pytest.mark.skipif(LINUX, reason="https://github.com/dask/distributed/issues/9052")
195191
@gen_tls_cluster(client=True, Worker=Nanny)
196192
async def test_retire_workers(c, s, a, b):
197193
assert set(s.workers) == {a.worker_address, b.worker_address}

distributed/utils_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ async def start_cluster(
761761
Worker: type[ServerNode] = Worker,
762762
scheduler_kwargs: dict[str, Any] | None = None,
763763
worker_kwargs: dict[str, Any] | None = None,
764+
timeout: float = _TEST_TIMEOUT // 4,
764765
) -> tuple[Scheduler, list[ServerNode]]:
765766
scheduler_kwargs = scheduler_kwargs or {}
766767
worker_kwargs = worker_kwargs or {}
@@ -797,13 +798,15 @@ async def start_cluster(
797798
or any(comm.comm is None for comm in s.stream_comms.values())
798799
):
799800
await asyncio.sleep(0.01)
800-
if time() > start + 30:
801+
if time() > start + timeout:
801802
await asyncio.gather(*(w.close(timeout=1) for w in workers))
802803
await s.close()
803804
check_invalid_worker_transitions(s)
804805
check_invalid_task_states(s)
805806
check_worker_fail_hard(s)
806-
raise TimeoutError("Cluster creation timeout")
807+
raise TimeoutError(
808+
"Cluster creation timeout. Workers did not come up and register in time."
809+
)
807810
return s, workers
808811

809812

@@ -969,26 +972,25 @@ async def _cluster_factory():
969972
workers = []
970973
s = None
971974
try:
972-
for _ in range(60):
975+
while True:
973976
try:
977+
if not deadline.remaining:
978+
raise TimeoutError("Timeout on cluster creation")
974979
s, ws = await start_cluster(
975980
nthreads,
976981
scheduler,
977982
security=security,
978983
Worker=Worker,
979984
scheduler_kwargs=scheduler_kwargs,
980-
worker_kwargs=merge(
981-
{"death_timeout": min(15, int(deadline.remaining))},
982-
worker_kwargs,
983-
),
985+
worker_kwargs=worker_kwargs,
986+
timeout=timeout // 4,
984987
)
985988
except Exception as e:
986989
logger.error(
987990
"Failed to start gen_cluster: "
988991
f"{e.__class__.__name__}: {e}; retrying",
989992
exc_info=True,
990993
)
991-
await asyncio.sleep(1)
992994
else:
993995
workers[:] = ws
994996
break

0 commit comments

Comments
 (0)