Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 7 additions & 16 deletions examples/models/nemotron_labs_diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ The CPT checkpoint from Stage 1 is passed via `checkpoint.pretrained_checkpoint`
torchrun --nproc_per_node=8 examples/models/nemotron_labs_diffusion/ar_to_dlm.py \
--model-size 3b \
--hf-path mistralai/Ministral-3-3B-Base-2512 \
--data-paths /path/to/dclm/merged_tokenized_text_document \
checkpoint.finetune=true \
checkpoint.pretrained_checkpoint=/path/to/cpt_checkpoint
checkpoint.pretrained_checkpoint=/path/to/cpt_checkpoint \
--data-paths /path/to/dclm/merged_tokenized_text_document
```


Expand All @@ -61,25 +61,24 @@ The script [`inference_nemotron_labs_diffusion.py`](inference_nemotron_labs_diff

```bash
torchrun --nproc_per_node=4 examples/models/nemotron_labs_diffusion/inference_nemotron_labs_diffusion.py \
--megatron-path /path/to/checkpoints/ar_to_dlm_8b \
--hf-model mistralai/Ministral-3-8B-Base-2512 \
--megatron-path /path/to/checkpoints/ar_to_dlm_3b/iter_xxxxxxx \
--hf-model mistralai/Ministral-3-3B-Base-2512 \
--prompts "The capital of France is" \
--gen-length 256 --block-length 32 --steps-per-block 32 \
--tp 4
--gen-length 256 --block-length 32 --steps-per-block 32
```

### AR mode

```bash
python examples/models/nemotron_labs_diffusion/inference_nemotron_labs_diffusion.py \
--megatron-path /path/to/checkpoints/ar_to_dlm_3b \
--megatron-path /path/to/checkpoints/ar_to_dlm_3b/iter_xxxxxxx \
--hf-model mistralai/Ministral-3-3B-Base-2512 \
--mode ar \
--prompts "Once upon a time" \
--max-new-tokens 128
```

The `--tp` argument must match the tensor parallelism degree of the saved checkpoint (e.g. `--tp 4` for 8B checkpoints saved with TP=4). `--hf-model` is used for the tokenizer and model config only — weights are loaded from `--megatron-path`.
You can pass `--tp` argument, but it must match the tensor parallelism degree of the saved checkpoint (e.g. `--tp 4` for 8B checkpoints saved with TP=4). `--hf-model` is used for the tokenizer and model config only — weights are loaded from `--megatron-path`.

---

Expand All @@ -102,14 +101,6 @@ python examples/models/nemotron_labs_diffusion/convert_checkpoints.py import \
--torch-dtype bfloat16
```

For the 8B model (TP=4):
```bash
python examples/models/nemotron_labs_diffusion/convert_checkpoints.py import \
--hf-model nvidia/Nemotron-Labs-Diffusion-8B \
--megatron-path /path/to/checkpoints/hf_to_mb_8b \
--torch-dtype bfloat16
```

The Megatron checkpoint is written under `--megatron-path` (e.g. `.../hf_to_mb_3b/iter_0000000/`). Use the parent directory for training with `checkpoint.load`.

### Export: Megatron → HuggingFace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,25 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> "NemotronLabsDif
# Auto-detect checkpoint format: VLM configs nest text params under text_config
self._is_text_only = not hasattr(hf_config, "text_config")

# NemotronLabsDiffusionConfig (a trust_remote_code config) does not declare
# model-specific fields as dataclass fields. In transformers 5.x
# PretrainedConfig is a dataclass, so MLM's _convert_value_to_dict uses the
# dataclass-fields path and silently drops all model-specific attributes
# (hidden_size, rope_parameters, etc.). Adding to_cfg_dict to the class
# makes the serializer use PretrainedConfig.to_dict() which captures everything.
cfg_cls = type(hf_config)
if not hasattr(cfg_cls, "to_cfg_dict") and hasattr(hf_config, "to_dict"):

def _to_cfg_dict(self):
cls = self.__class__
return {
"_target_": f"{cls.__module__}.{cls.__qualname__}.from_dict",
"_call_": True,
"config_dict": self.to_dict(),
}

cfg_cls.to_cfg_dict = _to_cfg_dict

return NemotronLabsDiffusionModelProvider(
hidden_size=text_config.hidden_size,
ffn_hidden_size=text_config.intermediate_size,
Expand Down
5 changes: 5 additions & 0 deletions src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,11 @@ def load_megatron_model(
except ImportError:
raise ImportError("megatron.bridge.training is not available.")

if self.trust_remote_code:
from megatron.bridge.utils.instantiate_utils import register_allowed_target_prefix

register_allowed_target_prefix("transformers_modules.")

checkpoint_path = Path(path)

# Check for iter_* folders
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,94 @@ def test_vlm_output_layer_mapping_uses_lm_head(self):
]
assert len(out_mappings) == 1
assert out_mappings[0].hf_param == "language_model.lm_head.weight"


class _MockConfig:
"""A mutable config class with to_dict, simulating a trust_remote_code PretrainedConfig."""

def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

def to_dict(self):
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}


class TestToCfgDictMonkeyPatch:
"""Tests for the to_cfg_dict monkey-patch in provider_bridge()."""

def _make_mock_hf_config(self):
text_cfg = _MockConfig(
hidden_size=1024,
intermediate_size=4096,
num_hidden_layers=8,
tie_word_embeddings=False,
rope_parameters={"rope_theta": 10000.0},
vocab_size=32000,
)
hf_cfg = _MockConfig(text_config=text_cfg)
return hf_cfg

