Skip to content
128 changes: 116 additions & 12 deletions lib/zephyr/src/zephyr/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
chunk's compressed bytes (typically a few MB). This bound is essential for
skewed shuffles where one reducer pulls disproportionate data and the
external-sort fan-in opens hundreds of chunk iterators at once.

Write-side memory is bounded by a byte budget (``_SCATTER_WRITE_BUFFER_BYTES``)
rather than a fixed row count. When the estimated total bytes across all
shard buffers exceeds the budget, the largest buffer is flushed. This prevents
OOM on skewed or large-item workloads where a row-count limit provides no
reliable bound.
"""

from __future__ import annotations
Expand All @@ -41,11 +47,12 @@
import cloudpickle
import msgspec
import zstandard as zstd
from iris.env_resources import TaskResources
from rigging.filesystem import open_url, url_to_fs
from rigging.timing import log_time

from zephyr.plan import deterministic_hash
from zephyr.writers import INTERMEDIATE_CHUNK_SIZE, ensure_parent_dir
from zephyr.writers import ensure_parent_dir

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -91,8 +98,17 @@ def get_iterators(self) -> Iterator[Iterator]:
# ScatterReader. Sidecars are small msgpack files (a few KB) and reads are
# GCS GET-bound, so a modest pool keeps latency low without thrashing.
_SIDECAR_READ_CONCURRENCY = 32
# Number of items sampled from the first flush to estimate avg_item_bytes.
# Items sampled on the first flush to establish an avg_item_bytes baseline.
_SCATTER_SAMPLE_SIZE = 100
# Items sampled on each subsequent flush to track item-size drift cheaply.
_SCATTER_ONGOING_SAMPLE_SIZE = 10
# How often (in items written) to re-sample one item's pickle size and update
# the EMA estimate in write(). This is independent of flush-time sampling and
# ensures the estimate tracks drift even when no flush has fired yet.
_ESTIMATE_WRITE_SAMPLE_INTERVAL = 10
# EMA weight given to each new observation. 0.3 absorbs about two-thirds of a
# 2x step-change in item size within 3 samples (1 - 0.7^3 ~= 0.66) and nearly
# all of it within ~10, while staying stable under small fluctuations.
_ESTIMATE_EMA_ALPHA = 0.3
# Fraction of total memory budgeted for read-side decompression buffers.
_SCATTER_READ_BUFFER_FRACTION = 0.25

Expand All @@ -101,6 +117,23 @@ def get_iterators(self) -> Iterator[Iterator]:
# dispatch overhead), smaller = lower per-iterator read memory.
_SUB_BATCH_SIZE = 1024

# Fraction of cgroup memory allocated to scatter write buffers.
_SCATTER_WRITE_BUFFER_FRACTION = 0.25
# Static fallback used when the cgroup memory limit cannot be determined.
_SCATTER_WRITE_BUFFER_BYTES_FALLBACK = 256 * 1024 * 1024 # 256 MB


def _default_scatter_write_buffer_bytes() -> int:
    """Return the scatter write buffer budget based on the cgroup memory limit.

    Uses 25% of the container memory limit so the budget scales with the
    worker size. Falls back to 256 MB when the limit cannot be read.
    """
    limit = TaskResources.from_environment().memory_bytes
    if limit <= 0:
        # Limit unreadable (e.g. no cgroup information available) — static fallback.
        return _SCATTER_WRITE_BUFFER_BYTES_FALLBACK
    return int(limit * _SCATTER_WRITE_BUFFER_FRACTION)


# ---------------------------------------------------------------------------
# Sidecar / manifest helpers
Expand Down Expand Up @@ -447,6 +480,10 @@ class ScatterWriter:
Items are routed to target shards by ``key_fn``, buffered, optionally
combined and sorted, then flushed as zstd frames. A JSON sidecar is
written on close.

Flushing is byte-budget-based: when the estimated total bytes across all
shard buffers exceeds ``buffer_limit_bytes``, the largest buffer is flushed.
This bounds peak RSS regardless of item count or output shard count.
"""

