|
| 1 | +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. |
| 2 | + |
| 3 | +"""Framework-agnostic primitives for heterogeneous inference sharding: build a |
| 4 | +``ProcessGroupCollection`` per shard, each over a contiguous rank window at its |
| 5 | +own parallelism.""" |
| 6 | + |
| 7 | +from dataclasses import dataclass |
| 8 | +from typing import List, Optional, Sequence, Union |
| 9 | + |
| 10 | +import torch.distributed as dist |
| 11 | + |
| 12 | +from megatron.core import mpu |
| 13 | +from megatron.core.hyper_comm_grid import HyperCommGrid |
| 14 | +from megatron.core.inference.shards_spec import InferenceShardSpec, normalize_shard_specs |
| 15 | +from megatron.core.process_groups_config import ProcessGroupCollection |
| 16 | + |
| 17 | + |
def build_inference_pg_collection(
    world_size: int,
    tp_size: Optional[int] = None,
    pp_size: Optional[int] = None,
    cp_size: Optional[int] = None,
    ep_size: Optional[int] = None,
    expt_tp_size: Optional[int] = None,
    use_tp_pp_dp_mapping: bool = False,
    rank_offset: int = 0,
) -> ProcessGroupCollection:
    """Build a ProcessGroupCollection for one inference model.

    Uses two HyperCommGrids matching mpu:
    - decoder_grid for dense/attention layers (tp, cp, dp, pp)
    - expert_grid for MoE expert layers (expt_tp, ep, expt_dp, pp)

    Process-group creation is collective, so every rank is expected to call
    this function with the same arguments, whether or not it lies inside the
    window.

    Args:
        world_size: Number of ranks in this inference window.
        tp_size: Tensor model parallel size. Defaults to training's TP size
            (``mpu.get_tensor_model_parallel_world_size()``).
        pp_size: Pipeline parallel size. Defaults to training's PP size
            (``mpu.get_pipeline_model_parallel_world_size()``).
        cp_size: Context parallel size. Defaults to training's CP size
            (``mpu.get_context_parallel_world_size()``).
        ep_size: Expert parallel size. Defaults to training's EP size
            (``mpu.get_expert_model_parallel_world_size()``).
        expt_tp_size: Expert tensor parallel size. Defaults to training's
            expert TP size (``mpu.get_expert_tensor_parallel_world_size()``).
        use_tp_pp_dp_mapping: If True, use 'tp-pp-dp' order; otherwise
            'tp-dp-pp'.
        rank_offset: Starting global rank of the window. Use ``0`` for
            collocated inference (shares ranks with training); use a non-zero
            offset for non-collocated setups where inference ranks are disjoint
            from training ranks.

    Returns:
        ProcessGroupCollection configured for the inference model. On ranks
        outside the ``[rank_offset, rank_offset + world_size)`` window the
        grid-created process-group fields are non-member sentinels returned by
        :func:`torch.distributed.new_subgroups_by_enumeration` -- callers should
        not use that instance; see
        :func:`build_inference_pg_collections_for_shards` for the right way to
        filter. ``embd`` and ``pos_embd`` are ``None`` on every rank that is
        not in the corresponding embedding rank list.

    Raises:
        ValueError: If any resolved parallel size is smaller than 1.
        AssertionError: If ``world_size`` is not an exact multiple of
            ``tp*cp*pp`` or of ``expt_tp*ep*pp``, or if the decoder and expert
            grids disagree on the pipeline-parallel rank enumeration.
    """
    if tp_size is None:
        tp_size = mpu.get_tensor_model_parallel_world_size()
    if cp_size is None:
        cp_size = mpu.get_context_parallel_world_size()
    if pp_size is None:
        pp_size = mpu.get_pipeline_model_parallel_world_size()
    if ep_size is None:
        ep_size = mpu.get_expert_model_parallel_world_size()
    if expt_tp_size is None:
        expt_tp_size = mpu.get_expert_tensor_parallel_world_size()

    # Reject non-positive sizes before doing any arithmetic with them: a zero
    # would surface as a bare ZeroDivisionError below, and a pair of negative
    # sizes multiplies back to a positive product that satisfies the
    # divisibility checks.
    for size_name, size in (
        ("tp_size", tp_size),
        ("cp_size", cp_size),
        ("pp_size", pp_size),
        ("ep_size", ep_size),
        ("expt_tp_size", expt_tp_size),
    ):
        if size < 1:
            raise ValueError(f"{size_name} must be >= 1, got {size}")

    # Dense layer DP size (world = tp * cp * dp * pp)
    dp_size = world_size // (tp_size * cp_size * pp_size)
    assert dp_size >= 1 and (tp_size * cp_size * dp_size * pp_size) == world_size, (
        f"World size ({world_size}) must be divisible by tp*cp*pp "
        f"({tp_size * cp_size * pp_size})"
    )

    # Expert DP size (world = expt_tp * ep * expt_dp * pp)
    expt_dp_size = world_size // (expt_tp_size * ep_size * pp_size)
    assert expt_dp_size >= 1 and (expt_tp_size * ep_size * expt_dp_size * pp_size) == world_size, (
        f"World size ({world_size}) must be divisible by expt_tp*ep*pp "
        f"({expt_tp_size * ep_size * pp_size})"
    )

    rank = dist.get_rank()

    # The two supported orders differ only in whether "pp" or "dp" comes last.
    if use_tp_pp_dp_mapping:
        decoder_grid = HyperCommGrid(
            [tp_size, cp_size, pp_size, dp_size], ["tp", "cp", "pp", "dp"], rank_offset=rank_offset
        )
    else:
        decoder_grid = HyperCommGrid(
            [tp_size, cp_size, dp_size, pp_size], ["tp", "cp", "dp", "pp"], rank_offset=rank_offset
        )

    # NOTE: group creation is collective; keep the creation order below stable
    # so that every rank issues the same sequence of calls.
    tp_group = decoder_grid.create_pg("tp")
    cp_group = decoder_grid.create_pg("cp")
    pp_group = decoder_grid.create_pg("pp")
    dp_group = decoder_grid.create_pg("dp")
    mp_group = decoder_grid.create_pg(["tp", "pp"])
    tp_cp_group = decoder_grid.create_pg(["tp", "cp"])
    dp_cp_group = decoder_grid.create_pg(["cp", "dp"])
    tp_dp_cp_group = decoder_grid.create_pg(["tp", "cp", "dp"])

    if use_tp_pp_dp_mapping:
        expert_grid = HyperCommGrid(
            [expt_tp_size, ep_size, pp_size, expt_dp_size],
            ["tp", "ep", "pp", "dp"],
            rank_offset=rank_offset,
        )
    else:
        expert_grid = HyperCommGrid(
            [expt_tp_size, ep_size, expt_dp_size, pp_size],
            ["tp", "ep", "dp", "pp"],
            rank_offset=rank_offset,
        )

    # Only the decoder grid creates a "pp" group, so the expert grid must
    # enumerate exactly the same pipeline ranks for that group to serve both.
    decoder_pp_enum = decoder_grid.get_rank_enum("pp")
    expert_pp_enum = expert_grid.get_rank_enum("pp")
    assert decoder_pp_enum == expert_pp_enum, (
        f"PP groups must match between decoder and expert grids. "
        f"Decoder: {decoder_pp_enum}, Expert: {expert_pp_enum}"
    )

    ep_group = expert_grid.create_pg("ep")
    expt_tp_group = expert_grid.create_pg("tp")
    expt_dp_group = expert_grid.create_pg("dp")
    tp_ep_group = expert_grid.create_pg(["tp", "ep"])
    tp_ep_pp_group = expert_grid.create_pg(["tp", "ep", "pp"])

    # Embedding groups: one per pipeline, built from the first (and, when the
    # pipeline has more than one stage, last) rank of each pp rank list. The
    # position-embedding group holds the first rank only. The enumeration
    # computed above is reused instead of querying the grid a second time.
    embd_group = None
    pos_embd_group = None
    for pp_ranks in decoder_pp_enum:
        if len(pp_ranks) == 1:
            embd_ranks = [pp_ranks[0]]
        else:
            embd_ranks = [pp_ranks[0], pp_ranks[-1]]
        group = dist.new_group(ranks=embd_ranks)
        if rank in embd_ranks:
            embd_group = group
        pos_embd_ranks = [pp_ranks[0]]
        group = dist.new_group(ranks=pos_embd_ranks)
        if rank in pos_embd_ranks:
            pos_embd_group = group

    return ProcessGroupCollection(
        tp=tp_group,
        cp=cp_group,
        pp=pp_group,
        ep=ep_group,
        embd=embd_group,
        pos_embd=pos_embd_group,
        dp=dp_group,
        tp_cp=tp_cp_group,
        mp=mp_group,
        expt_tp=expt_tp_group,
        expt_dp=expt_dp_group,
        tp_ep=tp_ep_group,
        tp_ep_pp=tp_ep_pp_group,
        dp_cp=dp_cp_group,
        tp_dp_cp=tp_dp_cp_group,
    )
