
Commit 816ba2f

revert: remove _alive_workers counter from this branch
1 parent 285d467 commit 816ba2f

1 file changed: 46 additions & 39 deletions

lib/zephyr/src/zephyr/execution.py
@@ -377,11 +377,14 @@ def __init__(self):
         self._is_last_stage: bool = False
         self._initialized: bool = False
         self._pipeline_running: bool = False
-        # O(1) count of workers in READY or BUSY state, maintained by _set_worker_state.
-        self._alive_workers: int = 0

         # Lock for accessing coordinator state from background thread
         self._lock = threading.Lock()
+        # Signaled on every state change that _wait_for_stage cares about
+        # (shard completed, task requeued, fatal error, worker registered).
+        # Allows _wait_for_stage to wake immediately instead of spinning up
+        # to its max 1-second poll interval.
+        self._stage_event = threading.Event()

         actor_ctx = current_actor()
         self._name = f"{actor_ctx.group_name}"
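
Note: the comment above relies on the standard threading.Event wakeup idiom, where writers call set() after a state change and a waiter blocked in wait(timeout=...) returns immediately instead of sleeping out the rest of its poll interval. A minimal, self-contained sketch of that behaviour (the names here are illustrative, not the coordinator's API):

import threading
import time

# Illustrative only: shows why set() matters for wait(timeout=...).
stage_event = threading.Event()

def waiter() -> None:
    start = time.monotonic()
    # Without a signal this would sleep the full 1-second poll interval.
    stage_event.wait(timeout=1.0)
    print(f"woke after {time.monotonic() - start:.3f}s")

t = threading.Thread(target=waiter)
t.start()
time.sleep(0.05)
stage_event.set()  # e.g. a shard completed; the waiter wakes now, not at 1.0s
t.join()

Running this prints a wake-up time of roughly 0.05s rather than the full 1-second timeout.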
@@ -419,20 +422,6 @@ def set_worker_group(self, worker_group: Any) -> None:
         """Set the worker ActorGroup so the coordinator can detect permanent worker death."""
         self._worker_group = worker_group

-    def _set_worker_state(self, worker_id: str, new_state: WorkerState) -> None:
-        """Transition a worker's state and keep _alive_workers in sync.
-
-        Must be called with self._lock held.
-        """
-        old_state = self._worker_states.get(worker_id)
-        was_alive = old_state in {WorkerState.READY, WorkerState.BUSY}
-        is_alive = new_state in {WorkerState.READY, WorkerState.BUSY}
-        if was_alive and not is_alive:
-            self._alive_workers -= 1
-        elif not was_alive and is_alive:
-            self._alive_workers += 1
-        self._worker_states[worker_id] = new_state
-
     def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
         """Called by workers when they come online to register with coordinator.
@@ -443,19 +432,20 @@ def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
             if worker_id in self._worker_handles:
                 logger.info("Worker %s re-registering (likely reconstructed), updating handle", worker_id)
                 self._worker_handles[worker_id] = worker_handle
-                self._set_worker_state(worker_id, WorkerState.READY)
+                self._worker_states[worker_id] = WorkerState.READY
                 self._last_seen[worker_id] = time.monotonic()
                 # NOTE: if there was a task assigned to the worker, there's a race condition between marking
                 # the worker as unhealthy via heartbeat and re-registration. If we do not requeue we may silently
                 # lose tasks.
                 self._maybe_requeue_worker_task(worker_id)
-                return
-
-            self._worker_handles[worker_id] = worker_handle
-            self._set_worker_state(worker_id, WorkerState.READY)
-            self._last_seen[worker_id] = time.monotonic()
+            else:
+                self._worker_handles[worker_id] = worker_handle
+                self._worker_states[worker_id] = WorkerState.READY
+                self._last_seen[worker_id] = time.monotonic()

-        logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
+            logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
+            # Wake _wait_for_stage so it re-evaluates the alive-worker count.
+            self._stage_event.set()

     def _coordinator_loop(self) -> None:
         """Background loop for heartbeat checking and worker job monitoring."""
@@ -504,10 +494,10 @@ def _has_active_execution(self) -> bool:

     def _log_status(self) -> None:
         with self._lock:
-            alive = self._alive_workers
-            total_workers = len(self._worker_handles)
+            states = list(self._worker_states.values())
             retried = {idx: att for idx, att in self._task_attempts.items() if att > 0}
-            dead = total_workers - alive
+            alive = sum(1 for s in states if s in {WorkerState.READY, WorkerState.BUSY})
+            dead = sum(1 for s in states if s in {WorkerState.FAILED, WorkerState.DEAD})
            logger.info(
                 "[%s] [%s] %d/%d complete, %d in-flight, %d queued, %d/%d workers alive, %d dead",
                 self._execution_id,
@@ -517,7 +507,7 @@ def _log_status(self) -> None:
                 len(self._in_flight),
                 len(self._task_queue),
                 alive,
-                total_workers,
+                len(self._worker_handles),
                 dead,
             )
             if retried:
@@ -605,7 +595,7 @@ def _check_worker_heartbeats(self, timeout: float = 120.0) -> None:
         for worker_id, last in list(self._last_seen.items()):
             if now - last > timeout and self._worker_states.get(worker_id) not in {WorkerState.FAILED, WorkerState.DEAD}:
                 logger.warning(f"Zephyr worker {worker_id} failed to heartbeat within timeout ({now - last:.1f}s)")
-                self._set_worker_state(worker_id, WorkerState.FAILED)
+                self._worker_states[worker_id] = WorkerState.FAILED
                 self._maybe_requeue_worker_task(worker_id)

     def pull_task(self, worker_id: str) -> tuple[ShardTask, int, dict] | str | None:
@@ -618,10 +608,10 @@ def pull_task(self, worker_id: str) -> tuple[ShardTask, int, dict] | str | None:
         """
         with self._lock:
             self._last_seen[worker_id] = time.monotonic()
-            self._set_worker_state(worker_id, WorkerState.READY)
+            self._worker_states[worker_id] = WorkerState.READY

             if self._shutdown_event.is_set():
-                self._set_worker_state(worker_id, WorkerState.DEAD)
+                self._worker_states[worker_id] = WorkerState.DEAD
                 return "SHUTDOWN"

             if self._fatal_error:
@@ -634,14 +624,14 @@ def pull_task(self, worker_id: str) -> tuple[ShardTask, int, dict] | str | None:
                     # restarts the worker which re-registers and picks it up.
                     # _check_worker_group() detects permanent worker-job death
                     # as a failsafe so we never deadlock.
-                    self._set_worker_state(worker_id, WorkerState.DEAD)
+                    self._worker_states[worker_id] = WorkerState.DEAD
                     return "SHUTDOWN"
                 return None

             task = self._task_queue.popleft()
             attempt = self._task_attempts[task.shard_idx]
             self._in_flight[worker_id] = (task, attempt)
-            self._set_worker_state(worker_id, WorkerState.BUSY)
+            self._worker_states[worker_id] = WorkerState.BUSY

             config = {
                 "chunk_prefix": self._chunk_prefix,
@@ -687,19 +677,24 @@ def report_result(
             self._results[shard_idx] = result
             self._completed_shards += 1
             self._in_flight.pop(worker_id, None)
-            self._set_worker_state(worker_id, WorkerState.READY)
+            self._worker_states[worker_id] = WorkerState.READY
             self._completed_counters.append(counter_snapshot)
             # Zero the in-flight counters but keep the generation watermark
             # so late heartbeats from this task are rejected.
             self._worker_counters[worker_id] = CounterSnapshot.empty(counter_snapshot.generation)
+            # Wake _wait_for_stage immediately so it advances to the next stage
+            # without waiting out the polling backoff interval.
+            self._stage_event.set()

     def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
         """Worker reports a task failure. Re-queues up to MAX_SHARD_FAILURES."""
         with self._lock:
             self._last_seen[worker_id] = time.monotonic()
             self._assert_in_flight_consistent(worker_id, shard_idx)
             aborted = self._record_shard_failure(worker_id, ShardFailureKind.TASK, error_info)
-            self._set_worker_state(worker_id, WorkerState.DEAD if aborted else WorkerState.READY)
+            self._worker_states[worker_id] = WorkerState.DEAD if aborted else WorkerState.READY
+            # Wake _wait_for_stage so it re-evaluates the alive-worker count and fatal error.
+            self._stage_event.set()

     def heartbeat(self, worker_id: str, counter_snapshot: CounterSnapshot | None = None) -> None:
         self._last_seen[worker_id] = time.monotonic()
@@ -765,6 +760,8 @@ def abort(self, reason: str) -> None:
         if self._fatal_error is None:
             logger.error("Coordinator aborted: %s", reason)
             self._fatal_error = reason
+            # Wake _wait_for_stage so it raises ZephyrWorkerError immediately.
+            self._stage_event.set()

     def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: bool = False) -> None:
         """Load a new stage's tasks into the queue."""
@@ -784,6 +781,9 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: b
             # Only reset in-flight worker snapshots; completed snapshots
             # accumulate across stages for full pipeline visibility.
             self._worker_counters = {}
+        # Clear after releasing the lock so _wait_for_stage can't miss
+        # a signal set between lock release and the clear call.
+        self._stage_event.clear()

     def _wait_for_stage(self) -> None:
         """Block until current stage completes or error occurs."""
@@ -792,7 +792,6 @@ def _wait_for_stage(self) -> None:
         start_time = time.monotonic()
         all_dead_since: float | None = None
         no_workers_timeout = self._no_workers_timeout
-        stage_done = threading.Event()

         while True:
             with self._lock:
@@ -805,9 +804,11 @@ def _wait_for_stage(self) -> None:
                 if completed >= total:
                     return

-                # _alive_workers is kept in sync by _set_worker_state on every
-                # state transition, avoiding a per-wakeup O(n_workers) scan.
-                alive_workers = self._alive_workers
+                # Count alive workers (READY or BUSY), not just total registered.
+                # Dead/failed workers stay in _worker_handles but can't make progress.
+                alive_workers = sum(
+                    1 for s in self._worker_states.values() if s in {WorkerState.READY, WorkerState.BUSY}
+                )

                 if alive_workers == 0:
                     now = time.monotonic()
@@ -834,7 +835,13 @@ def _wait_for_stage(self) -> None:
                     last_log_completed = completed
                     backoff.reset()

-            stage_done.wait(timeout=backoff.next_interval())
+            # Wait for a state-change signal (shard completed, error, worker registered)
+            # with a timeout as backstop for the alive-worker check and log lines.
+            # _stage_event is set by report_result/report_error/abort/register_worker,
+            # so the stage transition is detected within microseconds of the last
+            # shard completing rather than after up to the full backoff interval.
+            self._stage_event.wait(timeout=backoff.next_interval())
+            self._stage_event.clear()

     def _collect_results(self) -> dict[int, TaskResult]:
         """Return results for the completed stage."""
