44 commits
71821c2
retrieve PGCollection from legacy globals via parallel_state in setup
yaoyu-33 Oct 23, 2025
911ec14
Merge branch 'main' into m4/0_prepare
yaoyu-33 Oct 23, 2025
3f7ff31
fix setup
yaoyu-33 Oct 23, 2025
57c971a
pass pg_collection directly not leverage global state
yaoyu-33 Oct 29, 2025
b6a2b59
add unit test
yaoyu-33 Oct 30, 2025
70ae249
license
yaoyu-33 Oct 30, 2025
14607ba
lint
yaoyu-33 Oct 30, 2025
e5acfd9
Merge branch 'main' into m4/0_prepare
yaoyu-33 Oct 30, 2025
224e1a3
fix unit tests
yaoyu-33 Oct 30, 2025
7ca7dee
fix pretrain api
yaoyu-33 Oct 31, 2025
05939dc
remove parallel_state from train.py
yaoyu-33 Nov 3, 2025
bac52e2
update gpt_step and vlm_step to not rely on parallel_state
yaoyu-33 Nov 3, 2025
0a2e29f
add util to get pg collection from model
yaoyu-33 Nov 3, 2025
384488b
remove parallel state from train utils
yaoyu-33 Nov 3, 2025
aa82d5e
unit test update
yaoyu-33 Nov 3, 2025
1b54119
unit tests fixes
yaoyu-33 Nov 3, 2025
2f57a7b
Merge branch 'main' into m4/1_train_loops_and_steps
yaoyu-33 Nov 3, 2025
3df91bf
update get_pg_collection to use get_attr_wrapped_model
yaoyu-33 Nov 3, 2025
10acaad
Merge branch 'main' into m4/1_train_loops_and_steps
yaoyu-33 Nov 4, 2025
c8e6636
update model provider to m4
yaoyu-33 Nov 4, 2025
44c9bb4
update model providers for m4
yaoyu-33 Nov 4, 2025
6a5a16b
fix model provider unit tests
yaoyu-33 Nov 5, 2025
ca797f8
fix unit tests
yaoyu-33 Nov 5, 2025
71f43cb
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Nov 12, 2025
522acef
lint
yaoyu-33 Nov 12, 2025
a1cdf4c
fix unit test
yaoyu-33 Nov 21, 2025
a26232b
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 5, 2025
5ea19fb
add pg_collection in model providers
yaoyu-33 Dec 5, 2025
5e6eedf
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 5, 2025
8ca3d4c
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 6, 2025
cbd5490
update mlm and provider
yaoyu-33 Dec 6, 2025
7012412
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 15, 2025
2dc3e34
merge main
yaoyu-33 Dec 15, 2025
1a2ed53
update to use `_pg_collection`
yaoyu-33 Dec 15, 2025
a4c956e
update to use `_pg_collection`
yaoyu-33 Dec 15, 2025
ae9aa19
Revert "update to use `_pg_collection`"
yaoyu-33 Dec 15, 2025
69de3f4
fix unit test
yaoyu-33 Dec 16, 2025
766a397
fix tests
yaoyu-33 Dec 17, 2025
c416fe1
fix tests
yaoyu-33 Dec 17, 2025
c2b090c
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 17, 2025
32ebce4
try fix
yaoyu-33 Dec 18, 2025
abb1dc1
Simplify process group removal
yaoyu-33 Dec 18, 2025
6d227ed
Merge remote-tracking branch 'origin/main' into m4/3_model_provider
yaoyu-33 Dec 31, 2025
9fe6ebc
remove pg_collection from transformer config when pass in to mcore
yaoyu-33 Dec 31, 2025
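Taken together, these commits move the model providers off the legacy megatron.core.parallel_state globals and onto an explicitly passed ProcessGroupCollection. A minimal sketch of the resulting call pattern, assuming the collection has already been built during distributed setup (its construction is not part of this diff) and that build_model is a hypothetical caller:

from typing import Optional

from megatron.core.process_groups_config import ProcessGroupCollection

from megatron.bridge.models.gpt_provider import GPTModelProvider


def build_model(provider: GPTModelProvider, pg_collection: ProcessGroupCollection, vp_stage: Optional[int] = None):
    # Attach the process groups to the provider instead of letting provide()
    # read them from megatron.core.parallel_state globals.
    provider._pg_collection = pg_collection
    # pre_process/post_process default to None and are derived inside provide()
    # from pg_collection.pp and the (optional) virtual pipeline stage.
    return provider.provide(vp_stage=vp_stage)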
10 changes: 10 additions & 0 deletions src/megatron/bridge/models/conversion/utils.py
@@ -183,6 +183,16 @@ def remove_non_pickleables(obj, max_depth: int = 3, current_depth: int = 0):
if obj is None:
return obj

# Explicitly drop process group objects without importing their classes directly.
cls = obj if isinstance(obj, type) else type(obj)
cls_module = getattr(cls, "__module__", "")
cls_name = getattr(cls, "__qualname__", getattr(cls, "__name__", ""))
if (cls_module, cls_name) in {
("megatron.core.process_groups_config", "ProcessGroupCollection"),
("torch._C._distributed_c10d", "ProcessGroup"),
}:
return None

# Check if object is a problematic callable
if callable(obj):
# Allow classes/types but remove function objects, methods, partials
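The check added above identifies process-group types by (module, qualified name) so utils.py does not have to import those classes directly. A small, self-contained illustration of that predicate, using a stand-in class so it runs without an initialized process group; the _is_process_group_like helper is not part of the diff:

_DROPPED_TYPES = {
    ("megatron.core.process_groups_config", "ProcessGroupCollection"),
    ("torch._C._distributed_c10d", "ProcessGroup"),
}


def _is_process_group_like(obj) -> bool:
    # Same identification strategy as remove_non_pickleables: compare the class's
    # module and qualified name instead of doing an isinstance() check.
    cls = obj if isinstance(obj, type) else type(obj)
    module = getattr(cls, "__module__", "")
    name = getattr(cls, "__qualname__", getattr(cls, "__name__", ""))
    return (module, name) in _DROPPED_TYPES


class _FakeProcessGroup:
    pass


# Pretend this class was defined in torch's C extension module.
_FakeProcessGroup.__module__ = "torch._C._distributed_c10d"
_FakeProcessGroup.__qualname__ = "ProcessGroup"

assert _is_process_group_like(_FakeProcessGroup)      # the type itself matches
assert _is_process_group_like(_FakeProcessGroup())    # so does an instance
assert not _is_process_group_like("not a process group")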
20 changes: 15 additions & 5 deletions src/megatron/bridge/models/gemma/gemma2_provider.py
@@ -24,6 +24,12 @@
from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.pipeline_parallel.utils import (
is_pp_first_stage,
is_pp_last_stage,
is_vp_first_stage,
is_vp_last_stage,
)
from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer import (
MegatronModule,
@@ -87,7 +93,7 @@ def __init__(
projection_size = self.config.kv_channels * self.config.num_attention_heads

# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
world_size = self.config.tensor_model_parallel_size
self.hidden_size_per_partition = divide(projection_size, world_size)
self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
@@ -172,8 +178,8 @@ def forward(
# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
(output_size[0] * output_size[1], output_size[2], output_size[3]),
query.dtype,
"mpu",
dtype=query.dtype,
device=query.device,
)

# Raw attention scores. [b * np, sq, sk]
@@ -377,11 +383,15 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> "MCoreG
model = super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage)

# Apply Embedding Scaling for Gemma2: sqrt(hidden_size)
if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage):
if is_vp_first_stage(
vp_stage=vp_stage, vp_size=self.virtual_pipeline_model_parallel_size
) and is_pp_first_stage(self._pg_collection.pp):
extend_instance(model.embedding, EmbeddingScalingMixin)

# Prevents final logits from growing excessively by scaling them to a fixed range
if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage):
if is_vp_last_stage(vp_stage=vp_stage, vp_size=self.virtual_pipeline_model_parallel_size) and is_pp_last_stage(
self._pg_collection.pp
):
extend_instance(model.output_layer, Gemma2OutputLayer)

return model
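The embedding and output-layer wrapping above now combines the virtual-pipeline check with a process-group-based pipeline check instead of parallel_state.is_pipeline_first/last_stage. A minimal sketch of that gating, assuming pg_collection.pp is the pipeline-parallel process group and first_last_stage_flags is a hypothetical helper name:

from megatron.core.pipeline_parallel.utils import (
    is_pp_first_stage,
    is_pp_last_stage,
    is_vp_first_stage,
    is_vp_last_stage,
)


def first_last_stage_flags(pg_collection, vp_stage, vp_size):
    # A rank only holds the embedding if it is the first virtual chunk on the
    # first pipeline rank; likewise the output layer lives on the last of both.
    is_first = is_vp_first_stage(vp_stage=vp_stage, vp_size=vp_size) and is_pp_first_stage(pg_collection.pp)
    is_last = is_vp_last_stage(vp_stage=vp_stage, vp_size=vp_size) and is_pp_last_stage(pg_collection.pp)
    return is_first, is_last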
9 changes: 4 additions & 5 deletions src/megatron/bridge/models/gemma/gemma_provider.py
@@ -16,9 +16,9 @@
from typing import Callable

import torch
from megatron.core import parallel_state
from megatron.core.activations import fast_gelu
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_vp_first_stage
from megatron.core.transformer.enums import AttnBackend

from megatron.bridge.models.gpt_provider import GPTModelProvider
@@ -70,10 +70,9 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> "MCoreG
model = super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage)

# Apply Embedding Scaling for Gemma: sqrt(hidden_size)
if parallel_state.is_pipeline_first_stage(
ignore_virtual=False,
vp_stage=vp_stage,
):
if is_vp_first_stage(
vp_stage=vp_stage, vp_size=self.virtual_pipeline_model_parallel_size
) and is_pp_first_stage(self._pg_collection.pp):
from megatron.bridge.models.gemma.modules import EmbeddingScalingMixin, extend_instance

extend_instance(model.embedding, EmbeddingScalingMixin)
21 changes: 11 additions & 10 deletions src/megatron/bridge/models/gpt_full_te_layer_autocast_spec.py
@@ -17,7 +17,7 @@

import packaging
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core import tensor_parallel
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.cuda_graphs import CudaGraphManager
from megatron.core.transformer.module import MegatronModule
@@ -186,7 +186,7 @@ def __init__(self, config, layer_number=1, hidden_dropout=None, **kwargs):
"attention_dropout": config.attention_dropout,
"layer_number": layer_number + self._get_layer_offset(),
"kv_channels": config.kv_channels,
"tp_size": parallel_state.get_tensor_model_parallel_world_size(),
"tp_size": config.tensor_model_parallel_size,
"params_dtype": config.params_dtype,
"get_rng_state_tracker": tensor_parallel.random.get_cuda_rng_tracker,
"fuse_wgrad_accumulation": config.gradient_accumulation_fusion,
@@ -265,15 +265,16 @@ def forward(
return hidden_states, context

def _get_layer_offset(self):
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
# Derive pipeline/virtual pipeline indices from provided pg_collection/config
pg_collection = getattr(self.config, "_pg_collection", None)
pp_group = pg_collection.pp if pg_collection is not None else None
pipeline_rank = pp_group.rank() if pp_group is not None else 0

num_layers_per_pipeline_rank = (
self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
num_layers_per_pipeline_rank = self.config.num_layers // self.config.pipeline_model_parallel_size

if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
vp_size = getattr(self.config, "virtual_pipeline_model_parallel_size", None)
vp_rank = getattr(self.config, "_vp_stage", None)
if vp_size is not None:
assert vp_rank is not None, "_vp_stage must be set on config when using virtual pipeline parallelism"

total_num_layers = self.config.num_layers
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
@@ -282,7 +283,7 @@ def _get_layer_offset(self):

else:
# Each stage gets a contiguous set of layers.
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if self.config.pipeline_model_parallel_size > 1:
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
offset = 0
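For the non-virtual-pipeline branch of _get_layer_offset above, the offset is just the pipeline rank (now taken from config._pg_collection.pp) times the per-rank layer count. A small worked example, with layer_offset_no_vp as a hypothetical standalone helper:

def layer_offset_no_vp(num_layers: int, pp_size: int, pp_rank: int) -> int:
    # Mirrors the else-branch above: each pipeline stage owns a contiguous slice.
    num_layers_per_pipeline_rank = num_layers // pp_size
    return pp_rank * num_layers_per_pipeline_rank if pp_size > 1 else 0


# 24 layers split over 4 pipeline ranks -> offsets 0, 6, 12, 18
assert [layer_offset_no_vp(24, 4, rank) for rank in range(4)] == [0, 6, 12, 18]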
51 changes: 37 additions & 14 deletions src/megatron/bridge/models/gpt_provider.py
@@ -13,6 +13,7 @@
# limitations under the License.

import contextlib
import copy
import inspect
import logging
from dataclasses import dataclass, field
@@ -22,13 +23,19 @@
import modelopt.torch.distill as mtd
import modelopt.torch.distill.plugins.megatron as mtd_mcore
import torch
from megatron.core import parallel_state
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.pipeline_parallel.utils import (
is_pp_first_stage,
is_pp_last_stage,
is_vp_first_stage,
is_vp_last_stage,
)
from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.dot_product_attention import DotProductAttention as MCoreDotProductAttention
from megatron.core.transformer.enums import AttnBackend
@@ -95,6 +102,8 @@ def local_layer_spec(config: "GPTModelProvider") -> ModuleSpec:

def quantization_layer_spec(config: "GPTModelProvider") -> ModuleSpec:
"""Layer specification for quantization with ModelOpt."""
from megatron.core import parallel_state

use_arbitrary_attention_mask = parallel_state.get_context_parallel_world_size() == 1
# arbitrary attention mask is used for speculative decoding training
# When context parallel > 1, only causal mask type is supported
@@ -184,6 +193,8 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[MCoreGPTModel]):
# When resuming modelopt_state, we also change the transformer_layer_spec to `megatron.core.post_training.modelopt.gpt.model_specs` which is a combination of local spec + TEDotProductAttention.
restore_modelopt_state: bool = False

_pg_collection: Optional[ProcessGroupCollection] = None

def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel:
"""Configure and instantiate a Megatron Core GPT model based on this configuration.

@@ -250,9 +261,24 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGP
if self.attention_backend == AttnBackend.local:
if hasattr(transformer_layer_spec, "submodules"):
transformer_layer_spec.submodules.self_attention.submodules.core_attention = MCoreDotProductAttention
# Determine pre/post process flags from the vp + pp stage when they are not provided
if pre_process is None:
pre_process = is_vp_first_stage(vp_stage=vp_stage, vp_size=vp_size) and is_pp_first_stage(
self._pg_collection.pp
)
if post_process is None:
post_process = is_vp_last_stage(vp_stage=vp_stage, vp_size=vp_size) and is_pp_last_stage(
self._pg_collection.pp
)
# Expose vp stage on config for downstream modules (e.g., TE layers)
# so they can compute correct offsets without legacy globals.
self._vp_stage = vp_stage
config_for_model = copy.copy(self)
config_for_model._pg_collection = None
config_for_model._vp_stage = vp_stage
with model_init_device_context():
model = MCoreGPTModel(
self,
config_for_model,
transformer_layer_spec=transformer_layer_spec,
vocab_size=padded_vocab_size,
max_sequence_length=self.seq_length,
@@ -263,11 +289,10 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGP
rotary_percent=self.rotary_percent,
rotary_base=self.rotary_base,
seq_len_interpolation_factor=self.seq_len_interpolation_factor,
pre_process=pre_process
or parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage),
post_process=post_process
or parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage),
pre_process=pre_process,
post_process=post_process,
scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
pg_collection=self._pg_collection,
vp_stage=vp_stage,
**kwargs,
)
@@ -279,25 +304,23 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGP
if self.use_transformer_engine_full_layer_spec:
# Copied from:
# https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py
if parallel_state.get_tensor_model_parallel_world_size() > 1:
if self._pg_collection.tp.size() > 1:
for index, child in enumerate(model.modules()):
if index == 0:
continue
if hasattr(child, "set_tensor_parallel_group"):
tp_group = parallel_state.get_tensor_model_parallel_group()
tp_group = self._pg_collection.tp
child.set_tensor_parallel_group(tp_group)

if parallel_state.get_context_parallel_world_size() > 1:
if self._pg_collection.cp.size() > 1:
cp_stream = torch.cuda.Stream()
for index, child in enumerate(model.modules()):
if index == 0:
continue
if hasattr(child, "set_context_parallel_group"):
child.set_context_parallel_group(
parallel_state.get_context_parallel_group(),
parallel_state.get_context_parallel_global_ranks(),
cp_stream,
)
cp_group = self._pg_collection.cp
cp_global_ranks = torch.distributed.get_process_group_ranks(cp_group)
child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)

return model

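The provide() changes above separate the transformer config from the process groups: the provider keeps _pg_collection for its own stage checks, but hands MCoreGPTModel a shallow copy with that field cleared and passes the collection as an explicit pg_collection argument. A minimal sketch of that hand-off, with split_config_and_groups as a hypothetical helper:

import copy


def split_config_and_groups(provider):
    # Shallow-copy the provider/config and drop the process groups from the copy
    # so they reach the model only via the explicit pg_collection argument.
    pg_collection = provider._pg_collection
    config_for_model = copy.copy(provider)
    config_for_model._pg_collection = None
    return config_for_model, pg_collection


# config_for_model, pg_collection = split_config_and_groups(provider)
# model = MCoreGPTModel(config_for_model, ..., pg_collection=pg_collection, vp_stage=vp_stage)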
15 changes: 11 additions & 4 deletions src/megatron/bridge/models/mamba/mamba_provider.py
@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import warnings
from dataclasses import dataclass
from typing import Callable, Literal, Optional, Union

import torch
from megatron.core import parallel_state
from megatron.core.models.mamba import MambaModel as MCoreMambaModel
from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec as default_mamba_stack_spec
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
from megatron.core.post_training.modelopt.mamba.model_specs import get_mamba_stack_modelopt_spec
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer import ModuleSpec
from megatron.core.transformer.enums import AttnBackend

@@ -125,6 +127,7 @@ class MambaModelProvider(TransformerConfig, ModelProviderMixin[MCoreMambaModel])
vocab_size: Optional[int] = None
should_pad_vocab: bool = False
hf_model_id: Optional[str] = None
"""Optional HuggingFace model identifier associated with this provider."""
_pg_collection: Optional[ProcessGroupCollection] = None

# If True, restore the modelopt_state that contains quantization, sparsity, speculative decoding transformation state.
@@ -165,8 +168,11 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreMa
else:
padded_vocab_size = self.vocab_size

model_config = copy.copy(self)
model_config._pg_collection = None

return MCoreMambaModel(
self,
model_config,
mamba_stack_spec=mamba_stack_spec,
vocab_size=padded_vocab_size,
max_sequence_length=self.seq_length,
@@ -180,8 +186,9 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreMa
rotary_percent=self.rotary_percent,
rotary_base=self.rotary_base,
seq_len_interpolation_factor=self.seq_len_interpolation_factor,
pre_process=pre_process or parallel_state.is_pipeline_first_stage(),
post_process=post_process or parallel_state.is_pipeline_last_stage(),
pre_process=pre_process or is_pp_first_stage(self._pg_collection.pp),
post_process=post_process or is_pp_last_stage(self._pg_collection.pp),
pg_collection=self._pg_collection,
)

