Skip to content

Commit ac090af

Browse files
hsuhanooi and claude committed
[zephyr] Use Arrow IPC for scatter chunks, fall back to pickle
Replace the cloudpickle-only chunk format with Arrow IPC when items are Arrow-compatible (plain dicts with primitive/string values). Chunks that cannot be Arrow-encoded (frozenset, custom classes) fall back to cloudpickle sub-batches. A one-byte format tag at the start of each frame distinguishes the two paths on read. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 263f57b commit ac090af

2 files changed

Lines changed: 88 additions & 33 deletions

File tree

lib/zephyr/src/zephyr/shuffle.py

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
"""Scatter/shuffle support for Zephyr pipelines.
55
66
Each source-shard's scatter output is a single binary file containing a
7-
sequence of zstd-compressed frames. Within one chunk's zstd frame, items
8-
are written in sub-batches of ``_SUB_BATCH_SIZE`` — each sub-batch is a
9-
single ``pickle.dump(list_of_items)`` into the zstd stream. This amortises
10-
per-item pickle/zstd dispatch over a sub-batch while still letting the
11-
reader stream sub-batches lazily without materialising the full chunk.
7+
sequence of zstd-compressed frames. Each frame starts with a one-byte format
8+
tag: ``\\x01`` for Arrow IPC and ``\\x00`` for cloudpickle sub-batches.
9+
10+
Arrow IPC is used when the chunk items are Arrow-compatible (plain dicts with
11+
primitive or string values). For items that Arrow cannot represent (e.g.
12+
frozenset, custom classes), the writer falls back to cloudpickle sub-batches
13+
of ``_SUB_BATCH_SIZE`` items per ``pickle.dump`` call.
1214
1315
A msgpack sidecar (``.scatter_meta``) maps ``target_shard -> [(offset, length)]``
1416
byte ranges into the data file, plus per-shard ``max_chunk_rows`` and a global
@@ -17,12 +19,10 @@
1719
which reducers consume to build :class:`ScatterReader` instances.
1820
1921
On read, each chunk is fetched with a single ``cat_file`` range GET (one
20-
HTTP request, no per-chunk file handle), then streamed via
21-
``pickle.load`` on a length-bounded zstd reader. Per-iterator memory stays
22-
near-constant: one buffered item plus the zstd decoder state plus the
23-
chunk's compressed bytes (typically a few MB). This bound is essential for
24-
skewed shuffles where one reducer pulls disproportionate data and the
25-
external-sort fan-in opens hundreds of chunk iterators at once.
22+
HTTP request, no per-chunk file handle). The format tag is inspected and the
23+
payload is dispatched to Arrow IPC or pickle deserialization accordingly.
24+
Per-iterator memory is bounded: the pickle path streams sub-batches, while the
25+
Arrow path materialises the full decompressed chunk as Python objects.
2626
"""
2727

2828
from __future__ import annotations
@@ -40,6 +40,7 @@
4040

4141
import cloudpickle
4242
import msgspec
43+
import pyarrow as pa
4344
import zstandard as zstd
4445
from rigging.filesystem import open_url, url_to_fs
4546
from rigging.timing import log_time
@@ -101,6 +102,11 @@ def get_iterators(self) -> Iterator[Iterator]:
101102
# dispatch overhead), smaller = lower per-iterator read memory.
102103
_SUB_BATCH_SIZE = 1024
103104

105+
# One-byte format tags written at the start of every chunk frame.
106+
# Arrow IPC is used when items are Arrow-compatible; pickle is the fallback.
107+
_FRAME_FORMAT_PICKLE = b"\x00"
108+
_FRAME_FORMAT_ARROW = b"\x01"
109+
104110

105111
# ---------------------------------------------------------------------------
106112
# Sidecar / manifest helpers
@@ -249,19 +255,25 @@ def get_chunk_iterators(self) -> Iterator[Iterator]:
249255
def _iter_chunk(fs: Any, fs_path: str, offset: int, length: int) -> Iterator:
    """Fetch one chunk's compressed bytes via cat_file and stream items.

    Reads the one-byte format tag to dispatch to Arrow IPC or pickle
    deserialization. Arrow chunks are decompressed in one shot and converted
    via ``table.to_pylist()``. Pickle chunks are streamed sub-batch by
    sub-batch via ``pickle.load``.

    Args:
        fs: Filesystem object exposing ``cat_file`` (fsspec-style range GET).
        fs_path: Path of the scatter data file on ``fs``.
        offset: Byte offset of this chunk within the data file.
        length: Byte length of this chunk (one tag byte + compressed payload).

    Yields:
        The chunk's items, one at a time.
    """
    blob = fs.cat_file(fs_path, start=offset, end=offset + length)
    fmt = blob[0:1]
    # Zero-copy slice: ``blob[1:]`` would duplicate the (potentially
    # multi-MB) compressed payload. zstd's decompress/stream_reader and
    # io.BytesIO all accept any bytes-like object.
    payload = memoryview(blob)[1:]
    if fmt == _FRAME_FORMAT_ARROW:
        # One-shot decompress is safe here: the writer uses
        # ZstdCompressor.compress(), which embeds the content size.
        ipc_bytes = zstd.ZstdDecompressor().decompress(payload)
        with pa.ipc.open_stream(pa.py_buffer(ipc_bytes)) as reader:
            yield from reader.read_all().to_pylist()
    else:
        with zstd.ZstdDecompressor().stream_reader(io.BytesIO(payload)) as reader:
            while True:
                try:
                    sub_batch = pickle.load(reader)
                except EOFError:
                    # Clean end of the pickled sub-batch sequence.
                    return
                yield from sub_batch
