Skip to content

Commit 8f2b98a

Browse files
authored
[zephyr] External sort spill: Parquet instead of pickle+zstd (#4695)
## Summary - Replaces `.pkl.zst` spill files in `external_sort_merge` with Parquet files written via a new `SpillWriter` (byte-budgeted `pq.ParquetWriter` with a background I/O thread). - Items are pickled into a single `_zephyr_payload` binary column. Python `heapq.merge` semantics on both passes are unchanged, so behavior is identical — this is a format swap only. - Pass 2 reads spills back with `pq.ParquetFile.iter_batches` and unpickles one row group at a time to feed the heap merge. - Pass-2 read-batch-size estimation now reads row-group metadata directly from the parquet file instead of probing a pickled sample. This is the minimal slice cherry-picked from #4178 that removes raw pickle files from zephyr's shuffle data plane. The scatter envelope and reduce merge are untouched; follow-ups will promote the sort key to a first-class column (Tier 2) and move reduce to columnar Arrow merge (Tier 3). `SpillWriter` is added as `lib/zephyr/src/zephyr/spill_writer.py` verbatim from #4178. `external_sort.py` is its only caller in this PR.
1 parent b589918 commit 8f2b98a

File tree

2 files changed

+256
-54
lines changed

2 files changed

+256
-54
lines changed

lib/zephyr/src/zephyr/external_sort.py

Lines changed: 45 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,66 +8,66 @@
88
exhausting worker memory.
99
1010
Pass 1: batch the k iterators into groups of EXTERNAL_SORT_FAN_IN, merge each
11-
group with heapq.merge, and spill items in batches of ``_WRITE_BATCH_SIZE`` to
12-
a zstd-compressed pickle run file under
13-
``{external_sort_dir}/run-{i:04d}.pkl.zst``. Items are streamed to disk
14-
rather than accumulated in a list, so peak memory per batch is bounded by the
15-
number of open iterators rather than their total item count.
11+
group with heapq.merge, and spill items to a run file under
12+
``{external_sort_dir}/run-{i:04d}.spill`` via :class:`SpillWriter`.
1613
1714
Pass 2: heapq.merge over the (much smaller) set of run file iterators. Each
18-
iterator reads one batch at a time and yields items one-by-one; the read batch
19-
size is computed from the cgroup memory limit so that all concurrent batches
20-
together stay within ``_READ_MEMORY_FRACTION`` of available memory.
15+
iterator streams chunks from its spill file via :class:`SpillReader`; the read
16+
batch size is computed from the cgroup memory limit so that all concurrent
17+
batches together stay within ``_READ_MEMORY_FRACTION`` of available memory.
2118
2219
Run files are deleted after the final merge completes.
2320
"""
2421

2522
import heapq
2623
import logging
27-
import pickle
2824
from collections.abc import Callable, Iterator
2925
from itertools import islice
3026

31-
import fsspec
32-
import zstandard as zstd
3327
from iris.env_resources import TaskResources
3428
from rigging.filesystem import url_to_fs
3529

30+
from zephyr.spill import SpillReader, SpillWriter
31+
3632
logger = logging.getLogger(__name__)
3733

3834
# Maximum simultaneous chunk iterators per pass-1 batch.
3935
EXTERNAL_SORT_FAN_IN = 500
4036

41-
# Items per pickle.dump in pass-1. Larger batches compress better (zstd
42-
# dictionary spans the whole batch) and reduce per-call overhead.
37+
# Items buffered before handing to the SpillWriter. Larger values amortize
38+
# per-chunk overhead in the spill format.
4339
_WRITE_BATCH_SIZE = 10_000
4440

41+
# Target bytes per spill chunk in pass-1 runs.
42+
_ROW_GROUP_BYTES = 8 * 1024 * 1024
43+
4544
# Fraction of container memory budgeted for pass-2 read buffers.
4645
_READ_MEMORY_FRACTION = 0.25
4746

4847

4948
def _safe_read_batch_size(n_runs: int, sample_run_path: str) -> int:
5049
"""Compute a pass-2 read batch size that fits within the memory budget.
5150
52-
Probes the first batch from ``sample_run_path`` to estimate in-memory
53-
bytes per item, then divides the memory budget by ``n_runs * item_bytes``
54-
so that all concurrent run-file buffers together stay within
51+
Uses the spill's per-item byte estimate to divide the memory budget across
52+
concurrent run-file buffers so they together stay within
5553
``_READ_MEMORY_FRACTION`` of available container memory.
5654
"""
57-
dctx = zstd.ZstdDecompressor()
5855
try:
59-
with fsspec.open(sample_run_path, "rb") as raw_f:
60-
with dctx.stream_reader(raw_f) as f:
61-
sample_batch: list = pickle.load(f)
56+
item_bytes_raw = SpillReader(sample_run_path).approx_item_bytes
6257
except Exception:
58+
logger.warning(
59+
"Failed to read spill metadata from %s; falling back to default batch size",
60+
sample_run_path,
61+
exc_info=True,
62+
)
6363
return _WRITE_BATCH_SIZE
6464

65-
sample = sample_batch[:100]
66-
if not sample:
65+
if item_bytes_raw <= 0:
6766
return _WRITE_BATCH_SIZE
68-
# pickle size x 3 approximates Python object overhead (dicts are ~3x larger
69-
# in memory than their serialised form).
70-
item_bytes = max(64, len(pickle.dumps(sample)) // len(sample) * 3)
67+
68+
# Payload size x 3 approximates Python object overhead (dicts are ~3x
69+
# larger in memory than their pickled form).
70+
item_bytes = max(64, item_bytes_raw * 3)
7171

7272
available = TaskResources.from_environment().memory_bytes
7373
budget = int(available * _READ_MEMORY_FRACTION)
@@ -101,28 +101,31 @@ def external_sort_merge(
101101
Yields:
102102
Items in merged sort order.
103103
"""
104-
cctx = zstd.ZstdCompressor(level=3)
105104
run_paths: list[str] = []
106105
batch_idx = 0
107106

107+
# SpillWriter does not auto-create parent directories, so ensure the spill
108+
# dir exists up front.
109+
spill_fs, spill_dir = url_to_fs(external_sort_dir)
110+
spill_fs.makedirs(spill_dir, exist_ok=True)
111+
108112
while True:
109113
batch = list(islice(chunk_iterators_gen, EXTERNAL_SORT_FAN_IN))
110114
if not batch:
111115
break
112-
run_path = f"{external_sort_dir}/run-{batch_idx:04d}.pkl.zst"
116+
run_path = f"{external_sort_dir}/run-{batch_idx:04d}.spill"
113117
item_count = 0
114118
pending: list = []
115-
with fsspec.open(run_path, "wb") as raw_f:
116-
with cctx.stream_writer(raw_f, closefd=False) as f:
117-
for item in heapq.merge(*batch, key=merge_key):
118-
pending.append(item)
119-
if len(pending) >= _WRITE_BATCH_SIZE:
120-
pickle.dump(pending, f, protocol=pickle.HIGHEST_PROTOCOL)
121-
item_count += len(pending)
122-
pending = []
123-
if pending:
124-
pickle.dump(pending, f, protocol=pickle.HIGHEST_PROTOCOL)
119+
with SpillWriter(run_path, row_group_bytes=_ROW_GROUP_BYTES) as writer:
120+
for item in heapq.merge(*batch, key=merge_key):
121+
pending.append(item)
122+
if len(pending) >= _WRITE_BATCH_SIZE:
123+
writer.write(pending)
125124
item_count += len(pending)
125+
pending = []
126+
if pending:
127+
writer.write(pending)
128+
item_count += len(pending)
126129
run_paths.append(run_path)
127130
logger.info(
128131
"External sort: wrote run %d (%d items) to %s",
@@ -135,29 +138,17 @@ def external_sort_merge(
135138
read_batch_size = _safe_read_batch_size(len(run_paths), run_paths[0]) if run_paths else _WRITE_BATCH_SIZE
136139

137140
def _read_run(path: str) -> Iterator:
138-
with fsspec.open(path, "rb") as raw_f:
139-
with zstd.ZstdDecompressor().stream_reader(raw_f) as f:
140-
while True:
141-
try:
142-
items: list = pickle.load(f)
143-
# Yield in read_batch_size chunks and delete consumed
144-
# items in-place so memory is released progressively
145-
# even while the generator is suspended in heapq.merge.
146-
while items:
147-
chunk = items[:read_batch_size]
148-
del items[:read_batch_size]
149-
yield from chunk
150-
except EOFError:
151-
break
141+
reader = SpillReader(path, batch_size=read_batch_size)
142+
for chunk in reader.iter_chunks():
143+
yield from chunk
152144

153145
run_iters = [_read_run(p) for p in run_paths]
154146
try:
155147
yield from heapq.merge(*run_iters, key=merge_key)
156148
finally:
157-
fs, _ = fsspec.core.url_to_fs(external_sort_dir)
158149
for path in run_paths:
159150
try:
160-
_, fs_path = url_to_fs(path)
161-
fs.rm(fs_path)
151+
rm_fs, rm_path = url_to_fs(path)
152+
rm_fs.rm(rm_path)
162153
except Exception:
163154
pass

lib/zephyr/src/zephyr/spill.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Opaque chunked row format for zephyr spill files.
5+
6+
SpillWriter and SpillReader hide the on-disk representation from callers.
7+
Items are pickled into an opaque binary payload and written as chunks of a
8+
chunked row format. Callers do not see the schema, serialization, or storage
9+
format — they append items and read back items (or chunks of items) in the
10+
same order.
11+
12+
Currently backed by Parquet with a single binary payload column, a background
13+
I/O thread, and byte-budgeted row groups. The file format is an implementation
14+
detail; do not rely on it outside this module.
15+
"""
16+
17+
import logging
18+
import pickle
19+
from collections.abc import Iterable, Iterator
20+
from typing import Any
21+
22+
import fsspec
23+
import pyarrow as pa
24+
import pyarrow.parquet as pq
25+
26+
from zephyr.writers import ThreadedBatchWriter
27+
28+
logger = logging.getLogger(__name__)
29+
30+
# Single binary payload column. Not part of the public API.
31+
_PAYLOAD_COL = "_zephyr_payload"
32+
_SCHEMA = pa.schema([pa.field(_PAYLOAD_COL, pa.binary())])
33+
34+
35+
class _TableAccumulator:
    """Buffers Arrow tables and hands back one merged table once a byte threshold is hit.

    Byte-budgeted batching produces uniformly-sized output regardless of row
    width, which matters for write performance and memory predictability.
    """

    def __init__(self, byte_threshold: int) -> None:
        self._byte_threshold = byte_threshold
        self._pending: list[pa.Table] = []
        self._pending_bytes: int = 0

    def add(self, table: pa.Table) -> pa.Table | None:
        """Buffer ``table``; return the merged buffer iff the byte threshold is now reached."""
        self._pending.append(table)
        self._pending_bytes += table.nbytes
        return self._drain() if self._pending_bytes >= self._byte_threshold else None

    def flush(self) -> pa.Table | None:
        """Return everything still buffered as one merged table, or ``None`` when empty."""
        return self._drain() if self._pending else None

    def _drain(self) -> pa.Table:
        # Concatenate the buffered tables into one and reset the accumulator state.
        merged = pa.concat_tables(self._pending, promote_options="default")
        self._pending.clear()
        self._pending_bytes = 0
        return merged
64+
65+
66+
def _items_to_table(items: Iterable[Any]) -> pa.Table:
    """Pickle each item and pack the blobs into the single opaque payload column."""
    blobs = (pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL) for item in items)
    payload_column = pa.array(list(blobs), type=pa.binary())
    return pa.table({_PAYLOAD_COL: payload_column})
69+
70+
71+
class SpillWriter:
    """Writes items to an opaque chunked row-format spill file.

    Use ``write`` to stream items; the writer accumulates a byte budget and
    emits chunks when the budget is exceeded. Use ``write_chunk`` to commit
    a batch of items as its own chunk immediately (no accumulation) — useful
    when the caller wants each logical batch to round-trip as one chunk.

    Writes are offloaded to a :class:`ThreadedBatchWriter` so one write can be
    in-flight while the caller produces the next batch. Backpressure, error
    propagation, and clean teardown on the exception path are delegated to it.
    """

    def __init__(
        self,
        path: str,
        *,
        row_group_bytes: int = 8 * 1024 * 1024,
        compression: str = "zstd",
        compression_level: int = 1,
    ) -> None:
        self._writer = pq.ParquetWriter(path, _SCHEMA, compression=compression, compression_level=compression_level)
        self._accumulator = _TableAccumulator(row_group_bytes)

        def _drain(tables: Iterable[pa.Table]) -> None:
            for table in tables:
                self._writer.write_table(table)

        # maxsize=1: at most one chunk in-flight so memory stays bounded while
        # the producer keeps working on the next batch.
        self._threaded = ThreadedBatchWriter(_drain, maxsize=1)
        self._closed = False

    def write(self, items: Iterable[Any]) -> None:
        """Append items. Emits a chunk when the accumulated byte budget is exceeded."""
        table = _items_to_table(items)
        if len(table) == 0:
            return
        merged = self._accumulator.add(table)
        if merged is not None:
            self._threaded.submit(merged)

    def write_chunk(self, items: Iterable[Any]) -> None:
        """Commit items as their own chunk immediately (no accumulation)."""
        table = _items_to_table(items)
        if len(table) == 0:
            return
        self._threaded.submit(table)

    def close(self) -> None:
        """Flush remaining buffered items and wait for the background writer.

        Idempotent: subsequent calls are no-ops. The Parquet writer is closed
        even if the flush or thread teardown raises.
        """
        if self._closed:
            return
        self._closed = True
        try:
            remaining = self._accumulator.flush()
            if remaining is not None:
                self._threaded.submit(remaining)
            self._threaded.close()
        finally:
            self._writer.close()

    def __enter__(self) -> "SpillWriter":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if exc_type is None:
            # Success path is exactly an explicit close(); delegate so the
            # flush/teardown sequence lives in one place and cannot drift.
            self.close()
            return
        if self._closed:
            return
        self._closed = True
        try:
            # Error path: skip final flush (partial file will never be read)
            # and let ThreadedBatchWriter.__exit__ tear down the thread
            # without blocking the caller.
            self._threaded.__exit__(exc_type, exc_val, exc_tb)
        finally:
            self._writer.close()
153+
154+
155+
class SpillReader:
    """Reads items from an opaque chunked row-format spill file.

    Iteration yields items one at a time in write order. ``iter_chunks`` yields
    lists of items grouped by the on-disk chunks; callers that want a specific
    batch size can pass ``batch_size`` to re-batch.
    """

    def __init__(self, path: str, *, batch_size: int | None = None) -> None:
        self._path = path
        self._batch_size = batch_size

    @property
    def path(self) -> str:
        return self._path

    @property
    def num_rows(self) -> int:
        # Row count lives in the file footer; no payload data is read here.
        with fsspec.open(self._path, "rb") as f:
            return pq.ParquetFile(f).metadata.num_rows

    @property
    def approx_item_bytes(self) -> int:
        """Uncompressed payload bytes per item, read from file metadata.

        Returns 0 for an empty spill. Useful as a memory-budgeting hint without
        exposing the underlying format.
        """
        with fsspec.open(self._path, "rb") as f:
            meta = pq.ParquetFile(f).metadata
            if meta.num_rows <= 0:
                return 0
            uncompressed = 0
            for group_idx in range(meta.num_row_groups):
                uncompressed += meta.row_group(group_idx).column(0).total_uncompressed_size
            return uncompressed // meta.num_rows

    @staticmethod
    def _decode(column) -> list[Any]:
        # Unpickle every payload blob in an Arrow column back into Python objects.
        return [pickle.loads(blob) for blob in column.to_pylist()]

    def iter_chunks(self) -> Iterator[list[Any]]:
        """Yield chunks of items (lists).

        Chunk boundaries follow the on-disk layout unless ``batch_size`` was
        set on the reader, in which case items are re-batched to approximately
        that size.
        """
        with fsspec.open(self._path, "rb") as f:
            pf = pq.ParquetFile(f)
            if self._batch_size is not None:
                # Caller-requested re-batching: let Parquet slice across chunks.
                for batch in pf.iter_batches(batch_size=self._batch_size, columns=[_PAYLOAD_COL]):
                    yield self._decode(batch.column(_PAYLOAD_COL))
            else:
                # Natural boundaries: one list per on-disk row group.
                for group_idx in range(pf.num_row_groups):
                    table = pf.read_row_group(group_idx, columns=[_PAYLOAD_COL])
                    yield self._decode(table.column(_PAYLOAD_COL))

    def __iter__(self) -> Iterator[Any]:
        for chunk in self.iter_chunks():
            yield from chunk

0 commit comments

Comments
 (0)