Skip to content

Commit 3587238

Browse files
committed
Fix bugs for MSE
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 1d796f9 commit 3587238

5 files changed

Lines changed: 175 additions & 26 deletions

File tree

modelopt/torch/export/moe_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,25 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
110110
and w_quantizer._amax.dim() >= 1
111111
):
112112
amax = w_quantizer._amax
113+
# Static block-quant calibration (e.g. NVFP4 MSE FP8 sweep)
114+
# produces a per-block _amax with shape (num_blocks_total, ...)
115+
# where num_blocks_total = fused_total * blocks_per_row. That
116+
# shape collapses the row axis we want to slice on. Restore the
117+
# row dimension so the dim-0 slicing below splits gate / up
118+
# correctly. No-op when _amax is already aligned with fused_total.
119+
if amax.numel() != fused_total and amax.numel() % fused_total == 0:
120+
amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
113121
amax_dim0 = amax.shape[0]
114122
if fused_total % amax_dim0 == 0:
115123
slice_start = fused_start * amax_dim0 // fused_total
116124
slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
117-
w_quantizer.amax = amax[slice_start:slice_end].contiguous()
125+
sliced = amax[slice_start:slice_end].contiguous()
126+
# The amax setter refuses shape changes once `_amax` exists,
127+
# so drop the existing buffer before re-registering with the
128+
# sliced shape.
129+
if hasattr(w_quantizer, "_amax"):
130+
delattr(w_quantizer, "_amax")
131+
w_quantizer.amax = sliced
118132
else:
119133
warnings.warn(
120134
f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "

modelopt/torch/export/unified_export_hf.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""Code that export quantized Hugging Face models for deployment."""
1717

1818
import collections.abc
19+
import contextlib
1920
import json
2021
import re
2122
import tempfile
@@ -1134,6 +1135,30 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
11341135
mod.revert_weight_conversion = original
11351136

11361137

1138+
def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
1139+
"""Coerce ``model.generation_config`` so it passes transformers' strict validation.
1140+
1141+
Some upstream HF checkpoints ship a ``generation_config.json`` that mixes
1142+
``do_sample=False`` with sampling-only attrs (``top_p``, ``top_k``, ...).
1143+
Newer transformers raise ``ValueError("GenerationConfig is invalid: ...")``
1144+
inside ``save_pretrained``, blocking export. We try a strict validate and
1145+
on failure flip ``do_sample`` to ``True`` so the upstream sampling intent
1146+
is preserved (rather than silently dropping ``top_p`` etc.). Quietly does
1147+
nothing if the model has no generation_config or it's already valid.
1148+
"""
1149+
gc = getattr(model, "generation_config", None)
1150+
if gc is None or not hasattr(gc, "validate"):
1151+
return
1152+
try:
1153+
gc.validate(strict=True)
1154+
return
1155+
except Exception:
1156+
pass
1157+
if not getattr(gc, "do_sample", False):
1158+
with contextlib.suppress(Exception):
1159+
gc.do_sample = True
1160+
1161+
11371162
def export_speculative_decoding(
11381163
model: torch.nn.Module,
11391164
dtype: torch.dtype | None = None,
@@ -1228,6 +1253,12 @@ def export_hf_checkpoint(
12281253
# modeling_utils does `from core_model_loading import revert_weight_conversion`.
12291254
_patches = _patch_revert_weight_conversion()
12301255

1256+
# Some upstream HF checkpoints ship a generation_config.json that fails
1257+
# transformers' strict validation on save (e.g. ``top_p`` set without
1258+
# ``do_sample=True`` — newer transformers raises). Flip ``do_sample`` to
1259+
# the sampling-attrs intent so save_pretrained can write the file.
1260+
_sanitize_generation_config_for_save(model)
1261+
12311262
try:
12321263
model.save_pretrained(
12331264
export_dir,

modelopt/torch/quantization/model_calib.py

Lines changed: 112 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,98 @@
6464
"max_calibrate",
6565
"smoothquant",
6666
"svdquant",
67+
"sync_grouped_weight_global_amax",
6768
]
6869

70+
71+
# Sibling weight-quantizer name groups whose ``global_amax`` should share an
72+
# FP8 scale-of-scales. All members of a group sit under the same parent module
73+
# (e.g. one self-attention or one MLP block) and either consume the same input
74+
# tensor or get fused at deployment, so a divergent global_amax across siblings
75+
# would split their FP8 grids and skew the round.
76+
_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = (
77+
# Standard self-attention (skipped for fused qkv_proj — single weight).
78+
("q_proj", "k_proj", "v_proj"),
79+
# Gated MLP, modern naming (Llama / Qwen / Mistral / etc.).
80+
("gate_proj", "up_proj"),
81+
# Gated MLP, older Mixtral-style naming.
82+
("w1", "w3"),
83+
)
84+
85+
86+
def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
87+
"""True for an NVFP4-static weight quantizer that ``max_calibrate`` already
88+
populated with a per-block ``_amax`` and that is currently enabled.
89+
"""
90+
return (
91+
isinstance(q, TensorQuantizer)
92+
and not q._disabled
93+
and q.is_nvfp4_static
94+
and hasattr(q, "_amax")
95+
and q._amax is not None
96+
)
97+
98+
99+
def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
100+
"""Find groups of Linear-like submodules whose NVFP4-static weight quantizers
101+
should share ``global_amax`` (Q/K/V under one attention parent; gate/up under
102+
one MLP parent).
103+
"""
104+
groups: list[list[nn.Module]] = []
105+
wq_attr = quantizer_attr_names("weight").weight_quantizer
106+
for parent in model.modules():
107+
for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS:
108+
members: list[nn.Module] = []
109+
for n in sibling_names:
110+
child = getattr(parent, n, None)
111+
if child is None:
112+
continue
113+
wq = getattr(child, wq_attr, None)
114+
if _is_calibrated_nvfp4_static_weight_quantizer(wq):
115+
members.append(child)
116+
if len(members) >= 2:
117+
groups.append(members)
118+
return groups
119+
120+
121+
@torch.no_grad()
122+
def sync_grouped_weight_global_amax(model: nn.Module) -> int:
123+
"""Sync ``global_amax`` across sibling NVFP4-static weight quantizers.
124+
125+
For each group of siblings (Q/K/V projections under one attention parent;
126+
gate/up — a.k.a. ``w1``/``w3`` — under one MLP parent) unifies the
127+
NVFP4 ``global_amax`` so the per-block FP8 round picks scales against a
128+
consistent FP8 grid across the group during MSE / local-Hessian search.
129+
130+
Reuses :func:`modelopt.torch.export.quant_utils.preprocess_linear_fusion`
131+
(whose ``NVFP4StaticQuantizer`` branch performs the same
132+
``max(stack(global_amax))`` unification at export time). To call it before
133+
MSE, this helper first promotes each grouped weight quantizer to
134+
:class:`NVFP4StaticQuantizer` with its local ``global_amax`` (=
135+
``reduce_amax(_amax)``); ``preprocess_linear_fusion`` then unifies in
136+
place.
137+
138+
Must be called after ``max_calibrate`` has populated each weight
139+
quantizer's ``_amax``. Idempotent. Returns the number of groups synced.
140+
"""
141+
from modelopt.torch.export.quant_utils import preprocess_linear_fusion
142+
143+
n_groups = 0
144+
for group in _collect_grouped_linears(model):
145+
# Promote each member's weight quantizer so `preprocess_linear_fusion`
146+
# sees post-conversion NVFP4StaticQuantizers (its NVFP4 branch reads
147+
# `global_amax`, which only exists post-promotion).
148+
wq_attr = quantizer_attr_names("weight").weight_quantizer
149+
for child in group:
150+
wq = getattr(child, wq_attr)
151+
if not isinstance(wq, NVFP4StaticQuantizer):
152+
local_global = reduce_amax(wq._amax, axis=None)
153+
NVFP4StaticQuantizer.from_tensor_quantizer(wq, global_amax=local_global)
154+
preprocess_linear_fusion(group)
155+
n_groups += 1
156+
return n_groups
157+
158+
69159
CalibratorFactory: TypeAlias = Callable[
70160
[torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
71161
]
@@ -349,6 +439,13 @@ def mse_calibrate(
349439
# Step 1: First get initial amax using max calibration
350440
max_calibrate(model, forward_loop, distributed_sync)
351441

442+
# Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers
443+
# (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one
444+
# MLP block) so their FP8 scale-of-scales matches and the per-block FP8
445+
# round uses a consistent grid. No-op when there are no sibling groups
446+
# (e.g. fused QKV / fused gate_up_proj).
447+
sync_grouped_weight_global_amax(model)
448+
352449
# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
353450
# and identify weight quantizers
354451
weight_quantizers = []
@@ -360,19 +457,16 @@ def mse_calibrate(
360457
# Get the initial amax from max calibration
361458
initial_amax = module._amax.clone().detach()
362459

363-
is_nvfp4_static = (
364-
module.is_static_block_quant
365-
and module._num_bits == (2, 1)
366-
and module._block_sizes is not None
367-
and module._block_sizes.get("scale_bits") == (4, 3)
368-
)
460+
is_nvfp4_static = module.is_nvfp4_static
369461

370462
if is_nvfp4_static:
371-
# Compute and set global_amax
372-
global_amax = reduce_amax(initial_amax, axis=None)
373-
374-
# Convert to NVFP4StaticQuantizer in-place
375-
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
463+
# If sync_grouped_weight_global_amax already promoted this
464+
# quantizer (it's a sibling in a Q/K/V or gate/up group),
465+
# its global_amax has been unified across the group; just
466+
# leave it. Otherwise convert + set local global_amax.
467+
if not isinstance(module, NVFP4StaticQuantizer):
468+
global_amax = reduce_amax(initial_amax, axis=None)
469+
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
376470

377471
if fp8_scale_sweep:
378472
# Check if backend has a registered custom calibrator factory.
@@ -612,6 +706,11 @@ def forward(self, input, *args, **kwargs):
612706
print_rank_0("local_hessian: Running max calibration for all quantizers...")
613707
max_calibrate(model, forward_loop, distributed_sync)
614708

709+
# Sync global_amax across sibling NVFP4-static weight quantizers
710+
# (q/k/v_proj, gate/up_proj a.k.a. w1/w3) so the FP8 scale-of-scales
711+
# is consistent across the group. Idempotent; no-op when fused.
712+
sync_grouped_weight_global_amax(model)
713+
615714
# Setup helpers for all quantized linear modules
616715
name_to_module = dict(model.named_modules())
617716
weight_quantizers_info = []
@@ -666,14 +765,9 @@ def quant_func(x, amax, quantizer=weight_quantizer):
666765

667766
return xq
668767

669-
is_nvfp4_static = (
670-
weight_quantizer.is_static_block_quant
671-
and weight_quantizer._num_bits == (2, 1)
672-
and weight_quantizer._block_sizes is not None
673-
and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
674-
)
768+
is_nvfp4_static = weight_quantizer.is_nvfp4_static
675769

676-
if is_nvfp4_static:
770+
if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer):
677771
global_amax = reduce_amax(initial_amax, axis=None)
678772
NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)
679773

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,22 @@ def is_mx_format(self):
514514
and self.block_sizes.get("scale_bits", None) == (8, 0)
515515
)
516516

517+
@property
518+
def is_nvfp4_static(self):
519+
"""Check if this quantizer is configured for NVFP4 static block quantization.
520+
521+
Format-only check (does not consider whether ``_amax`` has been
522+
populated by calibration). True when the quantizer holds E2M1 weights
523+
with E4M3 per-block scales in a static layout — i.e. the two-level
524+
scaling NVFP4 path consumed by :class:`NVFP4StaticQuantizer`.
525+
"""
526+
return (
527+
self.is_static_block_quant
528+
and self._num_bits == (2, 1)
529+
and self._block_sizes is not None
530+
and self._block_sizes.get("scale_bits") == (4, 3)
531+
)
532+
517533
def is_mxfp(self, bits):
518534
"""Check if is MXFP4/MXFP6/MXFP8."""
519535
if bits == 4:

modelopt/torch/quantization/utils/core_utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -957,13 +957,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
957957
for _name, module in list(model.named_modules()):
958958
if isinstance(module, TensorQuantizer) and not module._disabled:
959959
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
960-
is_nvfp4_static = (
961-
module.is_static_block_quant
962-
and module._num_bits == (2, 1)
963-
and module._block_sizes is not None
964-
and module._block_sizes.get("scale_bits") == (4, 3)
965-
)
966-
if is_nvfp4_static:
960+
if module.is_nvfp4_static:
967961
initial_amax = module._amax.clone().detach()
968962
global_amax = reduce_amax(initial_amax, axis=None)
969963
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

0 commit comments

Comments
 (0)