Commit 82d2359

[zephyr] Replace Parquet shuffle with zstd-chunk format (#4782)
## Summary

- Replaces the Parquet-based scatter/reduce shuffle with flat zstd frames + a byte-range sidecar. Drops Arrow from the shuffle data plane.
- Adds memory-bounded external-sort fan-in and a byte-budgeted pass-1 spill batch so skewed and large-item shuffles don't OOM the worker.
- Net **−165 lines** across `shuffle.py` / `external_sort.py` / `plan.py` / `execution.py`.

## Format

Each scatter source writes a single binary file: a concatenation of zstd frames. Each frame is one sorted chunk, containing repeated `pickle.dump(sub_batch)` calls into a single zstd stream (sub-batch size = 1024 items by default). A JSON `.scatter_meta` sidecar maps `target_shard → [(offset, length)]`. Sidecars aggregate into one `scatter_metadata` manifest per stage (wire format unchanged from before).

On read, `ScatterFileIterator` fetches each chunk via one `cat_file` range GET and streams sub-batches with `pickle.load`. Per-iterator memory is bounded by `sub_batch_size * avg_item_bytes + chunk_compressed_bytes` — independent of chunk row count.

Gone:

- Segment rotation for schema evolution (`_ensure_writer`, `seg_idx`, `pa.unify_schemas`)
- Arrow-vs-pickle envelope peek (`use_pickle_envelope`, `pa.RecordBatch.from_pylist`)
- Row-group statistics + predicate pushdown (`equality_predicates`, `iter_parquet_row_groups`)
- PyArrow dataset memory-leak workaround (`_get_scatter_read_fs`, block-size budgeting)
- `pyarrow` on the shuffle data plane

## External-sort scaling

Two knobs now scale with the workload instead of being hardcoded:

- `compute_fan_in(per_iter_bytes, mem_limit)` — pass-1 fan-in floored at 4, capped at `EXTERNAL_SORT_FAN_IN=500`, otherwise sized so the open chunks (at `max_chunk_rows * avg_item_bytes` each) fit in 50% of worker memory.
- `compute_write_batch_size(avg_item_bytes)` — pass-1 `pending` buffer sized to ~64 MB of items (capped at 10k items). The prior fixed `_WRITE_BATCH_SIZE=10_000` could balloon to 10 GB on 1 MB items.

`_merge_sorted_chunks` reads `shard.max_chunk_rows` and `shard.avg_item_bytes` from the manifest and passes both values through to `external_sort_merge`.

## Benchmarks on marin-dev (4 workers × 8 GB RAM)

| Workload | Baseline (Parquet) | New | Hot-worker peak mem |
|---|---|---|---|
| Uniform 10 GB, 250 B items | 736 s | **392 s** | 551 MB |
| Uniform 10 GB, 1 MB items | — | **352 s** | 621 MB |
| Skew 90% 10 GB, 250 B items | **OOM** | 800 s | 3.09 GB |
| Skew 90% 10 GB, 1 MB items | **OOM** | 1349 s | 7.18 GB |
| Skew 90% 50 GB, 250 B items | **OOM** | 7796 s\* | 3.58 GB |
| Skew 90% 50 GB, 1 MB items | **OOM** | 2996 s | 7.33 GB |

\* Includes ~35 min lost to a mid-run coordinator preemption + automatic pipeline retry.

Uniform throughput is 1.88× faster than Parquet at 10 GB with small items. Every skewed case that the baseline OOMs on now completes, with memory bounded below the worker limit.

## Tests

- `lib/zephyr/tests/test_shuffle.py` rewritten for the new API (13 tests).
- `lib/zephyr/tests/test_groupby.py` pickle-roundtrip test updated.
- `lib/zephyr/tests/benchmark_shuffle.py` new — synthetic 10–50 GB shuffle with `--hot-shard-frac`/`--hot-key-pool`.
- `test_shuffle.py` (13), `test_groupby.py` (23), `test_execution.py` (40) all pass locally.

## Test plan

- [x] Unit tests pass locally
- [x] Uniform shuffle (10 GB, small + large items) on marin-dev
- [x] Skewed shuffle (10 GB + 50 GB, small + large items) on marin-dev
- [ ] Datakit ferry on marin — deferred (not necessary given direct shuffle benchmarks)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Rafal Wojdyla <ravwojdyla@gmail.com>
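As a rough sketch of the file format described above — illustrative only: `zlib` stands in for zstd (the framing and sidecar logic are identical, and CPython ships no zstd before 3.14), and helper names like `write_scatter_file` are hypothetical, not the actual `shuffle.py` API:

```python
import io
import json
import pickle
import zlib


def write_scatter_file(path: str, shard_chunks, sub_batch_size: int = 1024) -> None:
    """One compressed frame per sorted chunk; sidecar maps target_shard -> [(offset, length)]."""
    sidecar: dict[str, list] = {}
    with open(path, "wb") as f:
        for target_shard, items in shard_chunks:
            offset = f.tell()
            raw = io.BytesIO()
            for i in range(0, len(items), sub_batch_size):
                pickle.dump(items[i : i + sub_batch_size], raw)  # repeated dumps, one stream
            f.write(zlib.compress(raw.getvalue()))  # stand-in for one zstd frame
            sidecar.setdefault(str(target_shard), []).append((offset, f.tell() - offset))
    with open(path + ".scatter_meta", "w") as f:
        json.dump(sidecar, f)


def read_chunk(path: str, offset: int, length: int):
    """One range read per chunk, then stream sub-batches with pickle.load."""
    with open(path, "rb") as f:  # a remote FS would issue a cat_file range GET here
        f.seek(offset)
        frame = f.read(length)
    stream = io.BytesIO(zlib.decompress(frame))
    while True:
        try:
            yield from pickle.load(stream)  # one sub_batch list per load
        except EOFError:
            return
```

Because the reader holds one compressed frame plus one decoded sub-batch at a time, peak memory per iterator tracks the `sub_batch_size * avg_item_bytes + chunk_compressed_bytes` bound quoted above.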
1 parent e78e1dc commit 82d2359

File tree

8 files changed: +520 −642 lines changed


lib/zephyr/AGENTS.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -15,7 +15,7 @@ Lazy dataset processing library. Start with the shared instructions in `/AGENTS.
 - `src/zephyr/plan.py` — `compute_plan`, `PhysicalPlan`, operation fusion
 - `src/zephyr/readers.py` — `load_jsonl`, `load_parquet`, `load_vortex`, `InputFileSpec`
 - `src/zephyr/writers.py` — `write_jsonl_file`, `write_parquet_file`, `write_vortex_file`, Levanter cache writer
-- `src/zephyr/shuffle.py` — scatter pipeline internals (`ScatterParquetIterator`, `ScatterShard`, hash-routing, combiner, Parquet envelope)
+- `src/zephyr/shuffle.py` — scatter pipeline internals (`ScatterFileIterator`, `ScatterShard`, hash-routing, combiner, zstd-chunk file format with byte-range sidecar)
 - `src/zephyr/expr.py` — `Expr`, `col`, `lit` for filter expressions
 - `src/zephyr/external_sort.py` — `external_sort_merge` k-way merge of sorted runs
 - `src/zephyr/counters.py` — `increment` / `get_counters` per-worker counter API (`CounterSnapshot` lives in `execution.py`)
```

lib/zephyr/src/zephyr/execution.py

Lines changed: 7 additions & 21 deletions
```diff
@@ -35,7 +35,6 @@
 from typing import Any, Protocol
 
 import cloudpickle
-import pyarrow as pa
 from rigging.filesystem import open_url, url_to_fs
 from fray.v2 import ActorConfig, ActorFuture, ActorHandle, Client, ResourceConfig
 from fray.v2.client import JobHandle
@@ -112,10 +111,11 @@ def read(self) -> list:
 from zephyr.shuffle import (  # noqa: E402
     ListShard,
     MemChunk,
-    ScatterShard,  # noqa: F401 — re-exported for plan.py and external callers
+    ScatterReader,  # noqa: F401 — re-exported for plan.py and external callers
+    ScatterShard,  # noqa: F401 — backward-compat alias for ScatterReader
+    ScatterWriter,  # noqa: F401 — re-exported for external callers
     _build_scatter_shard_from_manifest,  # noqa: F401 — re-exported for plan.py
-    _make_envelope,
-    _write_parquet_scatter,
+    _write_scatter,
     _write_scatter_manifest,
     _SCATTER_MANIFEST_NAME,
 )
@@ -232,36 +232,22 @@ def _write_stage_output(
         TaskResult with a ListShard.
     """
     if scatter_op is not None:
-        # Peek first item to test Arrow serializability
         first_item = next(stage_gen, None)
         if first_item is None:
             return TaskResult(shard=ListShard(refs=[]))
 
         full_gen = itertools.chain([first_item], stage_gen)
 
-        use_pickle_envelope = False
-        try:
-            test_envelope = _make_envelope([first_item], 0, 0)
-            pa.RecordBatch.from_pylist(test_envelope)
-            logger.info("Using Parquet for scatter serialization for shard %d", source_shard)
-        except Exception:
-            use_pickle_envelope = True
-            logger.info(
-                "Using Parquet with pickle envelope for scatter serialization for shard %d",
-                source_shard,
-            )
-
         num_output_shards = scatter_op.num_output_shards if scatter_op.num_output_shards > 0 else total_shards
-        parquet_path = f"{stage_dir}/shard-{shard_idx:04d}.parquet"
-        shard = _write_parquet_scatter(
+        data_path = f"{stage_dir}/shard-{shard_idx:04d}.shuffle"
+        shard = _write_scatter(
             full_gen,
             source_shard,
-            parquet_path,
+            data_path,
             key_fn=scatter_op.key_fn,
             num_output_shards=num_output_shards,
             sort_fn=scatter_op.sort_fn,
             combiner_fn=scatter_op.combiner_fn,
-            pickled=use_pickle_envelope,
         )
         return TaskResult(shard=shard)
```

lib/zephyr/src/zephyr/external_sort.py

Lines changed: 66 additions & 9 deletions
```diff
@@ -7,9 +7,11 @@
 ``EXTERNAL_SORT_FAN_IN``, to avoid opening O(k) scanners simultaneously and
 exhausting worker memory.
 
-Pass 1: batch the k iterators into groups of EXTERNAL_SORT_FAN_IN, merge each
-group with heapq.merge, and spill items to a run file under
-``{external_sort_dir}/run-{i:04d}.spill`` via :class:`SpillWriter`.
+Pass 1: batch the k iterators into groups of ``fan_in`` (defaulting to
+``EXTERNAL_SORT_FAN_IN`` but typically lowered via :func:`compute_fan_in` to
+fit the worker's memory budget), merge each group with ``heapq.merge``, and
+spill items to a run file under ``{external_sort_dir}/run-{i:04d}.spill`` via
+:class:`SpillWriter`.
 
 Pass 2: heapq.merge over the (much smaller) set of run file iterators. Each
 iterator streams chunks from its spill file via :class:`SpillReader`; the read
@@ -31,20 +33,64 @@
 
 logger = logging.getLogger(__name__)
 
-# Maximum simultaneous chunk iterators per pass-1 batch.
+# Hard cap on simultaneous chunk iterators per pass-1 batch. Used as the
+# default when the caller cannot estimate per-iterator memory; otherwise
+# ``compute_fan_in`` lowers it to fit within the worker's memory budget.
 EXTERNAL_SORT_FAN_IN = 500
 
-# Items buffered before handing to the SpillWriter. Larger values amortize
-# per-chunk overhead in the spill format.
+# Fraction of worker memory budgeted for the open chunk iterators during a
+# pass-1 merge batch.
+_FAN_IN_MEMORY_FRACTION = 0.5
+
+# Floor on fan-in. Below 2, pass-1 just rewrites each chunk to its own run
+# file with no merging — pass-2 still produces correct output but the extra
+# round-trip is wasteful, so we keep at least a small merge fan-in.
+_FAN_IN_FLOOR = 4
+
+# Default item count per write into the SpillWriter in pass-1. Large enough
+# for good compression + low per-call overhead. For large items (e.g. 1 MB
+# each) the caller should pass a smaller ``write_batch_size`` via
+# :func:`compute_write_batch_size` so the in-memory ``pending`` buffer stays
+# bounded by bytes rather than count.
 _WRITE_BATCH_SIZE = 10_000
 
+# Target bytes for the in-memory pass-1 spill buffer.
+_WRITE_BATCH_TARGET_BYTES = 64 * 1024 * 1024
+
 # Target bytes per spill chunk in pass-1 runs.
 _ROW_GROUP_BYTES = 8 * 1024 * 1024
 
 # Fraction of container memory budgeted for pass-2 read buffers.
 _READ_MEMORY_FRACTION = 0.25
 
 
+def compute_fan_in(per_iterator_bytes: int, memory_limit: int) -> int:
+    """Pick a pass-1 fan-in that fits within the memory budget.
+
+    ``per_iterator_bytes`` is the caller's estimate of memory held per open
+    chunk iterator (typically compressed chunk bytes plus a small decoded
+    buffer). Returns at least ``_FAN_IN_FLOOR`` and at most
+    ``EXTERNAL_SORT_FAN_IN``.
+    """
+    if per_iterator_bytes <= 0 or memory_limit <= 0:
+        return EXTERNAL_SORT_FAN_IN
+    budget = int(memory_limit * _FAN_IN_MEMORY_FRACTION)
+    fan_in = budget // per_iterator_bytes
+    fan_in = max(_FAN_IN_FLOOR, fan_in)
+    return min(fan_in, EXTERNAL_SORT_FAN_IN)
+
+
+def compute_write_batch_size(avg_item_bytes: float) -> int:
+    """Pick a pass-1 pending-buffer size sized to a byte budget.
+
+    Caps at the ``_WRITE_BATCH_SIZE`` default when items are small.
+    """
+    if avg_item_bytes <= 0:
+        return _WRITE_BATCH_SIZE
+    by_bytes = int(_WRITE_BATCH_TARGET_BYTES // avg_item_bytes)
+    return max(1, min(by_bytes, _WRITE_BATCH_SIZE))
+
+
 def _safe_read_batch_size(n_runs: int, sample_run_path: str) -> int:
     """Compute a pass-2 read batch size that fits within the memory budget.
@@ -87,16 +133,25 @@ def external_sort_merge(
     chunk_iterators_gen: Iterator[Iterator],  # lazy — consumed in batches
     merge_key: Callable,
     external_sort_dir: str,
+    fan_in: int = EXTERNAL_SORT_FAN_IN,
+    write_batch_size: int = _WRITE_BATCH_SIZE,
 ) -> Iterator:
     """Merge ``chunk_iterators_gen`` via a two-pass external sort.
 
     Args:
         chunk_iterators_gen: Lazy iterator of sorted iterators (one per scatter chunk).
-            Consumed in batches of EXTERNAL_SORT_FAN_IN to avoid opening all file
+            Consumed in batches of ``fan_in`` to avoid opening all file
             handles simultaneously.
         merge_key: Key function passed to heapq.merge.
         external_sort_dir: GCS prefix for spill files, e.g.
             ``gs://bucket/.../stage1-external-sort/shard-0042``.
+        fan_in: Maximum number of chunk iterators to merge in one pass-1
+            batch. Defaults to ``EXTERNAL_SORT_FAN_IN``; callers should pass
+            a value computed by :func:`compute_fan_in` to bound memory.
+        write_batch_size: Item count threshold for the pass-1 ``pending``
+            buffer. Callers should pass a value from
+            :func:`compute_write_batch_size` to keep the buffer bounded by
+            bytes rather than item count.
 
     Yields:
         Items in merged sort order.
@@ -109,8 +164,10 @@ def external_sort_merge(
     spill_fs, spill_dir = url_to_fs(external_sort_dir)
     spill_fs.makedirs(spill_dir, exist_ok=True)
 
+    logger.info("External sort: pass-1 fan_in=%d, write_batch_size=%d", fan_in, write_batch_size)
+
     while True:
-        batch = list(islice(chunk_iterators_gen, EXTERNAL_SORT_FAN_IN))
+        batch = list(islice(chunk_iterators_gen, fan_in))
         if not batch:
             break
         run_path = f"{external_sort_dir}/run-{batch_idx:04d}.spill"
@@ -119,7 +176,7 @@
         with SpillWriter(run_path, row_group_bytes=_ROW_GROUP_BYTES) as writer:
             for item in heapq.merge(*batch, key=merge_key):
                 pending.append(item)
-                if len(pending) >= _WRITE_BATCH_SIZE:
+                if len(pending) >= write_batch_size:
                     writer.write(pending)
                     item_count += len(pending)
                     pending = []
```
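The two sizing helpers are pure arithmetic, so their behavior is easy to check standalone. This sketch copies the constants and formulas from the diff above and verifies the memory claim from the PR description: with 1 MB items the pass-1 buffer is now bounded at ~64 MB rather than 10 GB.

```python
# Constants mirrored from external_sort.py above.
EXTERNAL_SORT_FAN_IN = 500                    # hard cap on pass-1 fan-in
_FAN_IN_MEMORY_FRACTION = 0.5                 # fraction of worker memory for open chunk iterators
_FAN_IN_FLOOR = 4                             # keep at least a small merge fan-in
_WRITE_BATCH_SIZE = 10_000                    # item-count cap on the pending buffer
_WRITE_BATCH_TARGET_BYTES = 64 * 1024 * 1024  # byte budget for the pending buffer


def compute_fan_in(per_iterator_bytes: int, memory_limit: int) -> int:
    if per_iterator_bytes <= 0 or memory_limit <= 0:
        return EXTERNAL_SORT_FAN_IN  # no estimate available: fall back to the cap
    budget = int(memory_limit * _FAN_IN_MEMORY_FRACTION)
    return min(max(_FAN_IN_FLOOR, budget // per_iterator_bytes), EXTERNAL_SORT_FAN_IN)


def compute_write_batch_size(avg_item_bytes: float) -> int:
    if avg_item_bytes <= 0:
        return _WRITE_BATCH_SIZE
    return max(1, min(int(_WRITE_BATCH_TARGET_BYTES // avg_item_bytes), _WRITE_BATCH_SIZE))


MiB, GiB = 1024**2, 1024**3
print(compute_fan_in(16 * MiB, 8 * GiB))  # 256: half of an 8 GB worker / 16 MB per chunk
print(compute_fan_in(2 * GiB, 8 * GiB))   # 4: the floor kicks in for huge chunks
print(compute_write_batch_size(250))      # 10000: the byte budget never binds for small items
print(compute_write_batch_size(1 * MiB))  # 64: ~64 MB buffer instead of the old 10 GB
```

The old behavior is the last case with the fixed cap: 10,000 pending items at 1 MB each is 10 GB of buffered data, which exceeds an 8 GB worker on its own.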

lib/zephyr/src/zephyr/plan.py

Lines changed: 24 additions & 10 deletions
```diff
@@ -25,7 +25,7 @@
 from iris.env_resources import TaskResources as _TaskResources
 from rigging.filesystem import url_to_fs
 
-from zephyr.external_sort import EXTERNAL_SORT_FAN_IN, external_sort_merge
+from zephyr.external_sort import external_sort_merge
 
 from zephyr.dataset import (
     Dataset,
@@ -64,7 +64,7 @@ class Shard(Protocol):
 
     Implementations:
     - ListShard: backed by iterable references (source data, non-scatter)
-    - ScatterShard: backed by scatter Parquet files with predicate pushdown
+    - ScatterShard: backed by scatter zstd-chunk files with byte-range sidecar
     """
 
     def __iter__(self) -> Iterator: ...
@@ -635,7 +635,7 @@ def merge_key(item):
 
     # Check if external sort is needed BEFORE materializing all iterators.
     # ScatterShard can decide using manifest stats (no file opens needed).
-    from zephyr.shuffle import ScatterShard
+    from zephyr.shuffle import ScatterShard  # ScatterShard is an alias for ScatterReader
 
     use_external = (
         external_sort_dir is not None
@@ -644,21 +644,35 @@ def merge_key(item):
     )
 
     if use_external:
+        from zephyr.external_sort import compute_fan_in, compute_write_batch_size
+
+        memory_limit = _TaskResources.from_environment().memory_bytes
+        # Per-iterator memory ~= compressed bytes for one chunk held by
+        # cat_file. Use the actual max compressed chunk size from the sidecar.
+        per_iter_bytes = shard.max_compressed_chunk_bytes
+        fan_in = compute_fan_in(per_iter_bytes, memory_limit)
+        write_batch_size = compute_write_batch_size(shard.avg_item_bytes)
         logger.info(
-            "External sort triggered for shard with %d iterators, spilling to %s",
+            "External sort triggered for shard with %d iterators, "
+            "fan_in=%d (per_iter≈%dKB), write_batch_size=%d, spilling to %s",
             sum(it.chunk_count for it in shard.iterators),
+            fan_in,
+            per_iter_bytes // 1024,
+            write_batch_size,
             external_sort_dir,
         )
         # Pass lazy generator — external_sort_merge consumes in batches without opening all files
-        merged_stream = external_sort_merge(shard.get_iterators(), merge_key, external_sort_dir)
+        merged_stream = external_sort_merge(
+            shard.get_iterators(),
+            merge_key,
+            external_sort_dir,
+            fan_in=fan_in,
+            write_batch_size=write_batch_size,
+        )
     else:
         chunk_iterators = list(shard.get_iterators())
         logger.info(f"Merging {len(chunk_iterators):,} sorted chunk iterators")
-        if external_sort_dir is not None and len(chunk_iterators) > EXTERNAL_SORT_FAN_IN:
-            # Fallback: stats unavailable, use fan_in threshold
-            merged_stream = external_sort_merge(iter(chunk_iterators), merge_key, external_sort_dir)
-        else:
-            merged_stream = heapq.merge(*chunk_iterators, key=merge_key)
+        merged_stream = heapq.merge(*chunk_iterators, key=merge_key)
     yield from groupby(merged_stream, key=key_fn)
```
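For intuition on the in-memory fallback path retained above, here is a toy run of `heapq.merge` over pre-sorted chunk iterators followed by a group-by — using stdlib `itertools.groupby` as an assumed stand-in for zephyr's `groupby`:

```python
import heapq
from itertools import groupby

# Two chunks, each already sorted by key — the invariant the scatter writer guarantees.
chunks = [[("a", 1), ("b", 2)], [("a", 3), ("c", 4)]]

merge_key = lambda kv: kv[0]
# heapq.merge streams the chunks in globally sorted order without materializing them.
merged_stream = heapq.merge(*map(iter, chunks), key=merge_key)

# groupby over a key-sorted stream yields one contiguous group per key.
result = {k: [v for _, v in grp] for k, grp in groupby(merged_stream, key=merge_key)}
print(result)  # {'a': [1, 3], 'b': [2], 'c': [4]}
```

This only works because every chunk is sorted on the same key; the external-sort path exists to preserve that invariant when the chunks are too numerous to open all at once.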
