Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,8 @@ def forward(
inputs_embeds = inputs_embeds.transpose(1, 0)

if self.config.sequence_parallel:
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)
tp_group = self.config._pg_collection.tp if self.config._pg_collection is not None else None
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds, group=tp_group)

# Compute 3D MRoPE position IDs on ALL pipeline stages
# Each stage has input_ids and visual grid info from the data iterator
Expand Down
3 changes: 2 additions & 1 deletion src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def forward(
# bypassed when decoder_input is provided. Matches Megatron Core's LLaVA pattern
# (llava_model.py:747-750): CP slice first, then SP scatter → [S/(CP*TP), B, H].
if self.config.sequence_parallel and inputs_embeds is not None:
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)
tp_group = self.config._pg_collection.tp if self.config._pg_collection is not None else None
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds, group=tp_group)

outputs = self.language_model.forward(
input_ids=None,
Expand Down
3 changes: 2 additions & 1 deletion src/megatron/bridge/models/gemma_vl/modeling_gemma4_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def forward(

# SP scatter
if self.config.sequence_parallel and inputs_embeds is not None:
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)
tp_group = self.config._pg_collection.tp if self.config._pg_collection is not None else None
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds, group=tp_group)

outputs = self.language_model.forward(
input_ids=None,
Expand Down
3 changes: 2 additions & 1 deletion src/megatron/bridge/models/glm_vl/modeling_glm_45v.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def forward(
inputs_embeds = inputs_embeds.transpose(1, 0).contiguous()

if self.config.sequence_parallel:
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)
tp_group = self.config._pg_collection.tp if self.config._pg_collection is not None else None
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds, group=tp_group)

# Compute MRoPE position_ids on ALL pipeline stages
# Each stage has input_ids and visual grid info from the data iterator
Expand Down
3 changes: 2 additions & 1 deletion src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,8 @@ def forward(
inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # (B, T, D) -> (T, B, D)

if self.config.sequence_parallel:
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)
tp_group = self.config._pg_collection.tp if self.config._pg_collection is not None else None
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds, group=tp_group)

outputs = self.language_model.forward(
input_ids=None,
Expand Down
3 changes: 2 additions & 1 deletion src/megatron/bridge/models/ministral3/modeling_ministral3.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ def forward(
# bypassed when decoder_input is provided. Matches Megatron Core's LLaVA pattern
# (llava_model.py:747-750): CP slice first, then SP scatter → [S/(CP*TP), B, H].
if self.config.sequence_parallel and inputs_embeds is not None:
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)
tp_group = self.config._pg_collection.tp if self.config._pg_collection is not None else None
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds, group=tp_group)

# Forward through Megatron language model
outputs = self.language_model.forward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def forward(
sp_pad_len = (tp_size - seq_len % tp_size) % tp_size
if sp_pad_len > 0:
combined_embeddings = torch.nn.functional.pad(combined_embeddings, (0, 0, 0, 0, 0, sp_pad_len))
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings)
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
combined_embeddings, group=self.pg_collection.tp
)
combined_embeddings = combined_embeddings.contiguous()
else:
combined_embeddings = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ def forward(
sp_pad_len = (tp_size - seq_len % tp_size) % tp_size
if sp_pad_len > 0:
combined_embeddings = torch.nn.functional.pad(combined_embeddings, (0, 0, 0, 0, 0, sp_pad_len))
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings)
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
combined_embeddings, group=self.pg_collection.tp
)
combined_embeddings = combined_embeddings.contiguous()
else:
combined_embeddings = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,9 @@ def forward(
combined_embeddings = torch.nn.functional.pad(combined_embeddings, (0, 0, 0, 0, 0, sp_pad_len))
if visual_pos_masks is not None:
visual_pos_masks = torch.nn.functional.pad(visual_pos_masks, (0, sp_pad_len), value=False)
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings)
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
combined_embeddings, group=self.pg_collection.tp
)
combined_embeddings = combined_embeddings.contiguous()
else:
combined_embeddings = None
Expand Down
3 changes: 2 additions & 1 deletion src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def forward(
) # [b, decoder_seq_len, h_language] -> [decoder_seq_len, b, h_language]

if self.config.sequence_parallel:
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)
tp_group = self.config._pg_collection.tp if self.config._pg_collection is not None else None
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds, group=tp_group)

# Compute MRoPE position_ids on ALL pipeline stages
# Each stage has input_ids and visual grid info from the data iterator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,9 @@ def forward(
combined_embeddings = combined_embeddings_thd

if self.config.sequence_parallel:
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings)
combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
combined_embeddings, group=self.pg_collection.tp
)
combined_embeddings = combined_embeddings.contiguous()

