Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions lib/zephyr/src/zephyr/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,11 @@ def __init__(self):

# Lock for accessing coordinator state from background thread
self._lock = threading.Lock()
# Signaled on every state change that _wait_for_stage cares about
# (shard completed, task requeued, fatal error, worker registered).
# Allows _wait_for_stage to wake immediately instead of spinning up
# to its max 1-second poll interval.
self._stage_event = threading.Event()

actor_ctx = current_actor()
self._name = f"{actor_ctx.group_name}"
Expand Down Expand Up @@ -433,13 +438,14 @@ def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
# the worker as unhealthy via heartbeat and re-registration. If we do not requeue we may silently
# lose tasks.
self._maybe_requeue_worker_task(worker_id)
return

self._worker_handles[worker_id] = worker_handle
self._worker_states[worker_id] = WorkerState.READY
self._last_seen[worker_id] = time.monotonic()
else:
self._worker_handles[worker_id] = worker_handle
self._worker_states[worker_id] = WorkerState.READY
self._last_seen[worker_id] = time.monotonic()

logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
# Wake _wait_for_stage so it re-evaluates the alive-worker count.
self._stage_event.set()

def _coordinator_loop(self) -> None:
"""Background loop for heartbeat checking and worker job monitoring."""
Expand Down Expand Up @@ -676,6 +682,9 @@ def report_result(
# Zero the in-flight counters but keep the generation watermark
# so late heartbeats from this task are rejected.
self._worker_counters[worker_id] = CounterSnapshot.empty(counter_snapshot.generation)
# Wake _wait_for_stage immediately so it advances to the next stage
# without waiting out the polling backoff interval.
self._stage_event.set()

def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
"""Worker reports a task failure. Re-queues up to MAX_SHARD_FAILURES."""
Expand All @@ -684,6 +693,8 @@ def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
self._assert_in_flight_consistent(worker_id, shard_idx)
aborted = self._record_shard_failure(worker_id, ShardFailureKind.TASK, error_info)
self._worker_states[worker_id] = WorkerState.DEAD if aborted else WorkerState.READY
# Wake _wait_for_stage so it re-evaluates the alive-worker count and fatal error.
self._stage_event.set()

def heartbeat(self, worker_id: str, counter_snapshot: CounterSnapshot | None = None) -> None:
self._last_seen[worker_id] = time.monotonic()
Expand Down Expand Up @@ -749,6 +760,8 @@ def abort(self, reason: str) -> None:
if self._fatal_error is None:
logger.error("Coordinator aborted: %s", reason)
self._fatal_error = reason
# Wake _wait_for_stage so it raises ZephyrWorkerError immediately.
self._stage_event.set()

def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: bool = False) -> None:
"""Load a new stage's tasks into the queue."""
Expand All @@ -768,6 +781,9 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: b
# Only reset in-flight worker snapshots; completed snapshots
# accumulate across stages for full pipeline visibility.
self._worker_counters = {}
# NOTE(review): this ordering looks backwards — clearing *after* the lock is
# released opens a window in which a concurrent set() (from report_result /
# report_error / abort / register_worker) lands between the release and this
# clear() and is wiped. Clearing before the new stage's tasks become visible,
# while still holding the lock, would close that window. As written it is
# latency-only, not a correctness bug: _wait_for_stage re-checks all state
# under the lock on every loop iteration with a bounded wait timeout.
self._stage_event.clear()

def _wait_for_stage(self) -> None:
"""Block until current stage completes or error occurs."""
Expand All @@ -776,7 +792,6 @@ def _wait_for_stage(self) -> None:
start_time = time.monotonic()
all_dead_since: float | None = None
no_workers_timeout = self._no_workers_timeout
stage_done = threading.Event()

while True:
with self._lock:
Expand Down Expand Up @@ -820,7 +835,13 @@ def _wait_for_stage(self) -> None:
last_log_completed = completed
backoff.reset()

stage_done.wait(timeout=backoff.next_interval())
# Wait for a state-change signal (shard completed, task failed, abort, or
# worker registered), with the timeout as a backstop for the alive-worker
# check and the periodic progress log lines. _stage_event is set by
# report_result/report_error/abort/register_worker, so a stage transition
# is noticed promptly instead of only after the full backoff interval.
# The clear() below can wipe a set() that arrives between wait() returning
# and clear() running; that is benign — the loop re-checks all state under
# the lock each iteration, so a lost wake-up costs at most one backoff
# interval of extra latency.
self._stage_event.wait(timeout=backoff.next_interval())
self._stage_event.clear()

def _collect_results(self) -> dict[int, TaskResult]:
"""Return results for the completed stage."""
Expand Down