@@ -63,6 +63,7 @@ def __iter__(self) -> Iterator: ...
6363_ZEPHYR_SHUFFLE_SHARD_IDX_COL = "shard_idx"
6464_ZEPHYR_SHUFFLE_CHUNK_IDX_COL = "chunk_idx"
6565_ZEPHYR_SHUFFLE_ITEM_COL = "item"
66+ _ZEPHYR_SHUFFLE_PICKLED_COL = "pickled"
6667
6768
6869@dataclass (frozen = True )
@@ -114,34 +115,38 @@ class ParquetDiskChunk:
114115 for different (shard_idx, chunk_idx) pairs. Each chunk is pre-sorted
115116 by key, preserving the invariant needed for k-way merge in Reduce.
116117
117- Items are stored wrapped in an envelope struct with routing metadata: :
118+ Items are stored in one of two envelope formats :
118119
119- {"shard_idx": int, "chunk_idx": int, "item": <user_data>}
120+ * **Native** (``is_pickled=False``): ``{"shard_idx", "chunk_idx", "item": <data>}``
121+ * **Pickle** (``is_pickled=True``): ``{"shard_idx", "chunk_idx", "pickled": <bytes>}``
120122
121- The ``read`` method filters by shard/chunk and unwraps the ``item`` field.
122- Predicate pushdown in Parquet skips irrelevant row groups, so each
123- reducer reads only its own data efficiently.
123+ The pickle envelope is used when items are not Arrow-serializable.
124124 """
125125
126126 path : str
127127 filter_shard : int
128128 filter_chunk : int
129129 count : int
130+ is_pickled : bool = False
130131
131132 def __iter__ (self ) -> Iterator :
132133 return iter (self .read ())
133134
134135 def read (self ) -> list :
135136 """Load filtered chunk data from a Parquet file, unwrapping envelope."""
137+ col = _ZEPHYR_SHUFFLE_PICKLED_COL if self .is_pickled else _ZEPHYR_SHUFFLE_ITEM_COL
136138 table = pq .read_table (
137139 self .path ,
138- columns = [_ZEPHYR_SHUFFLE_ITEM_COL ],
140+ columns = [col ],
139141 filters = (
140142 (pc .field (_ZEPHYR_SHUFFLE_SHARD_IDX_COL ) == self .filter_shard )
141143 & (pc .field (_ZEPHYR_SHUFFLE_CHUNK_IDX_COL ) == self .filter_chunk )
142144 ),
143145 )
144- return table .column (_ZEPHYR_SHUFFLE_ITEM_COL ).to_pylist ()
146+ items = table .column (col ).to_pylist ()
147+ if self .is_pickled :
148+ return [pickle .loads (b ) for b in items ]
149+ return items
145150
146151
147152@dataclass
@@ -225,6 +230,18 @@ def _make_envelope(items: list, target_shard: int, chunk_idx: int) -> list[dict]
225230 ]
226231
227232
233+ def _make_pickle_envelope (items : list , target_shard : int , chunk_idx : int ) -> list [dict ]:
234+ """Wrap items as pickle-serialized bytes for Arrow-incompatible types."""
235+ return [
236+ {
237+ _ZEPHYR_SHUFFLE_SHARD_IDX_COL : target_shard ,
238+ _ZEPHYR_SHUFFLE_CHUNK_IDX_COL : chunk_idx ,
239+ _ZEPHYR_SHUFFLE_PICKLED_COL : cloudpickle .dumps (item ),
240+ }
241+ for item in items
242+ ]
243+
244+
228245def _segment_path (base_path : str , seg_idx : int ) -> str :
229246 """Return the file path for a given segment index.
230247
@@ -245,13 +262,17 @@ def _write_parquet_scatter(
245262 stage_gen : Iterator [StageResultChunk ],
246263 source_shard : int ,
247264 parquet_path : str ,
265+ pickled : bool = False ,
248266) -> list [ResultChunk ]:
249267 """Stream scatter chunks into Parquet files as row groups.
250268
251269 Writes batches to a Parquet file until a schema mismatch is detected
252270 (e.g. a field evolves from null to a concrete type). On mismatch the
253271 current file is closed, the schema is unified via ``pa.unify_schemas``,
254272 and a new segment file is opened with the evolved schema.
273+
274+ When ``pickled=True``, items are serialized via pickle into a binary
275+ ``pickled`` column instead of being stored natively in the ``item`` column.
255276 """
256277 chunk_results : list [_ChunkMetadata ] = []
257278 per_shard_chunk_cnt : dict [int , int ] = defaultdict (int )
@@ -285,7 +306,8 @@ def _flush_pending():
285306 target_shard = result .target_shard
286307 shard_chunk_idx = per_shard_chunk_cnt [target_shard ]
287308 per_shard_chunk_cnt [target_shard ] += 1
288- envelope = _make_envelope (chunk_items , target_shard , shard_chunk_idx )
309+ envelope_fn = _make_pickle_envelope if pickled else _make_envelope
310+ envelope = envelope_fn (chunk_items , target_shard , shard_chunk_idx )
289311 chunk_arrow = pa .RecordBatch .from_pylist (envelope )
290312
291313 if schema is None :
@@ -328,7 +350,11 @@ def _flush_pending():
328350 source_shard = source_shard ,
329351 target_shard = rec .target_shard ,
330352 data = ParquetDiskChunk (
331- path = rec .path , filter_shard = rec .target_shard , filter_chunk = rec .chunk_idx , count = rec .cnt
353+ path = rec .path ,
354+ filter_shard = rec .target_shard ,
355+ filter_chunk = rec .chunk_idx ,
356+ count = rec .cnt ,
357+ is_pickled = pickled ,
332358 ),
333359 )
334360 for rec in chunk_results
@@ -387,33 +413,26 @@ def _write_stage_chunks(
387413
388414 first_items = list (first_result .chunk )
389415
390- # Test Arrow serializability on the first chunk to decide parquet vs pickle
391- use_parquet = False
416+ # Prepend the already-consumed first result back into the stream
417+ first_with_materialized_chunk = dataclasses .replace (first_result , chunk = first_items )
418+ full_gen = itertools .chain ([first_with_materialized_chunk ], stage_gen )
419+
392420 if is_scatter :
421+ # Test Arrow serializability on the first chunk to decide native vs pickle envelope
422+ use_pickle_envelope = False
393423 try :
394424 test_envelope = _make_envelope (first_items , 0 , 0 )
395425 pa .RecordBatch .from_pylist (test_envelope )
396- use_parquet = True
397426 logger .info ("Using Parquet for scatter serialization for shard %d" , source_shard )
398427 except Exception :
399- sample_rows = str (test_envelope [:5 ]) if len (test_envelope ) > 5 else str (test_envelope )
400- if len (sample_rows ) > 1000 :
401- sample_rows = sample_rows [:1000 ] + "...(truncated)"
402- logger .warning (
403- "Arrow scatter serialization failed for shard %d; "
404- "falling back to pickle. Performance will be degraded. Sample rows: %s" ,
428+ use_pickle_envelope = True
429+ logger .info (
430+ "Using Parquet with pickle envelope for scatter serialization for shard %d" ,
405431 source_shard ,
406- sample_rows ,
407- exc_info = True ,
408432 )
409433
410- # Prepend the already-consumed first result back into the stream
411- first_with_materialized_chunk = dataclasses .replace (first_result , chunk = first_items )
412- full_gen = itertools .chain ([first_with_materialized_chunk ], stage_gen )
413-
414- if use_parquet :
415434 parquet_path = f"{ stage_dir } /shard-{ shard_idx :04d} .parquet"
416- return _write_parquet_scatter (full_gen , source_shard , parquet_path )
435+ return _write_parquet_scatter (full_gen , source_shard , parquet_path , pickled = use_pickle_envelope )
417436
418437 def chunk_path_fn (idx : int ) -> str :
419438 return f"{ stage_dir } /shard-{ shard_idx :04d} /chunk-{ idx :04d} .pkl"
0 commit comments