92 changes: 59 additions & 33 deletions lib/zephyr/src/zephyr/execution.py
@@ -57,12 +57,19 @@
 
 logger = logging.getLogger(__name__)
 
-# Maximum number of times a single shard can fail before aborting the pipeline.
-# Covers both explicit task errors (report_error) and implicit worker deaths
-# (heartbeat timeout / OOM).
+# Max explicit task errors (report_error) per shard before aborting. Preemption
+# requeues (re-registration, heartbeat timeout) do not count — they retry
+# unbounded. `_check_worker_group` backstops if workers fully exhaust Iris retries.
 MAX_SHARD_FAILURES = 3
 
 
+class ShardFailureKind(enum.StrEnum):
+    """TASK failures count toward MAX_SHARD_FAILURES; INFRA failures (preemption) do not."""
+
+    TASK = enum.auto()
+    INFRA = enum.auto()
+
+
 @dataclass(frozen=True)
 class PickleDiskChunk:
     """Reference to a pickle chunk stored on disk.
@@ -347,7 +354,10 @@ def __init__(self):
         self._completed_shards: int = 0
         self._retries: int = 0
         self._in_flight: dict[str, tuple[ShardTask, int]] = {}
+        # _task_attempts: monotonic generation for stale-result rejection (bumps on every
+        # requeue). _task_error_attempts: TASK-only counter, bounded by MAX_SHARD_FAILURES.
         self._task_attempts: dict[int, int] = {}
+        self._task_error_attempts: dict[int, int] = {}
         self._fatal_error: str | None = None
         self._shard_errors: dict[int, list[str]] = {}
         self._chunk_prefix: str = ""
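
Note on the retry policy these two hunks introduce: only TASK failures consume the MAX_SHARD_FAILURES budget; INFRA requeues (preemption, heartbeat timeout) retry without bound, with `_check_worker_group` as the backstop. A minimal standalone sketch of the accounting, using a hypothetical `should_abort` helper rather than the real scheduler code:

```python
# Standalone sketch of the failure-accounting policy (hypothetical helper;
# the real logic lives in _record_shard_failure).
import enum

MAX_SHARD_FAILURES = 3

class ShardFailureKind(enum.StrEnum):
    TASK = enum.auto()   # explicit task error, counts toward the budget
    INFRA = enum.auto()  # preemption / heartbeat timeout, retries unbounded

task_attempts: dict[int, int] = {}        # generation: bumps on every requeue
task_error_attempts: dict[int, int] = {}  # TASK-only, bounded

def should_abort(shard_idx: int, kind: ShardFailureKind) -> bool:
    task_attempts[shard_idx] = task_attempts.get(shard_idx, 0) + 1
    if kind is ShardFailureKind.INFRA:
        return False
    task_error_attempts[shard_idx] = task_error_attempts.get(shard_idx, 0) + 1
    return task_error_attempts[shard_idx] >= MAX_SHARD_FAILURES

assert not any(should_abort(0, ShardFailureKind.INFRA) for _ in range(10))
assert [should_abort(0, ShardFailureKind.TASK) for _ in range(3)] == [False, False, True]
```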
@@ -498,15 +508,15 @@ def _log_status(self) -> None:
             attempts_histogram = dict(sorted(Counter(retried.values()).items()))
             logger.warning("[%s] Shards retried (attempts: shard count): %s", self._execution_id, attempts_histogram)
 
-    def _record_shard_failure(self, worker_id: str, error_info: str | None = None) -> bool:
-        """Record a failure for the worker's in-flight shard. Must be called with lock held.
-
-        Pops the task from _in_flight, zeros the worker's counter snapshot,
-        records the error, increments the attempt counter, and either re-queues
-        the task or sets _fatal_error when MAX_SHARD_FAILURES is reached.
-
-        Returns True if the shard was aborted (fatal), False otherwise
-        (including when there was no in-flight task).
-        """
+    def _record_shard_failure(
+        self,
+        worker_id: str,
+        kind: ShardFailureKind,
+        error_info: str | None = None,
+    ) -> bool:
+        """Requeue the worker's in-flight shard; abort only if TASK errors hit MAX_SHARD_FAILURES.
+
+        Must be called with lock held. Returns True if the pipeline was aborted.
+        """
        task_and_attempt = self._in_flight.pop(worker_id, None)
 
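The `_task_attempts` generation exists so `report_result` can reject results from a superseded attempt (the int stored alongside each task in `_in_flight`). `report_result` itself is outside this diff, so the following is only an assumed shape of that check, not the actual implementation:

```python
# Assumed shape of generation-based stale-result rejection (report_result is
# not shown in this diff). The attempt recorded at dispatch time must still
# match the shard's current generation for the result to be accepted.
def is_stale(task_attempts: dict[int, int], shard_idx: int, dispatched_attempt: int) -> bool:
    return dispatched_attempt != task_attempts[shard_idx]

task_attempts = {7: 2}                    # shard 7 has been requeued twice
assert is_stale(task_attempts, 7, 1)      # result from old attempt: dropped
assert not is_stale(task_attempts, 7, 2)  # result from current attempt: accepted
```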
@@ -525,38 +535,53 @@ def _record_shard_failure(self, worker_id: str, error_info: str | None = None) -> bool:
         if error_info is not None:
             self._shard_errors.setdefault(shard_idx, []).append(error_info)
 
+        # Bump generation regardless of kind so report_result rejects stale attempts.
         self._task_attempts[shard_idx] += 1
-        attempts = self._task_attempts[shard_idx]
 
-        if attempts >= MAX_SHARD_FAILURES:
-            errors = self._shard_errors.get(shard_idx, [])
-            error_detail = f"\nLast error:\n{errors[-1]}" if errors else ""
-            logger.error(
-                "Shard %d has failed %d times (max %d), aborting pipeline.",
-                shard_idx,
-                attempts,
-                MAX_SHARD_FAILURES,
-            )
-            self._fatal_error = (
-                f"Shard {shard_idx} failed {attempts} times " f"(max {MAX_SHARD_FAILURES}).{error_detail}"
-            )
-            return True
-
-        logger.warning(
-            "Shard %d failed on worker %s (attempt %d/%d), re-queuing.",
-            shard_idx,
-            worker_id,
-            attempts,
-            MAX_SHARD_FAILURES,
-        )
+        if kind is ShardFailureKind.TASK:
+            self._task_error_attempts[shard_idx] += 1
+            error_attempts = self._task_error_attempts[shard_idx]
+            if error_attempts >= MAX_SHARD_FAILURES:
+                errors = self._shard_errors.get(shard_idx, [])
+                error_detail = f"\nLast error:\n{errors[-1]}" if errors else ""
+                logger.error(
+                    "Shard %d has failed %d times (max %d), last failure on worker %s, aborting pipeline.",
+                    shard_idx,
+                    error_attempts,
+                    MAX_SHARD_FAILURES,
+                    worker_id,
+                )
+                self._fatal_error = (
+                    f"Shard {shard_idx} failed {error_attempts} times "
+                    f"(max {MAX_SHARD_FAILURES}), last failure on worker {worker_id}.{error_detail}"
+                )
+                return True
+
+            logger.warning(
+                "Shard %d failed on worker %s (task error %d/%d), re-queuing.",
+                shard_idx,
+                worker_id,
+                error_attempts,
+                MAX_SHARD_FAILURES,
+            )
+        else:
+            logger.warning(
+                "Shard %d requeued from worker %s due to infra failure (preemption/heartbeat); "
+                "infra retries are unbounded. Total generation: %d, task errors so far: %d/%d.",
+                shard_idx,
+                worker_id,
+                self._task_attempts[shard_idx],
+                self._task_error_attempts[shard_idx],
+                MAX_SHARD_FAILURES,
+            )
 
         self._task_queue.append(task)
         self._retries += 1
         return False
 
     def _maybe_requeue_worker_task(self, worker_id: str) -> None:
-        """If the worker has a task in-flight, re-queue it unless the shard has
-        exceeded MAX_SHARD_FAILURES, in which case abort the pipeline."""
-        self._record_shard_failure(worker_id)
+        """Requeue the worker's in-flight task as an INFRA failure (preemption/heartbeat)."""
+        self._record_shard_failure(worker_id, ShardFailureKind.INFRA)
 
     def _check_worker_heartbeats(self, timeout: float = 120.0) -> None:
         """Internal heartbeat check (called with lock held)."""
@@ -657,7 +682,7 @@ def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
         with self._lock:
             self._last_seen[worker_id] = time.monotonic()
             self._assert_in_flight_consistent(worker_id, shard_idx)
-            aborted = self._record_shard_failure(worker_id, error_info)
+            aborted = self._record_shard_failure(worker_id, ShardFailureKind.TASK, error_info)
             self._worker_states[worker_id] = WorkerState.DEAD if aborted else WorkerState.READY
 
     def heartbeat(self, worker_id: str, counter_snapshot: CounterSnapshot | None = None) -> None:
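
With these call sites, the classification is: `report_error` counts as TASK, while heartbeat-timeout and re-registration paths go through `_maybe_requeue_worker_task` as INFRA. The net behavior, simulated below for a single shard under the diff's MAX_SHARD_FAILURES = 3 (simplified: no locking, no worker state):

```python
# Simplified simulation of the abort semantics: INFRA requeues never abort;
# the third TASK error does.
MAX_SHARD_FAILURES = 3

def outcome(events: list[str]) -> str:
    task_errors = 0
    for kind in events:
        if kind == "task":
            task_errors += 1
            if task_errors >= MAX_SHARD_FAILURES:
                return "aborted"
        # "infra" events requeue without touching the budget
    return "retrying"

assert outcome(["infra"] * 100) == "retrying"
assert outcome(["task", "infra", "task", "infra", "task"]) == "aborted"
```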
@@ -736,6 +761,7 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: b
         self._retries = 0
         self._in_flight = {}
         self._task_attempts = {task.shard_idx: 0 for task in tasks}
+        self._task_error_attempts = {task.shard_idx: 0 for task in tasks}
         self._shard_errors = {}
         self._fatal_error = None
         self._is_last_stage = is_last_stage
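
One consequence of the `_start_stage` reset: both dicts are re-seeded for every stage, so the TASK error budget is per shard per stage, not cumulative across the pipeline. A quick illustration, using a stand-in `ShardTask` (the real class carries more fields):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ShardTask:  # stand-in for illustration only
    shard_idx: int

tasks = [ShardTask(shard_idx=i) for i in range(3)]
# As in _start_stage: fresh generation and error budget every stage.
task_attempts = {t.shard_idx: 0 for t in tasks}
task_error_attempts = {t.shard_idx: 0 for t in tasks}
assert task_error_attempts == {0: 0, 1: 0, 2: 0}
```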