lib/zephyr/src/zephyr/execution.py (50 changes: 32 additions & 18 deletions)
@@ -377,6 +377,8 @@ def __init__(self):
self._is_last_stage: bool = False
self._initialized: bool = False
self._pipeline_running: bool = False
# O(1) count of workers in READY or BUSY state, maintained by _set_worker_state.
self._alive_workers: int = 0

# Lock for accessing coordinator state from background thread
self._lock = threading.Lock()
@@ -417,6 +419,20 @@ def set_worker_group(self, worker_group: Any) -> None:
"""Set the worker ActorGroup so the coordinator can detect permanent worker death."""
self._worker_group = worker_group

def _set_worker_state(self, worker_id: str, new_state: WorkerState) -> None:
"""Transition a worker's state and keep _alive_workers in sync.

Must be called with self._lock held.
"""
old_state = self._worker_states.get(worker_id)
was_alive = old_state in {WorkerState.READY, WorkerState.BUSY}
is_alive = new_state in {WorkerState.READY, WorkerState.BUSY}
if was_alive and not is_alive:
self._alive_workers -= 1
elif not was_alive and is_alive:
self._alive_workers += 1
self._worker_states[worker_id] = new_state

def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
"""Called by workers when they come online to register with coordinator.

@@ -427,7 +443,7 @@ def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
if worker_id in self._worker_handles:
logger.info("Worker %s re-registering (likely reconstructed), updating handle", worker_id)
self._worker_handles[worker_id] = worker_handle
self._worker_states[worker_id] = WorkerState.READY
self._set_worker_state(worker_id, WorkerState.READY)
self._last_seen[worker_id] = time.monotonic()
# NOTE: if there was a task assigned to the worker, there's a race condition between marking
# the worker as unhealthy via heartbeat and re-registration. If we do not requeue we may silently
@@ -436,7 +452,7 @@ def register_worker(self, worker_id: str, worker_handle: ActorHandle) -> None:
return

self._worker_handles[worker_id] = worker_handle
self._worker_states[worker_id] = WorkerState.READY
self._set_worker_state(worker_id, WorkerState.READY)
self._last_seen[worker_id] = time.monotonic()

logger.info("Worker %s registered, total: %d", worker_id, len(self._worker_handles))
@@ -488,10 +504,10 @@ def _has_active_execution(self) -> bool:

def _log_status(self) -> None:
with self._lock:
states = list(self._worker_states.values())
alive = self._alive_workers
total_workers = len(self._worker_handles)
retried = {idx: att for idx, att in self._task_attempts.items() if att > 0}
alive = sum(1 for s in states if s in {WorkerState.READY, WorkerState.BUSY})
dead = sum(1 for s in states if s in {WorkerState.FAILED, WorkerState.DEAD})
dead = total_workers - alive
logger.info(
"[%s] [%s] %d/%d complete, %d in-flight, %d queued, %d/%d workers alive, %d dead",
self._execution_id,
@@ -501,7 +517,7 @@ def _log_status(self) -> None:
len(self._in_flight),
len(self._task_queue),
alive,
len(self._worker_handles),
total_workers,
dead,
)
if retried:
@@ -589,7 +605,7 @@ def _check_worker_heartbeats(self, timeout: float = 120.0) -> None:
for worker_id, last in list(self._last_seen.items()):
if now - last > timeout and self._worker_states.get(worker_id) not in {WorkerState.FAILED, WorkerState.DEAD}:
logger.warning(f"Zephyr worker {worker_id} failed to heartbeat within timeout ({now - last:.1f}s)")
self._worker_states[worker_id] = WorkerState.FAILED
self._set_worker_state(worker_id, WorkerState.FAILED)
self._maybe_requeue_worker_task(worker_id)

def pull_task(self, worker_id: str) -> tuple[ShardTask, int, dict] | str | None:
@@ -602,10 +618,10 @@ def pull_task(self, worker_id: str) -> tuple[ShardTask, int, dict] | str | None:
"""
with self._lock:
self._last_seen[worker_id] = time.monotonic()
self._worker_states[worker_id] = WorkerState.READY
self._set_worker_state(worker_id, WorkerState.READY)

if self._shutdown_event.is_set():
self._worker_states[worker_id] = WorkerState.DEAD
self._set_worker_state(worker_id, WorkerState.DEAD)
return "SHUTDOWN"

if self._fatal_error:
@@ -618,14 +634,14 @@ def pull_task(self, worker_id: str) -> tuple[ShardTask, int, dict] | str | None:
# restarts the worker which re-registers and picks it up.
# _check_worker_group() detects permanent worker-job death
# as a failsafe so we never deadlock.
self._worker_states[worker_id] = WorkerState.DEAD
self._set_worker_state(worker_id, WorkerState.DEAD)
return "SHUTDOWN"
return None

task = self._task_queue.popleft()
attempt = self._task_attempts[task.shard_idx]
self._in_flight[worker_id] = (task, attempt)
self._worker_states[worker_id] = WorkerState.BUSY
self._set_worker_state(worker_id, WorkerState.BUSY)

config = {
"chunk_prefix": self._chunk_prefix,
@@ -671,7 +687,7 @@ def report_result(
self._results[shard_idx] = result
self._completed_shards += 1
self._in_flight.pop(worker_id, None)
self._worker_states[worker_id] = WorkerState.READY
self._set_worker_state(worker_id, WorkerState.READY)
self._completed_counters.append(counter_snapshot)
# Zero the in-flight counters but keep the generation watermark
# so late heartbeats from this task are rejected.
@@ -683,7 +699,7 @@ def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
self._last_seen[worker_id] = time.monotonic()
self._assert_in_flight_consistent(worker_id, shard_idx)
aborted = self._record_shard_failure(worker_id, ShardFailureKind.TASK, error_info)
self._worker_states[worker_id] = WorkerState.DEAD if aborted else WorkerState.READY
self._set_worker_state(worker_id, WorkerState.DEAD if aborted else WorkerState.READY)

def heartbeat(self, worker_id: str, counter_snapshot: CounterSnapshot | None = None) -> None:
self._last_seen[worker_id] = time.monotonic()
@@ -789,11 +805,9 @@ def _wait_for_stage(self) -> None:
if completed >= total:
return

# Count alive workers (READY or BUSY), not just total registered.
# Dead/failed workers stay in _worker_handles but can't make progress.
alive_workers = sum(
1 for s in self._worker_states.values() if s in {WorkerState.READY, WorkerState.BUSY}
)
# _alive_workers is kept in sync by _set_worker_state on every
# state transition, avoiding a per-wakeup O(n_workers) scan.
alive_workers = self._alive_workers

if alive_workers == 0:
now = time.monotonic()
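The O(1) counter's correctness hinges on `_set_worker_state` adjusting `_alive_workers` only on transitions into or out of the alive set {READY, BUSY}: self-transitions such as the READY -> READY that `pull_task` performs on every poll leave the count unchanged, so `_log_status` and `_wait_for_stage` can read a field instead of scanning every worker. A minimal standalone sketch of that invariant (the module-level names here are illustrative, not the coordinator's attributes, and the real helper additionally requires `self._lock` to be held):

```python
# Sketch of the _set_worker_state invariant; names are illustrative,
# not the coordinator's actual attributes.
from enum import Enum, auto


class WorkerState(Enum):
    READY = auto()
    BUSY = auto()
    FAILED = auto()
    DEAD = auto()


_ALIVE = {WorkerState.READY, WorkerState.BUSY}
states: dict[str, WorkerState] = {}
alive_workers = 0


def set_worker_state(worker_id: str, new_state: WorkerState) -> None:
    global alive_workers
    was_alive = states.get(worker_id) in _ALIVE   # unknown worker counts as not alive
    is_alive = new_state in _ALIVE
    if was_alive and not is_alive:
        alive_workers -= 1                        # alive -> dead edge
    elif not was_alive and is_alive:
        alive_workers += 1                        # registration / recovery edge
    states[worker_id] = new_state


set_worker_state("w0", WorkerState.READY)    # registration: 0 -> 1
set_worker_state("w0", WorkerState.READY)    # idle poll: no change
set_worker_state("w0", WorkerState.BUSY)     # task assignment: no change
set_worker_state("w0", WorkerState.FAILED)   # missed heartbeat: 1 -> 0
assert alive_workers == sum(1 for s in states.values() if s in _ALIVE) == 0
```

The final assert states the invariant the helper preserves: the counter always equals the number of workers whose recorded state is READY or BUSY.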
lib/zephyr/src/zephyr/shuffle.py (47 changes: 41 additions & 6 deletions)
@@ -23,6 +23,12 @@
chunk's compressed bytes (typically a few MB). This bound is essential for
skewed shuffles where one reducer pulls disproportionate data and the
external-sort fan-in opens hundreds of chunk iterators at once.

Write-side memory is bounded by a byte budget (``_SCATTER_WRITE_BUFFER_BYTES``)
rather than a fixed row count. When the estimated total bytes across all
shard buffers exceeds the budget, the largest buffer is flushed. This prevents
OOM on skewed or large-item workloads where a row-count limit provides no
reliable bound.
"""

from __future__ import annotations
@@ -45,7 +51,7 @@
from rigging.timing import log_time

from zephyr.plan import deterministic_hash
from zephyr.writers import INTERMEDIATE_CHUNK_SIZE, ensure_parent_dir
from zephyr.writers import ensure_parent_dir

logger = logging.getLogger(__name__)

@@ -101,6 +107,18 @@ def get_iterators(self) -> Iterator[Iterator]:
# dispatch overhead), smaller = lower per-iterator read memory.
_SUB_BATCH_SIZE = 1024

# Total byte budget for all shard buffers held by one ScatterWriter.
# When the estimate exceeds this, the largest buffer is flushed. Byte-based
# budgeting is safer than a fixed row count because item size varies widely
# (a 100K-row limit is fine for tiny dicts but OOMs on large documents).
_SCATTER_WRITE_BUFFER_BYTES = 256 * 1024 * 1024 # 256 MB

# Conservative bytes-per-item estimate used before the first flush samples
# actual sizes. 512 bytes covers a typical dict with a few string fields plus
# Python object overhead; the estimate is replaced by the measured average
# after the first sample is taken.
_INITIAL_ITEM_BYTES_ESTIMATE = 512


# ---------------------------------------------------------------------------
# Sidecar / manifest helpers
@@ -447,6 +465,10 @@ class ScatterWriter:
Items are routed to target shards by ``key_fn``, buffered, optionally
combined and sorted, then flushed as zstd frames. A JSON sidecar is
written on close.

Flushing is byte-budget-based: when the estimated total bytes across all
shard buffers exceeds ``buffer_limit_bytes``, the largest buffer is flushed.
This bounds peak RSS regardless of item count or output shard count.
"""

def __init__(
@@ -457,13 +479,14 @@ def __init__(
source_shard: int = 0,
sort_fn: Callable | None = None,
combiner_fn: Callable | None = None,
buffer_limit_bytes: int = _SCATTER_WRITE_BUFFER_BYTES,
) -> None:
self._data_path = data_path
self._key_fn = key_fn
self._num_output_shards = num_output_shards
self._source_shard = source_shard
self._combiner_fn = combiner_fn
self._chunk_size = INTERMEDIATE_CHUNK_SIZE
self._buffer_limit_bytes = buffer_limit_bytes

if sort_fn is not None:
captured_sort_fn = sort_fn
@@ -481,6 +504,10 @@ def _sort_key(item: Any) -> Any:
self._avg_item_bytes: float = 0.0
self._sampled_avg = False
self._n_chunks_written = 0
# Running total of rows across all shard buffers; used with
# _item_bytes_estimate to gate byte-budget flushes.
self._total_buffer_rows: int = 0
self._item_bytes_estimate: float = _INITIAL_ITEM_BYTES_ESTIMATE

ensure_parent_dir(data_path)
fs, fs_path = url_to_fs(data_path)
@@ -495,6 +522,7 @@ def _flush(self, target: int, buf: list) -> None:
sample = buf[: min(len(buf), _SCATTER_SAMPLE_SIZE)]
total_bytes = sum(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)) for item in sample)
self._avg_item_bytes = total_bytes / len(sample)
self._item_bytes_estimate = self._avg_item_bytes
self._sampled_avg = True

frame = _write_chunk_frame(buf)
@@ -514,13 +542,18 @@ def _flush(self, target: int, buf: list) -> None:
)

def write(self, item: Any) -> None:
"""Route a single item to its target shard buffer, flushing when full."""
"""Route a single item to its target shard buffer, flushing when over budget."""
key = self._key_fn(item)
target = deterministic_hash(key) % self._num_output_shards
self._buffers[target].append(item)
if self._chunk_size > 0 and len(self._buffers[target]) >= self._chunk_size:
self._flush(target, self._buffers[target])
self._buffers[target] = []
self._total_buffer_rows += 1

if self._total_buffer_rows * self._item_bytes_estimate > self._buffer_limit_bytes:
largest = max(self._buffers, key=lambda t: len(self._buffers[t]))
rows_flushed = len(self._buffers[largest])
self._flush(largest, self._buffers[largest])
self._buffers[largest] = []
self._total_buffer_rows -= rows_flushed

def close(self) -> ListShard:
"""Flush remaining buffers, write sidecar, return ListShard."""
@@ -557,6 +590,7 @@ def _write_scatter(
num_output_shards: int,
sort_fn: Callable | None = None,
combiner_fn: Callable | None = None,
buffer_limit_bytes: int = _SCATTER_WRITE_BUFFER_BYTES,
) -> ListShard:
"""Route items to target shards, buffer, sort, and append zstd chunks.

@@ -573,6 +607,7 @@
source_shard=source_shard,
sort_fn=sort_fn,
combiner_fn=combiner_fn,
buffer_limit_bytes=buffer_limit_bytes,
)
for item in items:
writer.write(item)
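Back-of-envelope: with the defaults, the gate first fires after roughly 256 MiB / 512 B = 524,288 buffered rows; that first flush then replaces the 512-byte guess with the measured pickled average. A self-contained sketch of just the gating loop (the `flush_fn` callback and `SAMPLE_SIZE` constant are placeholders, `hash` stands in for `deterministic_hash`, and the real writer's zstd framing, sort/combiner handling, and sidecar bookkeeping are omitted):

```python
# Sketch of the byte-budget flush gate; flush_fn is a hypothetical callback,
# not part of the real ScatterWriter API.
import pickle
from collections import defaultdict
from typing import Any, Callable, Iterable

BUDGET_BYTES = 256 * 1024 * 1024   # mirrors _SCATTER_WRITE_BUFFER_BYTES
INITIAL_ESTIMATE = 512.0           # mirrors _INITIAL_ITEM_BYTES_ESTIMATE
SAMPLE_SIZE = 100                  # stand-in for _SCATTER_SAMPLE_SIZE


def scatter(
    items: Iterable[Any],
    key_fn: Callable[[Any], Any],
    num_shards: int,
    flush_fn: Callable[[int, list], None],
) -> None:
    buffers: dict[int, list] = defaultdict(list)
    total_rows = 0
    est = INITIAL_ESTIMATE
    for item in items:
        target = hash(key_fn(item)) % num_shards  # stand-in for deterministic_hash
        buffers[target].append(item)
        total_rows += 1
        if total_rows * est > BUDGET_BYTES:
            # Flush only the largest buffer: under skew it holds most of the
            # bytes, so peak RSS stays near the budget without flushing
            # every shard's small buffer.
            largest = max(buffers, key=lambda t: len(buffers[t]))
            buf = buffers.pop(largest)
            sample = buf[:SAMPLE_SIZE]
            est = sum(len(pickle.dumps(x)) for x in sample) / len(sample)
            flush_fn(largest, buf)
            total_rows -= len(buf)
    for target, buf in buffers.items():  # close(): drain remaining buffers
        flush_fn(target, buf)
```

Passing `buffer_limit_bytes=1`, as the new test below does, makes `total_rows * est` exceed the budget on every write after the first, so chunks are emitted throughout the write phase rather than only at close.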
lib/zephyr/tests/test_shuffle.py (47 changes: 47 additions & 0 deletions)
@@ -11,6 +11,7 @@
from zephyr.shuffle import (
ScatterFileIterator,
ScatterReader,
ScatterWriter,
_write_chunk_frame,
_write_scatter,
)
@@ -171,6 +172,52 @@ def _ord(x):
assert sorted(recovered, key=_ord) == sorted(items, key=_ord)


# ---------------------------------------------------------------------------
# Byte-budget flushing
# ---------------------------------------------------------------------------


def test_scatter_byte_budget_flushes_mid_write(tmp_path):
"""A tiny byte budget forces flushes during write, not only at close."""
num_shards = 2
items = [{"k": i % num_shards, "v": i} for i in range(200)]
data_path = str(tmp_path / "shard-0000.shuffle")

# Budget of 1 byte forces a flush on every write after the first.
writer = ScatterWriter(
data_path=data_path,
key_fn=_key,
num_output_shards=num_shards,
buffer_limit_bytes=1,
)
for item in items:
writer.write(item)
writer.close()

# Multiple chunks must have been written (not just the close-time flush).
scatter_paths = [data_path]
total_chunks = sum(ScatterReader.from_sidecars(scatter_paths, s).total_chunks for s in range(num_shards))
assert total_chunks > 2, f"expected >2 chunks with 1-byte budget, got {total_chunks}"


def test_scatter_byte_budget_preserves_all_items(tmp_path):
"""Items are not lost or duplicated when byte-budget flushes fire mid-write."""
num_shards = 3
items = [{"k": i % num_shards, "v": i} for i in range(300)]
scatter_paths = _build_shard(
tmp_path,
items,
num_output_shards=num_shards,
)

recovered = []
for shard_idx in range(num_shards):
shard = ScatterReader.from_sidecars(scatter_paths, shard_idx)
recovered.extend(list(shard))

assert sorted(recovered, key=lambda x: x["v"]) == sorted(items, key=lambda x: x["v"])


# ---------------------------------------------------------------------------
# ScatterFileIterator low-level
# ---------------------------------------------------------------------------