def __init__(
Expand All @@ -457,13 +494,16 @@ def __init__(
source_shard: int = 0,
sort_fn: Callable | None = None,
combiner_fn: Callable | None = None,
buffer_limit_bytes: int | None = None,
) -> None:
self._data_path = data_path
self._key_fn = key_fn
self._num_output_shards = num_output_shards
self._source_shard = source_shard
self._combiner_fn = combiner_fn
self._chunk_size = INTERMEDIATE_CHUNK_SIZE
self._buffer_limit_bytes = (
buffer_limit_bytes if buffer_limit_bytes is not None else _default_scatter_write_buffer_bytes()
)

if sort_fn is not None:
captured_sort_fn = sort_fn
Expand All @@ -481,6 +521,16 @@ def _sort_key(item: Any) -> Any:
self._avg_item_bytes: float = 0.0
self._sampled_avg = False
self._n_chunks_written = 0
self._mid_write_flushes: int = 0
# Running total of rows across all shard buffers; used with
# _item_bytes_estimate to gate byte-budget flushes.
self._total_buffer_rows: int = 0
self._peak_buffer_rows: int = 0
# Estimate refined in two steps: (1) first-item pickle measurement in
# write(), (2) 100-item sample average in _flush(). Logging both lets
# operators see how representative the first item was.
self._item_bytes_estimate: float = 0.0 # set on first write()
self._first_item_bytes: float = 0.0 # logged at close for comparison

ensure_parent_dir(data_path)
fs, fs_path = url_to_fs(data_path)
Expand All @@ -491,11 +541,21 @@ def _flush(self, target: int, buf: list) -> None:
buf = _apply_combiner(buf, self._key_fn, self._combiner_fn)
buf.sort(key=self._sort_key)

if not self._sampled_avg and buf:
sample = buf[: min(len(buf), _SCATTER_SAMPLE_SIZE)]
total_bytes = sum(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)) for item in sample)
self._avg_item_bytes = total_bytes / len(sample)
self._sampled_avg = True
if buf:
# Sample a subset of the buffer to update the byte-size estimate.
# First flush: larger sample for a good baseline. Subsequent flushes:
# smaller sample to track drift cheaply via EMA. This prevents OOM
# when early items are small but later items are large — the estimate
# stays current rather than being frozen at the first-flush value.
n = _SCATTER_SAMPLE_SIZE if not self._sampled_avg else _SCATTER_ONGOING_SAMPLE_SIZE
sample = buf[: min(len(buf), n)]
observed = sum(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)) for item in sample) / len(sample)
if not self._sampled_avg:
self._avg_item_bytes = observed
self._sampled_avg = True
else:
self._avg_item_bytes = (1 - _ESTIMATE_EMA_ALPHA) * self._avg_item_bytes + _ESTIMATE_EMA_ALPHA * observed
self._item_bytes_estimate = self._avg_item_bytes

frame = _write_chunk_frame(buf)
offset = self._out.tell()
Expand All @@ -514,22 +574,64 @@ def _flush(self, target: int, buf: list) -> None:
)

