@@ -380,11 +380,6 @@ def __init__(self):
380380
381381 # Lock for accessing coordinator state from background thread
382382 self ._lock = threading .Lock ()
383- # Signaled on every state change that _wait_for_stage cares about
384- # (shard completed, task requeued, fatal error, worker registered).
385- # Allows _wait_for_stage to wake immediately instead of spinning up
386- # to its max 1-second poll interval.
387- self ._stage_event = threading .Event ()
388383
389384 actor_ctx = current_actor ()
390385 self ._name = f"{ actor_ctx .group_name } "
@@ -438,14 +433,13 @@ def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
438433 # the worker as unhealthy via heartbeat and re-registration. If we do not requeue we may silently
439434 # lose tasks.
440435 self ._maybe_requeue_worker_task (worker_id )
441- else :
442- self ._worker_handles [worker_id ] = worker_handle
443- self ._worker_states [worker_id ] = WorkerState .READY
444- self ._last_seen [worker_id ] = time .monotonic ()
436+ return
437+
438+ self ._worker_handles [worker_id ] = worker_handle
439+ self ._worker_states [worker_id ] = WorkerState .READY
440+ self ._last_seen [worker_id ] = time .monotonic ()
445441
446- logger .info ("Worker %s registered, total: %d" , worker_id , len (self ._worker_handles ))
447- # Wake _wait_for_stage so it re-evaluates the alive-worker count.
448- self ._stage_event .set ()
442+ logger .info ("Worker %s registered, total: %d" , worker_id , len (self ._worker_handles ))
449443
450444 def _coordinator_loop (self ) -> None :
451445 """Background loop for heartbeat checking and worker job monitoring."""
@@ -682,9 +676,6 @@ def report_result(
682676 # Zero the in-flight counters but keep the generation watermark
683677 # so late heartbeats from this task are rejected.
684678 self ._worker_counters [worker_id ] = CounterSnapshot .empty (counter_snapshot .generation )
685- # Wake _wait_for_stage immediately so it advances to the next stage
686- # without waiting out the polling backoff interval.
687- self ._stage_event .set ()
688679
689680 def report_error (self , worker_id : str , shard_idx : int , error_info : str ) -> None :
690681 """Worker reports a task failure. Re-queues up to MAX_SHARD_FAILURES."""
@@ -693,8 +684,6 @@ def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
693684 self ._assert_in_flight_consistent (worker_id , shard_idx )
694685 aborted = self ._record_shard_failure (worker_id , ShardFailureKind .TASK , error_info )
695686 self ._worker_states [worker_id ] = WorkerState .DEAD if aborted else WorkerState .READY
696- # Wake _wait_for_stage so it re-evaluates the alive-worker count and fatal error.
697- self ._stage_event .set ()
698687
699688 def heartbeat (self , worker_id : str , counter_snapshot : CounterSnapshot | None = None ) -> None :
700689 self ._last_seen [worker_id ] = time .monotonic ()
@@ -760,8 +749,6 @@ def abort(self, reason: str) -> None:
760749 if self ._fatal_error is None :
761750 logger .error ("Coordinator aborted: %s" , reason )
762751 self ._fatal_error = reason
763- # Wake _wait_for_stage so it raises ZephyrWorkerError immediately.
764- self ._stage_event .set ()
765752
766753 def _start_stage (self , stage_name : str , tasks : list [ShardTask ], is_last_stage : bool = False ) -> None :
767754 """Load a new stage's tasks into the queue."""
@@ -781,9 +768,6 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: b
781768 # Only reset in-flight worker snapshots; completed snapshots
782769 # accumulate across stages for full pipeline visibility.
783770 self ._worker_counters = {}
784- # Clear after releasing the lock so _wait_for_stage can't miss
785- # a signal set between lock release and the clear call.
786- self ._stage_event .clear ()
787771
788772 def _wait_for_stage (self ) -> None :
789773 """Block until current stage completes or error occurs."""
@@ -792,6 +776,7 @@ def _wait_for_stage(self) -> None:
792776 start_time = time .monotonic ()
793777 all_dead_since : float | None = None
794778 no_workers_timeout = self ._no_workers_timeout
779+ stage_done = threading .Event ()
795780
796781 while True :
797782 with self ._lock :
@@ -835,13 +820,7 @@ def _wait_for_stage(self) -> None:
835820 last_log_completed = completed
836821 backoff .reset ()
837822
838- # Wait for a state-change signal (shard completed, error, worker registered)
839- # with a timeout as backstop for the alive-worker check and log lines.
840- # _stage_event is set by report_result/report_error/abort/register_worker,
841- # so the stage transition is detected within microseconds of the last
842- # shard completing rather than after up to the full backoff interval.
843- self ._stage_event .wait (timeout = backoff .next_interval ())
844- self ._stage_event .clear ()
823+ stage_done .wait (timeout = backoff .next_interval ())
845824
846825 def _collect_results (self ) -> dict [int , TaskResult ]:
847826 """Return results for the completed stage."""
0 commit comments