else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def forward(

def _sp_scatter_embedding(input_ids, position_ids):
out = _original_embedding(input_ids=input_ids, position_ids=position_ids)
return tensor_parallel.scatter_to_sequence_parallel_region(out)
return tensor_parallel.scatter_to_sequence_parallel_region(out, group=self.pg_collection.tp)

_sp_scatter_embedding.word_embeddings = _original_embedding.word_embeddings
self.__dict__["embedding"] = _sp_scatter_embedding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,11 @@ def test_image_forward_sequence_parallel_path(self, thinker_config, monkeypatch)
model = self._build_model(thinker_config)
model.thinker.config.sequence_parallel = True

calls = {"scatter": 0, "split": 0}
calls = {"scatter": 0, "scatter_group": None, "split": 0}

def _identity_scatter(x):
def _identity_scatter(x, *, group=None):
calls["scatter"] += 1
calls["scatter_group"] = group
return x

def _identity_split(visual_pos_masks, deepstack_visual_embeds, **kwargs):
Expand Down Expand Up @@ -381,6 +382,7 @@ def _identity_split(visual_pos_masks, deepstack_visual_embeds, **kwargs):

assert output is not None
assert calls["scatter"] == 1
assert calls["scatter_group"] is model.thinker.pg_collection.tp
assert calls["split"] == 1

def test_audio_forward(self, thinker_config):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Unit tests for Qwen3VL text model forward behavior."""

from types import SimpleNamespace

import torch

from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import Qwen3VLGPTModel
Expand Down Expand Up @@ -74,3 +76,51 @@ def test_forward_accepts_extra_preprocess_output():
assert dummy.decoder.called_with["sequence_len_offset"] is preproc[4]
assert not any(value is preproc[5] for value in dummy.decoder.called_with.values())
assert dummy.postprocess_args["decoder_input"] is preproc[0]


def test_mtp_sequence_parallel_embedding_scatter_uses_tp_group(monkeypatch):
"""The MTP embedding wrapper must not fall back to global tensor-parallel state."""
expected_group = object()
calls = {"group": None}

def _identity_scatter(x, *, group=None):
calls["group"] = group
return x

monkeypatch.setattr(
"megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model.tensor_parallel.scatter_to_sequence_parallel_region",
_identity_scatter,
)

class _DummyEmbedding:
word_embeddings = object()

def __call__(self, *, input_ids, position_ids): # noqa: ARG002
return torch.ones(1, 1, 1)

class _DummyMTPModel(_DummyModel):
def __init__(self):
super().__init__()
self.config = SimpleNamespace(sequence_parallel=True)
self.embedding = _DummyEmbedding()
self.mtp_process = True
self.pg_collection = SimpleNamespace(tp=expected_group)

def _postprocess(self, **kwargs):
self.embedding(input_ids=kwargs["input_ids"], position_ids=kwargs["position_ids"])
return "ok"

dummy = _DummyMTPModel()
input_ids = torch.zeros((1, 4), dtype=torch.long)
position_ids = torch.zeros((1, 4), dtype=torch.long)
attention_mask = torch.ones((1, 4), dtype=torch.long)

output = Qwen3VLGPTModel.forward(
dummy,
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
)

assert output == "ok"
assert calls["group"] is expected_group
71 changes: 71 additions & 0 deletions tests/unit_tests/models/test_sequence_parallel_scatter_groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Regression tests for explicit process-group sequence-parallel scatters."""

import ast
from pathlib import Path

import pytest


_ROOT = Path(__file__).parents[3]
_MODELS_ROOT = _ROOT / "src/megatron/bridge/models"

pytestmark = pytest.mark.unit

# Intentionally-unfixed bare scatter sites go here as
# "src/megatron/bridge/models/...py:<line>": "short rationale".
# Keep this empty unless a model has no explicit TP group available at the call site.
_BARE_SCATTER_EXCLUSIONS: dict[str, str] = {}


def _modeling_sources() -> list[Path]:
return sorted(
path
for path in _MODELS_ROOT.rglob("*.py")
if any(part.startswith(("modeling", "modelling")) for part in path.relative_to(_MODELS_ROOT).parts)
)


def _scatter_calls(tree: ast.AST) -> list[ast.Call]:
calls = []
for node in ast.walk(tree):
if not isinstance(node, ast.Call):
continue
func = node.func
if isinstance(func, ast.Attribute) and func.attr == "scatter_to_sequence_parallel_region":
calls.append(node)
elif isinstance(func, ast.Name) and func.id == "scatter_to_sequence_parallel_region":
calls.append(node)
return calls


def test_explicit_process_group_scatter_sites_pass_group():
missing_group = []
seen_exclusions = set()

for path in _modeling_sources():
relative_path = path.relative_to(_ROOT).as_posix()
tree = ast.parse(path.read_text(), filename=str(path))
for call in _scatter_calls(tree):
if not any(keyword.arg == "group" for keyword in call.keywords):
location = f"{relative_path}:{call.lineno}"
if location in _BARE_SCATTER_EXCLUSIONS:
seen_exclusions.add(location)
else:
missing_group.append(location)

assert missing_group == []
assert set(_BARE_SCATTER_EXCLUSIONS) == seen_exclusions
Loading