@@ -377,11 +377,14 @@ def __init__(self):
         self._is_last_stage: bool = False
         self._initialized: bool = False
         self._pipeline_running: bool = False
-        # O(1) count of workers in READY or BUSY state, maintained by _set_worker_state.
-        self._alive_workers: int = 0
 
         # Lock for accessing coordinator state from background thread
         self._lock = threading.Lock()
+        # Signaled on every state change that _wait_for_stage cares about
+        # (shard completed, task requeued, fatal error, worker registered).
+        # Allows _wait_for_stage to wake immediately instead of sleeping
+        # out up to its full 1-second poll interval.
+        self._stage_event = threading.Event()
 
         actor_ctx = current_actor()
         self._name = f"{actor_ctx.group_name}"
@@ -419,20 +422,6 @@ def set_worker_group(self, worker_group: Any) -> None:
         """Set the worker ActorGroup so the coordinator can detect permanent worker death."""
         self._worker_group = worker_group
 
-    def _set_worker_state(self, worker_id: str, new_state: WorkerState) -> None:
-        """Transition a worker's state and keep _alive_workers in sync.
-
-        Must be called with self._lock held.
-        """
-        old_state = self._worker_states.get(worker_id)
-        was_alive = old_state in {WorkerState.READY, WorkerState.BUSY}
-        is_alive = new_state in {WorkerState.READY, WorkerState.BUSY}
-        if was_alive and not is_alive:
-            self._alive_workers -= 1
-        elif not was_alive and is_alive:
-            self._alive_workers += 1
-        self._worker_states[worker_id] = new_state
-
     def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
         """Called by workers when they come online to register with coordinator.
 
@@ -443,19 +432,20 @@ def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
         if worker_id in self._worker_handles:
             logger.info("Worker %s re-registering (likely reconstructed), updating handle", worker_id)
             self._worker_handles[worker_id] = worker_handle
-            self._set_worker_state(worker_id, WorkerState.READY)
+            self._worker_states[worker_id] = WorkerState.READY
             self._last_seen[worker_id] = time.monotonic()
             # NOTE: if there was a task assigned to the worker, there's a race condition between marking
             # the worker as unhealthy via heartbeat and re-registration. If we do not requeue we may silently
             # lose tasks.
             self._maybe_requeue_worker_task(worker_id)
-            return
-
-        self._worker_handles[worker_id] = worker_handle
-        self._set_worker_state(worker_id, WorkerState.READY)
-        self._last_seen[worker_id] = time.monotonic()
+        else:
+            self._worker_handles[worker_id] = worker_handle
+            self._worker_states[worker_id] = WorkerState.READY
+            self._last_seen[worker_id] = time.monotonic()
 
-        logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
+        logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
+        # Wake _wait_for_stage so it re-evaluates the alive-worker count.
+        self._stage_event.set()
 
     def _coordinator_loop(self) -> None:
         """Background loop for heartbeat checking and worker job monitoring."""
@@ -504,10 +494,10 @@ def _has_active_execution(self) -> bool:
 
     def _log_status(self) -> None:
         with self._lock:
-            alive = self._alive_workers
-            total_workers = len(self._worker_handles)
+            states = list(self._worker_states.values())
             retried = {idx: att for idx, att in self._task_attempts.items() if att > 0}
-            dead = total_workers - alive
+            alive = sum(1 for s in states if s in {WorkerState.READY, WorkerState.BUSY})
+            dead = sum(1 for s in states if s in {WorkerState.FAILED, WorkerState.DEAD})
             logger.info(
                 "[%s] [%s] %d/%d complete, %d in-flight, %d queued, %d/%d workers alive, %d dead",
                 self._execution_id,
@@ -517,7 +507,7 @@ def _log_status(self) -> None:
                 len(self._in_flight),
                 len(self._task_queue),
                 alive,
-                total_workers,
+                len(self._worker_handles),
                 dead,
             )
             if retried:
@@ -605,7 +595,7 @@ def _check_worker_heartbeats(self, timeout: float = 120.0) -> None:
         for worker_id, last in list(self._last_seen.items()):
             if now - last > timeout and self._worker_states.get(worker_id) not in {WorkerState.FAILED, WorkerState.DEAD}:
                 logger.warning(f"Zephyr worker {worker_id} failed to heartbeat within timeout ({now - last:.1f}s)")
-                self._set_worker_state(worker_id, WorkerState.FAILED)
+                self._worker_states[worker_id] = WorkerState.FAILED
                 self._maybe_requeue_worker_task(worker_id)
 
     def pull_task(self, worker_id: str) -> tuple[ShardTask, int, dict] | str | None:
@@ -618,10 +608,10 @@ def pull_task(self, worker_id: str) -> tuple[ShardTask, int, dict] | str | None:
         """
         with self._lock:
             self._last_seen[worker_id] = time.monotonic()
-            self._set_worker_state(worker_id, WorkerState.READY)
+            self._worker_states[worker_id] = WorkerState.READY
 
             if self._shutdown_event.is_set():
