Skip to content

Commit 6fa046b

Browse files
committed
ensure workers do not kill on restart
1 parent 66ced13 commit 6fa046b

13 files changed

+381
-373
lines changed

distributed/comm/tcp.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,9 @@ async def _handle_stream(self, stream, address):
665665
try:
666666
await self.on_connection(comm)
667667
except CommClosedError:
668-
logger.info("Connection from %s closed before handshake completed", address)
668+
logger.debug(
669+
"Connection from %s closed before handshake completed", address
670+
)
669671
return
670672

671673
await self.comm_handler(comm)

distributed/deploy/spec.py

+43-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from collections.abc import Awaitable, Generator
1111
from contextlib import suppress
1212
from inspect import isawaitable
13+
from time import time
1314
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
1415

1516
from tornado import gen
@@ -389,28 +390,64 @@ async def _correct_state_internal(self) -> None:
389390
# proper teardown.
390391
await asyncio.gather(*worker_futs)
391392

392-
def _update_worker_status(self, op, msg):
393+
def _update_worker_status(self, op, worker_addr):
393394
if op == "remove":
394-
name = self.scheduler_info["workers"][msg]["name"]
395+
worker_info = self.scheduler_info["workers"][worker_addr].copy()
396+
name = worker_info["name"]
397+
398+
from distributed import Nanny, Worker
395399

396400
def f():
401+
# FIXME: SpecCluster is tracking workers by `name`` which are
402+
# not necessarily unique.
403+
# Clusters with Nannies (default) are susceptible to falsely
404+
# removing the Nannies on restart due to this logic since the
405+
# restart emits a op==remove signal on the worker address but
406+
# the SpecCluster only tracks the names, i.e. after
407+
# `lost-worker-timeout` the Nanny is still around and this logic
408+
# could trigger a false close. The below code should handle this
409+
# but it would be cleaner if the cluster tracked by address
410+
# instead of name just like the scheduler does
397411
if (
398412
name in self.workers
399-
and msg not in self.scheduler_info["workers"]
413+
and worker_addr not in self.scheduler_info["workers"]
400414
and not any(
401415
d["name"] == name
402416
for d in self.scheduler_info["workers"].values()
403417
)
404418
):
405-
self._futures.add(asyncio.ensure_future(self.workers[name].close()))
406-
del self.workers[name]
419+
w = self.workers[name]
420+
421+
async def remove_worker():
422+
await w.close(reason=f"lost-worker-timeout-{time()}")
423+
self.workers.pop(name, None)
424+
425+
if (
426+
worker_info["type"] == "Worker"
427+
and (isinstance(w, Nanny) and w.worker_address == worker_addr)
428+
or (isinstance(w, Worker) and w.address == worker_addr)
429+
):
430+
self._futures.add(
431+
asyncio.create_task(
432+
remove_worker(),
433+
name="remove-worker-lost-worker-timeout",
434+
)
435+
)
436+
elif worker_info["type"] == "Nanny":
437+
# This should never happen
438+
logger.critical(
439+
"Unespected signal encountered. WorkerStatusPlugin "
440+
"emitted a op==remove signal for a Nanny which "
441+
"should not happen. This might cause a lingering "
442+
"Nanny process."
443+
)
407444

408445
delay = parse_timedelta(
409446
dask.config.get("distributed.deploy.lost-worker-timeout")
410447
)
411448

412449
asyncio.get_running_loop().call_later(delay, f)
413-
super()._update_worker_status(op, msg)
450+
super()._update_worker_status(op, worker_addr)
414451

415452
def __await__(self: Self) -> Generator[Any, Any, Self]:
416453
async def _() -> Self:

distributed/deploy/tests/test_local.py

+15
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pytest
1212
from tornado.httpclient import AsyncHTTPClient
1313

14+
import dask
1415
from dask.system import CPU_COUNT
1516

1617
from distributed import Client, LocalCluster, Nanny, Worker, get_client
@@ -1285,3 +1286,17 @@ def test_localcluster_get_client(loop):
12851286
with Client(cluster) as client2:
12861287
assert client1 != client2
12871288
assert client2 == cluster.get_client()
1289+
1290+
1291+
@pytest.mark.slow()
1292+
def test_localcluster_restart(loop):
1293+
with (
1294+
dask.config.set({"distributed.deploy.lost-worker-timeout": "0.5s"}),
1295+
LocalCluster(asynchronous=False, dashboard_address=":0", loop=loop) as cluster,
1296+
cluster.get_client() as client,
1297+
):
1298+
nworkers = len(client.run(lambda: None))
1299+
for _ in range(10):
1300+
assert len(client.run(lambda: None)) == nworkers
1301+
client.restart()
1302+
assert len(client.run(lambda: None)) == nworkers

0 commit comments

Comments
 (0)