Skip to content

Commit bfe399b

Browse files
committed
[zephyr] Drop dead output_shard_count handling in non-scatter regroup
`stage.output_shards` is only ever set on Scatter stages (via `group_by`) and Reshard stages, and Reshard short-circuits before reaching `_regroup_result_refs`. The non-scatter branch's `max(num_output, output_shard_count)` was therefore dead code that suggested resharding semantics which don't exist; resharding belongs to `ReshardOp`.
1 parent b5ee7ea commit bfe399b

1 file changed

Lines changed: 7 additions & 8 deletions

File tree

lib/zephyr/src/zephyr/execution.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1323,17 +1323,14 @@ def _regroup_result_refs(
13231323
Each reducer reads the per-mapper ``.scatter_meta`` sidecars in parallel
13241324
to build its own ``ScatterReader`` without coordinator-side consolidation.
13251325
"""
1326-
if is_scatter and output_shard_count is not None:
1326+
if is_scatter:
13271327
# Scatter routes records into exactly ``output_shard_count`` buckets via
13281328
# ``hash(key) % output_shard_count``; spawning more reduce tasks than that
13291329
# produces empty output files for shard indices that no record hashes to.
1330-
num_output = output_shard_count
1331-
else:
1332-
num_output = max(max(result_refs.keys(), default=0) + 1, input_shard_count)
1333-
if output_shard_count is not None:
1334-
num_output = max(num_output, output_shard_count)
1330+
# When output_shard_count is None (group_by auto-detect), inherit the
1331+
# input shard count.
1332+
num_output = output_shard_count if output_shard_count is not None else input_shard_count
13351333

1336-
if is_scatter:
13371334
# Collect all scatter file paths from all workers. The coordinator
13381335
# does NOT read the sidecars or write a consolidated manifest —
13391336
# reducers do their own parallel sidecar reads.
@@ -1344,7 +1341,9 @@ def _regroup_result_refs(
13441341
shared_refs = MemChunk(items=all_paths)
13451342
return [ListShard(refs=[shared_refs]) for _ in range(num_output)]
13461343

1347-
# Non-scatter: each result's shard maps to its own index
1344+
# Non-scatter: 1:1 mapping from input shard index to output. Resharding
1345+
# to a different shard count belongs to ReshardOp, not here.
1346+
num_output = max(max(result_refs.keys(), default=0) + 1, input_shard_count)
13481347
return [result_refs[idx].shard if idx in result_refs else ListShard(refs=[]) for idx in range(num_output)]
13491348

13501349

0 commit comments

Comments
 (0)