|
4 | 4 | """Scatter/shuffle support for Zephyr pipelines. |
5 | 5 |
|
6 | 6 | Each source-shard's scatter output is a single binary file containing a |
7 | | -sequence of zstd-compressed frames. Within one chunk's zstd frame, items |
8 | | -are written in sub-batches of ``_SUB_BATCH_SIZE`` — each sub-batch is a |
9 | | -single ``pickle.dump(list_of_items)`` into the zstd stream. This amortises |
10 | | -per-item pickle/zstd dispatch over a sub-batch while still letting the |
11 | | -reader stream sub-batches lazily without materialising the full chunk. |
| 7 | +sequence of zstd-compressed frames. Each frame starts with a one-byte format |
| 8 | +tag: ``\\x01`` for Arrow IPC and ``\\x00`` for cloudpickle sub-batches. |
| 9 | +
|
| 10 | +Arrow IPC is used when the chunk items are Arrow-compatible (plain dicts with |
| 11 | +primitive or string values). For items that Arrow cannot represent (e.g. |
| 12 | +frozenset, custom classes), the writer falls back to cloudpickle sub-batches |
| 13 | +of ``_SUB_BATCH_SIZE`` items per ``pickle.dump`` call. |
12 | 14 |
|
13 | 15 | A msgpack sidecar (``.scatter_meta``) maps ``target_shard -> [(offset, length)]`` |
14 | 16 | byte ranges into the data file, plus per-shard ``max_chunk_rows`` and a global |
|
17 | 19 | which reducers consume to build :class:`ScatterReader` instances. |
18 | 20 |
|
19 | 21 | On read, each chunk is fetched with a single ``cat_file`` range GET (one |
20 | | -HTTP request, no per-chunk file handle), then streamed via |
21 | | -``pickle.load`` on a length-bounded zstd reader. Per-iterator memory stays |
22 | | -near-constant: one buffered item plus the zstd decoder state plus the |
23 | | -chunk's compressed bytes (typically a few MB). This bound is essential for |
24 | | -skewed shuffles where one reducer pulls disproportionate data and the |
25 | | -external-sort fan-in opens hundreds of chunk iterators at once. |
| 22 | +HTTP request, no per-chunk file handle). The format tag is inspected and the |
| 23 | +payload is dispatched to Arrow IPC or pickle deserialization accordingly. |
| 24 | +Per-iterator memory stays near-constant: bounded by the chunk's compressed |
| 25 | +bytes plus decompressed Arrow buffers or pickle state. |
26 | 26 | """ |
27 | 27 |
|
28 | 28 | from __future__ import annotations |
|
40 | 40 |
|
41 | 41 | import cloudpickle |
42 | 42 | import msgspec |
| 43 | +import pyarrow as pa |
43 | 44 | import zstandard as zstd |
44 | 45 | from rigging.filesystem import open_url, url_to_fs |
45 | 46 | from rigging.timing import log_time |
@@ -101,6 +102,11 @@ def get_iterators(self) -> Iterator[Iterator]: |
# Number of items per pickled sub-batch within a pickle-format chunk frame.
# Larger = less per-call pickle/zstd dispatch overhead; smaller = lower
# per-iterator read memory when the reader streams sub-batches back.
_SUB_BATCH_SIZE = 1024