def write(self, item: Any) -> None:
    """Route a single item to its target shard buffer, flushing when over budget."""
    if self._total_buffer_rows % _ESTIMATE_WRITE_SAMPLE_INTERVAL == 0:
        # Periodically measure a single item's serialised size and apply EMA.
        # This runs in write() — not just in _flush() — so the estimate tracks
        # size drift even when no flush has fired yet (the flush EMA is a
        # closed loop: if the estimate is too low no flush fires, so it never
        # updates). Interval-based sampling amortises the pickle.dumps cost
        # to 1-in-10 items while still catching step-changes within a few rows.
        observed = float(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)))
        if self._first_item_bytes == 0.0:
            # Genuinely the first sample: seed the estimate and record the
            # first-item figure logged at close(). Keyed on _first_item_bytes
            # (written exactly once) rather than `_total_buffer_rows == 0`:
            # mid-write flushes drain the buffers back to zero rows, and a
            # row-count check would re-enter this branch on later writes,
            # clobbering _first_item_bytes and discarding the accumulated EMA.
            self._item_bytes_estimate = observed
            self._first_item_bytes = observed
        else:
            self._item_bytes_estimate = (
                1 - _ESTIMATE_EMA_ALPHA
            ) * self._item_bytes_estimate + _ESTIMATE_EMA_ALPHA * observed

    key = self._key_fn(item)
    target = deterministic_hash(key) % self._num_output_shards
    self._buffers[target].append(item)
    self._total_buffer_rows += 1
    if self._total_buffer_rows > self._peak_buffer_rows:
        self._peak_buffer_rows = self._total_buffer_rows

    # Byte-budget gate: estimated bytes buffered across all shards. Flush only
    # the single largest buffer; if the total is still over budget, the next
    # write() flushes again, so the budget is enforced incrementally.
    if self._total_buffer_rows * self._item_bytes_estimate > self._buffer_limit_bytes:
        largest = max(self._buffers, key=lambda t: len(self._buffers[t]))
        rows_flushed = len(self._buffers[largest])
        self._flush(largest, self._buffers[largest])
        self._buffers[largest] = []
        self._total_buffer_rows -= rows_flushed
        self._mid_write_flushes += 1

def close(self) -> ListShard:
"""Flush remaining buffers, write sidecar, return ListShard."""
close_flushes = 0
with log_time(f"Flushing remaining buffers for {self._data_path}"):
for target, buf in sorted(self._buffers.items()):
if buf:
self._flush(target, buf)
close_flushes += 1
self._out.close()

measured_avg = self._avg_item_bytes if self._sampled_avg else self._item_bytes_estimate
logger.info(
"[shard %d] scatter write done: %d mid-write flushes + %d at close = %d total; "
"first-item estimate=%.0f B, measured avg=%.0f B (%.1fx), "
"peak buffered=%d rows, budget=%d MB",
self._source_shard,
self._mid_write_flushes,
close_flushes,
self._mid_write_flushes + close_flushes,
self._first_item_bytes,
measured_avg,
measured_avg / self._first_item_bytes if self._first_item_bytes > 0 else 0.0,
self._peak_buffer_rows,
self._buffer_limit_bytes // (1024 * 1024),
)

