Commit 04e4969

add missing support for grouped workers in SpecCluster._update_worker_status
1 parent 0f0adef commit 04e4969

File tree

1 file changed: +36 -16 lines changed


distributed/deploy/spec.py

Lines changed: 36 additions & 16 deletions
@@ -397,25 +397,45 @@ async def _correct_state_internal(self) -> None:
 
     def _update_worker_status(self, op, msg):
         if op == "remove":
-            name = self.scheduler_info["workers"][msg]["name"]
+            removed_worker_name = self.scheduler_info["workers"][msg]["name"]
 
+            # Closure to handle removal of a worker from the cluster
             def f():
-                if (
-                    name in self.workers
-                    and msg not in self.scheduler_info["workers"]
-                    and not any(
-                        d["name"] == name
-                        for d in self.scheduler_info["workers"].values()
-                    )
-                ):
-                    self._futures.add(asyncio.ensure_future(self.workers[name].close()))
-                    del self.workers[name]
-
-            delay = parse_timedelta(
-                dask.config.get("distributed.deploy.lost-worker-timeout")
-            )
-
+                # Check if worker is truly gone from scheduler
+                active_workers = {d["name"] for d in self.scheduler_info.get("workers", {}).values()}
+                if removed_worker_name in active_workers:
+                    return
+
+                # Build mapping from individual worker names to their worker spec names
+                # - For non-grouped workers: worker name == spec name (1:1)
+                # - For grouped workers: multiple workers map to one spec entry
+                worker_to_spec = {}
+                for worker_spec_name, spec in self.worker_spec.items():
+                    if "group" not in spec:
+                        worker_to_spec[worker_spec_name] = worker_spec_name
+                    else:
+                        grouped_workers = {
+                            str(worker_spec_name) + suffix: worker_spec_name
+                            for suffix in spec["group"]
+                        }
+                        worker_to_spec.update(grouped_workers)
+
+                # Find and remove the worker spec entry
+                # Note: For grouped workers, we remove the entire spec when ANY worker dies.
+                # This assumes that partial failure means the whole group is compromised
+                # (e.g., in HPC systems, if one process in a multi-process job fails, the
+                # entire job allocation is typically lost).
+                worker_spec_name = worker_to_spec.get(removed_worker_name)
+                if worker_spec_name and worker_spec_name in self.worker_spec:
+                    # Close and remove the worker object
+                    if worker_spec_name in self.workers:
+                        self._futures.add(asyncio.ensure_future(self.workers[worker_spec_name].close()))
+                        del self.workers[worker_spec_name]
+                    del self.worker_spec[worker_spec_name]
+
+            delay = parse_timedelta(dask.config.get("distributed.deploy.lost-worker-timeout"))
             asyncio.get_running_loop().call_later(delay, f)
+
         super()._update_worker_status(op, msg)
 
     def __await__(self: Self) -> Generator[Any, Any, Self]:
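
For context, a grouped entry in SpecCluster.worker_spec describes several workers behind one spec key, and the scheduler reports each member under the spec name plus a suffix from its "group" list (e.g. "big-0", "big-1"). The snippet below is a standalone sketch of the name-to-spec mapping that the new closure builds; the example spec contents and the helper name make_worker_to_spec are illustrative only, not part of the commit.

# Illustrative sketch: how scheduler-visible worker names resolve back to
# worker_spec keys, mirroring the mapping logic added in _update_worker_status.
# The spec contents below are placeholders, not real worker classes.
worker_spec = {
    "small": {"cls": "Worker", "options": {}},                        # non-grouped: one worker
    "big": {"cls": "Worker", "options": {}, "group": ["-0", "-1"]},   # grouped: two workers
}

def make_worker_to_spec(worker_spec):
    """Map each scheduler-visible worker name to its worker_spec key."""
    worker_to_spec = {}
    for spec_name, spec in worker_spec.items():
        if "group" not in spec:
            worker_to_spec[spec_name] = spec_name
        else:
            worker_to_spec.update(
                {str(spec_name) + suffix: spec_name for suffix in spec["group"]}
            )
    return worker_to_spec

print(make_worker_to_spec(worker_spec))
# {'small': 'small', 'big-0': 'big', 'big-1': 'big'}
# If the scheduler reports "big-1" as removed and it has not reappeared by the
# lost-worker timeout, the whole "big" spec entry is dropped, not just one member.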
