Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 45 additions & 54 deletions lib/zephyr/src/zephyr/external_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,66 +8,66 @@
exhausting worker memory.

Pass 1: batch the k iterators into groups of EXTERNAL_SORT_FAN_IN, merge each
group with heapq.merge, and spill items in batches of ``_WRITE_BATCH_SIZE`` to
a zstd-compressed pickle run file under
``{external_sort_dir}/run-{i:04d}.pkl.zst``. Items are streamed to disk
rather than accumulated in a list, so peak memory per batch is bounded by the
number of open iterators rather than their total item count.
group with heapq.merge, and spill items to a run file under
``{external_sort_dir}/run-{i:04d}.spill`` via :class:`SpillWriter`.

Pass 2: heapq.merge over the (much smaller) set of run file iterators. Each
iterator reads one batch at a time and yields items one-by-one; the read batch
size is computed from the cgroup memory limit so that all concurrent batches
together stay within ``_READ_MEMORY_FRACTION`` of available memory.
iterator streams chunks from its spill file via :class:`SpillReader`; the read
batch size is computed from the cgroup memory limit so that all concurrent
batches together stay within ``_READ_MEMORY_FRACTION`` of available memory.

Run files are deleted after the final merge completes.
"""

import heapq
import logging
import pickle
from collections.abc import Callable, Iterator
from itertools import islice

import fsspec
import zstandard as zstd
from iris.env_resources import TaskResources
from rigging.filesystem import url_to_fs

from zephyr.spill import SpillReader, SpillWriter

logger = logging.getLogger(__name__)

# Maximum simultaneous chunk iterators per pass-1 batch.
EXTERNAL_SORT_FAN_IN = 500

# Items per pickle.dump in pass-1. Larger batches compress better (zstd
# dictionary spans the whole batch) and reduce per-call overhead.
# Items buffered before handing to the SpillWriter. Larger values amortize
# per-chunk overhead in the spill format.
_WRITE_BATCH_SIZE = 10_000

# Target bytes per spill chunk in pass-1 runs.
_ROW_GROUP_BYTES = 8 * 1024 * 1024

# Fraction of container memory budgeted for pass-2 read buffers.
_READ_MEMORY_FRACTION = 0.25


def _safe_read_batch_size(n_runs: int, sample_run_path: str) -> int:
"""Compute a pass-2 read batch size that fits within the memory budget.

Probes the first batch from ``sample_run_path`` to estimate in-memory
bytes per item, then divides the memory budget by ``n_runs * item_bytes``
so that all concurrent run-file buffers together stay within
Uses the spill's per-item byte estimate to divide the memory budget across
concurrent run-file buffers so they together stay within
``_READ_MEMORY_FRACTION`` of available container memory.
"""
dctx = zstd.ZstdDecompressor()
try:
with fsspec.open(sample_run_path, "rb") as raw_f:
with dctx.stream_reader(raw_f) as f:
sample_batch: list = pickle.load(f)
item_bytes_raw = SpillReader(sample_run_path).approx_item_bytes
except Exception:
logger.warning(
"Failed to read spill metadata from %s; falling back to default batch size",
sample_run_path,
exc_info=True,
)
return _WRITE_BATCH_SIZE

Comment thread
rjpower marked this conversation as resolved.
sample = sample_batch[:100]
if not sample:
if item_bytes_raw <= 0:
return _WRITE_BATCH_SIZE
# pickle size x 3 approximates Python object overhead (dicts are ~3x larger
# in memory than their serialised form).
item_bytes = max(64, len(pickle.dumps(sample)) // len(sample) * 3)

# Payload size x 3 approximates Python object overhead (dicts are ~3x
# larger in memory than their pickled form).
item_bytes = max(64, item_bytes_raw * 3)

available = TaskResources.from_environment().memory_bytes
budget = int(available * _READ_MEMORY_FRACTION)
Expand Down Expand Up @@ -101,28 +101,31 @@ def external_sort_merge(
Yields:
Items in merged sort order.
"""
cctx = zstd.ZstdCompressor(level=3)
run_paths: list[str] = []
batch_idx = 0

# SpillWriter does not auto-create parent directories, so ensure the spill
# dir exists up front.
spill_fs, spill_dir = url_to_fs(external_sort_dir)
spill_fs.makedirs(spill_dir, exist_ok=True)