sidecar: dict = {
"shards": {str(k): v for k, v in self._shard_ranges.items()},
"max_chunk_rows": {str(k): v for k, v in self._per_shard_max_rows.items() if v > 0},
Expand Down Expand Up @@ -557,6 +659,7 @@ def _write_scatter(
num_output_shards: int,
sort_fn: Callable | None = None,
combiner_fn: Callable | None = None,
buffer_limit_bytes: int | None = None,
) -> ListShard:
"""Route items to target shards, buffer, sort, and append zstd chunks.

Expand All @@ -573,6 +676,7 @@ def _write_scatter(
source_shard=source_shard,
sort_fn=sort_fn,
combiner_fn=combiner_fn,
buffer_limit_bytes=buffer_limit_bytes,
)
for item in items:
writer.write(item)
Expand Down
118 changes: 118 additions & 0 deletions lib/zephyr/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from zephyr.shuffle import (
ScatterFileIterator,
ScatterReader,
ScatterWriter,
_write_chunk_frame,
_write_scatter,
)
Expand Down Expand Up @@ -171,6 +172,123 @@ def _ord(x):
assert sorted(recovered, key=_ord) == sorted(items, key=_ord)


# ---------------------------------------------------------------------------
# Byte-budget flushing
# ---------------------------------------------------------------------------


def test_scatter_byte_budget_flushes_mid_write(tmp_path):
    """A tiny byte budget forces flushes during write, not only at close."""
    num_shards = 2
    data_path = str(tmp_path / "shard-0000.shuffle")
    records = [{"k": n % num_shards, "v": n} for n in range(200)]

    # Budget of 1 byte forces a flush on every write after the first.
    scatter = ScatterWriter(
        data_path=data_path,
        key_fn=_key,
        num_output_shards=num_shards,
        buffer_limit_bytes=1,
    )
    for record in records:
        scatter.write(record)
    scatter.close()

    # Multiple chunks must have been written (not just the close-time flush).
    scatter_paths = [data_path]
    total_chunks = 0
    for shard_idx in range(num_shards):
        total_chunks += ScatterReader.from_sidecars(scatter_paths, shard_idx).total_chunks
    assert total_chunks > 2, f"expected >2 chunks with 1-byte budget, got {total_chunks}"


def test_scatter_estimate_tracks_skewed_items(tmp_path):
    """Write-time EMA sampling catches large late items and triggers mid-write flushes."""
    num_shards = 1
    data_path = str(tmp_path / "shard-0000.shuffle")

    # Start with tiny items, then switch to large items. With a frozen estimate
    # the budget check would never fire for the large items. With EMA updates it
    # should: _item_bytes_estimate rises and eventually exceeds budget / rows.
    all_items = [{"k": 0, "v": "x"} for _ in range(50)] + [{"k": 0, "v": "y" * 50_000} for _ in range(10)]

    # Budget large enough that small items alone never flush, but one large
    # item should push the estimate over threshold quickly.
    budget = 10_000  # 10 KB — well under 10 * 50 KB large items
    writer = ScatterWriter(
        data_path=data_path,
        key_fn=_key,
        num_output_shards=num_shards,
        buffer_limit_bytes=budget,
    )
    for record in all_items:
        writer.write(record)
    writer.close()

    def _by_value(rec):
        return rec["v"]

    # All items must survive the skewed flush pattern.
    recovered = list(ScatterReader.from_sidecars([data_path], 0))
    assert sorted(recovered, key=_by_value) == sorted(all_items, key=_by_value)

    # The estimate must have been updated: mid-write flushes should have fired
    # for the large items (not just at close).
    assert writer._mid_write_flushes > 0, "expected mid-write flushes for large items"


def test_scatter_estimate_adapts_to_gradual_drift(tmp_path):
    """Write-time EMA bounds peak buffered rows even when item sizes grow gradually."""
    data_path = str(tmp_path / "shard-0000.shuffle")

    # Items grow linearly from ~100 B to ~100 KB across 200 records.
    # If all 200 were buffered at once the real RSS would be ~10 MB.
    n_items = 200
    items = [{"k": 0, "v": "x" * (100 + idx * 500)} for idx in range(n_items)]

    # 500 KB budget. With a frozen first-item estimate (~110 B) the budget check
    # would read 200 * 110 = 22 KB < 500 KB and never flush mid-write, letting
    # all items accumulate. With EMA adaptation the estimate tracks the growing
    # sizes and flushes before peak RSS reaches the budget.
    writer = ScatterWriter(
        data_path=data_path,
        key_fn=_key,
        num_output_shards=1,
        buffer_limit_bytes=500_000,
    )
    for record in items:
        writer.write(record)
    writer.close()

    recovered = list(ScatterReader.from_sidecars([data_path], 0))
    assert sorted(recovered, key=lambda rec: rec["v"]) == sorted(items, key=lambda rec: rec["v"])

    assert writer._mid_write_flushes > 0, "expected mid-write flushes as item sizes grew"
    assert writer._peak_buffer_rows < n_items, (
        f"peak_buffer_rows={writer._peak_buffer_rows} should be < {n_items}; "
        "a frozen estimate lets all items accumulate before close()"
    )


def test_scatter_byte_budget_preserves_all_items(tmp_path):
    """Items are not lost or duplicated when byte-budget flushes fire mid-write."""
    num_shards = 3
    items = [{"k": n % num_shards, "v": n} for n in range(300)]
    scatter_paths = _build_shard(tmp_path, items, num_output_shards=num_shards)

    recovered = []
    for shard_idx in range(num_shards):
        recovered.extend(ScatterReader.from_sidecars(scatter_paths, shard_idx))

    assert sorted(recovered, key=lambda x: x["v"]) == sorted(items, key=lambda x: x["v"])


# ---------------------------------------------------------------------------
# ScatterFileIterator low-level
# ---------------------------------------------------------------------------
Expand Down
Loading