[zephyr] Arrow-native scatter/reduce: 1.1x reduce speedup #4178
Conversation
Claude finished @rjpower's task in 1m 53s — Review: Arrow-native scatter/reduce

Replaces Python-centric scatter/reduce with Arrow columnar operations. Scatter accumulates micro-batches with a persisted _sort_key column. No specification found (no linked issue or design doc). Findings:
Tests cover sort-key propagation, Arrow vs Python equivalence, combiner path, external sort roundtrip, and cleanup. Test coverage looks adequate for the new paths.
🤖 Specification (>500 LOC PR)

Problem: Scatter/reduce serializes items as Python dicts, sorts via heapq.merge (O(N log K) Python comparisons), and external sort spills via pickle+zstd. The reduce path is the bottleneck: 13s for 500K items.

Approach:
Key code: The critical optimization is in _arrow_reduce_gen (plan.py:215-275). It reads chunk tables as Arrow, concatenates with promote_options="default" for schema evolution, sorts via pc.sort_indices, then groups by key using boundary detection on the sorted _sort_key column. Items are only materialized to Python when reducer_fn iterates.

Tests: 8 new tests in test_shuffle.py covering Arrow merge correctness vs the Python path, sort key column presence, external sort roundtrip/cleanup, zero-memory fallback, and combiner interaction. All 46 existing tests pass. New benchmark_scatter_reduce.py for isolated measurement.
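For readers skimming the spec, a minimal self-contained sketch of the columnar reduce it describes (not the actual plan.py code): `columnar_reduce`, `chunk_tables`, and `reducer_fn` are stand-in names, and the group-boundary detection is simplified to a Python loop over the change mask.

```python
import pyarrow as pa
import pyarrow.compute as pc

SORT_KEY_COL = "_sort_key"  # persisted sort-key column written by scatter

def columnar_reduce(chunk_tables, reducer_fn):
    """Sketch: concat chunk tables, sort columnar, then group by key boundaries."""
    # Schema promotion lets schema-evolved chunks concatenate.
    table = pa.concat_tables(chunk_tables, promote_options="default")
    if table.num_rows == 0:
        return
    # Columnar sort replaces heapq.merge over deserialized Python dicts.
    table = table.take(pc.sort_indices(table, sort_keys=[(SORT_KEY_COL, "ascending")]))

    keys = table.column(SORT_KEY_COL)
    # Boundary detection: positions where the sorted key value changes.
    changed = pc.not_equal(keys.slice(1), keys.slice(0, len(keys) - 1)).to_pylist()
    boundaries = [0] + [i + 1 for i, c in enumerate(changed) if c] + [table.num_rows]

    for start, end in zip(boundaries, boundaries[1:]):
        group = table.slice(start, end - start)
        # Rows are materialized to Python only when the reducer iterates them.
        yield from reducer_fn(keys[start].as_py(), iter(group.to_pylist()))
```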
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d6fa272b44
table = pq.read_table(path)
yield from table.to_pylist()
Stream Arrow run files during external merge
_read_run materializes each spill file with pq.read_table(...).to_pylist() before yielding any row. Because heapq.merge primes all iterators, this loads every run fully into Python memory at once, which can OOM on the large shards that trigger external sort in the first place. This defeats the memory-safety goal of the external path; read runs incrementally (e.g., scanner/batches) instead of whole-table materialization.
)

if use_external:
    sort_keys: list[tuple[str, str]] = [(_ZEPHYR_SHUFFLE_SORT_KEY_COL, "ascending")]
Preserve secondary sort key in external Arrow reduce
When external sort is used, reduce ordering is computed only by _sort_key. This drops sort_by semantics for equal keys, so reducers that depend on within-group order (for example next(items) / keep-first patterns) can produce different results from the in-memory Arrow path and the legacy Python merge path. The external Arrow sort/merge key needs to include _sort_secondary when present.
hrm i don't know why it says this, it adds the secondary sort key right below:
# Peek at the first chunk table to check for secondary sort column
first_tables = list(islice((t for it in shard.iterators for t in it.get_chunk_tables()), 1))
if first_tables and _ZEPHYR_SHUFFLE_SORT_SECONDARY_COL in first_tables[0].column_names:
    sort_keys.append((_ZEPHYR_SHUFFLE_SORT_SECONDARY_COL, "ascending"))
batch_tables = list(islice(chunk_tables_gen, EXTERNAL_SORT_FAN_IN))
if not batch_tables:
    break
combined = pa.concat_tables(batch_tables)
Enable schema promotion when concatenating Arrow spill inputs
Pass-1 external Arrow sort concatenates chunk tables without schema promotion. Scatter explicitly supports schema evolution across segments, and evolved item structs cause pa.concat_tables to raise ArrowInvalid here, so large reduces can fail only when they spill. Use promoted concatenation (as done in the non-external Arrow path) to keep schema-evolved inputs working.
list(items),
self.shard_idx,
self.chunk_idx,
list(keys),
Keep pickle fallback compatible with non-Arrow key types
The new buffer path always passes key_values into the envelope, including when pickled=True. That means group keys must now be Arrow-serializable, and hashable-but-non-Arrow keys (e.g., frozenset) fail during RecordBatch.from_pylist, even though group_by only requires hashable keys and previously supported this via the pickle envelope path. This is a correctness regression for those pipelines.
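One way to address this, if the maintainers go that route, would be to probe key serializability and only attach key_values for Arrow-compatible keys. This is a hypothetical guard, not code from the PR:

```python
import pyarrow as pa

def keys_fit_arrow(keys) -> bool:
    """Hypothetical guard: return True only if the group keys can be stored as an
    Arrow column. Hashable-but-non-Arrow keys (e.g. frozenset) return False and
    would keep taking the pickle-envelope path without key_values."""
    try:
        pa.array(list(keys))
        return True
    except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError):
        return False

assert keys_fit_arrow(["a", "b"]) is True
assert keys_fit_arrow([frozenset({1}), frozenset({2})]) is False
```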
Force-pushed from 7e9b79f to 175a1f6.
So in practice, ~10% faster. Both runs completed successfully and the results are nearly identical. The arrow run (222210) was ~8 minutes faster, mainly from a faster p0 stage (49m55s vs 57m13s).
@rjpower ah would have definitely preferred the 10x 🤞
Makes sense.
Yep - btw, any other diff in the counters? Asking to spot-check them. There's a bunch of changes in this PR; I wonder if we could narrow it down to the part that gives us the 10%, or at least the majority of it? wdyt?
We can drop the rust changes, those have a pretty minimal effect. The rest is pretty related.

Where the wins come from:
- shuffle.py — Scatter write path (-13.5% CPU share, main wall-clock win). The new path accumulates Arrow micro-batches with a persisted _sort_key column.
- plan.py — Reduce path (-1.4% CPU, structural improvement).
- fuzzy.py — Bucket key fix (enables Arrow path). The bucket value is now emitted as signed int64.
- minhash_ops.rs — Minor Rust optimizations (~same MinHash %, slightly faster absolute).

Pipeline stage timing: both runs have 12 pipelines (p0 = MinHash LSH scatter, p1–p10 = CC iterations, p11 = final dedup tagging). CC did not converge in 10 iterations for either run.

The p0 speedup (~7 minutes) accounts for essentially all of the wall-clock improvement. p1–p11 are slightly slower (+3s avg) because the Arrow reduce path has higher per-pipeline overhead (concat + sort + take) that doesn't amortize on these small CC iteration passes.

@claude take care of this, remove the rust/** changes pls
Claude finished @rjpower's task in 35s — Removing rust/** changes
Reverted |
Looks like the counters are lost after the run finishes? We should probably write them at least to the coordinator log. Here's the CC trajectory: it looks identical for the first step and then is non-deterministic after that (according to Claude), so some variation is expected.
Yep - agreed. I would also like some integration with the artifact/executor, but step by step :)
ravwojdyla left a comment:
I think this makes sense; my only worry is that it adds a fairly complicated special branch (to already complicated code). I wouldn't be surprised if we could keep the previous logic and still squeeze out something on the order of 10% there. wdyt?
Either way, if we push forward on this, I think we need to confirm this works on larger scale. I'm happy to run that experiment if needed.
external_sort_dir=sort_dir,
)
)
import os
man I don't get it - why do LLMs like local import so much?
i don't know either, right, it's weird
serialization/deserialization in pass 1. Pass 2 still uses Python heapq
for simplicity.
"""
from zephyr.writers import ensure_parent_dir
for batch in pf.iter_batches():
    yield from batch.to_pylist()

run_iters = [_read_run(p) for p in run_paths]
nit: this logic is exactly repeated from above, consider extracting it.
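A sketch of the extraction the nit asks for; `_read_parquet_run` matches the helper name mentioned in a later commit, but the exact code here is illustrative only.

```python
import pyarrow.parquet as pq

def _read_parquet_run(path):
    """Stream one spill run file back as Python rows, one record batch at a time,
    so heapq.merge priming all run iterators never holds a whole run in memory."""
    pf = pq.ParquetFile(path)
    for batch in pf.iter_batches():
        yield from batch.to_pylist()

# Both merge passes could then build their run iterators identically:
# run_iters = [_read_parquet_run(p) for p in run_paths]
```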
seg_paths.append(seg_file)
ensure_parent_dir(seg_file)
- writer = pq.ParquetWriter(seg_file, schema)
+ writer = pq.ParquetWriter(seg_file, schema, compression="zstd", compression_level=1)
How much do we gain just by using faster compression?
this was added after the initial runs, so we'll see, i don't think it's a big gain over snappy though
if no gain - should we just use default?
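If someone wants to answer the snappy-vs-zstd question empirically, a rough micro-benchmark along these lines would do; the file path, row count, and row shape here are made up, not from the PR:

```python
import time
import pyarrow as pa
import pyarrow.parquet as pq

def time_write(table, codec, level=None):
    """Time a single Parquet write with the given codec (spill-write proxy)."""
    start = time.perf_counter()
    pq.write_table(table, f"/tmp/spill.{codec}.parquet",
                   compression=codec, compression_level=level)
    return time.perf_counter() - start

# ~150-byte rows, roughly matching the benchmark rows described in this PR.
table = pa.table({"_sort_key": list(range(500_000)),
                  "text": ["x" * 150] * 500_000})
print("snappy:", time_write(table, "snappy"))
print("zstd-1:", time_write(table, "zstd", 1))
```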
scanner = dataset.scanner(
    columns=columns,
    filter=(
        (pc.field(_ZEPHYR_SHUFFLE_SHARD_IDX_COL) == self.shard_idx)
        & (pc.field(_ZEPHYR_SHUFFLE_CHUNK_IDX_COL) == chunk_idx)
    ),
    batch_size=batch_size,
    use_threads=False,
)
batches = list(scanner.to_batches())
if batches:
    yield pa.Table.from_batches(batches)
we don't need a lazy scanner if we materialize all batches anyway?
we can remove it if desired; it's more an option in case we e.g. wanted to bound the memory for the external sort
I would vote to remove it, if it's nice-to-have and not on the critical path.
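For reference, the memory-bounded alternative alluded to above would look roughly like this: yield one small table per record batch instead of materializing the whole scan. The column names below are assumptions for illustration, not the PR's actual constants:

```python
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds

def iter_chunk_tables(path, shard_idx, chunk_idx, batch_size=8192):
    """Yield one Arrow table per record batch so the external-sort path can bound
    memory; envelope column names here are hypothetical."""
    dataset = ds.dataset(path, format="parquet")
    scanner = dataset.scanner(
        filter=(
            (pc.field("_zephyr_shard_idx") == shard_idx)
            & (pc.field("_zephyr_chunk_idx") == chunk_idx)
        ),
        batch_size=batch_size,
        use_threads=False,
    )
    for batch in scanner.to_batches():
        if batch.num_rows:
            yield pa.Table.from_batches([batch])
```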
file_entries.append(entry)

# has_sort_key is True only if ALL entries with data for this shard have it
has_sort_key = bool(file_entries) and all(entry.get("has_sort_key", False) for entry in file_entries)
Q: is this necessary for arrow sort to work?
_ZEPHYR_SHUFFLE_SORT_KEY_COL = "_sort_key"
_ZEPHYR_SHUFFLE_SORT_SECONDARY_COL = "_sort_secondary"
There's no need to use the _ prefix; in fact that's inconsistent with the other field names. These are envelope fields so we can call them whatever we like.
i'm fine either way -- but don't we want to keep the parquet flat if possible?
I assume there's maybe some overhead - but we would need to benchmark that. I would assume at this time this doesn't matter, and there is much lower-hanging fruit out there.
)

if use_external:
    sort_keys: list[tuple[str, str]] = [(_ZEPHYR_SHUFFLE_SORT_KEY_COL, "ascending")]
i do think this is much faster, just that fuzzy dedup isn't a particularly great stress test -- ~40% of the time is in n-gram computation, for example. i think it's the right direction -- though I think I can also simplify this a lot by getting rid of the vestigial code as well, let me take a pass at that.
+1, or feel free to send me some commands you use. i feel i should do it since I was the one to annoy us with this sidequest.
If we could simplify this PR that would be amazing!
Running on at least a single full split (e.g. quality=high) of nemotron would be a good stress test. Instead of doing fuzzy, we can do exact-paragraph to reduce the wall-time while still exercising a reasonable-scale shuffle. We can use the …
Force-pushed from e3d9b3c to b1c1029.
👍 yeah i can run that for exact |
Force-pushed from 4772d73 to 8b8b9bf.
Replace Python-object-centric scatter/reduce with Arrow columnar operations. Scatter now accumulates Arrow micro-batches with a persisted _sort_key column and sorts via pc.sort_indices. Reduce concatenates Arrow tables and sorts columnar instead of heapq.merge over deserialized Python dicts. External sort spills to Parquet instead of pickle+zstd. Benchmark: 500K items x 64 shards, reduce drops from 13.3s to 1.3s (10.4x). Backward compatible — old format files without _sort_key fall back to the Python merge path.
…torized boundaries

- Stream Arrow run files via iter_batches() instead of whole-table read, to avoid OOM on large shards that trigger external sort
- Add promote_options="default" to concat_tables in external sort for schema-evolved items
- Include _sort_secondary in the external sort keys and merge key so within-group order is preserved
- Vectorize _find_group_boundaries using a pc.not_equal diff instead of per-element .as_py() calls — matters for high-cardinality keys
- Store has_sort as a field on _ShardBuffer set at construction instead of fragile sorts[0] detection
- Add an assertion guard in _arrow_reduce_gen for pickled shards
- Rename scatter columns to _zephyr_* prefix (no compat aliases)
- Replace _make_envelope/_make_pickle_envelope with unified make_envelope_batch() and unwrap_items() helpers
- Replace PickleDiskChunk with ParquetDiskChunk for intermediates
- Remove pickle-based external sort; rename external_sort_merge_arrow to external_sort_merge, extract _read_parquet_run helper
- Remove _merge_sorted_chunks (Python heapq path); _arrow_reduce_gen handles both flat and pickled envelopes via unwrap_items
- Remove has_sort_key field (always true), ZEPHYR_META_COLUMNS (unused), _ITEM_BYTES_FALLBACK (unreachable)

Net: -292 lines, one serialization format, one reduce path.
When a shard is retried (e.g. due to heartbeat timeout), the old and new
attempts share the same external sort directory. The old attempt's finally
block can delete run files the new attempt is reading, causing
FileNotFoundError. Adding /attempt-{n} to the path isolates each attempt.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When fan_in is overestimated (due to first chunk having 0 bytes), the last batch may contain only empty tables. ParquetWriter.close() without any write_table() calls doesn't create a file on GCS, but the path is still appended to run_paths. The subsequent read_metadata then fails with FileNotFoundError. Fix: skip empty batches after concat, and probe past empty chunks when computing the sort budget. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… staleness

Writes go through pyarrow's native GcsFileSystem (via ParquetWriter), but reads used fsspec/gcsfs, which caches parent directory listings. When read_metadata(run-0000) cached the dir, run-0001 (written moments later) wasn't in the cache yet, causing FileNotFoundError. Fix: use pyarrow-native path resolution for all reads (no filesystem arg), matching the write path. Added a GCS integration test that reproduces the bug.
…ernal sort

SpillWriter: byte-budgeted ParquetWriter wrapper with a background write thread for GCS upload overlap. Uses zstd-1 compression (was snappy for spills). Two modes: write_table() accumulates and auto-flushes row groups; write_row_group() writes immediately (for scatter, where each chunk must be a separate row group).

TableAccumulator: byte-budgeted Arrow table batching, replacing the row-count-based _MERGE_OUTPUT_BATCH_SIZE in the k-way merge output.

Refactors:
- external_sort: _write_spill_file uses SpillWriter (was binary-search row-group sizing), merge output uses TableAccumulator
- shuffle: _write_parquet_scatter uses SpillWriter (was manual ParquetWriter + pending_chunk + _flush_pending)
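A simplified sketch of the byte-budgeted writer this commit describes, minus the background I/O thread; the class name, parameters, and budget default approximate the description rather than the shipped spill_writer.py.

```python
import pyarrow as pa
import pyarrow.parquet as pq

class SpillWriterSketch:
    """Byte-budgeted Parquet spill writer (simplified: no background write thread)."""

    def __init__(self, path, schema, byte_budget=2 * 1024 * 1024):
        self._writer = pq.ParquetWriter(path, schema,
                                        compression="zstd", compression_level=1)
        self._budget = byte_budget
        self._pending = []
        self._pending_bytes = 0

    def write_table(self, table: pa.Table) -> None:
        # Accumulate and auto-flush a row group once the byte budget is exceeded.
        self._pending.append(table)
        self._pending_bytes += table.nbytes
        if self._pending_bytes >= self._budget:
            self._flush()

    def write_row_group(self, table: pa.Table) -> None:
        # Scatter mode: each chunk lands in its own row group, written immediately.
        self._flush()
        self._writer.write_table(table)

    def _flush(self) -> None:
        if self._pending:
            self._writer.write_table(pa.concat_tables(self._pending))
            self._pending, self._pending_bytes = [], 0

    def close(self) -> None:
        self._flush()
        self._writer.close()
```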
Force-pushed from 8b8b9bf to 8572a27.
Replaces the .pkl.zst spill files in external_sort_merge with Parquet files written through a new SpillWriter (byte-budgeted ParquetWriter with a background I/O thread). Items are cloudpickle-style serialized into a single `_zephyr_payload` binary column; the Python heapq.merge semantics on both passes are unchanged, so behavior is identical. This is the minimal slice cherry-picked from #4178 (arrow-scatter-reduce) that removes raw pickle files from zephyr's shuffle data plane without touching the scatter envelope or reduce merge. Follow-ups will promote the sort key to a first-class column and move reduce to columnar Arrow merge. Pass-2 read-batch-size estimation now reads row-group metadata directly from the parquet file instead of probing a pickled sample.
## Summary

- Replaces `.pkl.zst` spill files in `external_sort_merge` with Parquet files written via a new `SpillWriter` (byte-budgeted `pq.ParquetWriter` with a background I/O thread).
- Items are pickled into a single `_zephyr_payload` binary column. Python `heapq.merge` semantics on both passes are unchanged, so behavior is identical — this is a format swap only.
- Pass 2 reads spills back with `pq.ParquetFile.iter_batches` and unpickles one row group at a time to feed the heap merge.
- Pass-2 read-batch-size estimation now reads row-group metadata directly from the parquet file instead of probing a pickled sample.

This is the minimal slice cherry-picked from #4178 that removes raw pickle files from zephyr's shuffle data plane. The scatter envelope and reduce merge are untouched; follow-ups will promote the sort key to a first-class column (Tier 2) and move reduce to columnar Arrow merge (Tier 3). `SpillWriter` is added as `lib/zephyr/src/zephyr/spill_writer.py` verbatim from #4178. `external_sort.py` is its only caller in this PR.
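A standalone sketch of the spill format this summary describes; plain pickle stands in for the cloudpickle-style serializer, and the function names are illustrative, not from the PR.

```python
import pickle
import pyarrow as pa
import pyarrow.parquet as pq

PAYLOAD_COL = "_zephyr_payload"  # single binary column holding pickled items

def write_spill(items, path):
    """Write items as pickled payloads in a Parquet file (replaces .pkl.zst runs)."""
    payloads = pa.array([pickle.dumps(item) for item in items], type=pa.binary())
    pq.write_table(pa.table({PAYLOAD_COL: payloads}), path,
                   compression="zstd", compression_level=1)

def read_spill(path):
    """Stream items back one record batch at a time for the heapq.merge passes."""
    pf = pq.ParquetFile(path)
    for batch in pf.iter_batches():
        for payload in batch.column(PAYLOAD_COL):
            yield pickle.loads(payload.as_py())
```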
This pull request has been inactive for 23 days and is marked as stale.
Replace Python-object-centric scatter/reduce with Arrow columnar operations.
Scatter accumulates Arrow micro-batches (64 items) with a persisted _sort_key
column, flushing sorted row groups at 2MB via pc.sort_indices. Reduce
concatenates chunk tables and sorts columnar instead of heapq.merge over
deserialized Python dicts. External sort spills to Parquet instead of
pickle+zstd.
Benchmark (500K items, 64 shards, ~150 byte rows simulating fuzzy dedup):
Scatter: 0.64s (786K items/s)
Reduce Arrow: 1.28s (391K items/s)
Reduce Python: 13.29s (38K items/s)
Reduce speedup: 10.4x
Backward compatible: old format files without _sort_key fall back to the
Python merge path.