Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,75 @@ def _saved_config_disables_mtp(path: str | Path) -> bool:
return _config_disables_mtp(json.load(f))


def _mtp_source_key_prefixes(source: Any, *configs: Any) -> tuple[str, ...]:
    """Collect the source-checkpoint key prefixes that hold MTP/nextn tensors.

    HF architectures store their Multi-Token-Prediction (nextn) tensors under
    different names, and both layouts are probed here:

    * DeepSeek-style checkpoints use a dedicated ``mtp.*`` prefix.
    * GLM-4.x ``glm4_moe_lite`` stores the nextn layer as an ordinary decoder
      layer at index ``num_hidden_layers`` (one past the last real layer), so
      its tensors sit under ``model.layers.{num_hidden_layers}.*``.

    Why this matters: a megatron model built without an MTP head never yields
    these tensors during export. Left in the source sharding map, they keep
    any shard that mixes them with real tensors from ever completing, and the
    whole shard (real boundary params included) is dropped. Removing the
    prefixes from the expected source map lets such shards finish with only
    their real keys.

    Precondition: call this only when the built model omits MTP (see
    ``_model_omits_mtp``). Every returned prefix is stripped unconditionally,
    so using it for an MTP-enabled export would silently discard real nextn
    tensors.

    Args:
        source: Object exposing ``has_glob(pattern)`` over the source keys.
        *configs: Config objects searched in order for ``num_hidden_layers``
            (directly, or under a nested ``text_config``). Pass the HF config
            before the model config so the HF value wins on conflict; the
            source keys are HF-shaped.

    Returns:
        The prefixes that are present in ``source`` and should be ignored,
        with ``"mtp."`` (if any) listed before the nextn layer prefix.
    """
    found: list[str] = []

    if source.has_glob("mtp.*"):
        found.append("mtp.")

    # Resolve the layer count from the first config that provides a usable
    # value; HF decoder layers are 0-indexed, so this count is also the index
    # of the GLM nextn layer.
    layer_count = _MISSING
    for cfg in configs:
        candidate = _get_config_field(cfg, "num_hidden_layers")
        if candidate is _MISSING or candidate is None:
            # Fall back to a nested text_config when the top level has no value.
            nested = _get_config_field(cfg, "text_config")
            if nested is not _MISSING:
                candidate = _get_config_field(nested, "num_hidden_layers")
        if candidate is _MISSING or candidate is None:
            continue
        layer_count = int(candidate)
        break

    if layer_count is _MISSING:
        return tuple(found)

    nextn_prefix = f"model.layers.{layer_count}."
    if source.has_glob(nextn_prefix + "*"):
        found.append(nextn_prefix)
    return tuple(found)


def _model_omits_mtp(model_config: Any) -> bool:
    """Report whether the *built* megatron model was created without an MTP head.

    :func:`_config_disables_mtp` treats ``None``/unset as "unspecified". Here
    the config is the built model's own, where ``mtp_num_layers`` is expected
    to be explicit: a falsy value (``None`` or ``0``) means no MTP head was
    instantiated, so the export generator will not yield MTP/nextn tensors.

    Returns ``False`` when ``model_config`` is ``None`` or carries no
    ``mtp_num_layers`` field at all, since nothing can be concluded then.
    """
    if model_config is None:
        return False
    mtp_num_layers = _get_config_field(model_config, "mtp_num_layers")
    # A missing field is "unknown" rather than "omitted"; only a field that is
    # present and falsy counts as no MTP head.
    return mtp_num_layers is not _MISSING and not mtp_num_layers


# Preformatted display string for error/help messages: every entry of
# SUPPORTED_HF_ARCHITECTURES wrapped in single quotes and joined with " or ".
SUPPORTED_HF_ARCHITECTURES_DISPLAY = " or ".join(f"'{s}'" for s in SUPPORTED_HF_ARCHITECTURES)

