Commit 8b8b9bf

[zephyr] Add SpillWriter and TableAccumulator, refactor scatter + external sort

SpillWriter: byte-budgeted ParquetWriter wrapper with background write
thread for GCS upload overlap. Uses zstd-1 compression (was snappy for
spills). Two modes: write_table() accumulates and auto-flushes row groups;
write_row_group() writes immediately (for scatter, where each chunk must
be a separate row group).

TableAccumulator: byte-budgeted Arrow table batching, replacing the
row-count-based _MERGE_OUTPUT_BATCH_SIZE in the k-way merge output.

Refactors:
- external_sort: _write_spill_file uses SpillWriter (was binary-search
  row-group sizing); merge output uses TableAccumulator
- shuffle: _write_parquet_scatter uses SpillWriter (was manual
  ParquetWriter + pending_chunk + _flush_pending)
1 parent a967602 commit 8b8b9bf

10 files changed: 377 additions & 92 deletions
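
The new zephyr/spill_writer.py module is among the files not excerpted below, so here is a rough sketch of what the description above implies. The class name, method names, and constructor shape match the call sites in this commit's diffs; the internals (queue depth, threading, flush policy) are guesses, not the actual implementation:

```python
# Illustrative sketch of SpillWriter -- NOT the actual implementation
# (spill_writer.py is not shown in this excerpt). Signatures inferred from
# the call sites in external_sort.py and shuffle.py below.
import queue
import threading

import pyarrow as pa
import pyarrow.parquet as pq


class SpillWriter:
    """Byte-budgeted ParquetWriter wrapper. Parquet encoding and upload run
    on a background thread, overlapping with the caller's CPU work."""

    def __init__(self, path: str, schema: pa.Schema, row_group_bytes: int = 8 << 20):
        self._writer = pq.ParquetWriter(path, schema, compression="zstd", compression_level=1)
        self._budget = row_group_bytes
        self._pending: list[pa.Table] = []
        self._pending_bytes = 0
        # Bounded queue applies backpressure if the writer thread falls behind.
        self._queue: queue.Queue = queue.Queue(maxsize=2)
        self._thread = threading.Thread(target=self._drain, daemon=True)
        self._thread.start()

    def _drain(self) -> None:
        while (table := self._queue.get()) is not None:
            self._writer.write_table(table)  # each enqueued table -> one row group

    def write_row_group(self, table: pa.Table) -> None:
        """Scatter mode: write immediately as its own row group."""
        self._queue.put(table)

    def write_table(self, table: pa.Table) -> None:
        """Spill mode: accumulate rows, auto-flushing a row group each time
        the byte budget is reached (large inputs split into several groups)."""
        if len(table) == 0:
            return
        avg_row = max(1, table.nbytes // len(table))
        for batch in table.to_batches(max_chunksize=max(1, self._budget // avg_row)):
            self._pending.append(pa.Table.from_batches([batch]))
            self._pending_bytes += batch.nbytes
            if self._pending_bytes >= self._budget:
                self._flush_pending()

    def _flush_pending(self) -> None:
        if self._pending:
            self._queue.put(pa.concat_tables(self._pending, promote_options="default"))
            self._pending, self._pending_bytes = [], 0

    def close(self) -> None:
        self._flush_pending()
        self._queue.put(None)  # sentinel stops the writer thread
        self._thread.join()
        self._writer.close()

    def __enter__(self) -> "SpillWriter":
        return self

    def __exit__(self, *exc) -> None:
        self.close()
```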

lib/iris/src/iris/actor/client.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -59,8 +59,8 @@ def __init__(
         resolver: Resolver,
         name: str,
         call_timeout: float | None = None,
-        max_call_attempts: int = 5,
-        backoff: ExponentialBackoff = ExponentialBackoff(initial=0.1, maximum=10.0, factor=2.0, jitter=0.25),
+        max_call_attempts: int = 10,
+        backoff: ExponentialBackoff = ExponentialBackoff(initial=0.5, maximum=10.0, factor=2.0, jitter=0.25),
     ):
         """Initialize the actor client.
```
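For scale (not in the diff): ignoring the ±25% jitter and assuming one sleep between consecutive attempts, this change grows the worst-case retry window from roughly 1.5s to roughly 55s:

```python
# Worst-case total backoff: delays start at `initial`, multiply by `factor`,
# and clamp at `maximum`; one sleep between each of the N attempts.
def total_backoff(attempts: int, initial: float, maximum: float, factor: float) -> float:
    delay, total = initial, 0.0
    for _ in range(attempts - 1):
        total += min(delay, maximum)
        delay *= factor
    return total

print(total_backoff(5, 0.1, 10.0, 2.0))   # 1.5  (old defaults)
print(total_backoff(10, 0.5, 10.0, 2.0))  # 55.5 (new defaults)
```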
lib/iris/src/iris/client/client.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -555,7 +555,7 @@ def submit(
         coscheduling: CoschedulingConfig | None = None,
         replicas: int = 1,
         max_retries_failure: int = 0,
-        max_retries_preemption: int = 100,
+        max_retries_preemption: int = 10_000,
         timeout: Duration | None = None,
         user: str | None = None,
         reservation: list[ReservationEntry] | None = None,
```

lib/iris/src/iris/cluster/controller/scaling_group.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -98,8 +98,8 @@ class AvailabilityState:
     until: Timestamp | None = None
 
 
-DEFAULT_SCALE_UP_RATE_LIMIT = 5  # per minute
-DEFAULT_SCALE_DOWN_RATE_LIMIT = 5  # per minute
+DEFAULT_SCALE_UP_RATE_LIMIT = 32  # per minute
+DEFAULT_SCALE_DOWN_RATE_LIMIT = 32  # per minute
 DEFAULT_SCALE_UP_COOLDOWN = Duration.from_minutes(1)
 DEFAULT_BACKOFF_INITIAL = Duration.from_minutes(5)
 DEFAULT_BACKOFF_MAX = Duration.from_minutes(15)
```

lib/iris/src/iris/cluster/providers/gcp/bootstrap.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -136,9 +136,11 @@ def replace_var(match: re.Match) -> str:
 sudo systemctl start docker || true
 
 # Tune network stack for high-connection workloads (#3066).
-# Expands ephemeral port range and allows reuse of TIME_WAIT sockets.
+# Expands ephemeral port range, allows reuse of TIME_WAIT sockets,
+# and raises listen backlog for actor servers handling 1000s of workers.
 sudo sysctl -w net.ipv4.ip_local_port_range="1024 65535"
 sudo sysctl -w net.ipv4.tcp_tw_reuse=1
+sudo sysctl -w net.core.somaxconn=4096
 
 # Create cache directory
 sudo mkdir -p {{ cache_dir }}
```
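
For context on the new sysctl (not part of the diff): the kernel silently caps the backlog argument of listen(2) at net.core.somaxconn, so without this setting a server asking for a deep accept queue gets the default instead. A minimal illustration:

```python
# listen(backlog) is silently capped at net.core.somaxconn (historically 128;
# 4096 since Linux 5.4). Raising the sysctl lets the accept queue actually
# absorb a reconnect stampede from thousands of workers.
import socket

srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.bind(("127.0.0.1", 0))
srv.listen(4096)  # effective backlog = min(4096, net.core.somaxconn)
```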

lib/iris/src/iris/cluster/runtime/docker.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -618,6 +618,8 @@ def _docker_create(
         "create",
         "--ulimit",
         "core=0:0",
+        "--ulimit",
+        "nofile=65536:524288",
         "-w",
         config.workdir,
     ]
```
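
A quick way to sanity-check the new limit from inside a container (illustrative snippet, not from this commit):

```python
# With --ulimit nofile=65536:524288, a process in the container should see a
# soft limit of 65536 and a hard limit of 524288 open file descriptors.
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print(soft, hard)  # expected: 65536 524288
```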

lib/zephyr/src/zephyr/external_sort.py

Lines changed: 11 additions & 31 deletions
```diff
@@ -30,6 +30,8 @@
 import pyarrow.parquet as pq
 from iris.env_resources import TaskResources as _TaskResources
 
+from zephyr.spill_writer import SpillWriter, TableAccumulator
+
 logger = logging.getLogger(__name__)
 
 # Fraction of worker memory available for sort (pass 1 and pass 2 are
@@ -43,9 +45,6 @@
 # memory during merge, so this controls per-run memory footprint.
 _SPILL_ROW_GROUP_TARGET_BYTES = 8 * 1024 * 1024  # 8 MB
 
-# Output batch size yielded from the merge.
-_MERGE_OUTPUT_BATCH_SIZE = 100_000
-
 
 @dataclass
 class _SortBudget:
@@ -79,24 +78,8 @@ def _compute_budget(chunk_bytes: int) -> _SortBudget:
 
 def _write_spill_file(table: pa.Table, path: str) -> None:
     """Write a sorted table as a Parquet file with byte-budgeted row groups."""
-    writer = pq.ParquetWriter(path, table.schema)
-    offset = 0
-    n = len(table)
-    while offset < n:
-        # Grow the row group until we hit the byte target.
-        # Double the slice size each probe to keep overhead O(log n).
-        lo = offset + 1
-        hi = n
-        while lo < hi:
-            mid = (lo + hi + 1) // 2
-            if table.slice(offset, mid - offset).nbytes <= _SPILL_ROW_GROUP_TARGET_BYTES:
-                lo = mid
-            else:
-                hi = mid - 1
-        rg_end = lo
-        writer.write_table(table.slice(offset, rg_end - offset))
-        offset = rg_end
-    writer.close()
+    with SpillWriter(path, table.schema, row_group_bytes=_SPILL_ROW_GROUP_TARGET_BYTES) as w:
+        w.write_table(table)
 
 
 def _promote_to_large_string(table: pa.Table) -> pa.Table:
@@ -215,8 +198,7 @@ def _streaming_k_way_merge(
     for src in sources:
         heapq.heappush(heap, _MergeEntry(src.current_sort_value(), src.idx, src))
 
-    output_chunks: list[pa.Table] = []
-    output_rows = 0
+    accumulator = TableAccumulator(_SPILL_ROW_GROUP_TARGET_BYTES)
 
     while heap:
         entry = heapq.heappop(heap)
@@ -229,19 +211,17 @@
         take_count = winner.remaining()
 
         chunk = winner.take(take_count)
-        output_chunks.append(chunk)
-        output_rows += len(chunk)
 
         if winner.has_data:
             heapq.heappush(heap, _MergeEntry(winner.current_sort_value(), winner.idx, winner))
 
-        if output_rows >= _MERGE_OUTPUT_BATCH_SIZE:
-            yield pa.concat_tables(output_chunks, promote_options="default")
-            output_chunks.clear()
-            output_rows = 0
+        merged = accumulator.add(chunk)
+        if merged is not None:
+            yield merged
 
-    if output_chunks:
-        yield pa.concat_tables(output_chunks, promote_options="default")
+    remaining = accumulator.flush()
+    if remaining is not None:
+        yield remaining
 
 
 def external_sort_merge(
```
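
TableAccumulator is the other new piece: batching by bytes rather than by the fixed 100_000-row _MERGE_OUTPUT_BATCH_SIZE makes output-batch memory predictable regardless of row width. A minimal sketch consistent with the call sites above (the real class lives in zephyr/spill_writer.py, which is not in this excerpt):

```python
# Illustrative sketch of TableAccumulator -- NOT the actual implementation.
# Behavior inferred from the merge loop above: add() returns a merged batch
# once the byte budget is reached, otherwise None; flush() drains the rest.
import pyarrow as pa


class TableAccumulator:
    def __init__(self, budget_bytes: int):
        self._budget = budget_bytes
        self._chunks: list[pa.Table] = []
        self._nbytes = 0

    def add(self, chunk: pa.Table) -> pa.Table | None:
        self._chunks.append(chunk)
        self._nbytes += chunk.nbytes
        if self._nbytes < self._budget:
            return None
        return self.flush()

    def flush(self) -> pa.Table | None:
        if not self._chunks:
            return None
        out = pa.concat_tables(self._chunks, promote_options="default")
        self._chunks, self._nbytes = [], 0
        return out
```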

lib/zephyr/src/zephyr/plan.py

Lines changed: 33 additions & 13 deletions
```diff
@@ -27,7 +27,7 @@
 from iris.env_resources import TaskResources as _TaskResources
 from rigging.filesystem import url_to_fs
 
-from zephyr.external_sort import external_sort_merge
+from zephyr.external_sort import _promote_to_large_string, external_sort_merge
 
 from zephyr.dataset import (
     Dataset,
@@ -211,7 +211,7 @@ def _arrow_merge_sorted_chunks(shard: Any) -> pa.Table:
         all_tables.append(table)
     if not all_tables:
         return pa.table({})
-    combined = pa.concat_tables(all_tables, promote_options="default")
+    combined = pa.concat_tables([_promote_to_large_string(t) for t in all_tables], promote_options="default")
     sort_keys: list[tuple[str, str]] = [(_ZEPHYR_SORT_KEY, "ascending")]
     if _ZEPHYR_SORT_SECONDARY in combined.column_names:
         sort_keys.append((_ZEPHYR_SORT_SECONDARY, "ascending"))
@@ -259,22 +259,42 @@ def _chunk_tables() -> Iterator[pa.Table]:
            for it in shard.iterators:
                yield from it.get_chunk_tables()
 
-        all_tables = list(external_sort_merge(_chunk_tables(), sort_keys, external_sort_dir))
-        if not all_tables:
-            return
-        sorted_table = pa.concat_tables(all_tables, promote_options="default")
+        # Stream through the merge, grouping by sort key across batch boundaries.
+        # Only one batch + one group's accumulated rows are in memory at a time.
+        is_gen = inspect.isgeneratorfunction(reducer_fn)
+        current_key = None
+        current_group_tables: list[pa.Table] = []
 
-        key_col = sorted_table.column(_ZEPHYR_SORT_KEY)
-        pickled = _ZEPHYR_PAYLOAD in sorted_table.column_names
+        for batch_table in external_sort_merge(_chunk_tables(), sort_keys, external_sort_dir):
+            pickled = _ZEPHYR_PAYLOAD in batch_table.column_names
+            key_col = batch_table.column(_ZEPHYR_SORT_KEY)
 
-        is_gen = inspect.isgeneratorfunction(reducer_fn)
-        for start, end, key_value in _find_group_boundaries(key_col):
-            group_table = sorted_table.slice(start, end - start)
+            for start, end, key_value in _find_group_boundaries(key_col):
+                group_slice = batch_table.slice(start, end - start)
+
+                if current_key is None:
+                    current_key = key_value
+                    current_group_tables = [group_slice]
+                elif key_value == current_key:
+                    current_group_tables.append(group_slice)
+                else:
+                    group_table = pa.concat_tables(current_group_tables, promote_options="default")
+                    group_items = unwrap_items(group_table, pickled)
+                    if is_gen:
+                        yield from reducer_fn(current_key, iter(group_items))
+                    else:
+                        yield reducer_fn(current_key, iter(group_items))
+                    current_key = key_value
+                    current_group_tables = [group_slice]
+
+        if current_group_tables:
+            group_table = pa.concat_tables(current_group_tables, promote_options="default")
+            pickled = _ZEPHYR_PAYLOAD in group_table.column_names
             group_items = unwrap_items(group_table, pickled)
             if is_gen:
-                yield from reducer_fn(key_value, iter(group_items))
+                yield from reducer_fn(current_key, iter(group_items))
             else:
-                yield reducer_fn(key_value, iter(group_items))
+                yield reducer_fn(current_key, iter(group_items))
         return
 
     sorted_table = _arrow_merge_sorted_chunks(shard)
```
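
The carry logic above exists because a group's rows can straddle two merge output batches; the reducer must not fire for a key until a different key (or end of stream) proves the group is complete. A toy illustration of the same control flow, with strings standing in for Arrow table slices:

```python
# Sorted stream split into two batches: the "b" group spans the boundary,
# so it is reduced only after "c" arrives in the second batch.
batches = [["a", "a", "b"], ["b", "b", "c"]]

current_key, run = None, []
for batch in batches:
    for key in batch:  # stand-in for per-batch _find_group_boundaries runs
        if current_key is None or key == current_key:
            run.append(key)
        else:
            print(f"reduce({current_key}, {len(run)} rows)")  # group complete
            run = [key]
        current_key = key
if run:
    print(f"reduce({current_key}, {len(run)} rows)")  # flush the final group
# -> reduce(a, 2 rows), reduce(b, 3 rows), reduce(c, 1 rows)
```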

lib/zephyr/src/zephyr/shuffle.py

Lines changed: 24 additions & 40 deletions
```diff
@@ -28,12 +28,12 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 import pyarrow.dataset as pad
-import pyarrow.parquet as pq
 from iris.env_resources import TaskResources as _TaskResources
 from rigging.filesystem import open_url, url_to_fs
 from rigging.timing import log_time
 
 from zephyr.plan import deterministic_hash
+from zephyr.spill_writer import SpillWriter
 from zephyr.writers import ensure_parent_dir
 
 logger = logging.getLogger(__name__)
@@ -571,12 +571,7 @@ def _write_parquet_scatter(
     seg_idx = 0
     seg_paths: list[str] = []
     schema: pa.Schema | None = None
-    writer: pq.ParquetWriter | None = None
-    seg_file = ""
-
-    pending_chunk: pa.RecordBatch | None = None
-    pending_target: int = -1
-    pending_cnt: int = 0
+    spill_writer: SpillWriter | None = None
 
     avg_item_bytes: float = 0.0
     _sampled_avg = False
@@ -586,57 +581,36 @@ def _get_buffer(target: int) -> _ShardBuffer:
             buffers[target] = _ShardBuffer(shard_idx=target, pickled=pickled, has_sort=sort_fn is not None)
         return buffers[target]
 
-    def _flush_pending() -> None:
-        nonlocal n_chunks_flushed, pending_chunk
-        if pending_chunk is None:
-            return
-        writer.write_batch(pending_chunk)
-        seg_shard_counts[seg_idx][pending_target] = seg_shard_counts[seg_idx].get(pending_target, 0) + 1
-        n_chunks_flushed += 1
-        pending_chunk = None
-        if n_chunks_flushed % 10 == 0:
-            logger.info(
-                "[shard %d segment %d] Wrote %d parquet chunks so far (latest chunk size: %d items)",
-                source_shard,
-                seg_idx,
-                n_chunks_flushed,
-                pending_cnt,
-            )
-
     def _ensure_writer(chunk_schema: pa.Schema) -> pa.Schema:
-        nonlocal schema, writer, seg_file, seg_idx
+        nonlocal schema, spill_writer, seg_idx
         if schema is None:
             schema = chunk_schema
             seg_file = _segment_path(parquet_path, seg_idx)
             seg_paths.append(seg_file)
             ensure_parent_dir(seg_file)
-            writer = pq.ParquetWriter(seg_file, schema, compression="zstd", compression_level=1)
+            spill_writer = SpillWriter(seg_file, schema)
        elif chunk_schema != schema:
-            _flush_pending()
-            writer.close()
+            spill_writer.close()
             schema = pa.unify_schemas([schema, chunk_schema])
             seg_idx += 1
             for buf in buffers.values():
                 buf.chunk_idx = 0
             seg_file = _segment_path(parquet_path, seg_idx)
             seg_paths.append(seg_file)
             ensure_parent_dir(seg_file)
-            writer = pq.ParquetWriter(seg_file, schema, compression="zstd", compression_level=1)
+            spill_writer = SpillWriter(seg_file, schema)
             logger.info(
                 "[shard %d] Schema evolved after %d chunks; starting segment %d",
                 source_shard,
                 n_chunks_flushed,
                 seg_idx,
             )
-        else:
-            _flush_pending()
         return schema
 
     def _flush_buffer(buf: _ShardBuffer) -> None:
-        nonlocal pending_chunk, pending_target, pending_cnt, avg_item_bytes, _sampled_avg
+        nonlocal n_chunks_flushed, avg_item_bytes, _sampled_avg
 
         if combiner_fn is not None:
-            # Combiner path: drain buffer to Python, apply combiner, re-sort in Arrow
             buf._flush_micro()
             if not buf.tables:
                 return
@@ -664,9 +638,21 @@ def _flush_buffer(buf: _ShardBuffer) -> None:
         write_schema = _ensure_writer(batch.schema)
         if batch.schema != write_schema:
             batch = batch.cast(write_schema)
-        pending_chunk = batch
-        pending_target = buf.shard_idx
-        pending_cnt = len(batch)
+
+        # Each sorted chunk is its own row group (distinct shard/chunk metadata).
+        batch_table = pa.Table.from_batches([batch])
+        spill_writer.write_row_group(batch_table)
+        seg_shard_counts[seg_idx][buf.shard_idx] = seg_shard_counts[seg_idx].get(buf.shard_idx, 0) + 1
+        n_chunks_flushed += 1
+
+        if n_chunks_flushed % 10 == 0:
+            logger.info(
+                "[shard %d segment %d] Wrote %d parquet chunks so far (latest chunk size: %d items)",
+                source_shard,
+                seg_idx,
+                n_chunks_flushed,
+                len(batch),
+            )
 
         if not _sampled_avg and len(batch) > 0:
             avg_item_bytes = batch.nbytes / len(batch)
@@ -682,16 +668,14 @@ def _flush_buffer(buf: _ShardBuffer) -> None:
             _flush_buffer(buf)
 
     with log_time(f"Flushing remaining buffers for {parquet_path}"):
-        _flush_pending()
         for target in sorted(buffers.keys()):
             buf = buffers[target]
             if buf.item_count == 0:
                 continue
             _flush_buffer(buf)
-        _flush_pending()
 
-    if writer is not None:
-        writer.close()
+    if spill_writer is not None:
+        spill_writer.close()
 
     per_shard_max_rows: dict[int, int] = {target: buf.max_rows for target, buf in buffers.items() if buf.max_rows > 0}
```