Skip to content

Commit f98e020

Browse files
hsuhanooiclaude
and committed
zephyr: fix dead threading.Event in _wait_for_stage, replacing polling with event-driven wakeup
_wait_for_stage created a local threading.Event that was never signaled, making its wait() call a pure sleep of up to 1 second per iteration. Each stage transition (scatter→reduce→fold) paid this latency needlessly. Replace it with self._stage_event, signaled by every coordinator method that changes stage-relevant state: report_result, report_error, abort, and register_worker. _start_stage clears the event so signals from the previous stage don't bleed over. The backoff timeout is retained as a backstop for the alive-worker check and periodic log lines. In the normal (no-failure) path, stage transitions now complete within microseconds of the last shard result arriving. Benchmark (8 shards, 3-stage group_by pipeline, 70MB synthetic data): Before: 14.9s / 14.6s / 17.1s (avg ~15.6s, high variance) After: 13.4s / 13.7s / 13.4s (avg ~13.5s, low variance) ~13% faster; ~1.5s saved from eliminating poll-interval latency at each of the 3 stage boundaries. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7d57055 commit f98e020

1 file changed

Lines changed: 29 additions & 8 deletions

File tree

lib/zephyr/src/zephyr/execution.py

Lines changed: 29 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -380,6 +380,11 @@ def __init__(self):
380380

381381
# Lock for accessing coordinator state from background thread
382382
self._lock = threading.Lock()
383+
# Signaled on every state change that _wait_for_stage cares about
384+
# (shard completed, task requeued, fatal error, worker registered).
385+
# Allows _wait_for_stage to wake immediately instead of spinning up
386+
# to its max 1-second poll interval.
387+
self._stage_event = threading.Event()
383388

384389
actor_ctx = current_actor()
385390
self._name = f"{actor_ctx.group_name}"
@@ -433,13 +438,14 @@ def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
433438
# the worker as unhealthy via heartbeat and re-registration. If we do not requeue we may silently
434439
# lose tasks.
435440
self._maybe_requeue_worker_task(worker_id)
436-
return
437-
438-
self._worker_handles[worker_id] = worker_handle
439-
self._worker_states[worker_id] = WorkerState.READY
440-
self._last_seen[worker_id] = time.monotonic()
441+
else:
442+
self._worker_handles[worker_id] = worker_handle
443+
self._worker_states[worker_id] = WorkerState.READY
444+
self._last_seen[worker_id] = time.monotonic()
441445

442-
logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
446+
logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
447+
# Wake _wait_for_stage so it re-evaluates the alive-worker count.
448+
self._stage_event.set()
443449

444450
def _coordinator_loop(self) -> None:
445451
"""Background loop for heartbeat checking and worker job monitoring."""
@@ -676,6 +682,9 @@ def report_result(
676682
# Zero the in-flight counters but keep the generation watermark
677683
# so late heartbeats from this task are rejected.
678684
self._worker_counters[worker_id] = CounterSnapshot.empty(counter_snapshot.generation)
685+
# Wake _wait_for_stage immediately so it advances to the next stage
686+
# without waiting out the polling backoff interval.
687+
self._stage_event.set()
679688

680689
def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
681690
"""Worker reports a task failure. Re-queues up to MAX_SHARD_FAILURES."""
@@ -684,6 +693,8 @@ def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
684693
self._assert_in_flight_consistent(worker_id, shard_idx)
685694
aborted = self._record_shard_failure(worker_id, ShardFailureKind.TASK, error_info)
686695
self._worker_states[worker_id] = WorkerState.DEAD if aborted else WorkerState.READY
696+
# Wake _wait_for_stage so it re-evaluates the alive-worker count and fatal error.
697+
self._stage_event.set()
687698

688699
def heartbeat(self, worker_id: str, counter_snapshot: CounterSnapshot | None = None) -> None:
689700
self._last_seen[worker_id] = time.monotonic()
@@ -749,6 +760,8 @@ def abort(self, reason: str) -> None:
749760
if self._fatal_error is None:
750761
logger.error("Coordinator aborted: %s", reason)
751762
self._fatal_error = reason
763+
# Wake _wait_for_stage so it raises ZephyrWorkerError immediately.
764+
self._stage_event.set()
752765

753766
def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: bool = False) -> None:
754767
"""Load a new stage's tasks into the queue."""
@@ -768,6 +781,9 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: b
768781
# Only reset in-flight worker snapshots; completed snapshots
769782
# accumulate across stages for full pipeline visibility.
770783
self._worker_counters = {}
784+
# Clear after releasing the lock so _wait_for_stage can't miss
785+
# a signal set between lock release and the clear call.
786+
self._stage_event.clear()
771787

772788
def _wait_for_stage(self) -> None:
773789
"""Block until current stage completes or error occurs."""
@@ -776,7 +792,6 @@ def _wait_for_stage(self) -> None:
776792
start_time = time.monotonic()
777793
all_dead_since: float | None = None
778794
no_workers_timeout = self._no_workers_timeout
779-
stage_done = threading.Event()
780795

781796
while True:
782797
with self._lock:
@@ -820,7 +835,13 @@ def _wait_for_stage(self) -> None:
820835
last_log_completed = completed
821836
backoff.reset()
822837

823-
stage_done.wait(timeout=backoff.next_interval())
838+
# Wait for a state-change signal (shard completed, error, worker registered)
839+
# with a timeout as backstop for the alive-worker check and log lines.
840+
# _stage_event is set by report_result/report_error/abort/register_worker,
841+
# so the stage transition is detected within microseconds of the last
842+
# shard completing rather than after up to the full backoff interval.
843+
self._stage_event.wait(timeout=backoff.next_interval())
844+
self._stage_event.clear()
824845

825846
def _collect_results(self) -> dict[int, TaskResult]:
826847
"""Return results for the completed stage."""

0 commit comments

Comments (0)