Expand Down Expand Up @@ -1029,10 +1098,14 @@ def _filter_quant(gen):
source = self.hf_pretrained.state.source
model_config = getattr(model_instance, "config", None)
hf_config = getattr(self.hf_pretrained, "config", self.hf_pretrained)
mtp_disabled = _saved_config_disables_mtp(path) or any(
_config_disables_mtp(config) for config in (hf_config, model_config)
mtp_disabled = (
_saved_config_disables_mtp(path)
or any(_config_disables_mtp(config) for config in (hf_config, model_config))
or _model_omits_mtp(model_config)
)
ignored_source_key_prefixes = ("mtp.",) if mtp_disabled and source.has_glob("mtp.*") else None
ignored_source_key_prefixes = (
_mtp_source_key_prefixes(source, hf_config, model_config) if mtp_disabled else ()
) or None
source.save_generator(
generator,
path,
Expand Down
124 changes: 124 additions & 0 deletions tests/unit_tests/models/test_auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import json
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import Mock, PropertyMock, patch

import pytest
Expand All @@ -29,10 +30,13 @@
AutoBridge,
_config_disables_mtp,
_drop_readonly_config_properties,
_model_omits_mtp,
_mtp_source_key_prefixes,
_saved_config_disables_mtp,
)
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.hf_pretrained.state import SafeTensorsStateSource


def create_mock_pretrained_causal_lm():
Expand All @@ -45,6 +49,26 @@ def __init__(self):
return MockPreTrainedCausalLM()


def _make_fake_source(present):
    """Create a ``SafeTensorsStateSource`` stand-in for ``save_hf_weights`` tests.

    The stand-in is a ``Mock(spec=SafeTensorsStateSource)``, so the
    ``isinstance(source, SafeTensorsStateSource)`` gate in ``save_hf_weights``
    is still satisfied while the real ``__init__`` is never run.

    Args:
        present: Container of glob patterns; ``has_glob(pattern)`` answers
            ``pattern in present``.

    Returns:
        The mock source. After ``save_generator`` is called, the keyword
        arguments it received are available on ``save_generator_kwargs``
        (``None`` until then).
    """
    fake = Mock(spec=SafeTensorsStateSource)
    fake.save_generator_kwargs = None

    def _glob_exists(pattern):
        return pattern in present

    def _record_kwargs(generator, path, **kwargs):
        # Keep only the keyword arguments; tests assert on those.
        fake.save_generator_kwargs = kwargs

    fake.has_glob.side_effect = _glob_exists
    fake.save_generator.side_effect = _record_kwargs
    return fake


class TestAutoBridge:
"""Test cases for AutoBridge automatic selection and full bridge functionality."""

Expand Down Expand Up @@ -180,6 +204,106 @@ def test_mtp_disabled_helpers(self, tmp_path):

assert _saved_config_disables_mtp(tmp_path) is True

def test_model_omits_mtp(self):
    """Only a present-but-falsy ``mtp_num_layers`` counts as "no MTP head"."""
    # No config, or a config without the attribute -> unknown -> not omitted.
    for unknown_config in (None, SimpleNamespace()):
        assert _model_omits_mtp(unknown_config) is False

    # Falsy values mean the head is omitted from export
    # (SkyRL forces mtp_num_layers=None).
    for falsy_layers in (None, 0):
        assert _model_omits_mtp(Mock(mtp_num_layers=falsy_layers)) is True

    # A positive layer count means the head exists.
    assert _model_omits_mtp(Mock(mtp_num_layers=1)) is False

def test_mtp_source_key_prefixes(self):
    """Check which MTP/nextn source-key prefixes are selected per architecture."""

    def fake_source(*globs):
        # Source stub whose has_glob() only recognises the given patterns.
        known = frozenset(globs)

        def has_glob(pattern):
            return pattern in known

        return Mock(has_glob=has_glob)

    nextn_prefix = "model.layers.47."
    nextn_glob = "model.layers.47.*"

    # DeepSeek-style: dedicated mtp.* prefix, no layer count needed.
    deepseek_source = fake_source("mtp.*")
    assert _mtp_source_key_prefixes(deepseek_source, {}) == ("mtp.",)

    # GLM glm4_moe_lite: nextn layer stored at index == num_hidden_layers.
    glm_source = fake_source(nextn_glob)
    assert _mtp_source_key_prefixes(glm_source, {"num_hidden_layers": 47}) == (nextn_prefix,)

    # The layer count may also live in a nested text_config.
    nested_config = {"text_config": {"num_hidden_layers": 47}}
    assert _mtp_source_key_prefixes(glm_source, nested_config) == (nextn_prefix,)

    # Source without any matching keys -> nothing to strip.
    empty_source = fake_source()
    assert _mtp_source_key_prefixes(empty_source, {"num_hidden_layers": 47}) == ()

    # Both layouts present: "mtp." is reported first, then the nextn layer.
    combined_source = fake_source("mtp.*", nextn_glob)
    assert _mtp_source_key_prefixes(combined_source, {"num_hidden_layers": 47}) == ("mtp.", nextn_prefix)

def test_save_hf_weights_strips_nextn_prefix_when_mtp_omitted(self, tmp_path):
    """Regression: without an MTP head, the GLM nextn layer prefix is stripped
    from the source map before the streaming save.

    This covers the bug being fixed (a 45/48-shard checkpoint dropping
    boundary shards on GLM-4.x glm4_moe_lite). The helper-level tests only
    check prefix resolution; this one checks that ``save_hf_weights`` passes
    the resolved prefixes on to ``save_generator``. Neither the HF config nor
    the saved config explicitly disables MTP here, so the built model omitting
    the head is the only signal: removing the ``_model_omits_mtp(...)`` branch
    makes this test fail.
    """
    present_globs = {"model.layers.47.*", "model.layers.46.*"}
    source = _make_fake_source(present=present_globs)

    # Built megatron model omits the MTP head (SkyRL forces mtp_num_layers=None).
    self._run_save_hf_weights(source, tmp_path, mtp_num_layers=None)

    captured = source.save_generator_kwargs
    assert captured is not None
    # Only the nextn layer (index 47) is ignored; real layer 46 is kept.
    assert captured["ignored_source_key_prefixes"] == ("model.layers.47.",)

def test_save_hf_weights_keeps_all_keys_when_mtp_enabled(self, tmp_path):
    """Counterpart: when the model keeps its MTP head, nothing is stripped.

    Also guards the ``if mtp_disabled`` gate: if a future refactor drops the
    gate and always calls ``_mtp_source_key_prefixes``, the helper would strip
    the real ``model.layers.47.`` layer here and this assertion would fail.
    """
    source = _make_fake_source(present={"model.layers.47.*"})
    self._run_save_hf_weights(source, tmp_path, mtp_num_layers=1)

    # Confirm save_generator actually ran before subscripting its kwargs;
    # otherwise a missing call surfaces as an opaque TypeError on None rather
    # than a clear assertion failure.
    assert source.save_generator_kwargs is not None
    assert source.save_generator_kwargs["ignored_source_key_prefixes"] is None

def _run_save_hf_weights(self, source, tmp_path, *, mtp_num_layers):
    """Invoke ``save_hf_weights`` on a stubbed bridge/model pair.

    Everything except the MTP prefix-resolution wiring is stubbed out. The HF
    config carries ``num_hidden_layers=47`` and no MTP-disable field, so the
    export decision depends only on whether the *built* model omits the head,
    which is controlled by ``mtp_num_layers``.

    Args:
        source: Fake state source handed to ``save_hf_weights`` through the
            patched ``state`` property.
        tmp_path: Output directory passed as the save path.
        mtp_num_layers: Value placed on the built model's config.
    """
    pretrained = create_mock_pretrained_causal_lm()
    # Layer count only; deliberately no MTP-disable field on the HF config.
    pretrained.config = SimpleNamespace(num_hidden_layers=47)

    built_config = SimpleNamespace(mtp_num_layers=mtp_num_layers)
    built_model = SimpleNamespace(config=built_config)

    # Skip AutoBridge.__init__; only hf_pretrained is set on the instance.
    bridge = object.__new__(AutoBridge)
    bridge.hf_pretrained = pretrained

    stub_model_bridge = Mock()
    stub_model_bridge.stream_weights_megatron_to_hf.return_value = iter([])

    with (
        # ``state`` is a read-only property on PreTrainedBase, so it has to be
        # patched on the type instead of being assigned on the instance.
        patch.object(
            type(pretrained),
            "state",
            new_callable=PropertyMock,
            return_value=SimpleNamespace(source=source),
        ),
        patch.object(
            AutoBridge,
            "_model_bridge",
            new_callable=PropertyMock,
            return_value=stub_model_bridge,
        ),
        patch.object(AutoBridge, "_get_model_instance", return_value=built_model),
        patch("megatron.bridge.models.conversion.auto_bridge.is_quantized", return_value=False),
    ):
        bridge.save_hf_weights([Mock()], tmp_path, show_progress=False)

def test_can_handle_supported_model(self, llama_config_mock):
"""Test can_handle returns True for supported models."""
with patch(
Expand Down