Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions examples/conversion/convert_checkpoints_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def export_megatron_to_hf(
distributed_save: bool = False,
save_every_n_ranks: int = 1,
distributed_timeout_minutes: int | None = None,
export_weight_dtype: str | None = None,
) -> None:
"""Export a distributed Megatron checkpoint to HuggingFace format."""
_ensure_distributed_initialized(distributed_timeout_minutes)
Expand Down Expand Up @@ -244,6 +245,7 @@ def export_megatron_to_hf(
strict=strict,
distributed_save=distributed_save,
save_every_n_ranks=save_every_n_ranks,
weight_dtype=_parse_dtype(export_weight_dtype) if export_weight_dtype else None,
)
print_rank_0(f"Export complete: {hf_path}")

Expand Down Expand Up @@ -304,6 +306,16 @@ def main():
default=1,
help="Only every N-th rank writes files (reduces I/O, only with --distributed-save)",
)
export_parser.add_argument(
"--export-weight-dtype",
choices=sorted(DTYPE_MAP),
default=None,
help=(
"Emit plain weights in this dtype instead of re-creating the source repo's "
"quantized weight/scale layout (currently honored by the DeepSeek-V4 bridge). "
"Use for SFT products that need exact train/inference numerical parity."
),
)
args = parser.parse_args()

if not args.command:
Expand Down Expand Up @@ -338,6 +350,7 @@ def main():
distributed_save=args.distributed_save,
save_every_n_ranks=args.save_every_n_ranks,
distributed_timeout_minutes=args.distributed_timeout_minutes,
export_weight_dtype=args.export_weight_dtype,
)


Expand Down
6 changes: 6 additions & 0 deletions src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def export_hf_weights(
show_progress: bool = True,
conversion_tasks: Optional[List[WeightConversionTask]] = None,
merge_adapter_weights: bool = True,
weight_dtype: Optional[torch.dtype] = None,
) -> Iterable["HFWeightTuple"]:
"""
Export Megatron model weights to HuggingFace format.
Expand Down Expand Up @@ -563,6 +564,7 @@ def export_hf_weights(
show_progress=show_progress,
conversion_tasks=conversion_tasks,
merge_adapter_weights=merge_adapter_weights,
weight_dtype=weight_dtype,
)

def export_hf_weights_modelopt(
Expand Down Expand Up @@ -814,6 +816,7 @@ def save_hf_pretrained(
merge_adapter_weights: bool = True,
distributed_save: bool = False,
save_every_n_ranks: int = 1,
weight_dtype: Optional[torch.dtype] = None,
) -> None:
"""
Save a Megatron model in HuggingFace format.
Expand Down Expand Up @@ -926,6 +929,7 @@ def _save_artifacts():
merge_adapter_weights=merge_adapter_weights,
distributed_save=distributed_save,
save_every_n_ranks=save_every_n_ranks,
weight_dtype=weight_dtype,
)

