def _mtp_source_key_prefixes(source: Any, *configs: Any) -> tuple[str, ...]:
    """Return the source-checkpoint key prefixes of MTP/nextn tensors to ignore.

    Intended for exporting a model that was built without an MTP head.
    Different HF architectures name their Multi-Token-Prediction (nextn)
    tensors differently:

    * DeepSeek-style: a dedicated ``mtp.*`` prefix.
    * GLM-4.x ``glm4_moe_lite``: the nextn layer is stored as a regular decoder
      layer at index ``num_hidden_layers`` (i.e. one past the last real layer),
      so its tensors live under ``model.layers.{num_hidden_layers}.*``.

    When the megatron model is built without an MTP head the generator never
    yields these tensors. If they remain in the source sharding map, the shards
    that co-locate them with real (non-MTP) tensors can never be completed and
    get dropped wholesale (taking real boundary params with them). Stripping
    these prefixes from the expected source map lets those shards complete with
    only their real keys.

    Precondition: only call this when the built model omits MTP (see
    ``_model_omits_mtp``). The returned prefixes are always stripped, so calling
    it for an MTP-enabled export would silently drop real nextn tensors.

    Args:
        source: State source; only its ``has_glob(pattern)`` method is used.
        *configs: Config objects searched in order for ``num_hidden_layers``
            (directly, or nested under ``text_config``). The first config that
            yields a non-``None`` value wins. ``None`` entries are skipped.

    Returns:
        The tuple of prefixes that exist in ``source`` and should be ignored,
        ``"mtp."`` first (when present), then the nextn layer prefix.
    """
    prefixes: list[str] = []

    if source.has_glob("mtp.*"):
        prefixes.append("mtp.")

    # GLM nextn layer at index == num_hidden_layers. HF decoder layers are
    # 0-indexed, so layer ``num_hidden_layers`` is one past the last real layer.
    # The first config with a usable value wins; the export call site passes
    # hf_config before model_config so the HF value takes precedence (source
    # keys are HF-shaped).
    num_hidden_layers = _MISSING
    for config in configs:
        # The export call site may pass ``None`` for a model without a
        # ``config`` attribute; there is nothing to look up on it.
        if config is None:
            continue
        value = _get_config_field(config, "num_hidden_layers")
        if value is _MISSING or value is None:
            text_config = _get_config_field(config, "text_config")
            # Only descend into a nested config that is actually present.
            if text_config is not _MISSING and text_config is not None:
                value = _get_config_field(text_config, "num_hidden_layers")
        if value is not _MISSING and value is not None:
            num_hidden_layers = int(value)
            break

    if num_hidden_layers is not _MISSING:
        nextn_prefix = f"model.layers.{num_hidden_layers}."
        # Only report the prefix when the source really stores such keys.
        if source.has_glob(f"{nextn_prefix}*"):
            prefixes.append(nextn_prefix)

    return tuple(prefixes)
def _model_omits_mtp(model_config: Any) -> bool:
    """Report whether the *built* megatron model carries no MTP head.

    Returns ``False`` when ``model_config`` is ``None`` or when it has no
    ``mtp_num_layers`` field at all (nothing can be concluded). Otherwise a
    falsy ``mtp_num_layers`` (``None`` or ``0``) yields ``True`` and any truthy
    value yields ``False``.

    This differs from :func:`_config_disables_mtp`, which, per the original
    note here, treats ``None``/unset as "unspecified".
    """
    if model_config is None:
        return False
    mtp_num_layers = _get_config_field(model_config, "mtp_num_layers")
    # A missing field is "unknown"; only a present-but-falsy value means the
    # head was omitted.
    return mtp_num_layers is not _MISSING and not mtp_num_layers
def _make_fake_source(present):
    """Create a ``SafeTensorsStateSource`` stand-in for ``save_hf_weights`` tests.

    The stand-in is a ``Mock(spec=SafeTensorsStateSource)``, so ``isinstance``
    checks against ``SafeTensorsStateSource`` pass without running the real
    ``__init__``. ``has_glob(pattern)`` answers ``pattern in present``. The
    keyword arguments of the most recent ``save_generator`` call are stored on
    ``save_generator_kwargs`` (``None`` until the first call).
    """
    fake = Mock(spec=SafeTensorsStateSource)
    fake.save_generator_kwargs = None

    def _glob_exists(pattern):
        return pattern in present

    def _record_kwargs(generator, path, **kwargs):
        fake.save_generator_kwargs = kwargs

    fake.has_glob.side_effect = _glob_exists
    fake.save_generator.side_effect = _record_kwargs
    return fake


def test_model_omits_mtp(self):
    """A built model with a falsy mtp_num_layers has no MTP head."""
    expectations = (
        # No config at all -> cannot conclude anything.
        (None, False),
        # Attribute unset -> unknown -> do not assume omitted.
        (SimpleNamespace(), False),
        # Explicit None (the case the original author attributes to SkyRL).
        (Mock(mtp_num_layers=None), True),
        (Mock(mtp_num_layers=0), True),
        (Mock(mtp_num_layers=1), False),
    )
    for model_config, expected in expectations:
        assert _model_omits_mtp(model_config) is expected
