Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions lib/zephyr/src/zephyr/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,9 +1323,15 @@ def _regroup_result_refs(
Each reducer reads the per-mapper ``.scatter_meta`` sidecars in parallel
to build its own ``ScatterReader`` without coordinator-side consolidation.
"""
num_output = max(max(result_refs.keys(), default=0) + 1, input_shard_count)
if output_shard_count is not None:
num_output = max(num_output, output_shard_count)
if is_scatter and output_shard_count is not None:
# Scatter routes records into exactly ``output_shard_count`` buckets via
# ``hash(key) % output_shard_count``; spawning more reduce tasks than that
# produces empty output files for shard indices that no record hashes to.
num_output = output_shard_count
else:
num_output = max(max(result_refs.keys(), default=0) + 1, input_shard_count)
if output_shard_count is not None:
num_output = max(num_output, output_shard_count)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior here is a little weird for non-scatter cases. It seems like for non-scatter we should just ignore output shards? If a user wants a different shard count we should be going through a reshard?

Copy link
Copy Markdown
Contributor

@ravwojdyla ravwojdyla Apr 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point - that's cleaner and more direct - will update


if is_scatter:
# Collect all scatter file paths from all workers. The coordinator
Expand Down
33 changes: 33 additions & 0 deletions lib/zephyr/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,39 @@ def test_deduplicate_with_num_output_shards(zephyr_ctx):
assert ids == [0, 1, 2]


def test_group_by_num_output_shards_smaller_than_input(zephyr_ctx, tmp_path):
    """``num_output_shards`` must win even when the input has more shards.

    Regression test for marin#5162: scatter buckets records with
    ``hash(key) % num_output_shards``, but the reduce stage used to launch
    ``max(input_shards, num_output_shards)`` tasks. Each surplus reduce task
    was assigned a ``shard_idx`` that scatter never populated, so it wrote an
    empty output file. With 10 input shards and 3 output shards we therefore
    expect exactly 3 files, all stamped with the correct total.
    """
    out_dir = tmp_path / "out"
    pattern = str(out_dir / "data-{shard:05d}-of-{total:05d}.parquet")

    # 60 distinct keys spread over 10 input shards, reduced into 3 buckets.
    records = [{"id": n, "val": n} for n in range(60)]
    pipeline = Dataset.from_list(records).reshard(10)
    pipeline = pipeline.group_by(
        key=lambda rec: rec["id"],
        reducer=lambda _key, group: next(iter(group)),
        num_output_shards=3,
    )
    pipeline = pipeline.write_parquet(pattern)

    files = zephyr_ctx.execute(pipeline).results

    assert len(files) == 3, (
        f"expected 3 output files (= num_output_shards), got {len(files)}; "
        "extra files come from reduce tasks that ran on padding shards"
    )
    for path in files:
        assert "of-00003" in path, f"unexpected total in {path}"


def test_group_by_with_hash_key_large(zephyr_ctx, large_document_dataset):
"""Test group_by with MD5 hash on larger dataset, counting duplicates."""

Expand Down
Loading