Commit ff01478

[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE
Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE / Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`, `down_proj_weight_quantizers`):

1. Add `_QuantFusedExperts.iter_weights_for_calibration` that yields per-expert (weight_slice, quantizer) pairs for both projections. The base impl uses singular `*_weight_quantizer` and silently skips fused-experts modules, so weight-only calibration paths never reach per-expert quantizers.
2. Refactor `mse_calibrate`:
   - Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate` to populate `_amax` on quantizers the forward pass didn't reach (dead MoE experts that received no calibration tokens). Runs the existing calibrator on the weight slice surfaced by `iter_weights_for_calibration`.
   - Replace the singular-only `weight_attr_names` discovery + `getattr`-by-name walk with an `iter_weights_for_calibration` walk done inside each parent module's `enable_weight_access_and_writeback` context, so MSE processes every per-expert quantizer (active and dead) and remains FSDP-safe.

Without this, the export-time fallback in `_export_fused_experts` derived separate gate/up amaxes from each half of the fused weight, breaking the gate==up `weight_scale_2` invariant on dead experts.

End-to-end check on Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`:
- Before: 1/12288 (layer 38 expert 69) gate != up; 0 weights MSE-calibrated
- After: 0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
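For intuition on the invariant the fix protects, here is a toy sketch in plain PyTorch (not ModelOpt code): deriving an amax independently for each half of a fused gate/up weight yields two different `weight_scale_2` values, while a single per-expert amax makes the halves agree by construction. The `amax / (448 * 6)` form is the usual NVFP4 two-level scale and is an assumption of this sketch, not something taken from the diff.

```python
# Toy illustration only: per-half amaxes break gate == up weight_scale_2,
# a shared per-expert amax keeps them equal by construction.
import torch

hidden = 16
gate_up = torch.randn(2 * hidden, hidden)   # one expert's fused [gate; up] weight
gate, up = gate_up[:hidden], gate_up[hidden:]

def weight_scale_2(amax: torch.Tensor) -> torch.Tensor:
    # Assumed NVFP4 two-level scale: FP8-E4M3 max (448) times FP4-E2M1 max (6).
    return amax.float() / (448.0 * 6.0)

# Export-time fallback without a calibrated per-expert amax: one amax per half.
per_half = (weight_scale_2(gate.abs().amax()), weight_scale_2(up.abs().amax()))
# With a calibrated per-expert quantizer: one shared amax, so both halves agree.
shared = weight_scale_2(gate_up.abs().amax())

print("per-half scales match:", torch.isclose(per_half[0], per_half[1]).item())  # almost always False
print("shared scale used for both halves:", shared.item())
```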
1 parent eaa953b commit ff01478

6 files changed

Lines changed: 335 additions & 144 deletions

modelopt/torch/export/moe_utils.py

Lines changed: 3 additions & 9 deletions
```diff
@@ -110,22 +110,16 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             and w_quantizer._amax.dim() >= 1
         ):
             amax = w_quantizer._amax
-            # Static block-quant calibration (e.g. NVFP4 MSE FP8 sweep)
-            # produces a per-block _amax with shape (num_blocks_total, ...)
-            # where num_blocks_total = fused_total * blocks_per_row. That
-            # shape collapses the row axis we want to slice on. Restore the
-            # row dimension so the dim-0 slicing below splits gate / up
-            # correctly. No-op when _amax is already aligned with fused_total.
+            # Per-block _amax (NVFP4 static) collapses the row axis we want
+            # to slice on; restore it so dim-0 slicing splits gate/up.
             if amax.numel() != fused_total and amax.numel() % fused_total == 0:
                 amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
             amax_dim0 = amax.shape[0]
             if fused_total % amax_dim0 == 0:
                 slice_start = fused_start * amax_dim0 // fused_total
                 slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
                 sliced = amax[slice_start:slice_end].contiguous()
-                # The amax setter refuses shape changes once `_amax` exists,
-                # so drop the existing buffer before re-registering with the
-                # sliced shape.
+                # The amax setter refuses shape changes; drop _amax first.
                 if hasattr(w_quantizer, "_amax"):
                     delattr(w_quantizer, "_amax")
                 w_quantizer.amax = sliced
```
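A minimal numeric sketch of the reshape-and-slice path in this hunk (shapes and values are invented): the collapsed per-block `_amax` is viewed back to `(fused_total, blocks_per_row)` so that dim-0 slicing separates the gate rows from the up rows.

```python
# Standalone sketch of the view + slice arithmetic above; numbers are made up.
import torch

fused_total, blocks_per_row = 8, 4                 # fused gate+up rows, NVFP4 blocks per row
amax = torch.rand(fused_total * blocks_per_row)    # collapsed per-block amax

if amax.numel() != fused_total and amax.numel() % fused_total == 0:
    amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)

fused_start, rows = 0, fused_total // 2            # e.g. the gate half occupies rows [0, 4)
amax_dim0 = amax.shape[0]
slice_start = fused_start * amax_dim0 // fused_total
slice_end = (fused_start + rows) * amax_dim0 // fused_total
gate_amax = amax[slice_start:slice_end].contiguous()
print(gate_amax.shape)  # torch.Size([4, 4]): per-block amax for the gate rows only
```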

modelopt/torch/export/unified_export_hf.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -1241,10 +1241,6 @@ def export_hf_checkpoint(
     # modeling_utils does `from core_model_loading import revert_weight_conversion`.
     _patches = _patch_revert_weight_conversion()
 
-    # Some upstream HF checkpoints ship a generation_config.json that fails
-    # transformers' strict validation on save (e.g. ``top_p`` set without
-    # ``do_sample=True`` — newer transformers raises). Flip ``do_sample`` to
-    # the sampling-attrs intent so save_pretrained can write the file.
     _sanitize_generation_config_for_save(model)
 
     try:
```

modelopt/torch/quantization/model_calib.py

Lines changed: 114 additions & 124 deletions
```diff
@@ -52,7 +52,6 @@
     promote_nvfp4_static_quantizers,
     quantizer_attr_names,
     reduce_amax,
-    weight_attr_names,
 )
 from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
 
@@ -64,93 +63,104 @@
     "max_calibrate",
     "smoothquant",
     "svdquant",
-    "sync_grouped_weight_global_amax",
 ]
 
 
-# Sibling weight-quantizer name groups whose ``global_amax`` should share an
-# FP8 scale-of-scales. All members of a group sit under the same parent module
-# (e.g. one self-attention or one MLP block) and either consume the same input
-# tensor or get fused at deployment, so a divergent global_amax across siblings
-# would split their FP8 grids and skew the round.
-_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = (
-    # Standard self-attention (skipped for fused qkv_proj — single weight).
-    ("q_proj", "k_proj", "v_proj"),
-    # Gated MLP, modern naming (Llama / Qwen / Mistral / etc.).
-    ("gate_proj", "up_proj"),
-    # Gated MLP, older Mixtral-style naming.
-    ("w1", "w3"),
-)
-
-
-def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
-    """True for an NVFP4-static weight quantizer that ``max_calibrate`` already
-    populated with a per-block ``_amax`` and that is currently enabled.
-    """
+def _is_calibrated_nvfp4_static(q) -> bool:
+    """True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set."""
     return (
         isinstance(q, TensorQuantizer)
         and not q._disabled
         and q.is_nvfp4_static
-        and hasattr(q, "_amax")
-        and q._amax is not None
+        and getattr(q, "_amax", None) is not None
     )
 
 
 def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
-    """Find groups of Linear-like submodules whose NVFP4-static weight quantizers
-    should share ``global_amax`` (Q/K/V under one attention parent; gate/up under
-    one MLP parent).
-    """
+    """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
+    # Inline: layer_utils → quant_utils → model_calib cycle.
+    from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS
+
+    # Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant
+    # in export). Single source for the gate/up half avoids parallel lists.
+    patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS)
     groups: list[list[nn.Module]] = []
     wq_attr = quantizer_attr_names("weight").weight_quantizer
     for parent in model.modules():
-        for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS:
-            members: list[nn.Module] = []
-            for n in sibling_names:
-                child = getattr(parent, n, None)
-                if child is None:
-                    continue
-                wq = getattr(child, wq_attr, None)
-                if _is_calibrated_nvfp4_static_weight_quantizer(wq):
-                    members.append(child)
+        for sibling_names in patterns:
+            members = [
+                child
+                for child in (getattr(parent, n, None) for n in sibling_names)
+                if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None))
+            ]
            if len(members) >= 2:
                 groups.append(members)
     return groups
 
 
 @torch.no_grad()
-def sync_grouped_weight_global_amax(model: nn.Module) -> int:
-    """Sync ``global_amax`` across sibling NVFP4-static weight quantizers.
-
-    For each group of siblings (Q/K/V projections under one attention parent;
-    gate/up — a.k.a. ``w1``/``w3`` — under one MLP parent) unifies the
-    NVFP4 ``global_amax`` so the per-block FP8 round picks scales against a
-    consistent FP8 grid across the group during MSE / local-Hessian search.
-
-    Reuses :func:`modelopt.torch.export.quant_utils.preprocess_linear_fusion`
-    (whose ``NVFP4StaticQuantizer`` branch performs the same
-    ``max(stack(global_amax))`` unification at export time). To call it before
-    MSE, this helper first promotes each grouped weight quantizer to
-    :class:`NVFP4StaticQuantizer` with its local ``global_amax`` (=
-    ``reduce_amax(_amax)``); ``preprocess_linear_fusion`` then unifies in
-    place.
-
-    Must be called after ``max_calibrate`` has populated each weight
-    quantizer's ``_amax``. Idempotent. Returns the number of groups synced.
+def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
+    """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``.
+
+    Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE
+    doesn't drop them and break the gate==up ``weight_scale_2`` export invariant.
+    Activation quantizers on those modules remain uncalibrated; emits a warning.
+    """
+    name_to_module = dict(model.named_modules())
+    n = 0
+    for module in name_to_module.values():
+        if not isinstance(module, QuantModule):
+            continue
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            for weight, q in module.iter_weights_for_calibration():
+                if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
+                    continue
+                if q._calibrator is None:
+                    continue
+                if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0):
+                    continue
+                q.disable_quant()
+                q.enable_calib()
+                q(weight)
+                if q._calibrator.compute_amax() is not None:
+                    q.load_calib_amax()
+                q.enable_quant()
+                q.disable_calib()
+                if hasattr(q._calibrator, "reset"):
+                    q._calibrator.reset()
+                n += 1
+    if n > 0:
+        warnings.warn(
+            f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; "
+            f"their activation quantizers (if any) remain uncalibrated. "
+            f"Increase calib size/seq len to activate all experts.",
+            stacklevel=2,
+        )
+    return n
+
+
+@torch.no_grad()
+def _sync_grouped_weight_global_amax(model: nn.Module) -> int:
+    """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.
+
+    Run after ``max_calibrate``. Sibling discovery is name-based via
+    ``_collect_grouped_linears``; non-matching architectures (wqkv, fused
+    qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently
+    fall back to per-module global_amax. Fused-experts containers already
+    share a single quantizer across gate/up halves and need no sync.
     """
+    # quant_utils imports back from this module; top-level would cycle.
     from modelopt.torch.export.quant_utils import preprocess_linear_fusion
 
+    wq_attr = quantizer_attr_names("weight").weight_quantizer
     n_groups = 0
     for group in _collect_grouped_linears(model):
-        # Promote each member's weight quantizer so `preprocess_linear_fusion`
-        # sees post-conversion NVFP4StaticQuantizers (its NVFP4 branch reads
-        # `global_amax`, which only exists post-promotion).
-        wq_attr = quantizer_attr_names("weight").weight_quantizer
         for child in group:
             wq = getattr(child, wq_attr)
             if not isinstance(wq, NVFP4StaticQuantizer):
-                local_global = reduce_amax(wq._amax, axis=None)
-                NVFP4StaticQuantizer.from_tensor_quantizer(wq, global_amax=local_global)
+                NVFP4StaticQuantizer.from_tensor_quantizer(
+                    wq, global_amax=reduce_amax(wq._amax, axis=None)
+                )
         preprocess_linear_fusion(group)
         n_groups += 1
     return n_groups
@@ -436,37 +446,24 @@ def mse_calibrate(
     See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
     details on the remaining arguments.
     """
-    # Step 1: First get initial amax using max calibration
+    # Step 1: max calibrate, bootstrap dead-expert weight quantizers,
+    # unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid.
     max_calibrate(model, forward_loop, distributed_sync)
+    _bootstrap_uncalibrated_weight_quantizers(model)
+    _sync_grouped_weight_global_amax(model)
 
-    # Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers
-    # (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one
-    # MLP block) so their FP8 scale-of-scales matches and the per-block FP8
-    # round uses a consistent grid. No-op when there are no sibling groups
-    # (e.g. fused QKV / fused gate_up_proj).
-    sync_grouped_weight_global_amax(model)
-
-    # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
-    # and identify weight quantizers
-    weight_quantizers = []
-    seen_modules = set()
-
+    # Step 2: replace calibrators with MseCalibrator for enabled quantizers.
     for name, module in list(model.named_modules()):
         if isinstance(module, TensorQuantizer) and not module._disabled:
             if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
-                # Get the initial amax from max calibration
                 initial_amax = module._amax.clone().detach()
-
                 is_nvfp4_static = module.is_nvfp4_static
 
-                if is_nvfp4_static:
-                    # If sync_grouped_weight_global_amax already promoted this
-                    # quantizer (it's a sibling in a Q/K/V or gate/up group),
-                    # its global_amax has been unified across the group; just
-                    # leave it. Otherwise convert + set local global_amax.
-                    if not isinstance(module, NVFP4StaticQuantizer):
-                        global_amax = reduce_amax(initial_amax, axis=None)
-                        NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
+                # Promote standalone NVFP4-static quantizers; grouped siblings
+                # already promoted by _sync_grouped_weight_global_amax above.
+                if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer):
+                    global_amax = reduce_amax(initial_amax, axis=None)
+                    NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
 
                 if fp8_scale_sweep:
                     # Check if backend has a registered custom calibrator factory.
@@ -506,52 +503,48 @@ def mse_calibrate(
                         quant_func=partial(_mse_quant_func, quantizer=module),
                     )
 
-    # Identify weight quantizers by checking if they have corresponding weight parameters
+    # Step 3: calibrate weight quantizers via iter_weights_for_calibration.
     name_to_module = dict(model.named_modules())
+    seen_modules: set[int] = set()
+    pbar = tqdm(desc="MSE weight calibration")
+    n_calibrated = 0
     for parent_module in name_to_module.values():
-        if parent_module in seen_modules:
+        if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule):
             continue
-        for weight_name in weight_attr_names(parent_module):
-            weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
-            weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
-            if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
-                if getattr(weight_quantizer, "_calibrator", None) is not None:
-                    weight_quantizers.append((parent_module, weight_name, weight_quantizer))
-        seen_modules.add(parent_module)
-
-    # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
-    # This prevents massive memory accumulation seen in large models
-    for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
-        tqdm(weight_quantizers, desc="MSE weight calibration")
-    ):
-        # Enable calibration mode for the weight quantizer
-        weight_quantizer.disable_quant()
-        weight_quantizer.enable_calib()
+        seen_modules.add(id(parent_module))
         with enable_weight_access_and_writeback(parent_module, model, name_to_module):
-            weight = getattr(parent_module, weight_name)
-            weight_quantizer(weight)
+            for weight, weight_quantizer in parent_module.iter_weights_for_calibration():
+                if not (
+                    isinstance(weight_quantizer, TensorQuantizer)
+                    and weight_quantizer.is_enabled
+                    and getattr(weight_quantizer, "_calibrator", None) is not None
+                ):
+                    continue
+                weight_quantizer.disable_quant()
+                weight_quantizer.enable_calib()
+                weight_quantizer(weight)
 
-        # IMMEDIATELY compute amax and reset calibrator to free memory
-        cal = getattr(weight_quantizer, "_calibrator", None)
-        if cal is not None and cal.compute_amax() is not None:
-            weight_quantizer.load_calib_amax()
+                cal = weight_quantizer._calibrator
+                if cal.compute_amax() is not None:
+                    weight_quantizer.load_calib_amax()
 
-        weight_quantizer.enable_quant()
-        weight_quantizer.disable_calib()
+                weight_quantizer.enable_quant()
+                weight_quantizer.disable_calib()
 
-        # Synchronize ALL CUDA devices before resetting to ensure all async operations complete
-        # This is critical for multi-GPU setups where tensors may be on different devices
-        if torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                if torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
 
-        if cal is not None and hasattr(cal, "reset"):
-            cal.reset()
+                if hasattr(cal, "reset"):
+                    cal.reset()
 
-        if (idx + 1) % 10 == 0 and torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
-            torch.cuda.empty_cache()
+                pbar.update(1)
+                n_calibrated += 1
+                if n_calibrated % 10 == 0 and torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                    torch.cuda.empty_cache()
+    pbar.close()
 
     if torch.cuda.is_available():
         for dev_id in range(torch.cuda.device_count()):
@@ -706,10 +699,7 @@ def forward(self, input, *args, **kwargs):
     print_rank_0("local_hessian: Running max calibration for all quantizers...")
     max_calibrate(model, forward_loop, distributed_sync)
 
-    # Sync global_amax across sibling NVFP4-static weight quantizers
-    # (q/k/v_proj, gate/up_proj a.k.a. w1/w3) so the FP8 scale-of-scales
-    # is consistent across the group. Idempotent; no-op when fused.
-    sync_grouped_weight_global_amax(model)
+    _sync_grouped_weight_global_amax(model)
 
     # Setup helpers for all quantized linear modules
     name_to_module = dict(model.named_modules())
```

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 1 addition & 7 deletions
```diff
@@ -516,13 +516,7 @@ def is_mx_format(self):
 
     @property
     def is_nvfp4_static(self):
-        """Check if this quantizer is configured for NVFP4 static block quantization.
-
-        Format-only check (does not consider whether ``_amax`` has been
-        populated by calibration). True when the quantizer holds E2M1 weights
-        with E4M3 per-block scales in a static layout — i.e. the two-level
-        scaling NVFP4 path consumed by :class:`NVFP4StaticQuantizer`.
-        """
+        """True for E2M1 weights + E4M3 per-block scales in static layout (format-only check)."""
         return (
             self.is_static_block_quant
             and self._num_bits == (2, 1)
```

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -900,6 +900,24 @@ def forward(self, *args, **kwargs):
             self._down_proj_linear = False
         return super().forward(*args, **kwargs)
 
+    def iter_weights_for_calibration(self):
+        """Yield ``(weight_slice, quantizer)`` per-expert pairs.
+
+        The base impl uses singular ``*_weight_quantizer`` and skips fused-
+        experts modules, so weight-only calibration never reaches per-expert
+        quantizers without this override.
+        """
+        for weight_name, quantizers_name in (
+            ("gate_up_proj", "gate_up_proj_weight_quantizers"),
+            ("down_proj", "down_proj_weight_quantizers"),
+        ):
+            weight = getattr(self, weight_name, None)
+            quantizers = getattr(self, quantizers_name, None)
+            if weight is None or quantizers is None:
+                continue
+            for idx, q in enumerate(quantizers):
+                yield weight[idx], q
+
     def fold_weight(self, keep_attrs: bool = False):
         """Fold per-expert weight quantizers into the fused 3-D weights.
```
