Commit 8ca1768
zephyr: address review — ScatterWriter/Reader classes, cloudpickle, exact memory budget
Review feedback from rjpower on #4782:

- cloudpickle on write: `_write_chunk_frame` now uses `cloudpickle.dump`, so lambdas, local classes, and dynamically-defined callables survive scatter serialization. The read side stays stdlib `pickle.load`.
- ScatterWriter class: extracts `_write_scatter` into a class with `write(item)` / `close()` / context-manager support. Nonlocal state becomes instance attributes. `_write_scatter` remains as a thin wrapper.
- ScatterReader class: collapses `ScatterShard` + `ScatterFileIterator` + `_build_scatter_shard_from_manifest`. The `ScatterShard = ScatterReader` alias preserves backward compatibility for isinstance checks. A `ScatterReader.from_manifest()` classmethod replaces the factory function.
- url_to_fs resolved once per file: `ScatterFileIterator.__post_init__` resolves (fs, fs_path) once; `_iter_chunk` receives them directly.
- Memory budget from compressed chunk sizes: the fan-in computation uses `max_compressed_chunk_bytes` (exact, from the sidecar) instead of the `max_chunk_rows * avg_item_bytes` heuristic. `avg_item_bytes` is kept only for `compute_write_batch_size`.
- Removed the stats-unavailable fallback in plan.py: stats are always written by the scatter writer, so that dead-code fallback path is gone.
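The write/read asymmetry in the first bullet can be sketched as follows. This is an illustrative frame layout (4-byte length prefix plus pickled payload), not zephyr's actual `_write_chunk_frame` format; the sketch falls back to stdlib `pickle.dump` when cloudpickle is not installed, purely so it runs anywhere. The key property is that cloudpickle emits a standard pickle stream, so the read side can stay stdlib `pickle`.

```python
import io
import pickle
import struct

try:
    import cloudpickle  # serializes lambdas, local classes, dynamic callables
    _dump = cloudpickle.dump
except ImportError:      # fallback so the sketch runs without cloudpickle;
    _dump = pickle.dump  # lambdas would NOT survive this path

def write_chunk_frame(buf, items):
    """Hypothetical frame: little-endian u32 length prefix + pickled payload."""
    payload = io.BytesIO()
    _dump(items, payload)
    data = payload.getvalue()
    buf.write(struct.pack("<I", len(data)))
    buf.write(data)

def read_chunk_frame(buf):
    """Read side stays stdlib pickle: cloudpickle output is a valid pickle stream."""
    (n,) = struct.unpack("<I", buf.read(4))
    return pickle.loads(buf.read(n))

buf = io.BytesIO()
write_chunk_frame(buf, [{"k": 1}, {"k": 2}])
buf.seek(0)
assert read_chunk_frame(buf) == [{"k": 1}, {"k": 2}]
```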
1 parent bf00f9e commit 8ca1768

File tree

3 files changed: +216 -140 lines changed

lib/zephyr/src/zephyr/execution.py

Lines changed: 3 additions & 1 deletion

@@ -111,7 +111,9 @@ def read(self) -> list:
 from zephyr.shuffle import (  # noqa: E402
     ListShard,
     MemChunk,
-    ScatterShard,  # noqa: F401 — re-exported for plan.py and external callers
+    ScatterReader,  # noqa: F401 — re-exported for plan.py and external callers
+    ScatterShard,  # noqa: F401 — backward-compat alias for ScatterReader
+    ScatterWriter,  # noqa: F401 — re-exported for external callers
     _build_scatter_shard_from_manifest,  # noqa: F401 — re-exported for plan.py
     _write_scatter,
     _write_scatter_manifest,

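The reader-side collapse can be sketched similarly. The manifest fields below (`chunk_files`, `max_compressed_chunk_bytes`, `avg_item_bytes`) are assumed from the stats mentioned elsewhere in this commit, not the real manifest schema; the sketch shows the `from_manifest()` classmethod replacing the factory function and the module-level alias keeping `isinstance(x, ScatterShard)` checks working.

```python
from dataclasses import dataclass


@dataclass
class ScatterReader:
    """Sketch of the collapsed reader; fields are illustrative assumptions."""
    chunk_files: list
    max_compressed_chunk_bytes: int = 0
    avg_item_bytes: float = 0.0

    @classmethod
    def from_manifest(cls, manifest: dict) -> "ScatterReader":
        # Replaces the old _build_scatter_shard_from_manifest factory function.
        return cls(
            chunk_files=manifest["chunk_files"],
            max_compressed_chunk_bytes=manifest["max_compressed_chunk_bytes"],
            avg_item_bytes=manifest["avg_item_bytes"],
        )


# Backward-compat alias: existing isinstance(x, ScatterShard) checks still pass.
ScatterShard = ScatterReader

reader = ScatterReader.from_manifest(
    {"chunk_files": ["c0", "c1"], "max_compressed_chunk_bytes": 1024, "avg_item_bytes": 12.5}
)
assert isinstance(reader, ScatterShard)
```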
lib/zephyr/src/zephyr/plan.py

Lines changed: 5 additions & 11 deletions

@@ -25,7 +25,7 @@
 from iris.env_resources import TaskResources as _TaskResources
 from rigging.filesystem import url_to_fs

-from zephyr.external_sort import EXTERNAL_SORT_FAN_IN, external_sort_merge
+from zephyr.external_sort import external_sort_merge

 from zephyr.dataset import (
     Dataset,
@@ -635,7 +635,7 @@ def merge_key(item):

     # Check if external sort is needed BEFORE materializing all iterators.
     # ScatterShard can decide using manifest stats (no file opens needed).
-    from zephyr.shuffle import ScatterShard
+    from zephyr.shuffle import ScatterShard  # ScatterShard is an alias for ScatterReader

     use_external = (
         external_sort_dir is not None
@@ -648,10 +648,8 @@ def merge_key(item):

     memory_limit = _TaskResources.from_environment().memory_bytes
     # Per-iterator memory ~= compressed bytes for one chunk held by
-    # cat_file. Use uncompressed (max_chunk_rows * avg_item_bytes) as a
-    # conservative upper bound — scatter writes ASCII-ish data with
-    # mediocre zstd ratio.
-    per_iter_bytes = int(shard.max_chunk_rows * shard.avg_item_bytes)
+    # cat_file. Use the actual max compressed chunk size from the sidecar.
+    per_iter_bytes = shard.max_compressed_chunk_bytes
     fan_in = compute_fan_in(per_iter_bytes, memory_limit)
     write_batch_size = compute_write_batch_size(shard.avg_item_bytes)
     logger.info(
@@ -674,11 +672,7 @@ def merge_key(item):
     else:
         chunk_iterators = list(shard.get_iterators())
         logger.info(f"Merging {len(chunk_iterators):,} sorted chunk iterators")
-        if external_sort_dir is not None and len(chunk_iterators) > EXTERNAL_SORT_FAN_IN:
-            # Fallback: stats unavailable, use the hard-cap fan-in.
-            merged_stream = external_sort_merge(iter(chunk_iterators), merge_key, external_sort_dir)
-        else:
-            merged_stream = heapq.merge(*chunk_iterators, key=merge_key)
+        merged_stream = heapq.merge(*chunk_iterators, key=merge_key)
     yield from groupby(merged_stream, key=key_fn)
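The memory-budget change in the last hunk above amounts to: divide the task's memory limit by the cost of holding one compressed chunk per iterator, now taken exactly from the sidecar stats rather than estimated from uncompressed size. A minimal sketch of what `compute_fan_in` could look like, assuming a floor of 2 and a cap of 1024 (both illustrative guards, not zephyr's real constants):

```python
def compute_fan_in(per_iter_bytes: int, memory_limit: int,
                   floor: int = 2, cap: int = 1024) -> int:
    """How many sorted chunk iterators can be merged at once if each
    holds one compressed chunk in memory. floor/cap are illustrative."""
    if per_iter_bytes <= 0:
        return cap  # degenerate stats: no per-iterator cost to budget
    return max(floor, min(cap, memory_limit // per_iter_bytes))


# With exact sidecar stats (max compressed chunk = 64 MiB) and a 4 GiB
# task budget, 64 iterators fit. The old uncompressed heuristic
# (max_chunk_rows * avg_item_bytes) would have overstated per-iterator
# cost and produced a smaller fan-in.
fan_in = compute_fan_in(64 * 2**20, 4 * 2**30)  # fan_in == 64
```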