def test_mtp_source_key_prefixes(self):
    """Check which MTP/nextn source-key prefixes are resolved per layout."""

    def fake_source(*globs):
        known = frozenset(globs)
        return Mock(has_glob=lambda pattern: pattern in known)

    nextn_glob = "model.layers.47.*"
    nextn_prefix = "model.layers.47."
    flat_config = {"num_hidden_layers": 47}

    # (globs present in the source, config, expected prefixes)
    cases = [
        # Dedicated mtp.* prefix, no layer count available.
        (("mtp.*",), {}, ("mtp.",)),
        # nextn layer stored at index == num_hidden_layers.
        ((nextn_glob,), flat_config, (nextn_prefix,)),
        # Layer count only available under a nested text_config.
        ((nextn_glob,), {"text_config": {"num_hidden_layers": 47}}, (nextn_prefix,)),
        # No matching source keys -> nothing to strip.
        ((), flat_config, ()),
        # Both layouts present: "mtp." is reported first.
        (("mtp.*", nextn_glob), flat_config, ("mtp.", nextn_prefix)),
    ]
    for globs, config, expected in cases:
        assert _mtp_source_key_prefixes(fake_source(*globs), config) == expected


def test_save_hf_weights_strips_nextn_prefix_when_mtp_omitted(self, tmp_path):
    """Regression: a model built without an MTP head must have the nextn layer
    prefix passed to ``save_generator`` as an ignored source-key prefix.

    Unlike the helper-level tests, this exercises the wiring inside
    ``save_hf_weights``. The HF config used by ``_run_save_hf_weights`` sets no
    MTP-disable field, so the only signal is the built model's
    ``mtp_num_layers=None``; the test is meant to fail if the
    ``_model_omits_mtp(...)`` branch is removed.
    """
    fake_source = _make_fake_source(present={"model.layers.47.*", "model.layers.46.*"})

    # Built megatron model omits the MTP head.
    self._run_save_hf_weights(fake_source, tmp_path, mtp_num_layers=None)

    captured = fake_source.save_generator_kwargs
    assert captured is not None
    # Only the layer one past the last real layer (46) is stripped.
    assert captured["ignored_source_key_prefixes"] == ("model.layers.47.",)
def test_save_hf_weights_keeps_all_keys_when_mtp_enabled(self, tmp_path):
    """Counterpart: when the model keeps its MTP head, nothing is stripped.

    Also guards the ``if mtp_disabled`` gate: if a future refactor drops the
    gate and always calls ``_mtp_source_key_prefixes``, the helper would strip
    the real ``model.layers.47.`` layer here and this assertion would fail.
    """
    source = _make_fake_source(present={"model.layers.47.*"})
    self._run_save_hf_weights(source, tmp_path, mtp_num_layers=1)

    captured = source.save_generator_kwargs
    # ``save_generator_kwargs`` starts as None; assert the save actually ran so
    # a regression shows up as a clear assertion failure instead of a
    # ``TypeError`` from subscripting None.
    assert captured is not None
    assert captured["ignored_source_key_prefixes"] is None


def _run_save_hf_weights(self, source, tmp_path, *, mtp_num_layers):
    """Drive ``save_hf_weights`` with a stubbed bridge/model so the only
    behavior under test is the MTP prefix-resolution wiring.

    The HF config carries ``num_hidden_layers=47`` and no MTP-disable field, so
    the export decision hinges on the built model's ``mtp_num_layers``.

    Args:
        source: Fake state source (see ``_make_fake_source``) exposed as
            ``hf_pretrained.state.source``.
        tmp_path: Output directory handed to ``save_hf_weights``.
        mtp_num_layers: Value placed on the stub model's ``config``.
    """
    hf_pretrained = create_mock_pretrained_causal_lm()
    # HF config carries layer count but does NOT set any MTP-disable field.
    hf_pretrained.config = SimpleNamespace(num_hidden_layers=47)
    model_instance = SimpleNamespace(config=SimpleNamespace(mtp_num_layers=mtp_num_layers))

    # Bypass AutoBridge.__init__; only ``hf_pretrained`` is set by hand.
    bridge_obj = object.__new__(AutoBridge)
    bridge_obj.hf_pretrained = hf_pretrained

    # The stubbed bridge streams no weights at all.
    fake_model_bridge = Mock()
    fake_model_bridge.stream_weights_megatron_to_hf.return_value = iter([])

    with (
        # ``state`` is a read-only property on PreTrainedBase, so patch it
        # on the type rather than assigning to the instance.
        patch.object(
            type(hf_pretrained),
            "state",
            new_callable=PropertyMock,
            return_value=SimpleNamespace(source=source),
        ),
        patch.object(
            AutoBridge,
            "_model_bridge",
            new_callable=PropertyMock,
            return_value=fake_model_bridge,
        ),
        patch.object(AutoBridge, "_get_model_instance", return_value=model_instance),
        patch("megatron.bridge.models.conversion.auto_bridge.is_quantized", return_value=False),
    ):
        bridge_obj.save_hf_weights([Mock()], tmp_path, show_progress=False)