while True:
batch = list(islice(chunk_iterators_gen, EXTERNAL_SORT_FAN_IN))
if not batch:
break
run_path = f"{external_sort_dir}/run-{batch_idx:04d}.pkl.zst"
run_path = f"{external_sort_dir}/run-{batch_idx:04d}.spill"
item_count = 0
pending: list = []
with fsspec.open(run_path, "wb") as raw_f:
with cctx.stream_writer(raw_f, closefd=False) as f:
for item in heapq.merge(*batch, key=merge_key):
pending.append(item)
if len(pending) >= _WRITE_BATCH_SIZE:
pickle.dump(pending, f, protocol=pickle.HIGHEST_PROTOCOL)
item_count += len(pending)
pending = []
if pending:
pickle.dump(pending, f, protocol=pickle.HIGHEST_PROTOCOL)
with SpillWriter(run_path, row_group_bytes=_ROW_GROUP_BYTES) as writer:
for item in heapq.merge(*batch, key=merge_key):
pending.append(item)
if len(pending) >= _WRITE_BATCH_SIZE:
writer.write(pending)
item_count += len(pending)
pending = []
if pending:
writer.write(pending)
item_count += len(pending)
run_paths.append(run_path)
logger.info(
"External sort: wrote run %d (%d items) to %s",
Expand All @@ -135,29 +138,17 @@ def external_sort_merge(
read_batch_size = _safe_read_batch_size(len(run_paths), run_paths[0]) if run_paths else _WRITE_BATCH_SIZE

def _read_run(path: str) -> Iterator:
with fsspec.open(path, "rb") as raw_f:
with zstd.ZstdDecompressor().stream_reader(raw_f) as f:
while True:
try:
items: list = pickle.load(f)
# Yield in read_batch_size chunks and delete consumed
# items in-place so memory is released progressively
# even while the generator is suspended in heapq.merge.
while items:
chunk = items[:read_batch_size]
del items[:read_batch_size]
yield from chunk
except EOFError:
break
reader = SpillReader(path, batch_size=read_batch_size)
for chunk in reader.iter_chunks():
yield from chunk

run_iters = [_read_run(p) for p in run_paths]
try:
yield from heapq.merge(*run_iters, key=merge_key)
finally:
fs, _ = fsspec.core.url_to_fs(external_sort_dir)
for path in run_paths:
try:
_, fs_path = url_to_fs(path)
fs.rm(fs_path)
rm_fs, rm_path = url_to_fs(path)
rm_fs.rm(rm_path)
except Exception:
pass
211 changes: 211 additions & 0 deletions lib/zephyr/src/zephyr/spill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Opaque chunked row format for zephyr spill files.

SpillWriter and SpillReader hide the on-disk representation from callers.
Items are pickled into an opaque binary payload and written as chunks of a
chunked row format. Callers do not see the schema, serialization, or storage
format — they append items and read back items (or chunks of items) in the
same order.

Currently backed by Parquet with a single binary payload column, a background
I/O thread, and byte-budgeted row groups. The file format is an implementation
detail; do not rely on it outside this module.
"""

import logging
import pickle
from collections.abc import Iterable, Iterator
from typing import Any

import fsspec
import pyarrow as pa
import pyarrow.parquet as pq

from zephyr.writers import ThreadedBatchWriter

logger = logging.getLogger(__name__)

# Single binary payload column. Not part of the public API.
_PAYLOAD_COL = "_zephyr_payload"
_SCHEMA = pa.schema([pa.field(_PAYLOAD_COL, pa.binary())])


class _TableAccumulator:
    """Buffers Arrow tables and emits one concatenated table when a byte threshold is crossed.

    Byte-budgeted batching produces uniformly-sized output regardless of row
    width, which matters for write performance and memory predictability.
    """

    def __init__(self, byte_threshold: int) -> None:
        # Threshold is compared against the sum of buffered table.nbytes.
        self._byte_threshold = byte_threshold
        self._tables: list[pa.Table] = []
        self._nbytes: int = 0

    def add(self, table: pa.Table) -> pa.Table | None:
        """Buffer *table*; return the merged batch once the threshold is reached, else None."""
        self._tables.append(table)
        self._nbytes += table.nbytes
        return self._take() if self._nbytes >= self._byte_threshold else None

    def flush(self) -> pa.Table | None:
        """Return whatever remains buffered as one table, or None when empty."""
        return self._take() if self._tables else None

    def _take(self) -> pa.Table:
        # Concatenate, then reset the buffer so the accumulator can be reused.
        merged = pa.concat_tables(self._tables, promote_options="default")
        self._tables.clear()
        self._nbytes = 0
        return merged


def _items_to_table(items: Iterable[Any]) -> pa.Table:
    """Pickle each item into the single opaque binary payload column of an Arrow table."""
    pickled = [pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL) for item in items]
    return pa.table({_PAYLOAD_COL: pa.array(pickled, type=pa.binary())})


class SpillWriter:
    """Writes items to an opaque chunked row-format spill file.

    Use ``write`` to stream items; the writer accumulates a byte budget and
    emits chunks when the budget is exceeded. Use ``write_chunk`` to commit
    a batch of items as its own chunk immediately (no accumulation) — useful
    when the caller wants each logical batch to round-trip as one chunk.

    Writes are offloaded to a :class:`ThreadedBatchWriter` so one write can be
    in-flight while the caller produces the next batch. Backpressure, error
    propagation, and clean teardown on the exception path are delegated to it.
    """

    def __init__(
        self,
        path: str,
        *,
        row_group_bytes: int = 8 * 1024 * 1024,
        compression: str = "zstd",
        compression_level: int = 1,
    ) -> None:
        self._writer = pq.ParquetWriter(path, _SCHEMA, compression=compression, compression_level=compression_level)
        self._accumulator = _TableAccumulator(row_group_bytes)

        def _drain(tables: Iterable[pa.Table]) -> None:
            for table in tables:
                self._writer.write_table(table)

        # maxsize=1: at most one chunk in-flight so memory stays bounded while
        # the producer keeps working on the next batch.
        self._threaded = ThreadedBatchWriter(_drain, maxsize=1)
        self._closed = False

    def write(self, items: Iterable[Any]) -> None:
        """Append items. Emits a chunk when the accumulated byte budget is exceeded."""
        table = _items_to_table(items)
        if len(table) == 0:
            return
        merged = self._accumulator.add(table)
        if merged is not None:
            self._threaded.submit(merged)

    def write_chunk(self, items: Iterable[Any]) -> None:
        """Commit items as their own chunk immediately (no accumulation)."""
        table = _items_to_table(items)
        if len(table) == 0:
            return
        self._threaded.submit(table)

    def close(self) -> None:
        """Flush remaining buffered items and wait for the background writer.

        Idempotent: a second call is a no-op. The underlying Parquet writer is
        closed even if the final flush raises.
        """
        if self._closed:
            return
        self._closed = True
        try:
            remaining = self._accumulator.flush()
            if remaining is not None:
                self._threaded.submit(remaining)
            self._threaded.close()
        finally:
            self._writer.close()

    def __enter__(self) -> "SpillWriter":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Success path is exactly an explicit close(); delegate rather than
        # duplicating the flush/teardown sequence (keeps the two paths from
        # drifting apart).
        if exc_type is None:
            self.close()
            return
        if self._closed:
            return
        self._closed = True
        try:
            # Error path: skip final flush (partial file will never be read)
            # and let ThreadedBatchWriter.__exit__ tear down the thread
            # without blocking the caller.
            self._threaded.__exit__(exc_type, exc_val, exc_tb)
        finally:
            self._writer.close()


class SpillReader:
    """Reads items back from an opaque chunked row-format spill file.

    Iterating the reader yields items one at a time in write order.
    ``iter_chunks`` yields lists of items following the on-disk chunk
    boundaries; pass ``batch_size`` to re-batch to approximately that size
    instead.
    """

    def __init__(self, path: str, *, batch_size: int | None = None) -> None:
        self._path = path
        self._batch_size = batch_size

    @property
    def path(self) -> str:
        """Location of the spill file."""
        return self._path

    @property
    def num_rows(self) -> int:
        """Total item count, taken from the file metadata."""
        with fsspec.open(self._path, "rb") as f:
            return pq.ParquetFile(f).metadata.num_rows

    @property
    def approx_item_bytes(self) -> int:
        """Uncompressed payload bytes per item, read from file metadata.

        Returns 0 for an empty spill. Useful as a memory-budgeting hint without
        exposing the underlying format.
        """
        with fsspec.open(self._path, "rb") as f:
            meta = pq.ParquetFile(f).metadata
            if meta.num_rows <= 0:
                return 0
            payload_bytes = sum(
                meta.row_group(idx).column(0).total_uncompressed_size for idx in range(meta.num_row_groups)
            )
            return payload_bytes // meta.num_rows

    def iter_chunks(self) -> Iterator[list[Any]]:
        """Yield chunks of items (lists).

        Chunk boundaries follow the on-disk layout unless ``batch_size`` was
        set on the reader, in which case items are re-batched to approximately
        that size.
        """
        with fsspec.open(self._path, "rb") as f:
            pf = pq.ParquetFile(f)
            if self._batch_size is not None:
                # Re-batch: let pyarrow slice row groups into batch_size pieces.
                columns = (
                    rb.column(_PAYLOAD_COL)
                    for rb in pf.iter_batches(batch_size=self._batch_size, columns=[_PAYLOAD_COL])
                )
            else:
                # Natural layout: one chunk per on-disk row group.
                columns = (
                    pf.read_row_group(idx, columns=[_PAYLOAD_COL]).column(_PAYLOAD_COL)
                    for idx in range(pf.num_row_groups)
                )
            for column in columns:
                yield [pickle.loads(blob) for blob in column.to_pylist()]

    def __iter__(self) -> Iterator[Any]:
        for chunk in self.iter_chunks():
            yield from chunk
Loading