 import itertools
 import logging
 import os
-import pickle
 import re
 from datetime import datetime, timezone
 import threading
 logger = logging.getLogger(__name__)


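+# Column name used when wrapping non-dict items as cloudpickled bytes.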
+_PARQUET_CHUNK_VALUE_COL = "_zephyr_value"
+
+
 @dataclass(frozen=True)
-class PickleDiskChunk:
-    """Reference to a pickle chunk stored on disk.
+class ParquetDiskChunk:
+    """Reference to a Parquet chunk stored on disk.

     Each write goes to a UUID-unique path to avoid collisions when multiple
     workers race on the same shard. No coordinator-side rename is needed;
     the winning result's paths are used directly and the entire execution
     directory is cleaned up after the pipeline completes.
+
+    Items that are dicts are stored as Arrow columns directly. Non-dict items
+    (scalars, frozensets, etc.) are wrapped in a ``_zephyr_value`` column via
+    cloudpickle so that arbitrary Python objects can round-trip through Parquet.
     """

     path: str
     count: int
+    wrapped: bool = False

     def __iter__(self) -> Iterator:
         return iter(self.read())

     @classmethod
-    def write(cls, path: str, data: list) -> PickleDiskChunk:
-        """Write *data* to a UUID-unique path derived from *path*.
+    def write(cls, path: str, data: list) -> ParquetDiskChunk:
+        """Write *data* as a Parquet file at a UUID-unique path derived from *path*."""
+        import pyarrow.parquet as pq

-        The UUID suffix avoids collisions when multiple workers race on
-        the same shard. The resulting path is used directly for reads —
-        no rename step is required.
-        """
         from zephyr.writers import unique_temp_path

         ensure_parent_dir(path)
         data = list(data)
         count = len(data)
-
         unique_path = unique_temp_path(path)
-        with open_url(unique_path, "wb") as f:
-            pickle.dump(data, f)
-        return cls(path=unique_path, count=count)
+
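+        # Dict items become Arrow columns directly; anything else (or a
+        # schema pyarrow cannot infer) falls back to the wrapped column.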
+        wrapped = False
+        if not data or not isinstance(data[0], dict):
+            wrapped = True
+        else:
+            try:
+                table = pa.Table.from_pylist(data)
+            except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError):
+                wrapped = True
+
+        if wrapped:
+            table = pa.table({_PARQUET_CHUNK_VALUE_COL: [cloudpickle.dumps(item) for item in data]})
+        pq.write_table(table, unique_path, compression="zstd")
+        return cls(path=unique_path, count=count, wrapped=wrapped)

     def read(self) -> list:
-        """Load chunk data from disk."""
-        with open_url(self.path, "rb") as f:
-            return pickle.load(f)
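+        """Load chunk data from disk, unwrapping pickled items if present."""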
+        import pickle
+
+        import pyarrow.parquet as pq
+
+        table = pq.read_table(self.path)
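+        # Wrapped chunks store cloudpickle bytes; cloudpickle emits standard
+        # pickle streams, so plain pickle.loads can decode them.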
+        if _PARQUET_CHUNK_VALUE_COL in table.column_names:
+            return [pickle.loads(b) for b in table.column(_PARQUET_CHUNK_VALUE_COL).to_pylist()]
+        return table.to_pylist()


 # ---------------------------------------------------------------------------
@@ -108,7 +127,7 @@ def read(self) -> list:
     ScatterParquetIterator,  # noqa: F401 — re-exported for external callers
     ScatterShard,  # noqa: F401 — re-exported for plan.py and external callers
     _build_scatter_shard_from_manifest,  # noqa: F401 — re-exported for plan.py
-    _make_envelope,
+    make_envelope_batch,
     _write_parquet_scatter,
     _write_scatter_manifest,
     _SCATTER_MANIFEST_NAME,
@@ -124,7 +143,7 @@ class TaskResult:
     """Result of a single worker task.

     Always contains a ListShard. For non-scatter stages, refs are
-    PickleDiskChunks. For scatter stages, refs contain file paths
+    ParquetDiskChunks. For scatter stages, refs contain file paths
     (the actual metadata lives in ``.scatter_meta`` sidecar files
     read lazily by reducers).
     """
@@ -160,16 +179,12 @@ def _cleanup_execution(prefix: str, execution_id: str) -> None:
     logger.info(f"Cleaned up execution directory {exec_dir} in {elapsed:.1f}s")


-def _write_pickle_chunks(
+def _write_parquet_chunks(
     items: Iterator,
     source_shard: int,
     chunk_path_fn: Callable[[int], str],
 ) -> ListShard:
-    """Batch a plain item stream into pickle chunk files.
-
-    Returns a ListShard containing PickleDiskChunk references.
-    """
-    # TODO: make chunk_size configurable per writer
+    """Batch a plain item stream into Parquet chunk files."""
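+    # Flush a full Parquet chunk file every 100k buffered items.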
     chunk_size = 100_000
     chunks: list[Iterable] = []
     batch: list = []
@@ -178,20 +193,20 @@ def _write_pickle_chunks(
     for item in items:
         batch.append(item)
         if chunk_size > 0 and len(batch) >= chunk_size:
-            chunk_ref = PickleDiskChunk.write(chunk_path_fn(pidx), batch)
+            chunk_ref = ParquetDiskChunk.write(chunk_path_fn(pidx), batch)
             chunks.append(chunk_ref)
             pidx += 1
             batch = []
             if pidx % 10 == 0:
                 logger.info(
-                    "[shard %d] Wrote %d pickle chunks so far (latest: %d items)",
+                    "[shard %d] Wrote %d parquet chunks so far (latest: %d items)",
                     source_shard,
                     pidx,
                     chunk_ref.count,
                 )

     if batch:
-        chunks.append(PickleDiskChunk.write(chunk_path_fn(pidx), batch))
+        chunks.append(ParquetDiskChunk.write(chunk_path_fn(pidx), batch))

     return ListShard(refs=chunks)

@@ -210,7 +225,7 @@ def _write_stage_output(
     wrapping and ``.scatter_meta`` sidecars. Returns TaskResult with compact
     scatter metadata.

-    For non-scatter stages, batches items into pickle chunk files. Returns
+    For non-scatter stages, batches items into Parquet chunk files. Returns
     TaskResult with a ListShard.
     """
     if scatter_op is not None:
@@ -224,8 +239,7 @@
         use_pickle_envelope = False
         try:
             test_key = scatter_op.key_fn(first_item)
-            test_envelope = _make_envelope([first_item], 0, 0, key_values=[test_key])
-            pa.RecordBatch.from_pylist(test_envelope)
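+            # Probe with the first item; if this raises, fall back to the
+            # pickle envelope path below.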
+            make_envelope_batch([first_item], 0, 0, key_values=[test_key], sort_values=None, pickled=False)
             logger.info("Using Parquet for scatter serialization for shard %d", source_shard)
         except Exception:
             use_pickle_envelope = True
@@ -249,9 +263,9 @@
         return TaskResult(shard=shard)

     def chunk_path_fn(idx: int) -> str:
-        return f"{stage_dir}/shard-{shard_idx:04d}/chunk-{idx:04d}.pkl"
+        return f"{stage_dir}/shard-{shard_idx:04d}/chunk-{idx:04d}.parquet"

-    return TaskResult(shard=_write_pickle_chunks(stage_gen, source_shard, chunk_path_fn))
+    return TaskResult(shard=_write_parquet_chunks(stage_gen, source_shard, chunk_path_fn))


 class WorkerState(enum.Enum):