def test_to_cfg_dict_added_when_config_has_to_dict(self):
"""provider_bridge adds to_cfg_dict to config classes that have to_dict."""
bridge = NemotronLabsDiffusionBridge()
hf_cfg = self._make_mock_hf_config()
hf = DummyHFPretrained(hf_cfg)

assert not hasattr(_MockConfig, "to_cfg_dict")
bridge.provider_bridge(hf)
assert hasattr(_MockConfig, "to_cfg_dict")

# Clean up monkey-patch so it doesn't leak to other tests
delattr(_MockConfig, "to_cfg_dict")

def test_to_cfg_dict_returns_correct_target(self):
"""to_cfg_dict must produce a _target_ using cls.__module__ and cls.__qualname__."""
bridge = NemotronLabsDiffusionBridge()
hf_cfg = self._make_mock_hf_config()
hf = DummyHFPretrained(hf_cfg)
bridge.provider_bridge(hf)

result = hf_cfg.to_cfg_dict()
expected_target = f"{_MockConfig.__module__}.{_MockConfig.__qualname__}.from_dict"
assert result["_target_"] == expected_target
assert result["_call_"] is True
assert "config_dict" in result

delattr(_MockConfig, "to_cfg_dict")

def test_to_cfg_dict_preserves_dynamic_attributes(self):
"""to_cfg_dict must capture dynamic attributes like rope_parameters via to_dict."""
bridge = NemotronLabsDiffusionBridge()
hf_cfg = self._make_mock_hf_config()
hf_cfg.llama_4_scaling_beta = 0.7 # dynamic attribute
hf = DummyHFPretrained(hf_cfg)
bridge.provider_bridge(hf)

result = hf_cfg.to_cfg_dict()
assert result["config_dict"]["llama_4_scaling_beta"] == 0.7

delattr(_MockConfig, "to_cfg_dict")

def test_to_cfg_dict_not_added_to_simplenamespace(self):
"""SimpleNamespace has no to_dict, so to_cfg_dict must not be added."""
bridge = NemotronLabsDiffusionBridge()
hf_cfg = _make_hf_config() # uses SimpleNamespace
hf = DummyHFPretrained(hf_cfg)
bridge.provider_bridge(hf)

assert not hasattr(types.SimpleNamespace, "to_cfg_dict")

def test_to_cfg_dict_not_added_twice(self):
"""If to_cfg_dict already exists, provider_bridge must not overwrite it."""
bridge = NemotronLabsDiffusionBridge()
hf_cfg = self._make_mock_hf_config()
hf = DummyHFPretrained(hf_cfg)

sentinel = lambda self: {"sentinel": True}
_MockConfig.to_cfg_dict = sentinel

bridge.provider_bridge(hf)
assert _MockConfig.to_cfg_dict is sentinel

delattr(_MockConfig, "to_cfg_dict")
26 changes: 26 additions & 0 deletions tests/unit_tests/models/test_auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,7 @@ def test_load_megatron_model_basic(self):

bridge = AutoBridge.__new__(AutoBridge)
bridge.hf_pretrained = mock_hf_model
bridge.trust_remote_code = False

with patch("megatron.bridge.training.model_load_save.load_megatron_model") as mock_load_megatron_model:
from pathlib import Path
Expand All @@ -1259,6 +1260,7 @@ def test_load_megatron_model_with_iter_folder(self):

bridge = AutoBridge.__new__(AutoBridge)
bridge.hf_pretrained = mock_hf_model
bridge.trust_remote_code = False

with patch("megatron.bridge.training.model_load_save.load_megatron_model") as mock_load_megatron_model:
from pathlib import Path
Expand Down Expand Up @@ -1298,6 +1300,7 @@ def test_load_megatron_model_with_mp_overrides(self):

bridge = AutoBridge.__new__(AutoBridge)
bridge.hf_pretrained = mock_hf_model
bridge.trust_remote_code = False

# Create model-parallel overrides
mp_overrides = {
Expand Down Expand Up @@ -1339,6 +1342,29 @@ def test_load_megatron_model_with_mp_overrides(self):
assert call_args.args[0] == "checkpoint_path" # path argument
assert "skip_temp_dist_context" in call_args.kwargs

def test_load_megatron_model_registers_prefix_when_trust_remote_code(self):
"""Test that load_megatron_model registers transformers_modules prefix when trust_remote_code=True."""
mock_hf_model = Mock(spec=PreTrainedCausalLM)
mock_config = Mock(spec=PretrainedConfig)
mock_config.architectures = ["LlamaForCausalLM"]
mock_hf_model.config = mock_config

bridge = AutoBridge.__new__(AutoBridge)
bridge.hf_pretrained = mock_hf_model
bridge.trust_remote_code = True

with patch("megatron.bridge.training.model_load_save.load_megatron_model") as mock_load_megatron_model:
with patch("megatron.bridge.utils.instantiate_utils.register_allowed_target_prefix") as mock_register:
from pathlib import Path

with patch.object(Path, "iterdir") as mock_iterdir:
mock_load_megatron_model.return_value = Mock()
mock_iterdir.return_value = []

bridge.load_megatron_model("./checkpoint_path")

mock_register.assert_called_once_with("transformers_modules.")

@patch("torch.distributed.is_available")
@patch("torch.distributed.is_initialized")
def test_save_hf_pretrained_uses_bridge_additional_file_patterns(self, mock_is_init, mock_is_avail):
Expand Down
Loading