Skip to content

Commit fc4597c

Browse files
wdykas and claude authored
Disag MR1: Add inference shard specs and pg-collection building (NVIDIA#5186)
Signed-off-by: wdykas <wdykas@nvidia.com> Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
1 parent ae2efd5 commit fc4597c

8 files changed

Lines changed: 511 additions & 181 deletions

File tree

examples/rl/benchmark_refit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from megatron.training import get_args, get_model as get_training_model, print_rank_0
1515
from megatron.training.initialize import initialize_megatron
1616
from megatron.training.arguments import core_transformer_config_from_args
17-
from megatron.rl.parallel_utils import build_inference_pg_collection
17+
from megatron.core.inference.shards import build_inference_pg_collection
1818
from gpt_builders import gpt_builder
1919
from megatron.core.resharding.copy_services.nvshmem_copy_service import NVSHMEMCopyService
2020
from megatron.core.resharding.copy_services.nccl_copy_service import NCCLCopyService

megatron/core/inference/shards.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
3+
"""Framework-agnostic primitives for heterogeneous inference sharding: build a
4+
``ProcessGroupCollection`` per shard, each over a contiguous rank window at its
5+
own parallelism."""
6+
7+
from dataclasses import dataclass
8+
from typing import List, Optional, Sequence, Union
9+
10+
import torch.distributed as dist
11+
12+
from megatron.core import mpu
13+
from megatron.core.hyper_comm_grid import HyperCommGrid
14+
from megatron.core.inference.shards_spec import InferenceShardSpec, normalize_shard_specs
15+
from megatron.core.process_groups_config import ProcessGroupCollection
16+
17+
18+
def build_inference_pg_collection(
    world_size: int,
    tp_size: Optional[int] = None,
    pp_size: Optional[int] = None,
    cp_size: Optional[int] = None,
    ep_size: Optional[int] = None,
    expt_tp_size: Optional[int] = None,
    use_tp_pp_dp_mapping: bool = False,
    rank_offset: int = 0,
) -> ProcessGroupCollection:
    """Create the process groups one inference model needs over a rank window.

    Two ``HyperCommGrid`` instances are laid over the same window:

    - a decoder grid with dims ``tp``, ``cp``, ``dp``, ``pp`` for dense /
      attention layers;
    - an expert grid with dims ``tp`` (expert TP), ``ep``, ``dp`` (expert DP),
      ``pp`` for MoE expert layers.

    Both grids must enumerate identical pipeline-parallel rank lists.

    Args:
        world_size: Number of ranks in the window.
        tp_size: Tensor-parallel size. ``None`` falls back to
            ``mpu.get_tensor_model_parallel_world_size()``.
        pp_size: Pipeline-parallel size. ``None`` falls back to
            ``mpu.get_pipeline_model_parallel_world_size()``.
        cp_size: Context-parallel size. ``None`` falls back to
            ``mpu.get_context_parallel_world_size()``.
        ep_size: Expert-parallel size. ``None`` falls back to
            ``mpu.get_expert_model_parallel_world_size()``.
        expt_tp_size: Expert tensor-parallel size. ``None`` falls back to
            ``mpu.get_expert_tensor_parallel_world_size()``.
        use_tp_pp_dp_mapping: When True the grids are ordered with ``pp``
            before ``dp`` ('tp-pp-dp'); otherwise ``dp`` before ``pp``
            ('tp-dp-pp').
        rank_offset: First global rank of the window, forwarded to both
            grids. ``0`` means the window starts at global rank 0.

    Returns:
        A ``ProcessGroupCollection`` holding the groups created for the
        calling rank. ``embd`` / ``pos_embd`` are ``None`` on ranks that are
        not in the corresponding embedding rank lists.
        NOTE(review): on ranks outside ``[rank_offset, rank_offset + world_size)``
        the grid-created fields are presumably non-member handles and should
        not be used -- confirm against ``HyperCommGrid.create_pg``.

    Raises:
        AssertionError: If ``world_size`` is not an exact multiple of
            ``tp*cp*pp`` or of ``expt_tp*ep*pp``, or if the decoder and expert
            grids disagree on the pipeline-parallel rank enumeration.
    """
    # Fill in any size the caller left unspecified from the mpu parallel state.
    if tp_size is None:
        tp_size = mpu.get_tensor_model_parallel_world_size()
    if cp_size is None:
        cp_size = mpu.get_context_parallel_world_size()
    if pp_size is None:
        pp_size = mpu.get_pipeline_model_parallel_world_size()
    if ep_size is None:
        ep_size = mpu.get_expert_model_parallel_world_size()
    if expt_tp_size is None:
        expt_tp_size = mpu.get_expert_tensor_parallel_world_size()

    # Dense layers: world = tp * cp * dp * pp, so dp is whatever is left over.
    dense_model_parallel = tp_size * cp_size * pp_size
    dp_size = world_size // dense_model_parallel
    assert dp_size >= 1 and dense_model_parallel * dp_size == world_size, (
        f"World size ({world_size}) must be divisible by tp*cp*pp "
        f"({dense_model_parallel})"
    )

    # Expert layers: world = expt_tp * ep * expt_dp * pp.
    expert_model_parallel = expt_tp_size * ep_size * pp_size
    expt_dp_size = world_size // expert_model_parallel
    assert expt_dp_size >= 1 and expert_model_parallel * expt_dp_size == world_size, (
        f"World size ({world_size}) must be divisible by expt_tp*ep*pp "
        f"({expert_model_parallel})"
    )

    rank = dist.get_rank()

    # Only the order of the two trailing dims depends on the mapping flag.
    if use_tp_pp_dp_mapping:
        tail_names = ["pp", "dp"]
        decoder_tail = [pp_size, dp_size]
        expert_tail = [pp_size, expt_dp_size]
    else:
        tail_names = ["dp", "pp"]
        decoder_tail = [dp_size, pp_size]
        expert_tail = [expt_dp_size, pp_size]

    decoder_grid = HyperCommGrid(
        [tp_size, cp_size, *decoder_tail], ["tp", "cp", *tail_names], rank_offset=rank_offset
    )

    # Groups are created strictly in table order; keep it stable, since every
    # rank has to issue the same sequence of group-creation calls.
    decoder_groups = {
        name: decoder_grid.create_pg(dims)
        for name, dims in (
            ("tp", "tp"),
            ("cp", "cp"),
            ("pp", "pp"),
            ("dp", "dp"),
            ("mp", ["tp", "pp"]),
            ("tp_cp", ["tp", "cp"]),
            ("dp_cp", ["cp", "dp"]),
            ("tp_dp_cp", ["tp", "cp", "dp"]),
        )
    }

    expert_grid = HyperCommGrid(
        [expt_tp_size, ep_size, *expert_tail], ["tp", "ep", *tail_names], rank_offset=rank_offset
    )

    # The pp group comes from the decoder grid only, so the expert grid must
    # describe exactly the same pipeline stages.
    decoder_pp_enum = decoder_grid.get_rank_enum("pp")
    expert_pp_enum = expert_grid.get_rank_enum("pp")
    assert decoder_pp_enum == expert_pp_enum, (
        f"PP groups must match between decoder and expert grids. "
        f"Decoder: {decoder_pp_enum}, Expert: {expert_pp_enum}"
    )

    expert_groups = {
        name: expert_grid.create_pg(dims)
        for name, dims in (
            ("ep", "ep"),
            ("expt_tp", "tp"),
            ("expt_dp", "dp"),
            ("tp_ep", ["tp", "ep"]),
            ("tp_ep_pp", ["tp", "ep", "pp"]),
        )
    }

    # Embedding groups: first and last pipeline stage of every pp rank list
    # (just the first when there is a single stage); position embeddings use
    # the first stage only. new_group is issued for every list, but a handle
    # is kept only when this rank is in the list.
    embd_group = None
    pos_embd_group = None
    for stage_ranks in decoder_grid.get_rank_enum("pp"):
        first_stage = stage_ranks[0]
        if len(stage_ranks) == 1:
            embd_ranks = [first_stage]
        else:
            embd_ranks = [first_stage, stage_ranks[-1]]
        candidate = dist.new_group(ranks=embd_ranks)
        if rank in embd_ranks:
            embd_group = candidate

        pos_embd_ranks = [first_stage]
        candidate = dist.new_group(ranks=pos_embd_ranks)
        if rank in pos_embd_ranks:
            pos_embd_group = candidate

    return ProcessGroupCollection(
        embd=embd_group, pos_embd=pos_embd_group, **decoder_groups, **expert_groups
    )
162+
163+
164+
@dataclass
class InferenceShard:
    """Record describing one shard of a multi-shard inference layout.

    Bundles the shard's spec, the contiguous window of global ranks it
    occupies, and the calling rank's process groups for that shard (if any).
    A rank locates its own shard by looking for the entry whose
    ``pg_collection`` is not ``None``.
    """

    # The shard's InferenceShardSpec (its parallelism sizes such as
    # tp / pp / ep / expt_tp / dp).
    spec: InferenceShardSpec
    # First global rank that belongs to this shard.
    rank_offset: int
    # Number of ranks in the shard's window (tp * pp * dp).
    world_size: int
    # ProcessGroupCollection for this shard when the current rank lies inside
    # the window, otherwise None. Group creation still runs on every rank for
    # every shard; only members keep the resulting handle.
    pg_collection: Optional[ProcessGroupCollection]
187+
188+
189+
def build_inference_pg_collections_for_shards(
    total_world_size: int,
    shards: Union[str, Sequence[InferenceShardSpec], Sequence[dict]],
    use_tp_pp_dp_mapping: bool = False,
) -> List[InferenceShard]:
    """Lay out heterogeneous inference shards and build each one's process groups.

    Global ranks are carved into back-to-back, non-overlapping windows starting
    at rank 0: shard ``i`` covers ``[offset_i, offset_i + tp_i*pp_i*dp_i)`` and
    the next shard starts where the previous one ended.

    :func:`build_inference_pg_collection` is invoked for every shard on the
    calling rank whether or not that rank is a member, so all ranks must call
    this function for the underlying group creation to line up. The resulting
    collection is kept only for shards whose window contains the calling rank;
    every other entry carries ``pg_collection=None``.

    Args:
        total_world_size: Upper bound on the rank range the shards may occupy
            (the full world size); also forwarded to ``normalize_shard_specs``.
        shards: Shard layout in any form ``normalize_shard_specs`` accepts -- a
            spec string, a sequence of :class:`InferenceShardSpec`, or a
            sequence of raw dicts.
        use_tp_pp_dp_mapping: Forwarded unchanged to
            ``build_inference_pg_collection``.

    Returns:
        One :class:`InferenceShard` per normalized spec, in spec order.

    Raises:
        AssertionError: If a shard's window would extend past
            ``total_world_size``.
    """
    specs = normalize_shard_specs(shards, total_world_size)
    my_rank = dist.get_rank()

    layout: List[InferenceShard] = []
    window_start = 0
    for index, spec in enumerate(specs):
        tp, pp, ep, expt_tp, dp = spec.tp, spec.pp, spec.ep, spec.expt_tp, spec.dp
        window_size = tp * pp * dp
        window_end = window_start + window_size
        assert window_end <= total_world_size, (
            f"Shard {index} ({spec}) runs out of ranks: needs "
            f"[{window_start}, {window_end}), total_world_size={total_world_size}."
        )

        # Called on every rank, member or not, so group creation stays in step.
        collection = build_inference_pg_collection(
            world_size=window_size,
            tp_size=tp,
            pp_size=pp,
            # Inference shards don't context-parallelize; the spec validates cp == 1.
            cp_size=spec.cp,
            ep_size=ep,
            expt_tp_size=expt_tp,
            use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
            rank_offset=window_start,
        )

        # Only ranks inside the window keep the collection.
        is_member = window_start <= my_rank < window_end
        shard = InferenceShard(
            spec=spec,
            rank_offset=window_start,
            world_size=window_size,
            pg_collection=collection if is_member else None,
        )
        layout.append(shard)
        window_start = window_end
    return layout

0 commit comments

Comments
 (0)