| 162 | + |
| 163 | + |
@dataclass
class InferenceShard:
    """Per-rank record describing one shard of a multi-shard inference layout.

    Groups together what identifies the shard (its spec), the contiguous
    window of global ranks it occupies, and the calling rank's process groups
    for that shard.

    Attributes:
        spec: The shard's
            :class:`~megatron.core.inference.shards_spec.InferenceShardSpec`
            (``tp``/``pp``/``ep``/``expt_tp``/``dp`` plus an optional ``role``
            of ``prefill`` or ``decode``).
        rank_offset: Global rank at which this shard's window starts.
        world_size: How many ranks the shard spans (tp*pp*dp).
        pg_collection: ``None`` when the current rank is not part of this
            shard; otherwise the shard's ProcessGroupCollection. Testing
            ``is not None`` is how a rank locates its own shard. All ranks
            still take part in creating the groups for every shard, because
            ``dist.new_group`` is world-collective -- but only member ranks
            end up holding a usable handle.
    """

    # What the shard is: its parallelism layout and optional role.
    spec: InferenceShardSpec
    # Where the shard lives: first global rank and number of ranks.
    rank_offset: int
    world_size: int
    # This rank's groups for the shard, or None if the rank is not a member.
    pg_collection: Optional[ProcessGroupCollection]
