
Commit 812cf80

revert: restore execution.py to origin/main (remove local revert of alive_workers)
1 parent: 94d7d8b

1 file changed: lib/zephyr/src/zephyr/execution.py
Lines changed: 8 additions & 29 deletions
@@ -380,11 +380,6 @@ def __init__(self):

         # Lock for accessing coordinator state from background thread
         self._lock = threading.Lock()
-        # Signaled on every state change that _wait_for_stage cares about
-        # (shard completed, task requeued, fatal error, worker registered).
-        # Allows _wait_for_stage to wake immediately instead of spinning up
-        # to its max 1-second poll interval.
-        self._stage_event = threading.Event()

         actor_ctx = current_actor()
         self._name = f"{actor_ctx.group_name}"
@@ -438,14 +433,13 @@ def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
             # the worker as unhealthy via heartbeat and re-registration. If we do not requeue we may silently
             # lose tasks.
             self._maybe_requeue_worker_task(worker_id)
-        else:
-            self._worker_handles[worker_id] = worker_handle
-            self._worker_states[worker_id] = WorkerState.READY
-            self._last_seen[worker_id] = time.monotonic()
+            return
+
+        self._worker_handles[worker_id] = worker_handle
+        self._worker_states[worker_id] = WorkerState.READY
+        self._last_seen[worker_id] = time.monotonic()

-        logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
-        # Wake _wait_for_stage so it re-evaluates the alive-worker count.
-        self._stage_event.set()
+        logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))

     def _coordinator_loop(self) -> None:
         """Background loop for heartbeat checking and worker job monitoring."""
@@ -682,9 +676,6 @@ def report_result(
         # Zero the in-flight counters but keep the generation watermark
         # so late heartbeats from this task are rejected.
         self._worker_counters[worker_id] = CounterSnapshot.empty(counter_snapshot.generation)
-        # Wake _wait_for_stage immediately so it advances to the next stage
-        # without waiting out the polling backoff interval.
-        self._stage_event.set()

     def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
         """Worker reports a task failure. Re-queues up to MAX_SHARD_FAILURES."""
@@ -693,8 +684,6 @@ def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
         self._assert_in_flight_consistent(worker_id, shard_idx)
         aborted = self._record_shard_failure(worker_id, ShardFailureKind.TASK, error_info)
         self._worker_states[worker_id] = WorkerState.DEAD if aborted else WorkerState.READY
-        # Wake _wait_for_stage so it re-evaluates the alive-worker count and fatal error.
-        self._stage_event.set()

     def heartbeat(self, worker_id: str, counter_snapshot: CounterSnapshot | None = None) -> None:
         self._last_seen[worker_id] = time.monotonic()
@@ -760,8 +749,6 @@ def abort(self, reason: str) -> None:
         if self._fatal_error is None:
             logger.error("Coordinator aborted: %s", reason)
             self._fatal_error = reason
-            # Wake _wait_for_stage so it raises ZephyrWorkerError immediately.
-            self._stage_event.set()

     def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: bool = False) -> None:
         """Load a new stage's tasks into the queue."""
@@ -781,9 +768,6 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: bool = False) -> None:
             # Only reset in-flight worker snapshots; completed snapshots
            # accumulate across stages for full pipeline visibility.
             self._worker_counters = {}
-        # Clear after releasing the lock so _wait_for_stage can't miss
-        # a signal set between lock release and the clear call.
-        self._stage_event.clear()

     def _wait_for_stage(self) -> None:
         """Block until current stage completes or error occurs."""
@@ -792,6 +776,7 @@ def _wait_for_stage(self) -> None:
         start_time = time.monotonic()
         all_dead_since: float | None = None
         no_workers_timeout = self._no_workers_timeout
+        stage_done = threading.Event()

         while True:
             with self._lock:
@@ -835,13 +820,7 @@ def _wait_for_stage(self) -> None:
                     last_log_completed = completed
                     backoff.reset()

-            # Wait for a state-change signal (shard completed, error, worker registered)
-            # with a timeout as backstop for the alive-worker check and log lines.
-            # _stage_event is set by report_result/report_error/abort/register_worker,
-            # so the stage transition is detected within microseconds of the last
-            # shard completing rather than after up to the full backoff interval.
-            self._stage_event.wait(timeout=backoff.next_interval())
-            self._stage_event.clear()
+            stage_done.wait(timeout=backoff.next_interval())

     def _collect_results(self) -> dict[int, TaskResult]:
         """Return results for the completed stage."""
