Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions lib/zephyr/src/zephyr/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ def read(self) -> list:
ScatterReader, # noqa: F401 — re-exported for plan.py and external callers
ScatterShard, # noqa: F401 — backward-compat alias for ScatterReader
ScatterWriter, # noqa: F401 — re-exported for external callers
_build_scatter_shard_from_manifest, # noqa: F401 — re-exported for plan.py
_write_scatter,
_write_scatter_manifest,
_SCATTER_MANIFEST_NAME,
)

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1296,25 +1293,23 @@ def _regroup_result_refs(
"""Regroup worker output refs by output shard index without loading data.

Non-scatter: each worker's ListShard maps to its own index (identity).
Scatter: writes a consolidated scatter manifest combining all sidecar
metadata into a single file, then gives each reducer a shard containing
just the manifest path.
Scatter: passes the list of scatter data-file paths to every reducer.
Each reducer reads the per-mapper ``.scatter_meta`` sidecars in parallel
to build its own ``ScatterReader`` without coordinator-side consolidation.
"""
num_output = max(max(result_refs.keys(), default=0) + 1, input_shard_count)
if output_shard_count is not None:
num_output = max(num_output, output_shard_count)

if is_scatter:
# Collect all scatter file paths from all workers
# Collect all scatter file paths from all workers. The coordinator
# does NOT read the sidecars or write a consolidated manifest —
# reducers do their own parallel sidecar reads.
all_paths: list[str] = []
for result in result_refs.values():
all_paths.extend(result.shard)

# Write consolidated manifest and point reducers at it
manifest_path = f"{scatter_manifest_dir}/{_SCATTER_MANIFEST_NAME}"
_write_scatter_manifest(all_paths, manifest_path)
shared_refs = MemChunk(items=[manifest_path])

shared_refs = MemChunk(items=all_paths)
return [ListShard(refs=[shared_refs]) for _ in range(num_output)]

# Non-scatter: each result's shard maps to its own index
Expand Down
14 changes: 7 additions & 7 deletions lib/zephyr/src/zephyr/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,16 +841,16 @@ def run_stage(
return

elif isinstance(op, Reduce):
# Build ScatterShard from scatter manifest if needed,
# then merge sorted chunks and reduce per key.
from zephyr.execution import ScatterShard, _build_scatter_shard_from_manifest
# Build ScatterReader directly from per-mapper sidecars, then
# merge sorted chunks and reduce per key.
from zephyr.execution import ScatterShard

shard = ctx.shard
if not isinstance(shard, ScatterShard):
# Shard contains a single manifest path — read it to build ScatterShard
paths = list(shard)
assert len(paths) == 1, f"Expected single scatter manifest path, got {len(paths)}"
shard = _build_scatter_shard_from_manifest(paths[0], ctx.shard_idx)
# Shard contains every mapper's scatter-data path — reducer
# reads all sidecars in parallel and filters for its target.
scatter_paths = list(shard)
shard = ScatterShard.from_sidecars(scatter_paths, ctx.shard_idx)
stream = _reduce_gen(
shard, op.key_fn, op.reducer_fn, sort_fn=op.sort_fn, external_sort_dir=external_sort_dir
)
Expand Down
71 changes: 31 additions & 40 deletions lib/zephyr/src/zephyr/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,12 @@ def get_iterators(self) -> Iterator[Iterator]:
# ---------------------------------------------------------------------------

_SCATTER_META_SUFFIX = ".scatter_meta"
_SCATTER_MANIFEST_NAME = "scatter_metadata"
_SCATTER_DATA_SUFFIX = ".shuffle"

_SCATTER_META_READ_CONCURRENCY = 256
# Number of parallel sidecar reads each reducer issues when building its
# ScatterReader. Sidecars are small JSON files (a few KB) and reads are
# GCS GET-bound, so a modest pool keeps latency low without thrashing.
_SIDECAR_READ_CONCURRENCY = 32
# Number of items sampled from the first flush to estimate avg_item_bytes.
_SCATTER_SAMPLE_SIZE = 100
# Fraction of total memory budgeted for read-side decompression buffers.
Expand Down Expand Up @@ -118,9 +120,8 @@ def _write_scatter_meta(data_path: str, sidecar: dict) -> None:
f.write(payload)


# Per-worker caches for sidecar + manifest reads.
# Per-worker cache for sidecar reads.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need these caches? Doesn't this belong in the ScatterReader?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we need a cache, let's put it in the worker itself; e.g., maybe a scatter reader can be reused.

_scatter_meta_cache: dict[str, dict] = {}
_scatter_manifest_cache: dict[str, list[dict]] = {}


def _read_scatter_meta(data_path: str) -> dict:
Expand All @@ -131,35 +132,22 @@ def _read_scatter_meta(data_path: str) -> dict:
return _scatter_meta_cache[meta_path]


def _read_scatter_manifest(manifest_path: str) -> list[dict]:
if manifest_path not in _scatter_manifest_cache:
with open_url(manifest_path, "r") as f:
_scatter_manifest_cache[manifest_path] = json.loads(f.read())
return _scatter_manifest_cache[manifest_path]
def _read_sidecars_parallel(scatter_paths: list[str]) -> list[tuple[str, dict]]:
"""Read every ``.scatter_meta`` sidecar concurrently, preserving input order.


def _write_scatter_manifest(scatter_paths: list[str], output_path: str) -> None:
"""Aggregate ``.scatter_meta`` sidecars into a single manifest.

Sidecar reads run in parallel since each is an independent GCS GET.
Each reducer calls this to build its ``ScatterReader`` directly from the
per-mapper sidecars, without going through a coordinator-written manifest.
"""

def _read_entry(path: str) -> tuple[str, dict]:
meta = _read_scatter_meta(path)
return path, {"path": path, **meta}
return path, _read_scatter_meta(path)

results: dict[str, dict] = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=_SCATTER_META_READ_CONCURRENCY) as pool:
for path, entry in pool.map(_read_entry, scatter_paths):
results[path] = entry

entries = [results[path] for path in scatter_paths]
with concurrent.futures.ThreadPoolExecutor(max_workers=_SIDECAR_READ_CONCURRENCY) as pool:
for path, meta in pool.map(_read_entry, scatter_paths):
results[path] = meta

ensure_parent_dir(output_path)
payload = json.dumps(entries)
with log_time(f"Writing scatter manifest ({len(entries)} files) to {output_path}"):
with open_url(output_path, "w") as f:
f.write(payload)
return [(path, results[path]) for path in scatter_paths]


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -246,34 +234,42 @@ def __init__(
self.avg_item_bytes: float = avg_item_bytes

@classmethod
def from_manifest(cls, manifest_path: str, target_shard: int) -> ScatterReader:
"""Build a ScatterReader for one target shard from the consolidated manifest."""
entries = _read_scatter_manifest(manifest_path)
def from_sidecars(cls, scatter_paths: list[str], target_shard: int) -> ScatterReader:
"""Build a ScatterReader by reading per-mapper sidecars directly.

Each reducer reads every mapper's ``.scatter_meta`` sidecar in parallel
and filters for its own ``target_shard``. No coordinator-written manifest
is needed, which eliminates a serialization bottleneck when there are
thousands of mappers.
"""
shard_key = str(target_shard)

iterators: list[ScatterFileIterator] = []
max_rows = 0
weighted_bytes = 0.0
total_chunks_for_avg = 0

with log_time(f"Building ScatterReader for target shard {target_shard} from manifest ({len(entries)} files)"):
for entry in entries:
shards = entry.get("shards", {})
with log_time(
f"Building ScatterReader for target shard {target_shard} "
f"from {len(scatter_paths)} sidecars (concurrency={_SIDECAR_READ_CONCURRENCY})"
):
for path, meta in _read_sidecars_parallel(scatter_paths):
shards = meta.get("shards", {})
ranges = shards.get(shard_key)
if not ranges:
continue

iterators.append(
ScatterFileIterator(
path=entry["path"],
path=path,
chunks=tuple((int(off), int(length)) for off, length in ranges),
)
)

per_shard_max = entry.get("max_chunk_rows", {})
per_shard_max = meta.get("max_chunk_rows", {})
max_rows = max(max_rows, per_shard_max.get(shard_key, 0))

ab = entry.get("avg_item_bytes", 0.0)
ab = meta.get("avg_item_bytes", 0.0)
if ab > 0:
count = len(ranges)
weighted_bytes += ab * count
Expand Down Expand Up @@ -324,11 +320,6 @@ def needs_external_sort(self, memory_limit: int, memory_fraction: float = 0.5) -
ScatterShard = ScatterReader


def _build_scatter_shard_from_manifest(manifest_path: str, target_shard: int) -> ScatterReader:
"""Build a ScatterReader for one target shard from the consolidated manifest."""
return ScatterReader.from_manifest(manifest_path, target_shard)


# ---------------------------------------------------------------------------
# Combiner / sort helper
# ---------------------------------------------------------------------------
Expand Down
40 changes: 18 additions & 22 deletions lib/zephyr/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
from zephyr.shuffle import (
ScatterFileIterator,
ScatterShard,
_build_scatter_shard_from_manifest,
_write_chunk_frame,
_write_scatter,
_write_scatter_manifest,
)


Expand All @@ -27,7 +25,7 @@ def _target(key, num_shards):


def _build_shard(tmp_path, items, num_output_shards=4, source_shard=0):
"""Write a scatter file and manifest; return (manifest_path, data_paths)."""
"""Write a scatter file + sidecar; return scatter_paths for direct reducer reads."""
data_path = str(tmp_path / f"shard-{source_shard:04d}.shuffle")
list_shard = _write_scatter(
iter(items),
Expand All @@ -36,10 +34,8 @@ def _build_shard(tmp_path, items, num_output_shards=4, source_shard=0):
key_fn=_key,
num_output_shards=num_output_shards,
)
data_paths = list(list_shard)
manifest_path = str(tmp_path / "scatter_metadata")
_write_scatter_manifest(data_paths, manifest_path)
return manifest_path, data_paths
scatter_paths = list(list_shard)
return scatter_paths


# ---------------------------------------------------------------------------
Expand All @@ -51,11 +47,11 @@ def test_scatter_roundtrip(tmp_path):
"""All items written via scatter are recovered when reading all shards."""
num_shards = 4
items = [{"k": i % 4, "v": i} for i in range(40)]
manifest_path, _ = _build_shard(tmp_path, items, num_output_shards=num_shards)
scatter_paths = _build_shard(tmp_path, items, num_output_shards=num_shards)

recovered = []
for shard_idx in range(num_shards):
shard = _build_scatter_shard_from_manifest(manifest_path, shard_idx)
shard = ScatterShard.from_sidecars(scatter_paths, shard_idx)
recovered.extend(list(shard))

assert sorted(recovered, key=lambda x: x["v"]) == sorted(items, key=lambda x: x["v"])
Expand All @@ -65,10 +61,10 @@ def test_scatter_each_shard_gets_correct_items(tmp_path):
"""Items are routed to shards by deterministic_hash(key) % num_shards."""
num_shards = 4
items = [{"k": i % 4, "v": i} for i in range(40)]
manifest_path, _ = _build_shard(tmp_path, items, num_output_shards=num_shards)
scatter_paths = _build_shard(tmp_path, items, num_output_shards=num_shards)

for shard_idx in range(num_shards):
shard = _build_scatter_shard_from_manifest(manifest_path, shard_idx)
shard = ScatterShard.from_sidecars(scatter_paths, shard_idx)
recovered = sorted(list(shard), key=lambda x: x["v"])
expected = sorted([x for x in items if _target(x["k"], num_shards) == shard_idx], key=lambda x: x["v"])
assert recovered == expected, f"shard {shard_idx} mismatch"
Expand All @@ -77,10 +73,10 @@ def test_scatter_each_shard_gets_correct_items(tmp_path):
def test_scatter_roundtrip_sorted_chunks(tmp_path):
"""Each chunk iterator from get_iterators() yields items sorted by key."""
items = [{"k": i % 2, "v": i} for i in range(20)]
manifest_path, _ = _build_shard(tmp_path, items, num_output_shards=2)
scatter_paths = _build_shard(tmp_path, items, num_output_shards=2)

for shard_idx in range(2):
shard = _build_scatter_shard_from_manifest(manifest_path, shard_idx)
shard = ScatterShard.from_sidecars(scatter_paths, shard_idx)
for chunk_iter in shard.get_iterators():
chunk = list(chunk_iter)
keys = [_key(x) for x in chunk]
Expand All @@ -98,10 +94,10 @@ def test_max_chunk_rows_per_shard(tmp_path):
items = [{"k": 3, "v": i} for i in range(500)]
items += [{"k": 0, "v": i + 1000} for i in range(2)]

manifest_path, _ = _build_shard(tmp_path, items, num_output_shards=num_shards)
scatter_paths = _build_shard(tmp_path, items, num_output_shards=num_shards)

shard0 = _build_scatter_shard_from_manifest(manifest_path, 0)
shard1 = _build_scatter_shard_from_manifest(manifest_path, 1)
shard0 = ScatterShard.from_sidecars(scatter_paths, 0)
shard1 = ScatterShard.from_sidecars(scatter_paths, 1)

assert shard0.max_chunk_rows == 500
assert shard1.max_chunk_rows == 2, (
Expand All @@ -127,8 +123,8 @@ def test_needs_external_sort_triggers():

def test_needs_external_sort_below_threshold(tmp_path):
items = [{"k": 0, "v": i} for i in range(5)]
manifest_path, _ = _build_shard(tmp_path, items, num_output_shards=1)
shard = _build_scatter_shard_from_manifest(manifest_path, 0)
scatter_paths = _build_shard(tmp_path, items, num_output_shards=1)
shard = ScatterShard.from_sidecars(scatter_paths, 0)
assert not shard.needs_external_sort(memory_limit=32 * 1024**3)


Expand All @@ -144,8 +140,8 @@ def test_needs_external_sort_empty_shard():

def test_avg_item_bytes_written(tmp_path):
items = [{"k": 0, "v": i} for i in range(20)]
manifest_path, _ = _build_shard(tmp_path, items, num_output_shards=1)
shard = _build_scatter_shard_from_manifest(manifest_path, 0)
scatter_paths = _build_shard(tmp_path, items, num_output_shards=1)
shard = ScatterShard.from_sidecars(scatter_paths, 0)
assert shard.avg_item_bytes > 0


Expand All @@ -162,11 +158,11 @@ def test_scatter_handles_arbitrary_python_objects(tmp_path):
{"k": 1, "v": None},
{"k": 1, "v": frozenset([6])},
]
manifest_path, _ = _build_shard(tmp_path, items, num_output_shards=2)
scatter_paths = _build_shard(tmp_path, items, num_output_shards=2)

recovered = []
for shard_idx in range(2):
shard = _build_scatter_shard_from_manifest(manifest_path, shard_idx)
shard = ScatterShard.from_sidecars(scatter_paths, shard_idx)
recovered.extend(list(shard))

def _ord(x):
Expand Down
Loading