| 187 | + |
| 188 | + |
def build_inference_pg_collections_for_shards(
    total_world_size: int,
    shards: Union[str, Sequence[InferenceShardSpec], Sequence[dict]],
    use_tp_pp_dp_mapping: bool = False,
) -> List[InferenceShard]:
    """Create a ProcessGroupCollection for each heterogeneous inference shard.

    Global ranks are carved into back-to-back, non-overlapping windows in spec
    order, starting at rank 0; shard ``i`` is assigned
    ``[offset_i, offset_i + tp_i*pp_i*dp_i)``.

    This must be called on every rank: the ``dist.new_group`` calls made by
    :func:`build_inference_pg_collection` are collective, so each rank has to
    take part in group creation for every shard. A rank only keeps the
    resulting ``pg_collection`` for the shard whose window contains it; for
    all other shards the field is ``None``.

    Args:
        total_world_size: Full world size across training + all inference
            shards.
        shards: The shard layout, in any form accepted by
            ``normalize_shard_specs`` -- a spec string, a list of
            :class:`InferenceShardSpec`, or a list of raw dicts. It is
            normalized to validated specs before use.
        use_tp_pp_dp_mapping: Forwarded unchanged to
            ``build_inference_pg_collection``.

    Returns:
        A list with one :class:`InferenceShard` per input spec, in spec order.
    """
    specs = normalize_shard_specs(shards, total_world_size)
    my_rank = dist.get_rank()

    layout: List[InferenceShard] = []
    window_start = 0
    for shard_idx, spec in enumerate(specs):
        shard_world = spec.tp * spec.pp * spec.dp
        window_end = window_start + shard_world
        assert window_end <= total_world_size, (
            f"Shard {shard_idx} ({spec}) runs out of ranks: needs "
            f"[{window_start}, {window_end}), total_world_size={total_world_size}."
        )

        # Built on every rank (collective); only members keep the result.
        collection = build_inference_pg_collection(
            world_size=shard_world,
            tp_size=spec.tp,
            pp_size=spec.pp,
            # Inference shards don't context-parallelize; the spec validates cp == 1.
            cp_size=spec.cp,
            ep_size=spec.ep,
            expt_tp_size=spec.expt_tp,
            use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
            rank_offset=window_start,
        )

        is_member = window_start <= my_rank < window_end
        layout.append(
            InferenceShard(
                spec=spec,
                rank_offset=window_start,
                world_size=shard_world,
                pg_collection=collection if is_member else None,
            )
        )
        window_start = window_end

    return layout
0 commit comments