Skip to content

Commit b577be3

Browse files
Meirtz and claude committed
feat(model): thread weight_dtype through HF export for plain-dtype output
Export has two consumers — online weight sync for RL rollout (export_hf_weights) and on-disk checkpoints (save_hf_pretrained). Each gains an optional weight_dtype that flows through WeightConversionTask into the export stream. Per review (HollowMan6): the plain-dtype cast is now generic, not DSv4-only. stream_weights_megatron_to_hf stamps weight_dtype onto each (frozen) task via dataclasses.replace rather than in build_conversion_tasks, which model bridges may override, and the cast lives in the shared stream path covering both the standard and grouped-export branches. The DSv4 hook simply skips requantization when weight_dtype is set and returns the converted weights unchanged, letting the generic path cast the dtype — keeping plain-dtype export identical across bridges. Updates the --export-weight-dtype help text in the multi-gpu convert example. Validated end-to-end on 32x GB300: bf16 export = 35020 tensors / 0 scales; quantized export = 69187 / 34167. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Lingrui Mei <lmei@nvidia.com>
1 parent fc05b1d commit b577be3

4 files changed

Lines changed: 53 additions & 19 deletions

File tree

examples/conversion/convert_checkpoints_multi_gpu.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,10 @@ def main():
329329
choices=sorted(DTYPE_MAP),
330330
default=None,
331331
help=(
332-
"Emit plain weights in this dtype instead of re-creating the source repo's "
333-
"quantized weight/scale layout (currently honored by the DeepSeek-V4 bridge). "
334-
"Use for SFT products that need exact train/inference numerical parity."
332+
"Emit plain weights cast to this dtype. For bridges that recreate a quantized "
333+
"source layout on export (e.g. DeepSeek-V4) this also skips the requantization, "
334+
"so no *.scale companions are written. Use for SFT products that need exact "
335+
"train/inference numerical parity."
335336
),
336337
)
337338
args = parser.parse_args()

src/megatron/bridge/models/conversion/model_bridge.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,10 @@ class WeightConversionTask(Generic[MappingT]):
120120
sub-module that owns the parameter (required for loads).
121121
param_weight (Optional[torch.Tensor]): The actual parameter tensor that will
122122
receive the converted weight (required for loads).
123-
weight_dtype (Optional[torch.dtype]): When set, bridges that re-create a quantized
124-
source layout on export emit plain weights in this dtype instead.
123+
weight_dtype (Optional[torch.dtype]): Export only. When set, floating-point
124+
weights are cast to this dtype; bridges that recreate a quantized source
125+
layout (e.g. DeepSeek-V4) additionally skip requantization so no scale
126+
companions are emitted.
125127
126128
"""
127129

@@ -899,6 +901,23 @@ def maybe_modify_converted_hf_weight(
899901
"""
900902
return converted_weights_dict
901903