def save_hf_weights(
Expand All @@ -937,6 +941,7 @@ def save_hf_weights(
merge_adapter_weights: bool = True,
distributed_save: bool = False,
save_every_n_ranks: int = 1,
weight_dtype: Optional[torch.dtype] = None,
) -> None:
"""
Save Megatron model weights in HuggingFace safetensors format.
Expand Down Expand Up @@ -989,6 +994,7 @@ def save_hf_weights(
cpu=True,
show_progress=show_progress,
merge_adapter_weights=merge_adapter_weights,
weight_dtype=weight_dtype,
)
model_instance = self._get_model_instance(model)
quant_tensors = None
Expand Down
10 changes: 9 additions & 1 deletion src/megatron/bridge/models/conversion/model_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import math
import re
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -120,6 +120,8 @@ class WeightConversionTask(Generic[MappingT]):
sub-module that owns the parameter (required for loads).
param_weight (Optional[torch.Tensor]): The actual parameter tensor that will
receive the converted weight (required for loads).
weight_dtype (Optional[torch.dtype]): When set, bridges that re-create a quantized
source layout on export emit plain weights in this dtype instead.

"""

Expand All @@ -130,6 +132,7 @@ class WeightConversionTask(Generic[MappingT]):
vp_stage: Optional[int] = None
megatron_module: Optional[torch.nn.Module] = None
param_weight: Optional[torch.Tensor] = None
weight_dtype: Optional[torch.dtype] = None


class _HFNameSuffixMapping:
Expand Down Expand Up @@ -1192,6 +1195,7 @@ def stream_weights_megatron_to_hf(
show_progress: bool = True,
conversion_tasks: Optional[List[WeightConversionTask]] = None,
merge_adapter_weights: bool = True,
weight_dtype: Optional[torch.dtype] = None,
) -> Iterable[HFWeightTuple]:
"""Export Megatron weights to HuggingFace format.

Expand Down Expand Up @@ -1251,6 +1255,9 @@ def stream_weights_megatron_to_hf(
# Use provided conversion tasks or build them
if conversion_tasks is None:
conversion_tasks = self.build_conversion_tasks(hf_pretrained, unwrapped_model_list)
if weight_dtype is not None:
# WeightConversionTask is frozen — rebuild the tasks with the dtype set
conversion_tasks = [replace(task, weight_dtype=weight_dtype) for task in conversion_tasks]

# Collect adapter conversion tasks when merge is requested
adapter_tasks_by_base: Dict[str, List[AdapterWeightConversionTask]] = {}
Expand Down Expand Up @@ -1990,6 +1997,7 @@ def stream_weights_megatron_to_hf(
show_progress: bool = True,
conversion_tasks: Optional[List[WeightConversionTask]] = None,
merge_adapter_weights: bool = True,
weight_dtype: Optional[torch.dtype] = None,
) -> Iterable[HFWeightTuple]:
"""Bridge Megatron model state to HuggingFace format."""
...
Expand Down
12 changes: 10 additions & 2 deletions src/megatron/bridge/models/deepseek/deepseek_v4_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,8 +900,16 @@ def maybe_modify_converted_hf_weight(
converted_weights_dict: Dict[str, torch.Tensor],
hf_state_dict: Mapping[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""Recreate DSv4 quantized weight/scale pairs expected by the source shard index."""
del task
"""Recreate DSv4 quantized weight/scale pairs expected by the source shard index.

When ``task.weight_dtype`` is set, plain weights are emitted in that dtype
instead (no ``*.scale`` companions).
"""
if task.weight_dtype is not None:
return {
name: (weight.to(task.weight_dtype) if weight.is_floating_point() else weight)
for name, weight in converted_weights_dict.items()
}
return quantization_utils.requantize_hf_weight_scale_pairs(
converted_weights_dict,
hf_state_dict,
Expand Down
48 changes: 47 additions & 1 deletion tests/unit_tests/models/deepseek/test_deepseek_v4_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def _by_megatron(registry):


def _dummy_task():
    """Build a minimal ``WeightConversionTask`` for bridge hook tests.

    A real task object is returned instead of a ``SimpleNamespace`` stand-in so
    that every task field exists with its default value, including
    ``weight_dtype``, which the DeepSeek-V4 export hook reads.
    """
    from megatron.bridge.models.conversion.model_bridge import WeightConversionTask

    return WeightConversionTask(param_name="", global_param_name="", mapping=None)


class TestNativeDeepSeekV4ConfigTranslation:
Expand Down Expand Up @@ -394,3 +396,47 @@ def test_provider_bridge_forces_full_rotary_percent(self):
out = bridge.provider_bridge(hf_pretrained)

assert out.rotary_percent == 1.0


class TestDeepSeekV4ExportWeightDtype:
    """Checks how the DeepSeek-V4 export hook reacts to ``task.weight_dtype``."""

    def test_weight_dtype_set_emits_plain_weights(self):
        """With a dtype on the task, float tensors are cast and the key set is unchanged."""
        from dataclasses import replace
        from unittest.mock import MagicMock

        from megatron.bridge.models.conversion.model_bridge import WeightConversionTask
        from megatron.bridge.models.deepseek.deepseek_v4_bridge import DeepSeekV4Bridge

        weight_key = "model.layers.0.mlp.weight"
        index_key = "model.layers.0.mlp.bias_idx"
        scale_key = "model.layers.0.mlp.scale"

        # Skip __init__: only the conversion hook is exercised here.
        bridge = DeepSeekV4Bridge.__new__(DeepSeekV4Bridge)
        base_task = WeightConversionTask(param_name="w", global_param_name="w", mapping=MagicMock())
        # The dtype has to be settable through dataclasses.replace.
        task = replace(base_task, weight_dtype=torch.bfloat16)

        fp32_weight = torch.randn(4, 4, dtype=torch.float32)
        converted = {
            weight_key: fp32_weight,
            index_key: torch.ones(2, dtype=torch.int32),
        }
        hf_state = {weight_key: fp32_weight, scale_key: torch.ones(1)}

        out = bridge.maybe_modify_converted_hf_weight(task, converted, hf_state)

        # Same keys in and out: nothing dropped, no extra entries added.
        assert set(out) == set(converted)
        # Floating-point tensors are cast; integer tensors keep their dtype.
        assert out[weight_key].dtype == torch.bfloat16
        assert out[index_key].dtype == torch.int32

    def test_no_weight_dtype_requantizes_by_default(self, monkeypatch):
        """Without a dtype on the task, the hook delegates to the requantize helper."""
        from unittest.mock import MagicMock

        from megatron.bridge.models.conversion.model_bridge import WeightConversionTask
        from megatron.bridge.models.deepseek.deepseek_v4_bridge import DeepSeekV4Bridge

        bridge = DeepSeekV4Bridge.__new__(DeepSeekV4Bridge)
        task = WeightConversionTask(param_name="w", global_param_name="w", mapping=MagicMock())
        calls = []

        def fake_requantize(converted, hf_state, *, use_mxfp4=None):
            # Record the invocation and hand back a recognisable sentinel payload.
            calls.append(True)
            return {"quantized": torch.zeros(1)}

        monkeypatch.setattr(quantization_utils, "requantize_hf_weight_scale_pairs", fake_requantize)
        out = bridge.maybe_modify_converted_hf_weight(task, {"a.weight": torch.ones(1)}, {})

        assert calls
        assert "quantized" in out
5 changes: 5 additions & 0 deletions tests/unit_tests/models/test_auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ def test_save_hf_pretrained(self, mock_is_init, mock_is_avail, mock_barrier, moc
merge_adapter_weights=True,
distributed_save=False,
save_every_n_ranks=1,
weight_dtype=None,
)

@patch("torch.distributed.is_initialized", return_value=False)
Expand Down Expand Up @@ -801,6 +802,7 @@ def test_save_hf_pretrained_non_zero_rank(
merge_adapter_weights=True,
distributed_save=False,
save_every_n_ranks=1,
weight_dtype=None,
)

def test_export_hf_weights(self):
Expand Down Expand Up @@ -841,6 +843,7 @@ def test_export_hf_weights(self):
show_progress=True,
conversion_tasks=None,
merge_adapter_weights=True,
weight_dtype=None,
)

def test_export_adapter_weights(self):
Expand Down Expand Up @@ -1474,6 +1477,7 @@ def fake_save_generator(gen, *args, **kwargs):
cpu=True,
show_progress=True,
merge_adapter_weights=True,
weight_dtype=None,
)

# The quantizer tensor should have been saved via torch.save sidecar
Expand Down Expand Up @@ -1530,6 +1534,7 @@ def test_save_hf_weights_no_sidecar_when_not_quantized(
cpu=True,
show_progress=True,
merge_adapter_weights=True,
weight_dtype=None,
)
mock_torch_save.assert_not_called()

Expand Down
Loading