# One-byte format tags written at the start of every chunk frame.
# Arrow IPC is used when items are Arrow-compatible; pickle is the fallback.
_FRAME_FORMAT_PICKLE = b"\x00"
_FRAME_FORMAT_ARROW = b"\x01"
104 | 110 |
|
105 | 111 | # --------------------------------------------------------------------------- |
106 | 112 | # Sidecar / manifest helpers |
@@ -249,19 +255,25 @@ def get_chunk_iterators(self) -> Iterator[Iterator]: |
def _iter_chunk(fs: Any, fs_path: str, offset: int, length: int) -> Iterator:
    """Fetch one chunk's compressed bytes via cat_file and stream items.

    The chunk is fetched with a single range GET. Its first byte is the
    format tag written by ``_write_chunk_frame``:

    * ``_FRAME_FORMAT_ARROW`` — the payload is decompressed in one shot
      and converted via ``table.to_pylist()``.
    * anything else (``_FRAME_FORMAT_PICKLE``) — the payload is streamed
      sub-batch by sub-batch via ``pickle.load`` on a zstd stream reader,
      so per-iterator memory stays bounded by one sub-batch plus the
      chunk's compressed bytes.

    Args:
        fs: fsspec-style filesystem exposing ``cat_file``.
        fs_path: Path of the scatter data file on ``fs``.
        offset: Byte offset of this chunk within the data file.
        length: Byte length of this chunk (format tag included).

    Yields:
        The deserialized items of the chunk, in write order.
    """
    blob = fs.cat_file(fs_path, start=offset, end=offset + length)
    if not blob:
        # Defensive: a zero-length range has no tag byte and no items.
        return
    # memoryview slice is zero-copy; ``blob[1:]`` would duplicate the
    # potentially multi-MB compressed payload just to strip the tag.
    payload = memoryview(blob)[1:]
    if blob[0:1] == _FRAME_FORMAT_ARROW:
        ipc_bytes = zstd.ZstdDecompressor().decompress(payload)
        reader = pa.ipc.open_stream(pa.py_buffer(ipc_bytes))
        yield from reader.read_all().to_pylist()
    else:
        with zstd.ZstdDecompressor().stream_reader(io.BytesIO(payload)) as reader:
            while True:
                try:
                    sub_batch = pickle.load(reader)
                except EOFError:
                    return
                yield from sub_batch
265 | 277 |
|
266 | 278 |
|
267 | 279 | # --------------------------------------------------------------------------- |
@@ -426,19 +438,28 @@ def _apply_combiner(buffer: list, key_fn: Callable, combiner_fn: Callable) -> li |
426 | 438 |
|
427 | 439 |
|
def _write_chunk_frame(items: list) -> bytes:
    """Encode a list of items as one zstd-compressed, format-tagged frame.

    Tries Arrow IPC first; falls back to cloudpickle sub-batches of
    ``_SUB_BATCH_SIZE`` items for types Arrow cannot represent (e.g.
    frozenset, custom classes). The first byte of the returned bytes is
    the format tag (``_FRAME_FORMAT_ARROW`` or ``_FRAME_FORMAT_PICKLE``)
    that ``_iter_chunk`` dispatches on.

    The ``try`` is scoped to the Arrow table conversion only, so unrelated
    failures (IPC serialization, zstd compression) propagate instead of
    being silently rerouted to the pickle path.

    Args:
        items: Chunk items to encode. Only a non-empty list of dicts is
            eligible for the Arrow path; anything else uses pickle.

    Returns:
        One tag byte followed by the zstd-compressed payload.
    """
    cctx = zstd.ZstdCompressor(level=_ZSTD_COMPRESS_LEVEL)
    table = None
    # ``from_pylist`` only accepts a list of dicts, so skip the attempt
    # outright for other item types instead of relying on it raising.
    # NOTE(review): Arrow unions differing key sets across rows (missing
    # keys read back as None) — confirm upstream guarantees a uniform
    # per-chunk schema, otherwise this path is lossy for ragged dicts.
    if items and all(isinstance(item, dict) for item in items):
        try:
            table = pa.Table.from_pylist(items)
        except Exception:
            logger.debug("_write_chunk_frame: Arrow IPC not applicable, using pickle")
    if table is not None:
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, table.schema) as writer:
            writer.write_table(table)
        return _FRAME_FORMAT_ARROW + cctx.compress(sink.getvalue().to_pybytes())
    raw = io.BytesIO()
    with cctx.stream_writer(raw, closefd=False) as zf:
        for i in range(0, len(items), _SUB_BATCH_SIZE):
            cloudpickle.dump(items[i : i + _SUB_BATCH_SIZE], zf, protocol=pickle.HIGHEST_PROTOCOL)
    return _FRAME_FORMAT_PICKLE + raw.getvalue()
442 | 463 |
|
443 | 464 |
|
444 | 465 | class ScatterWriter: |
|
0 commit comments