904+
@staticmethod
905+
def _cast_export_weight_dtype(
906+
weights: Dict[str, torch.Tensor], weight_dtype: Optional[torch.dtype]
907+
) -> Dict[str, torch.Tensor]:
908+
"""Cast floating-point export weights to ``weight_dtype`` (no-op if None).
909+
910+
Integer tensors (e.g. packed/index buffers) are left untouched. This is the
911+
generic plain-dtype export path; when ``weight_dtype`` is unset it is a no-op,
911+
so bridges that recreate a quantized source layout emit it unchanged.
913+
"""
914+
if weight_dtype is None:
915+
return weights
916+
return {
917+
name: (weight.to(weight_dtype) if weight.is_floating_point() else weight)
918+
for name, weight in weights.items()
919+
}
920+
902921
def _accumulate_grouped_export(
903922
self,
904923
task: "WeightConversionTask",
@@ -1256,7 +1275,8 @@ def stream_weights_megatron_to_hf(
12561275
if conversion_tasks is None:
12571276
conversion_tasks = self.build_conversion_tasks(hf_pretrained, unwrapped_model_list)
12581277
if weight_dtype is not None:
1259-
# WeightConversionTask is frozen — rebuild the tasks with the dtype set
1278+
# Stamp the export dtype on the (frozen) tasks here rather than in
1279+
# build_conversion_tasks, which model bridges may override.
12601280
conversion_tasks = [replace(task, weight_dtype=weight_dtype) for task in conversion_tasks]
12611281

12621282
# Collect adapter conversion tasks when merge is requested
@@ -1312,6 +1332,7 @@ def stream_weights_megatron_to_hf(
13121332
task, converted_weights_dict, model_config, _grouped_buffers, hf_state_dict
13131333
)
13141334
if merged_result is not None:
1335+
merged_result = self._cast_export_weight_dtype(merged_result, task.weight_dtype)
13151336
for hf_name, tensor in merged_result.items():
13161337
yield HFWeightTuple(hf_name, tensor.cpu() if cpu else tensor)
13171338
continue
@@ -1336,6 +1357,8 @@ def stream_weights_megatron_to_hf(
13361357
adapter_weights,
13371358
)
13381359

1360+
converted_weights_dict = self._cast_export_weight_dtype(converted_weights_dict, task.weight_dtype)
1361+
13391362
for hf_name, tensor in converted_weights_dict.items():
13401363
final_tensor = tensor.cpu() if cpu else tensor
13411364

src/megatron/bridge/models/deepseek/deepseek_v4_bridge.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -913,14 +913,12 @@ def maybe_modify_converted_hf_weight(
913913
) -> Dict[str, torch.Tensor]:
914914
"""Recreate DSv4 quantized weight/scale pairs expected by the source shard index.
915915
916-
When ``task.weight_dtype`` is set, plain weights are emitted in that dtype
917-
instead (no ``*.scale`` companions).
916+
When ``task.weight_dtype`` is set the caller wants plain (non-quantized)
917+
weights, so skip requantization and let the generic export path cast the
918+
dtype — keeping plain-dtype export behavior identical across bridges.
918919
"""
919920
if task.weight_dtype is not None:
920-
return {
921-
name: (weight.to(task.weight_dtype) if weight.is_floating_point() else weight)
922-
for name, weight in converted_weights_dict.items()
923-
}
921+
return converted_weights_dict
924922
return quantization_utils.requantize_hf_weight_scale_pairs(
925923
converted_weights_dict,
926924
hf_state_dict,

tests/unit_tests/models/deepseek/test_deepseek_v4_bridge.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def test_provider_bridge_preserves_fused_defaults_without_cuda(self):
441441

442442

443443
class TestDeepSeekV4ExportWeightDtype:
444-
def test_weight_dtype_set_emits_plain_weights(self):
444+
def test_weight_dtype_set_skips_requantization(self, monkeypatch):
445445
from dataclasses import replace
446446
from unittest.mock import MagicMock
447447

@@ -451,18 +451,30 @@ def test_weight_dtype_set_emits_plain_weights(self):
451451
bridge = DeepSeekV4Bridge.__new__(DeepSeekV4Bridge)
452452
task = WeightConversionTask(param_name="w", global_param_name="w", mapping=MagicMock())
453453
task = replace(task, weight_dtype=torch.bfloat16) # frozen: must be settable via replace
454+
455+
def fail_requantize(*args, **kwargs):
456+
raise AssertionError("requantize must be skipped when weight_dtype is set")
457+
458+
monkeypatch.setattr(quantization_utils, "requantize_hf_weight_scale_pairs", fail_requantize)
454459
weight = torch.randn(4, 4, dtype=torch.float32)
455-
converted = {
456-
"model.layers.0.mlp.weight": weight,
457-
"model.layers.0.mlp.bias_idx": torch.ones(2, dtype=torch.int32),
458-
}
460+
converted = {"model.layers.0.mlp.weight": weight}
459461
hf_state = {"model.layers.0.mlp.weight": weight, "model.layers.0.mlp.scale": torch.ones(1)}
460462

461463
out = bridge.maybe_modify_converted_hf_weight(task, converted, hf_state)
462464

463-
assert set(out) == set(converted)
465+
assert out is converted # returned unchanged; generic path casts the dtype
466+
467+
def test_generic_export_cast_applies_plain_dtype(self):
468+
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
469+
470+
weights = {
471+
"model.layers.0.mlp.weight": torch.randn(4, 4, dtype=torch.float32),
472+
"model.layers.0.mlp.bias_idx": torch.ones(2, dtype=torch.int32),
473+
}
474+
out = MegatronModelBridge._cast_export_weight_dtype(weights, torch.bfloat16)
464475
assert out["model.layers.0.mlp.weight"].dtype == torch.bfloat16
465-
assert out["model.layers.0.mlp.bias_idx"].dtype == torch.int32
476+
assert out["model.layers.0.mlp.bias_idx"].dtype == torch.int32 # int preserved
477+
assert MegatronModelBridge._cast_export_weight_dtype(weights, None) is weights
466478

467479
def test_no_weight_dtype_requantizes_by_default(self, monkeypatch):
468480
from unittest.mock import MagicMock

0 commit comments

Comments
 (0)