-                self._set_worker_state(worker_id, WorkerState.DEAD)
+                self._worker_states[worker_id] = WorkerState.DEAD
                 return "SHUTDOWN"
 
             if self._fatal_error:
@@ -634,14 +624,14 @@ def pull_task(self, worker_id: str) -> tuple[ShardTask, int, dict] | str | None:
                     # restarts the worker which re-registers and picks it up.
                     # _check_worker_group() detects permanent worker-job death
                     # as a failsafe so we never deadlock.
-                    self._set_worker_state(worker_id, WorkerState.DEAD)
+                    self._worker_states[worker_id] = WorkerState.DEAD
                     return "SHUTDOWN"
                 return None
 
             task = self._task_queue.popleft()
             attempt = self._task_attempts[task.shard_idx]
             self._in_flight[worker_id] = (task, attempt)
-            self._set_worker_state(worker_id, WorkerState.BUSY)
+            self._worker_states[worker_id] = WorkerState.BUSY
 
             config = {
                 "chunk_prefix": self._chunk_prefix,
@@ -687,19 +677,24 @@ def report_result(
             self._results[shard_idx] = result
             self._completed_shards += 1
             self._in_flight.pop(worker_id, None)
-            self._set_worker_state(worker_id, WorkerState.READY)
+            self._worker_states[worker_id] = WorkerState.READY
             self._completed_counters.append(counter_snapshot)
             # Zero the in-flight counters but keep the generation watermark
             # so late heartbeats from this task are rejected.
             self._worker_counters[worker_id] = CounterSnapshot.empty(counter_snapshot.generation)
+            # Wake _wait_for_stage immediately so it advances to the next stage
+            # without waiting out the polling backoff interval.
+            self._stage_event.set()
 
     def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
         """Worker reports a task failure. Re-queues up to MAX_SHARD_FAILURES."""
         with self._lock:
             self._last_seen[worker_id] = time.monotonic()
             self._assert_in_flight_consistent(worker_id, shard_idx)
             aborted = self._record_shard_failure(worker_id, ShardFailureKind.TASK, error_info)
-            self._set_worker_state(worker_id, WorkerState.DEAD if aborted else WorkerState.READY)
+            self._worker_states[worker_id] = WorkerState.DEAD if aborted else WorkerState.READY
+            # Wake _wait_for_stage so it re-evaluates the alive-worker count and fatal error.
+            self._stage_event.set()
 
     def heartbeat(self, worker_id: str, counter_snapshot: CounterSnapshot | None = None) -> None:
         self._last_seen[worker_id] = time.monotonic()
@@ -765,6 +760,8 @@ def abort(self, reason: str) -> None:
             if self._fatal_error is None:
                 logger.error("Coordinator aborted: %s", reason)
                 self._fatal_error = reason
+                # Wake _wait_for_stage so it raises ZephyrWorkerError immediately.
+                self._stage_event.set()
 
     def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: bool = False) -> None:
         """Load a new stage's tasks into the queue."""
@@ -784,6 +781,9 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: b
             # Only reset in-flight worker snapshots; completed snapshots
             # accumulate across stages for full pipeline visibility.
             self._worker_counters = {}
+        # Clear after releasing the lock so _wait_for_stage can't miss
+        # a signal set between lock release and the clear call.
+        self._stage_event.clear()
 
     def _wait_for_stage(self) -> None:
         """Block until current stage completes or error occurs."""
@@ -792,7 +792,6 @@ def _wait_for_stage(self) -> None:
         start_time = time.monotonic()
         all_dead_since: float | None = None
         no_workers_timeout = self._no_workers_timeout
-        stage_done = threading.Event()
 
         while True:
             with self._lock:
@@ -805,9 +804,11 @@ def _wait_for_stage(self) -> None:
                 if completed >= total:
                     return
 
-                # _alive_workers is kept in sync by _set_worker_state on every
-                # state transition, avoiding a per-wakeup O(n_workers) scan.
-                alive_workers = self._alive_workers
+                # Count alive workers (READY or BUSY), not just total registered.
+                # Dead/failed workers stay in _worker_handles but can't make progress.
+                alive_workers = sum(
+                    1 for s in self._worker_states.values() if s in {WorkerState.READY, WorkerState.BUSY}
+                )
 
                 if alive_workers == 0:
                     now = time.monotonic()
@@ -834,7 +835,13 @@ def _wait_for_stage(self) -> None:
                 last_log_completed = completed
                 backoff.reset()
 
-            stage_done.wait(timeout=backoff.next_interval())
+            # Wait for a state-change signal (shard completed, error, worker registered),
+            # with a timeout as a backstop for the alive-worker check and log lines.
+            # _stage_event is set by report_result/report_error/abort/register_worker,
+            # so the stage transition is detected within microseconds of the last
+            # shard completing rather than after up to the full backoff interval.
+            self._stage_event.wait(timeout=backoff.next_interval())
+            self._stage_event.clear()
 
     def _collect_results(self) -> dict[int, TaskResult]:
         """Return results for the completed stage."""