265277

266278

267279
# ---------------------------------------------------------------------------
@@ -426,19 +438,28 @@ def _apply_combiner(buffer: list, key_fn: Callable, combiner_fn: Callable) -> li
426438

427439

428440
# Value types Arrow encodes losslessly AND ``to_pylist()`` restores with the
# exact same Python type. Anything else must use the pickle path: Arrow will
# happily *accept* e.g. tuple values but hand back lists on read, silently
# changing the data. ``type(v) in ...`` (not isinstance) keeps bool distinct
# from int and rejects subclasses whose round-trip fidelity is unknown.
_ARROW_SAFE_VALUE_TYPES = (type(None), bool, int, float, str, bytes)


def _arrow_encodable(items: list) -> bool:
    """Return True if every item is a plain dict of Arrow-faithful scalars."""
    return bool(items) and all(
        type(item) is dict
        and all(type(value) in _ARROW_SAFE_VALUE_TYPES for value in item.values())
        for item in items
    )


def _write_chunk_frame(items: list) -> bytes:
    """Encode a list of items as one zstd-compressed frame.

    Uses Arrow IPC when every item is a plain dict with primitive or string
    values (checked explicitly, so types Arrow would lossily coerce — tuples,
    custom classes, frozensets — always take the pickle path). The fallback
    writes cloudpickle sub-batches of ``_SUB_BATCH_SIZE`` items per
    ``pickle.dump`` call. The first byte of the returned bytes is the format
    tag (``_FRAME_FORMAT_ARROW`` or ``_FRAME_FORMAT_PICKLE``).
    """
    if _arrow_encodable(items):
        try:
            table = pa.Table.from_pylist(items)
            sink = pa.BufferOutputStream()
            with pa.ipc.new_stream(sink, table.schema) as writer:
                writer.write_table(table)
            cctx = zstd.ZstdCompressor(level=_ZSTD_COMPRESS_LEVEL)
            return _FRAME_FORMAT_ARROW + cctx.compress(sink.getvalue().to_pybytes())
        except Exception:
            # Per-column inference can still fail (e.g. a key that mixes int
            # and str across rows, or an int exceeding int64). Fall through
            # to pickle; keep the cause in the debug log.
            logger.debug(
                "_write_chunk_frame: Arrow IPC not applicable, using pickle",
                exc_info=True,
            )
    raw = io.BytesIO()
    cctx = zstd.ZstdCompressor(level=_ZSTD_COMPRESS_LEVEL)
    with cctx.stream_writer(raw, closefd=False) as zf:
        for i in range(0, len(items), _SUB_BATCH_SIZE):
            cloudpickle.dump(items[i : i + _SUB_BATCH_SIZE], zf, protocol=pickle.HIGHEST_PROTOCOL)
    return _FRAME_FORMAT_PICKLE + raw.getvalue()
442463

443464

444465
class ScatterWriter:

lib/zephyr/tests/test_shuffle.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from zephyr.shuffle import (
1212
ScatterFileIterator,
1313
ScatterReader,
14+
_FRAME_FORMAT_ARROW,
15+
_FRAME_FORMAT_PICKLE,
1416
_write_chunk_frame,
1517
_write_scatter,
1618
)
@@ -194,6 +196,38 @@ def test_scatter_file_iterator_multiple_chunks(tmp_path):
194196
assert chunks == [chunk_a, chunk_b]
195197

196198

199+
# ---------------------------------------------------------------------------
200+
# Frame format: Arrow IPC vs pickle fallback
201+
# ---------------------------------------------------------------------------
202+
203+
204+
def test_write_chunk_frame_uses_arrow_for_plain_dicts():
    """Plain dict items with primitive values use the Arrow IPC format tag."""
    records = [{"k": idx, "v": float(idx)} for idx in range(10)]
    encoded = _write_chunk_frame(records)
    tag = encoded[0:1]
    assert tag == _FRAME_FORMAT_ARROW, "expected Arrow IPC format tag for Arrow-compatible dicts"
209+
210+
211+
def test_write_chunk_frame_falls_back_to_pickle_for_frozensets():
    """Items containing frozensets cannot be Arrow-encoded and use pickle."""
    records = [{"k": 0, "v": frozenset([1, 2, 3])}]
    tag = _write_chunk_frame(records)[0:1]
    assert tag == _FRAME_FORMAT_PICKLE, "expected pickle format tag for frozenset values"
216+
217+
218+
def test_arrow_roundtrip_end_to_end(tmp_path):
    """Items written via the Arrow path round-trip correctly through scatter."""
    items = [{"k": i % 3, "v": i, "label": f"item-{i}"} for i in range(30)]
    scatter_paths = _build_shard(tmp_path, items, num_output_shards=3)

    # Drain every output shard and pool the recovered items.
    recovered = [
        item
        for shard_idx in range(3)
        for item in ScatterReader.from_sidecars(scatter_paths, shard_idx)
    ]

    by_v = lambda row: row["v"]
    assert sorted(recovered, key=by_v) == sorted(items, key=by_v)
229+
230+
197231
# ---------------------------------------------------------------------------
198232
# external_sort_merge
199233
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)