@@ -397,25 +397,45 @@ async def _correct_state_internal(self) -> None:
397397
398398 def _update_worker_status (self , op , msg ):
399399 if op == "remove" :
400- name = self .scheduler_info ["workers" ][msg ]["name" ]
400+ removed_worker_name = self .scheduler_info ["workers" ][msg ]["name" ]
401401
402+ # Closure to handle removal of a worker from the cluster
402403 def f ():
403- if (
404- name in self .workers
405- and msg not in self .scheduler_info ["workers" ]
406- and not any (
407- d ["name" ] == name
408- for d in self .scheduler_info ["workers" ].values ()
409- )
410- ):
411- self ._futures .add (asyncio .ensure_future (self .workers [name ].close ()))
412- del self .workers [name ]
413-
414- delay = parse_timedelta (
415- dask .config .get ("distributed.deploy.lost-worker-timeout" )
416- )
417-
404+ # Check if worker is truly gone from scheduler
405+ active_workers = {d ["name" ] for d in self .scheduler_info .get ("workers" , {}).values ()}
406+ if removed_worker_name in active_workers :
407+ return
408+
409+ # Build mapping from individual worker names to their worker spec names
410+ # - For non-grouped workers: worker name == spec name (1:1)
411+ # - For grouped workers: multiple workers map to one spec entry
412+ worker_to_spec = {}
413+ for worker_spec_name , spec in self .worker_spec .items ():
414+ if "group" not in spec :
415+ worker_to_spec [worker_spec_name ] = worker_spec_name
416+ else :
417+ grouped_workers = {
418+ str (worker_spec_name ) + suffix : worker_spec_name
419+ for suffix in spec ["group" ]
420+ }
421+ worker_to_spec .update (grouped_workers )
422+
423+ # Find and remove the worker spec entry
424+ # Note: For grouped workers, we remove the entire spec when ANY worker dies.
425+ # This assumes that partial failure means the whole group is compromised
426+ # (e.g., in HPC systems, if one process in a multi-process job fails, the
427+ # entire job allocation is typically lost).
428+ worker_spec_name = worker_to_spec .get (removed_worker_name )
429+ if worker_spec_name and worker_spec_name in self .worker_spec :
430+ # Close and remove the worker object
431+ if worker_spec_name in self .workers :
432+ self ._futures .add (asyncio .ensure_future (self .workers [worker_spec_name ].close ()))
433+ del self .workers [worker_spec_name ]
434+ del self .worker_spec [worker_spec_name ]
435+
436+ delay = parse_timedelta (dask .config .get ("distributed.deploy.lost-worker-timeout" ))
418437 asyncio .get_running_loop ().call_later (delay , f )
438+
419439 super ()._update_worker_status (op , msg )
420440
421441 def __await__ (self : Self ) -> Generator [Any , Any , Self ]:
0 commit comments