Commit b1c1029

[zephyr] Unify envelope abstraction, remove pickle intermediates
- Rename scatter columns to _zephyr_* prefix (no compat aliases)
- Replace _make_envelope/_make_pickle_envelope with unified make_envelope_batch() and unwrap_items() helpers
- Replace PickleDiskChunk with ParquetDiskChunk for intermediates
- Remove pickle-based external sort; rename external_sort_merge_arrow to external_sort_merge, extract _read_parquet_run helper
- Remove _merge_sorted_chunks (Python heapq path); _arrow_reduce_gen handles both flat and pickled envelopes via unwrap_items
- Remove has_sort_key field (always true), ZEPHYR_META_COLUMNS (unused), _ITEM_BYTES_FALLBACK (unreachable)

Net: -292 lines, one serialization format, one reduce path.
1 parent: b84390d

8 files changed

Lines changed: 274 additions & 581 deletions
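
For orientation, here is a minimal Python sketch of the unified helpers named in the commit message. Only the make_envelope_batch call signature (visible in the diff below) and the _zephyr_* column prefix come from this commit; the specific column names (_zephyr_key, _zephyr_shard, _zephyr_seq, _zephyr_sort) and the binary-type check in unwrap_items are illustrative assumptions about code that lives in a file not shown here.

import cloudpickle
import pyarrow as pa

def make_envelope_batch(items, shard, seq, *, key_values, sort_values, pickled):
    # Hypothetical sketch. Pickled envelopes carry opaque bytes; flat envelopes
    # let Arrow infer a type (raising ArrowInvalid/ArrowTypeError for shapes
    # it cannot encode, which is what the probe in _write_stage_output relies on).
    if pickled:
        value_col = pa.array([cloudpickle.dumps(it) for it in items], type=pa.binary())
    else:
        value_col = pa.array(items)
    columns = {
        "_zephyr_value": value_col,
        "_zephyr_key": pa.array(key_values),
        "_zephyr_shard": pa.array([shard] * len(items), type=pa.int32()),
        "_zephyr_seq": pa.array([seq] * len(items), type=pa.int64()),
    }
    if sort_values is not None:
        columns["_zephyr_sort"] = pa.array(sort_values)
    return pa.RecordBatch.from_pydict(columns)

def unwrap_items(batch: pa.RecordBatch) -> list:
    # One reduce path: the value column's type says whether to unpickle.
    col = batch.column("_zephyr_value")
    if pa.types.is_binary(col.type):
        return [cloudpickle.loads(b) for b in col.to_pylist()]
    return col.to_pylist()

The point of the unification is that readers never branch on format by flag-passing: unwrap_items inspects the value column itself, so flat and pickled envelopes flow through the same reduce path.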

File tree

lib/zephyr/src/zephyr/execution.py

Lines changed: 46 additions & 32 deletions
@@ -17,7 +17,6 @@
 import itertools
 import logging
 import os
-import pickle
 import re
 from datetime import datetime, timezone
 import threading
@@ -57,45 +56,65 @@
 logger = logging.getLogger(__name__)
 
 
+_PARQUET_CHUNK_VALUE_COL = "_zephyr_value"
+
+
 @dataclass(frozen=True)
-class PickleDiskChunk:
-    """Reference to a pickle chunk stored on disk.
+class ParquetDiskChunk:
+    """Reference to a Parquet chunk stored on disk.
 
     Each write goes to a UUID-unique path to avoid collisions when multiple
     workers race on the same shard. No coordinator-side rename is needed;
     the winning result's paths are used directly and the entire execution
     directory is cleaned up after the pipeline completes.
+
+    Items that are dicts are stored as Arrow columns directly. Non-dict items
+    (scalars, frozensets, etc.) are wrapped in a ``_zephyr_value`` column via
+    cloudpickle so that arbitrary Python objects can round-trip through Parquet.
     """
 
     path: str
     count: int
+    wrapped: bool = False
 
     def __iter__(self) -> Iterator:
         return iter(self.read())
 
     @classmethod
-    def write(cls, path: str, data: list) -> PickleDiskChunk:
-        """Write *data* to a UUID-unique path derived from *path*.
+    def write(cls, path: str, data: list) -> ParquetDiskChunk:
+        """Write *data* as a Parquet file at a UUID-unique path derived from *path*."""
+        import pyarrow.parquet as pq
 
-        The UUID suffix avoids collisions when multiple workers race on
-        the same shard. The resulting path is used directly for reads —
-        no rename step is required.
-        """
         from zephyr.writers import unique_temp_path
 
         ensure_parent_dir(path)
         data = list(data)
         count = len(data)
-
         unique_path = unique_temp_path(path)
-        with open_url(unique_path, "wb") as f:
-            pickle.dump(data, f)
-        return cls(path=unique_path, count=count)
+
+        wrapped = False
+        if not data or not isinstance(data[0], dict):
+            wrapped = True
+        else:
+            try:
+                table = pa.Table.from_pylist(data)
+            except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError):
+                wrapped = True
+
+        if wrapped:
+            table = pa.table({_PARQUET_CHUNK_VALUE_COL: [cloudpickle.dumps(item) for item in data]})
+        pq.write_table(table, unique_path, compression="zstd")
+        return cls(path=unique_path, count=count, wrapped=wrapped)
 
     def read(self) -> list:
-        """Load chunk data from disk."""
-        with open_url(self.path, "rb") as f:
-            return pickle.load(f)
+        import pickle
+
+        import pyarrow.parquet as pq
+
+        table = pq.read_table(self.path)
+        if _PARQUET_CHUNK_VALUE_COL in table.column_names:
+            return [pickle.loads(b) for b in table.column(_PARQUET_CHUNK_VALUE_COL).to_pylist()]
+        return table.to_pylist()
 
 
 # ---------------------------------------------------------------------------
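
For reference, a usage sketch of the two storage paths (the path and items are illustrative; each write actually lands at a UUID-unique path derived from the one given):

# Dict items round-trip as native Parquet columns.
chunk = ParquetDiskChunk.write("/tmp/exec/chunk.parquet", [{"a": 1}, {"a": 2}])
assert not chunk.wrapped and chunk.read() == [{"a": 1}, {"a": 2}]

# Non-dict items (or dicts Arrow cannot encode) fall back to the
# cloudpickle-wrapped _zephyr_value column.
chunk = ParquetDiskChunk.write("/tmp/exec/chunk.parquet", [frozenset({1}), frozenset({2})])
assert chunk.wrapped and chunk.read() == [frozenset({1}), frozenset({2})]
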
@@ -108,7 +127,7 @@ def read(self) -> list:
     ScatterParquetIterator,  # noqa: F401 — re-exported for external callers
     ScatterShard,  # noqa: F401 — re-exported for plan.py and external callers
     _build_scatter_shard_from_manifest,  # noqa: F401 — re-exported for plan.py
-    _make_envelope,
+    make_envelope_batch,
     _write_parquet_scatter,
     _write_scatter_manifest,
     _SCATTER_MANIFEST_NAME,
@@ -124,7 +143,7 @@ class TaskResult:
     """Result of a single worker task.
 
     Always contains a ListShard. For non-scatter stages, refs are
-    PickleDiskChunks. For scatter stages, refs contain file paths
+    ParquetDiskChunks. For scatter stages, refs contain file paths
     (the actual metadata lives in ``.scatter_meta`` sidecar files
     read lazily by reducers).
     """
@@ -160,16 +179,12 @@ def _cleanup_execution(prefix: str, execution_id: str) -> None:
     logger.info(f"Cleaned up execution directory {exec_dir} in {elapsed:.1f}s")
 
 
-def _write_pickle_chunks(
+def _write_parquet_chunks(
     items: Iterator,
     source_shard: int,
     chunk_path_fn: Callable[[int], str],
 ) -> ListShard:
-    """Batch a plain item stream into pickle chunk files.
-
-    Returns a ListShard containing PickleDiskChunk references.
-    """
-    # TODO: make chunk_size configurable per writer
+    """Batch a plain item stream into Parquet chunk files."""
     chunk_size = 100_000
     chunks: list[Iterable] = []
     batch: list = []
@@ -178,20 +193,20 @@ def _write_pickle_chunks(
     for item in items:
         batch.append(item)
         if chunk_size > 0 and len(batch) >= chunk_size:
-            chunk_ref = PickleDiskChunk.write(chunk_path_fn(pidx), batch)
+            chunk_ref = ParquetDiskChunk.write(chunk_path_fn(pidx), batch)
             chunks.append(chunk_ref)
             pidx += 1
             batch = []
             if pidx % 10 == 0:
                 logger.info(
-                    "[shard %d] Wrote %d pickle chunks so far (latest: %d items)",
+                    "[shard %d] Wrote %d parquet chunks so far (latest: %d items)",
                     source_shard,
                     pidx,
                     chunk_ref.count,
                 )
 
     if batch:
-        chunks.append(PickleDiskChunk.write(chunk_path_fn(pidx), batch))
+        chunks.append(ParquetDiskChunk.write(chunk_path_fn(pidx), batch))
 
     return ListShard(refs=chunks)

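A sketch of how a non-scatter stage wires this up, mirroring the chunk_path_fn defined in _write_stage_output further down; stage_dir, shard_idx, source_shard, and stage_gen are assumed to be in scope:

def chunk_path_fn(idx: int) -> str:
    return f"{stage_dir}/shard-{shard_idx:04d}/chunk-{idx:04d}.parquet"

# One ParquetDiskChunk ref per batch of up to 100_000 items.
shard = _write_parquet_chunks(stage_gen, source_shard, chunk_path_fn)
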
@@ -210,7 +225,7 @@ def _write_stage_output(
     wrapping and ``.scatter_meta`` sidecars. Returns TaskResult with compact
     scatter metadata.
 
-    For non-scatter stages, batches items into pickle chunk files. Returns
+    For non-scatter stages, batches items into Parquet chunk files. Returns
     TaskResult with a ListShard.
     """
     if scatter_op is not None:
@@ -224,8 +239,7 @@
         use_pickle_envelope = False
         try:
             test_key = scatter_op.key_fn(first_item)
-            test_envelope = _make_envelope([first_item], 0, 0, key_values=[test_key])
-            pa.RecordBatch.from_pylist(test_envelope)
+            make_envelope_batch([first_item], 0, 0, key_values=[test_key], sort_values=None, pickled=False)
             logger.info("Using Parquet for scatter serialization for shard %d", source_shard)
         except Exception:
             use_pickle_envelope = True
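
Read as a standalone predicate, the probe attempts a flat envelope for the first item and falls back to pickled envelopes if Arrow rejects it. A hypothetical restatement, not code from this commit:

def needs_pickle_envelope(first_item, key_fn) -> bool:
    try:
        # A flat envelope raises if Arrow cannot infer a type for the item.
        make_envelope_batch([first_item], 0, 0, key_values=[key_fn(first_item)],
                            sort_values=None, pickled=False)
        return False
    except Exception:
        return True
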
@@ -249,9 +263,9 @@
         return TaskResult(shard=shard)
 
     def chunk_path_fn(idx: int) -> str:
-        return f"{stage_dir}/shard-{shard_idx:04d}/chunk-{idx:04d}.pkl"
+        return f"{stage_dir}/shard-{shard_idx:04d}/chunk-{idx:04d}.parquet"
 
-    return TaskResult(shard=_write_pickle_chunks(stage_gen, source_shard, chunk_path_fn))
+    return TaskResult(shard=_write_parquet_chunks(stage_gen, source_shard, chunk_path_fn))
 
 
 class WorkerState(enum.Enum):
