Skip to content

Commit f74cda9

Browse files
ravwojdyla-agentravwojdylaclaudegithub-actions[bot]
authored
zephyr: pack pickle inside parquet shuffle (#3656)
## Summary - When scatter items aren't Arrow-serializable, serialize them via `pickle.dumps()` into a binary `pickled` column in Parquet instead of falling back to per-chunk `.pkl` files - Eliminates the M×RxC file blowup in pickle mode while preserving single-file-per-shard compactness and predicate pushdown - Adds `is_pickled` field to `ParquetDiskChunk` for transparent deserialization on read Fixes #3640 ## Test plan - [x] Existing `test_group_by_non_vortex_serializable` passes on all 3 backends (local, iris, ray) — exercises the full pickle-in-parquet scatter+reduce path with `frozenset` items - [x] New `test_parquet_disk_chunk_pickle_roundtrip` — unit test for pickle envelope write/read - [x] All 49 groupby tests pass, all 61 execution tests pass - [x] Pre-commit clean 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Rafal Wojdyla <ravwojdyla@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Rafal Wojdyla <ravwojdyla@users.noreply.github.com>
1 parent 925472f commit f74cda9

File tree

2 files changed

+71
-31
lines changed

2 files changed

+71
-31
lines changed

lib/zephyr/src/zephyr/execution.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __iter__(self) -> Iterator: ...
6363
_ZEPHYR_SHUFFLE_SHARD_IDX_COL = "shard_idx"
6464
_ZEPHYR_SHUFFLE_CHUNK_IDX_COL = "chunk_idx"
6565
_ZEPHYR_SHUFFLE_ITEM_COL = "item"
66+
_ZEPHYR_SHUFFLE_PICKLED_COL = "pickled"
6667

6768

6869
@dataclass(frozen=True)
@@ -114,34 +115,38 @@ class ParquetDiskChunk:
114115
for different (shard_idx, chunk_idx) pairs. Each chunk is pre-sorted
115116
by key, preserving the invariant needed for k-way merge in Reduce.
116117
117-
Items are stored wrapped in an envelope struct with routing metadata::
118+
Items are stored in one of two envelope formats:
118119
119-
{"shard_idx": int, "chunk_idx": int, "item": <user_data>}
120+
* **Native** (``is_pickled=False``): ``{"shard_idx", "chunk_idx", "item": <data>}``
121+
* **Pickle** (``is_pickled=True``): ``{"shard_idx", "chunk_idx", "pickled": <bytes>}``
120122
121-
The ``read`` method filters by shard/chunk and unwraps the ``item`` field.
122-
Predicate pushdown in Parquet skips irrelevant row groups, so each
123-
reducer reads only its own data efficiently.
123+
The pickle envelope is used when items are not Arrow-serializable.
124124
"""
125125

126126
path: str
127127
filter_shard: int
128128
filter_chunk: int
129129
count: int
130+
is_pickled: bool = False
130131

131132
def __iter__(self) -> Iterator:
132133
return iter(self.read())
133134

134135
def read(self) -> list:
135136
"""Load filtered chunk data from a Parquet file, unwrapping envelope."""
137+
col = _ZEPHYR_SHUFFLE_PICKLED_COL if self.is_pickled else _ZEPHYR_SHUFFLE_ITEM_COL
136138
table = pq.read_table(
137139
self.path,
138-
columns=[_ZEPHYR_SHUFFLE_ITEM_COL],
140+
columns=[col],
139141
filters=(
140142
(pc.field(_ZEPHYR_SHUFFLE_SHARD_IDX_COL) == self.filter_shard)
141143
& (pc.field(_ZEPHYR_SHUFFLE_CHUNK_IDX_COL) == self.filter_chunk)
142144
),
143145
)
144-
return table.column(_ZEPHYR_SHUFFLE_ITEM_COL).to_pylist()
146+
items = table.column(col).to_pylist()
147+
if self.is_pickled:
148+
return [pickle.loads(b) for b in items]
149+
return items
145150

146151

147152
@dataclass
@@ -225,6 +230,18 @@ def _make_envelope(items: list, target_shard: int, chunk_idx: int) -> list[dict]
225230
]
226231

227232

233+
def _make_pickle_envelope(items: list, target_shard: int, chunk_idx: int) -> list[dict]:
234+
"""Wrap items as pickle-serialized bytes for Arrow-incompatible types."""
235+
return [
236+
{
237+
_ZEPHYR_SHUFFLE_SHARD_IDX_COL: target_shard,
238+
_ZEPHYR_SHUFFLE_CHUNK_IDX_COL: chunk_idx,
239+
_ZEPHYR_SHUFFLE_PICKLED_COL: cloudpickle.dumps(item),
240+
}
241+
for item in items
242+
]
243+
244+
228245
def _segment_path(base_path: str, seg_idx: int) -> str:
229246
"""Return the file path for a given segment index.
230247
@@ -245,13 +262,17 @@ def _write_parquet_scatter(
245262
stage_gen: Iterator[StageResultChunk],
246263
source_shard: int,
247264
parquet_path: str,
265+
pickled: bool = False,
248266
) -> list[ResultChunk]:
249267
"""Stream scatter chunks into Parquet files as row groups.
250268
251269
Writes batches to a Parquet file until a schema mismatch is detected
252270
(e.g. a field evolves from null to a concrete type). On mismatch the
253271
current file is closed, the schema is unified via ``pa.unify_schemas``,
254272
and a new segment file is opened with the evolved schema.
273+
274+
When ``pickled=True``, items are serialized via pickle into a binary
275+
``pickled`` column instead of being stored natively in the ``item`` column.
255276
"""
256277
chunk_results: list[_ChunkMetadata] = []
257278
per_shard_chunk_cnt: dict[int, int] = defaultdict(int)
@@ -285,7 +306,8 @@ def _flush_pending():
285306
target_shard = result.target_shard
286307
shard_chunk_idx = per_shard_chunk_cnt[target_shard]
287308
per_shard_chunk_cnt[target_shard] += 1
288-
envelope = _make_envelope(chunk_items, target_shard, shard_chunk_idx)
309+
envelope_fn = _make_pickle_envelope if pickled else _make_envelope
310+
envelope = envelope_fn(chunk_items, target_shard, shard_chunk_idx)
289311
chunk_arrow = pa.RecordBatch.from_pylist(envelope)
290312

291313
if schema is None:
@@ -328,7 +350,11 @@ def _flush_pending():
328350
source_shard=source_shard,
329351
target_shard=rec.target_shard,
330352
data=ParquetDiskChunk(
331-
path=rec.path, filter_shard=rec.target_shard, filter_chunk=rec.chunk_idx, count=rec.cnt
353+
path=rec.path,
354+
filter_shard=rec.target_shard,
355+
filter_chunk=rec.chunk_idx,
356+
count=rec.cnt,
357+
is_pickled=pickled,
332358
),
333359
)
334360
for rec in chunk_results
@@ -387,33 +413,26 @@ def _write_stage_chunks(
387413

388414
first_items = list(first_result.chunk)
389415

390-
# Test Arrow serializability on the first chunk to decide parquet vs pickle
391-
use_parquet = False
416+
# Prepend the already-consumed first result back into the stream
417+
first_with_materialized_chunk = dataclasses.replace(first_result, chunk=first_items)
418+
full_gen = itertools.chain([first_with_materialized_chunk], stage_gen)
419+
392420
if is_scatter:
421+
# Test Arrow serializability on the first chunk to decide native vs pickle envelope
422+
use_pickle_envelope = False
393423
try:
394424
test_envelope = _make_envelope(first_items, 0, 0)
395425
pa.RecordBatch.from_pylist(test_envelope)
396-
use_parquet = True
397426
logger.info("Using Parquet for scatter serialization for shard %d", source_shard)
398427
except Exception:
399-
sample_rows = str(test_envelope[:5]) if len(test_envelope) > 5 else str(test_envelope)
400-
if len(sample_rows) > 1000:
401-
sample_rows = sample_rows[:1000] + "...(truncated)"
402-
logger.warning(
403-
"Arrow scatter serialization failed for shard %d; "
404-
"falling back to pickle. Performance will be degraded. Sample rows: %s",
428+
use_pickle_envelope = True
429+
logger.info(
430+
"Using Parquet with pickle envelope for scatter serialization for shard %d",
405431
source_shard,
406-
sample_rows,
407-
exc_info=True,
408432
)
409433

410-
# Prepend the already-consumed first result back into the stream
411-
first_with_materialized_chunk = dataclasses.replace(first_result, chunk=first_items)
412-
full_gen = itertools.chain([first_with_materialized_chunk], stage_gen)
413-
414-
if use_parquet:
415434
parquet_path = f"{stage_dir}/shard-{shard_idx:04d}.parquet"
416-
return _write_parquet_scatter(full_gen, source_shard, parquet_path)
435+
return _write_parquet_scatter(full_gen, source_shard, parquet_path, pickled=use_pickle_envelope)
417436

418437
def chunk_path_fn(idx: int) -> str:
419438
return f"{stage_dir}/shard-{shard_idx:04d}/chunk-{idx:04d}.pkl"

lib/zephyr/tests/test_groupby.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,16 +330,16 @@ def reducer(key, items):
330330

331331

332332
def test_group_by_non_vortex_serializable(zephyr_ctx):
333-
"""Shuffle with items that Vortex/Arrow cannot serialize falls back to pickle.
333+
"""Shuffle with items that Vortex/Arrow cannot serialize uses pickle-in-parquet.
334334
335-
Uses SimpleNamespace (not a dict) so Arrow conversion fails and the pickle
336-
fallback is exercised. SimpleNamespace is a stdlib type importable by any
337-
worker process, avoiding module-resolution issues with test-local classes.
335+
Uses frozenset (not Arrow-serializable) so the pickle envelope path is
336+
exercised. Items are serialized via cloudpickle into a binary ``pickled``
337+
column inside Parquet, avoiding the N*M pickle file blowup.
338338
"""
339339

340340
from zephyr.writers import infer_arrow_schema
341341

342-
# NOTE: confirm frozenset is not arrow-serializable type to trigger the fallback path
342+
# NOTE: confirm frozenset is not arrow-serializable type to trigger the pickle envelope path
343343
with pytest.raises(pa.lib.ArrowInvalid, match="Could not convert frozenset"):
344344
infer_arrow_schema([{"foo": frozenset([1, 2, 3])}])
345345

@@ -361,6 +361,27 @@ def test_group_by_non_vortex_serializable(zephyr_ctx):
361361
assert results[1] == {"key": "b", "value": frozenset([2])}
362362

363363

364+
def test_parquet_disk_chunk_pickle_roundtrip(tmp_path):
365+
"""ParquetDiskChunk with is_pickled=True round-trips non-Arrow-serializable items."""
366+
import pyarrow.parquet as pq
367+
368+
from zephyr.execution import (
369+
ParquetDiskChunk,
370+
_make_pickle_envelope,
371+
)
372+
373+
items = [frozenset([1, 2]), frozenset([3, 4, 5])]
374+
envelope = _make_pickle_envelope(items, target_shard=0, chunk_idx=0)
375+
batch = pa.RecordBatch.from_pylist(envelope)
376+
377+
path = str(tmp_path / "test.parquet")
378+
pq.write_table(pa.Table.from_batches([batch]), path)
379+
380+
chunk = ParquetDiskChunk(path=path, filter_shard=0, filter_chunk=0, count=2, is_pickled=True)
381+
result = chunk.read()
382+
assert result == items
383+
384+
364385
def test_group_by_schema_evolution(zephyr_ctx):
365386
"""Schema evolution: a field that is null in some chunks gains a type in others."""
366387
data = []

0 commit comments

Comments
 (0)