Skip to content

Commit e042c48

Browse files
committed
[zephyr] Fix PR review feedback: streaming reads, secondary sort, vectorized boundaries

- Stream Arrow run files via iter_batches() instead of a whole-table read, to avoid OOM on large shards that trigger external sort
- Add promote_options="default" to concat_tables in external sort, for schema-evolved items
- Include _sort_secondary in the external sort keys and the merge key, so within-group order is preserved
- Vectorize _find_group_boundaries using a pc.not_equal diff instead of per-element .as_py() calls — this matters for high-cardinality keys
- Store has_sort as a field on _ShardBuffer, set at construction, instead of the fragile sorts[0] detection
- Add an assertion guard in _arrow_reduce_gen for pickled shards
1 parent 05deaf3 commit e042c48

3 files changed

Lines changed: 37 additions & 12 deletions

File tree

lib/zephyr/src/zephyr/external_sort.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def external_sort_merge_arrow(
193193
batch_tables = list(islice(chunk_tables_gen, EXTERNAL_SORT_FAN_IN))
194194
if not batch_tables:
195195
break
196-
combined = pa.concat_tables(batch_tables)
196+
combined = pa.concat_tables(batch_tables, promote_options="default")
197197
indices = pc.sort_indices(combined, sort_keys=sort_keys)
198198
sorted_table = combined.take(indices)
199199

@@ -214,8 +214,9 @@ def external_sort_merge_arrow(
214214
return
215215

216216
def _read_run(path: str) -> Iterator:
217-
table = pq.read_table(path)
218-
yield from table.to_pylist()
217+
pf = pq.ParquetFile(path)
218+
for batch in pf.iter_batches():
219+
yield from batch.to_pylist()
219220

220221
run_iters = [_read_run(p) for p in run_paths]
221222
try:

lib/zephyr/src/zephyr/plan.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,27 @@ def _flatmap_gen(stream: Iterator, fn: Callable) -> Iterator:
179179

180180

181181
def _find_group_boundaries(key_col: pa.ChunkedArray) -> Iterator[tuple[int, int, Any]]:
182-
"""Yield (start, end, key_value) for each contiguous group in a sorted key column."""
182+
"""Yield (start, end, key_value) for each contiguous group in a sorted key column.
183+
184+
Uses Arrow compute to find boundaries vectorized instead of per-element
185+
Python scalar extraction, which matters for high-cardinality keys.
186+
"""
183187
arr = key_col.combine_chunks()
184188
n = len(arr)
185189
if n == 0:
186190
return
191+
if n == 1:
192+
yield (0, 1, arr[0].as_py())
193+
return
194+
195+
# Vectorized boundary detection: compare adjacent elements
196+
ne_mask = pc.not_equal(arr[:-1], arr[1:])
197+
boundary_indices = pc.filter(pa.array(range(1, n), type=pa.int64()), ne_mask).to_pylist()
198+
187199
prev = 0
188-
for i in range(1, n):
189-
if arr[i].as_py() != arr[prev].as_py():
190-
yield (prev, i, arr[prev].as_py())
191-
prev = i
200+
for idx in boundary_indices:
201+
yield (prev, idx, arr[prev].as_py())
202+
prev = idx
192203
yield (prev, n, arr[prev].as_py())
193204

194205

@@ -223,8 +234,13 @@ def _arrow_reduce_gen(
223234
ScatterShard,
224235
_ZEPHYR_SHUFFLE_ITEM_COL,
225236
_ZEPHYR_SHUFFLE_SORT_KEY_COL,
237+
_ZEPHYR_SHUFFLE_SORT_SECONDARY_COL,
226238
)
227239

240+
assert not any(
241+
it.is_pickled for it in shard.iterators
242+
), "_arrow_reduce_gen requires non-pickled items; use _reduce_gen which routes pickled shards to the Python path"
243+
228244
use_external = (
229245
external_sort_dir is not None
230246
and isinstance(shard, ScatterShard)
@@ -233,6 +249,10 @@ def _arrow_reduce_gen(
233249

234250
if use_external:
235251
sort_keys: list[tuple[str, str]] = [(_ZEPHYR_SHUFFLE_SORT_KEY_COL, "ascending")]
252+
# Peek at the first chunk table to check for secondary sort column
253+
first_tables = list(islice((t for it in shard.iterators for t in it.get_chunk_tables()), 1))
254+
if first_tables and _ZEPHYR_SHUFFLE_SORT_SECONDARY_COL in first_tables[0].column_names:
255+
sort_keys.append((_ZEPHYR_SHUFFLE_SORT_SECONDARY_COL, "ascending"))
236256
logger.info(
237257
"Arrow external sort triggered for shard with %d iterators, spilling to %s",
238258
sum(it.chunk_count for it in shard.iterators),
@@ -243,7 +263,11 @@ def _chunk_tables() -> Iterator[pa.Table]:
243263
for it in shard.iterators:
244264
yield from it.get_chunk_tables()
245265

266+
has_secondary = len(sort_keys) > 1
267+
246268
def _merge_key(row: dict) -> Any:
269+
if has_secondary:
270+
return (row[_ZEPHYR_SHUFFLE_SORT_KEY_COL], row.get(_ZEPHYR_SHUFFLE_SORT_SECONDARY_COL))
247271
return row[_ZEPHYR_SHUFFLE_SORT_KEY_COL]
248272

249273
merged_stream = external_sort_merge_arrow(

lib/zephyr/src/zephyr/shuffle.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ class _ShardBuffer:
518518

519519
shard_idx: int
520520
pickled: bool = False
521+
has_sort: bool = False
521522
pending: list[tuple[Any, Any, Any | None]] = field(default_factory=list)
522523
tables: list[pa.RecordBatch] = field(default_factory=list)
523524
nbytes: int = 0
@@ -534,14 +535,13 @@ def _flush_micro(self) -> None:
534535
if not self.pending:
535536
return
536537
items, keys, sorts = zip(*self.pending, strict=True)
537-
has_sort = sorts[0] is not None
538538
envelope_fn = _make_pickle_envelope if self.pickled else _make_envelope
539539
enveloped = envelope_fn(
540540
list(items),
541541
self.shard_idx,
542542
self.chunk_idx,
543543
list(keys),
544-
list(sorts) if has_sort else None,
544+
list(sorts) if self.has_sort else None,
545545
)
546546
batch = pa.RecordBatch.from_pylist(enveloped, schema=self.schema)
547547
if self.schema is None:
@@ -629,7 +629,7 @@ def _write_parquet_scatter(
629629

630630
def _get_buffer(target: int) -> _ShardBuffer:
631631
if target not in buffers:
632-
buffers[target] = _ShardBuffer(shard_idx=target, pickled=pickled)
632+
buffers[target] = _ShardBuffer(shard_idx=target, pickled=pickled, has_sort=sort_fn is not None)
633633
return buffers[target]
634634

635635
def _flush_pending() -> None:
@@ -694,7 +694,7 @@ def _flush_buffer(buf: _ShardBuffer) -> None:
694694
py_items = [pickle.loads(b) for b in py_items]
695695
combined = _apply_combiner(py_items, key_fn, combiner_fn)
696696
# Re-create a fresh buffer from combined items
697-
combined_buf = _ShardBuffer(shard_idx=buf.shard_idx, pickled=pickled)
697+
combined_buf = _ShardBuffer(shard_idx=buf.shard_idx, pickled=pickled, has_sort=sort_fn is not None)
698698
combined_buf.chunk_idx = buf.chunk_idx
699699
for item in combined:
700700
k = key_fn(item)

0 commit comments

Comments (0)