Skip to content

Commit b2b4c13

Browse files
committed
zephyr: fix frozen size estimate — EMA update in write() not just flush()
The previous flush-time EMA was a closed loop: if the estimate was too low no flush fired, so the EMA never ran, and the estimate stayed low. Skewed datasets (small items early, large items later) could accumulate unbounded memory without any flush triggering. Fix: sample one item's pickle size every 10 writes and apply EMA directly in write(), independent of whether any flush has occurred. The flush-time sample (100 items first flush, 10 items ongoing) still runs for higher-quality multi-item measurements when flushes do happen. Adds a test confirming that mid-write flushes fire when large items arrive after a run of small items.
1 parent 9396701 commit b2b4c13

2 files changed

Lines changed: 74 additions & 13 deletions

File tree

lib/zephyr/src/zephyr/shuffle.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,17 @@ def get_iterators(self) -> Iterator[Iterator]:
9898
# ScatterReader. Sidecars are small msgpack files (a few KB) and reads are
9999
# GCS GET-bound, so a modest pool keeps latency low without thrashing.
100100
_SIDECAR_READ_CONCURRENCY = 32
101-
# Number of items sampled from the first flush to estimate avg_item_bytes.
101+
# Items sampled on the first flush to establish an avg_item_bytes baseline.
102102
_SCATTER_SAMPLE_SIZE = 100
103+
# Items sampled on each subsequent flush to track item-size drift cheaply.
104+
_SCATTER_ONGOING_SAMPLE_SIZE = 10
105+
# How often (in items written) to re-sample one item's pickle size and update
106+
# the EMA estimate in write(). This is independent of flush-time sampling and
107+
# ensures the estimate tracks drift even when no flush has fired yet.
108+
_ESTIMATE_WRITE_SAMPLE_INTERVAL = 10
109+
# EMA weight given to each new observation. 0.3 converges to a 2x step-change
110+
# in item size within ~3 samples while staying stable under small fluctuations.
111+
_ESTIMATE_EMA_ALPHA = 0.3
103112
# Fraction of total memory budgeted for read-side decompression buffers.
104113
_SCATTER_READ_BUFFER_FRACTION = 0.25
105114

@@ -532,12 +541,21 @@ def _flush(self, target: int, buf: list) -> None:
532541
buf = _apply_combiner(buf, self._key_fn, self._combiner_fn)
533542
buf.sort(key=self._sort_key)
534543

535-
if not self._sampled_avg and buf:
536-
sample = buf[: min(len(buf), _SCATTER_SAMPLE_SIZE)]
537-
total_bytes = sum(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)) for item in sample)
538-
self._avg_item_bytes = total_bytes / len(sample)
544+
if buf:
545+
# Sample a subset of the buffer to update the byte-size estimate.
546+
# First flush: larger sample for a good baseline. Subsequent flushes:
547+
# smaller sample to track drift cheaply via EMA. This prevents OOM
548+
# when early items are small but later items are large — the estimate
549+
# stays current rather than being frozen at the first-flush value.
550+
n = _SCATTER_SAMPLE_SIZE if not self._sampled_avg else _SCATTER_ONGOING_SAMPLE_SIZE
551+
sample = buf[: min(len(buf), n)]
552+
observed = sum(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)) for item in sample) / len(sample)
553+
if not self._sampled_avg:
554+
self._avg_item_bytes = observed
555+
self._sampled_avg = True
556+
else:
557+
self._avg_item_bytes = (1 - _ESTIMATE_EMA_ALPHA) * self._avg_item_bytes + _ESTIMATE_EMA_ALPHA * observed
539558
self._item_bytes_estimate = self._avg_item_bytes
540-
self._sampled_avg = True
541559

542560
frame = _write_chunk_frame(buf)
543561
offset = self._out.tell()
@@ -557,13 +575,21 @@ def _flush(self, target: int, buf: list) -> None:
557575

558576
def write(self, item: Any) -> None:
559577
"""Route a single item to its target shard buffer, flushing when over budget."""
560-
if self._total_buffer_rows == 0:
561-
# Calibrate from the first item before any batching occurs. A
562-
# hardcoded default (e.g. 512 B) can be orders of magnitude off for
563-
# large documents, allowing millions of rows to accumulate before the
564-
# first flush fires. One real measurement is far safer.
565-
self._item_bytes_estimate = float(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)))
566-
self._first_item_bytes = self._item_bytes_estimate
578+
if self._total_buffer_rows % _ESTIMATE_WRITE_SAMPLE_INTERVAL == 0:
579+
# Periodically measure a single item's serialised size and apply EMA.
580+
# This runs in write() — not just in _flush() — so the estimate tracks
581+
# size drift even when no flush has fired yet (the flush EMA is a
582+
# closed loop: if the estimate is too low no flush fires, so it never
583+
# updates). Interval-based sampling amortises the pickle.dumps cost
584+
# to 1-in-10 items while still catching step-changes within a few rows.
585+
observed = float(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)))
586+
if self._total_buffer_rows == 0:
587+
self._item_bytes_estimate = observed
588+
self._first_item_bytes = observed
589+
else:
590+
self._item_bytes_estimate = (
591+
1 - _ESTIMATE_EMA_ALPHA
592+
) * self._item_bytes_estimate + _ESTIMATE_EMA_ALPHA * observed
567593

568594
key = self._key_fn(item)
569595
target = deterministic_hash(key) % self._num_output_shards

lib/zephyr/tests/test_shuffle.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,41 @@ def test_scatter_byte_budget_flushes_mid_write(tmp_path):
200200
assert total_chunks > 2, f"expected >2 chunks with 1-byte budget, got {total_chunks}"
201201

202202

203+
def test_scatter_estimate_tracks_skewed_items(tmp_path):
204+
"""Estimate updates after each flush so large late items still trigger budget flushes."""
205+
num_shards = 1
206+
data_path = str(tmp_path / "shard-0000.shuffle")
207+
208+
# Start with tiny items, then switch to large items. With a frozen estimate
209+
# the budget check would never fire for the large items. With EMA updates it
210+
# should: _item_bytes_estimate rises and eventually exceeds budget / rows.
211+
small_items = [{"k": 0, "v": "x"} for _ in range(50)]
212+
large_items = [{"k": 0, "v": "y" * 50_000} for _ in range(10)]
213+
214+
# Budget large enough that small items alone never flush, but one large
215+
# item should push the estimate over threshold quickly.
216+
budget = 10_000 # 10 KB — well under 10 * 50 KB large items
217+
writer = ScatterWriter(
218+
data_path=data_path,
219+
key_fn=_key,
220+
num_output_shards=num_shards,
221+
buffer_limit_bytes=budget,
222+
)
223+
for item in small_items + large_items:
224+
writer.write(item)
225+
writer.close()
226+
227+
# All items must survive the skewed flush pattern.
228+
scatter_paths = [data_path]
229+
recovered = list(ScatterReader.from_sidecars(scatter_paths, 0))
230+
all_items = small_items + large_items
231+
assert sorted(recovered, key=lambda x: x["v"]) == sorted(all_items, key=lambda x: x["v"])
232+
233+
# The estimate must have been updated: mid-write flushes should have fired
234+
# for the large items (not just at close).
235+
assert writer._mid_write_flushes > 0, "expected mid-write flushes for large items"
236+
237+
203238
def test_scatter_byte_budget_preserves_all_items(tmp_path):
204239
"""Items are not lost or duplicated when byte-budget flushes fire mid-write."""
205240
num_shards = 3

0 commit comments

Comments (0)