Skip to content

Commit adee8b5

Browse files
committed
[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE
Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE / Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`, `down_proj_weight_quantizers`): 1. Add `_QuantFusedExperts.iter_weights_for_calibration` that yields per-expert (weight_slice, quantizer) pairs for both projections. The base impl uses singular `*_weight_quantizer` and silently skips fused-experts modules, so weight-only calibration paths never reach per-expert quantizers. 2. Refactor `mse_calibrate`: - Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate` to populate `_amax` on quantizers the forward pass didn't reach (dead MoE experts that received no calibration tokens). Runs the existing calibrator on the weight slice surfaced by `iter_weights_for_calibration`. - Replace the singular-only `weight_attr_names` discovery + `getattr`-by-name walk with an `iter_weights_for_calibration` walk done inside each parent module's `enable_weight_access_and_writeback` context, so MSE processes every per-expert quantizer (active and dead) and remains FSDP-safe. Without this, the export-time fallback in `_export_fused_experts` derived separate gate/up amaxes from each half of the fused weight, breaking the gate==up `weight_scale_2` invariant on dead experts. End-to-end check on Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`: - Before: 1/12288 (layer 38 expert 69) gate != up; 0 weights MSE-calibrated - After: 0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 3587238 commit adee8b5

5 files changed

Lines changed: 129 additions & 147 deletions

File tree

modelopt/torch/export/moe_utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,22 +110,16 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
110110
and w_quantizer._amax.dim() >= 1
111111
):
112112
amax = w_quantizer._amax
113-
# Static block-quant calibration (e.g. NVFP4 MSE FP8 sweep)
114-
# produces a per-block _amax with shape (num_blocks_total, ...)
115-
# where num_blocks_total = fused_total * blocks_per_row. That
116-
# shape collapses the row axis we want to slice on. Restore the
117-
# row dimension so the dim-0 slicing below splits gate / up
118-
# correctly. No-op when _amax is already aligned with fused_total.
113+
# Per-block _amax (NVFP4 static) collapses the row axis we want
114+
# to slice on; restore it so dim-0 slicing splits gate/up.
119115
if amax.numel() != fused_total and amax.numel() % fused_total == 0:
120116
amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
121117
amax_dim0 = amax.shape[0]
122118
if fused_total % amax_dim0 == 0:
123119
slice_start = fused_start * amax_dim0 // fused_total
124120
slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
125121
sliced = amax[slice_start:slice_end].contiguous()
126-
# The amax setter refuses shape changes once `_amax` exists,
127-
# so drop the existing buffer before re-registering with the
128-
# sliced shape.
122+
# The amax setter refuses shape changes; drop _amax first.
129123
if hasattr(w_quantizer, "_amax"):
130124
delattr(w_quantizer, "_amax")
131125
w_quantizer.amax = sliced

modelopt/torch/export/unified_export_hf.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,15 +1136,10 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
11361136

11371137

11381138
def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
1139-
"""Coerce ``model.generation_config`` so it passes transformers' strict validation.
1140-
1141-
Some upstream HF checkpoints ship a ``generation_config.json`` that mixes
1142-
``do_sample=False`` with sampling-only attrs (``top_p``, ``top_k``, ...).
1143-
Newer transformers raise ``ValueError("GenerationConfig is invalid: ...")``
1144-
inside ``save_pretrained``, blocking export. We try a strict validate and
1145-
on failure flip ``do_sample`` to ``True`` so the upstream sampling intent
1146-
is preserved (rather than silently dropping ``top_p`` etc.). Quietly does
1147-
nothing if the model has no generation_config or it's already valid.
1139+
"""Flip ``do_sample=True`` when generation_config mixes it with sampling attrs.
1140+
1141+
Some upstream HF checkpoints set ``top_p``/``top_k`` with ``do_sample=False``,
1142+
which newer transformers reject in ``save_pretrained``'s strict validate.
11481143
"""
11491144
gc = getattr(model, "generation_config", None)
11501145
if gc is None or not hasattr(gc, "validate"):
@@ -1253,10 +1248,6 @@ def export_hf_checkpoint(
12531248
# modeling_utils does `from core_model_loading import revert_weight_conversion`.
12541249
_patches = _patch_revert_weight_conversion()
12551250

1256-
# Some upstream HF checkpoints ship a generation_config.json that fails
1257-
# transformers' strict validation on save (e.g. ``top_p`` set without
1258-
# ``do_sample=True`` — newer transformers raises). Flip ``do_sample`` to
1259-
# the sampling-attrs intent so save_pretrained can write the file.
12601251
_sanitize_generation_config_for_save(model)
12611252

12621253
try:

modelopt/torch/quantization/model_calib.py

Lines changed: 103 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
promote_nvfp4_static_quantizers,
5353
quantizer_attr_names,
5454
reduce_amax,
55-
weight_attr_names,
5655
)
5756
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
5857

@@ -68,89 +67,91 @@
6867
]
6968

7069

71-
# Sibling weight-quantizer name groups whose ``global_amax`` should share an
72-
# FP8 scale-of-scales. All members of a group sit under the same parent module
73-
# (e.g. one self-attention or one MLP block) and either consume the same input
74-
# tensor or get fused at deployment, so a divergent global_amax across siblings
75-
# would split their FP8 grids and skew the round.
70+
# Sibling groups that share an FP8 scale-of-scales: members feed the same input
71+
# (Q/K/V) or get fused at deployment (gate/up), so divergent global_amax would
72+
# split their FP8 grids.
7673
_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = (
77-
# Standard self-attention (skipped for fused qkv_proj — single weight).
7874
("q_proj", "k_proj", "v_proj"),
79-
# Gated MLP, modern naming (Llama / Qwen / Mistral / etc.).
80-
("gate_proj", "up_proj"),
81-
# Gated MLP, older Mixtral-style naming.
82-
("w1", "w3"),
75+
("gate_proj", "up_proj"), # Llama/Qwen/Mistral
76+
("w1", "w3"), # Mixtral
8377
)
8478

8579

86-
def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool:
87-
"""True for an NVFP4-static weight quantizer that ``max_calibrate`` already
88-
populated with a per-block ``_amax`` and that is currently enabled.
89-
"""
90-
return (
91-
isinstance(q, TensorQuantizer)
92-
and not q._disabled
93-
and q.is_nvfp4_static
94-
and hasattr(q, "_amax")
95-
and q._amax is not None
96-
)
97-
98-
9980
def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
100-
"""Find groups of Linear-like submodules whose NVFP4-static weight quantizers
101-
should share ``global_amax`` (Q/K/V under one attention parent; gate/up under
102-
one MLP parent).
103-
"""
81+
"""Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
10482
groups: list[list[nn.Module]] = []
10583
wq_attr = quantizer_attr_names("weight").weight_quantizer
10684
for parent in model.modules():
10785
for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS:
108-
members: list[nn.Module] = []
86+
members = []
10987
for n in sibling_names:
11088
child = getattr(parent, n, None)
111-
if child is None:
112-
continue
113-
wq = getattr(child, wq_attr, None)
114-
if _is_calibrated_nvfp4_static_weight_quantizer(wq):
89+
wq = getattr(child, wq_attr, None) if child is not None else None
90+
if (
91+
isinstance(wq, TensorQuantizer)
92+
and not wq._disabled
93+
and wq.is_nvfp4_static
94+
and getattr(wq, "_amax", None) is not None
95+
):
11596
members.append(child)
11697
if len(members) >= 2:
11798
groups.append(members)
11899
return groups
119100

120101

102+
@torch.no_grad()
103+
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
104+
"""Populate ``_amax`` from weights for quantizers the forward pass didn't reach.
105+
106+
Dead MoE experts that received no tokens are otherwise skipped by
107+
``mse_calibrate``, leaving export to derive separate per-half amax for
108+
gate/up and break the gate==up ``weight_scale_2`` invariant.
109+
"""
110+
n = 0
111+
for module in model.modules():
112+
if not isinstance(module, QuantModule):
113+
continue
114+
for weight, q in module.iter_weights_for_calibration():
115+
if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
116+
continue
117+
if q._calibrator is None:
118+
continue
119+
if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0):
120+
continue
121+
q.disable_quant()
122+
q.enable_calib()
123+
q(weight)
124+
if q._calibrator.compute_amax() is not None:
125+
q.load_calib_amax()
126+
q.enable_quant()
127+
q.disable_calib()
128+
if hasattr(q._calibrator, "reset"):
129+
q._calibrator.reset()
130+
n += 1
131+
return n
132+
133+
121134
@torch.no_grad()
122135
def sync_grouped_weight_global_amax(model: nn.Module) -> int:
123-
"""Sync ``global_amax`` across sibling NVFP4-static weight quantizers.
124-
125-
For each group of siblings (Q/K/V projections under one attention parent;
126-
gate/up — a.k.a. ``w1``/``w3`` — under one MLP parent) unifies the
127-
NVFP4 ``global_amax`` so the per-block FP8 round picks scales against a
128-
consistent FP8 grid across the group during MSE / local-Hessian search.
129-
130-
Reuses :func:`modelopt.torch.export.quant_utils.preprocess_linear_fusion`
131-
(whose ``NVFP4StaticQuantizer`` branch performs the same
132-
``max(stack(global_amax))`` unification at export time). To call it before
133-
MSE, this helper first promotes each grouped weight quantizer to
134-
:class:`NVFP4StaticQuantizer` with its local ``global_amax`` (=
135-
``reduce_amax(_amax)``); ``preprocess_linear_fusion`` then unifies in
136-
place.
137-
138-
Must be called after ``max_calibrate`` has populated each weight
139-
quantizer's ``_amax``. Idempotent. Returns the number of groups synced.
136+
"""Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.
137+
138+
Reuses ``preprocess_linear_fusion`` (which performs the same unification at
139+
export time) to keep the FP8 scale-of-scales consistent across siblings
140+
during MSE / local-Hessian search. Must run after ``max_calibrate``.
140141
"""
142+
# Inline: quant_utils imports enable_stats_collection/finish_stats_collection/svd
143+
# from this module, so top-level would deadlock the cycle.
141144
from modelopt.torch.export.quant_utils import preprocess_linear_fusion
142145

146+
wq_attr = quantizer_attr_names("weight").weight_quantizer
143147
n_groups = 0
144148
for group in _collect_grouped_linears(model):
145-
# Promote each member's weight quantizer so `preprocess_linear_fusion`
146-
# sees post-conversion NVFP4StaticQuantizers (its NVFP4 branch reads
147-
# `global_amax`, which only exists post-promotion).
148-
wq_attr = quantizer_attr_names("weight").weight_quantizer
149149
for child in group:
150150
wq = getattr(child, wq_attr)
151151
if not isinstance(wq, NVFP4StaticQuantizer):
152-
local_global = reduce_amax(wq._amax, axis=None)
153-
NVFP4StaticQuantizer.from_tensor_quantizer(wq, global_amax=local_global)
152+
NVFP4StaticQuantizer.from_tensor_quantizer(
153+
wq, global_amax=reduce_amax(wq._amax, axis=None)
154+
)
154155
preprocess_linear_fusion(group)
155156
n_groups += 1
156157
return n_groups
@@ -436,37 +437,26 @@ def mse_calibrate(
436437
See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
437438
details on the remaining arguments.
438439
"""
439-
# Step 1: First get initial amax using max calibration
440+
# Step 1: max calibration; then populate _amax for dead experts so step 3
441+
# doesn't skip them, and unify NVFP4 global_amax across Q/K/V and gate/up
442+
# siblings so MSE searches against a consistent FP8 grid.
440443
max_calibrate(model, forward_loop, distributed_sync)
441-
442-
# Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers
443-
# (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one
444-
# MLP block) so their FP8 scale-of-scales matches and the per-block FP8
445-
# round uses a consistent grid. No-op when there are no sibling groups
446-
# (e.g. fused QKV / fused gate_up_proj).
444+
_bootstrap_uncalibrated_weight_quantizers(model)
447445
sync_grouped_weight_global_amax(model)
448446

449-
# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
450-
# and identify weight quantizers
451-
weight_quantizers = []
452-
seen_modules = set()
453-
447+
# Step 2: replace calibrators with MseCalibrator for enabled quantizers.
454448
for name, module in list(model.named_modules()):
455449
if isinstance(module, TensorQuantizer) and not module._disabled:
456450
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
457-
# Get the initial amax from max calibration
458451
initial_amax = module._amax.clone().detach()
459-
460452
is_nvfp4_static = module.is_nvfp4_static
461453

462-
if is_nvfp4_static:
463-
# If sync_grouped_weight_global_amax already promoted this
464-
# quantizer (it's a sibling in a Q/K/V or gate/up group),
465-
# its global_amax has been unified across the group; just
466-
# leave it. Otherwise convert + set local global_amax.
467-
if not isinstance(module, NVFP4StaticQuantizer):
468-
global_amax = reduce_amax(initial_amax, axis=None)
469-
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
454+
# sync_grouped_weight_global_amax may have already promoted +
455+
# unified global_amax across the sibling group; only promote
456+
# standalone (non-grouped) NVFP4-static quantizers here.
457+
if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer):
458+
global_amax = reduce_amax(initial_amax, axis=None)
459+
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)
470460

471461
if fp8_scale_sweep:
472462
# Check if backend has a registered custom calibrator factory.
@@ -506,52 +496,50 @@ def mse_calibrate(
506496
quant_func=partial(_mse_quant_func, quantizer=module),
507497
)
508498

509-
# Identify weight quantizers by checking if they have corresponding weight parameters
499+
# Step 3: calibrate weight quantizers via iter_weights_for_calibration.
500+
# The fused-experts override yields one pair per expert per projection, so
501+
# every per-expert quantizer is MSE-calibrated (not just routed ones).
510502
name_to_module = dict(model.named_modules())
503+
seen_modules: set[int] = set()
504+
pbar = tqdm(desc="MSE weight calibration")
505+
n_calibrated = 0
511506
for parent_module in name_to_module.values():
512-
if parent_module in seen_modules:
507+
if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule):
513508
continue
514-
for weight_name in weight_attr_names(parent_module):
515-
weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
516-
weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
517-
if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
518-
if getattr(weight_quantizer, "_calibrator", None) is not None:
519-
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
520-
seen_modules.add(parent_module)
521-
522-
# Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
523-
# This prevents massive memory accumulation seen in large models
524-
for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
525-
tqdm(weight_quantizers, desc="MSE weight calibration")
526-
):
527-
# Enable calibration mode for the weight quantizer
528-
weight_quantizer.disable_quant()
529-
weight_quantizer.enable_calib()
509+
seen_modules.add(id(parent_module))
530510
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
531-
weight = getattr(parent_module, weight_name)
532-
weight_quantizer(weight)
511+
for weight, weight_quantizer in parent_module.iter_weights_for_calibration():
512+
if not (
513+
isinstance(weight_quantizer, TensorQuantizer)
514+
and weight_quantizer.is_enabled
515+
and getattr(weight_quantizer, "_calibrator", None) is not None
516+
):
517+
continue
518+
weight_quantizer.disable_quant()
519+
weight_quantizer.enable_calib()
520+
weight_quantizer(weight)
533521

534-
# IMMEDIATELY compute amax and reset calibrator to free memory
535-
cal = getattr(weight_quantizer, "_calibrator", None)
536-
if cal is not None and cal.compute_amax() is not None:
537-
weight_quantizer.load_calib_amax()
522+
cal = weight_quantizer._calibrator
523+
if cal.compute_amax() is not None:
524+
weight_quantizer.load_calib_amax()
538525

539-
weight_quantizer.enable_quant()
540-
weight_quantizer.disable_calib()
526+
weight_quantizer.enable_quant()
527+
weight_quantizer.disable_calib()
541528

542-
# Synchronize ALL CUDA devices before resetting to ensure all async operations complete
543-
# This is critical for multi-GPU setups where tensors may be on different devices
544-
if torch.cuda.is_available():
545-
for dev_id in range(torch.cuda.device_count()):
546-
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
529+
if torch.cuda.is_available():
530+
for dev_id in range(torch.cuda.device_count()):
531+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
547532

548-
if cal is not None and hasattr(cal, "reset"):
549-
cal.reset()
533+
if hasattr(cal, "reset"):
534+
cal.reset()
550535

551-
if (idx + 1) % 10 == 0 and torch.cuda.is_available():
552-
for dev_id in range(torch.cuda.device_count()):
553-
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
554-
torch.cuda.empty_cache()
536+
pbar.update(1)
537+
n_calibrated += 1
538+
if n_calibrated % 10 == 0 and torch.cuda.is_available():
539+
for dev_id in range(torch.cuda.device_count()):
540+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
541+
torch.cuda.empty_cache()
542+
pbar.close()
555543

556544
if torch.cuda.is_available():
557545
for dev_id in range(torch.cuda.device_count()):
@@ -706,9 +694,6 @@ def forward(self, input, *args, **kwargs):
706694
print_rank_0("local_hessian: Running max calibration for all quantizers...")
707695
max_calibrate(model, forward_loop, distributed_sync)
708696

709-
# Sync global_amax across sibling NVFP4-static weight quantizers
710-
# (q/k/v_proj, gate/up_proj a.k.a. w1/w3) so the FP8 scale-of-scales
711-
# is consistent across the group. Idempotent; no-op when fused.
712697
sync_grouped_weight_global_amax(model)
713698

714699
# Setup helpers for all quantized linear modules

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -516,13 +516,7 @@ def is_mx_format(self):
516516

517517
@property
518518
def is_nvfp4_static(self):
519-
"""Check if this quantizer is configured for NVFP4 static block quantization.
520-
521-
Format-only check (does not consider whether ``_amax`` has been
522-
populated by calibration). True when the quantizer holds E2M1 weights
523-
with E4M3 per-block scales in a static layout — i.e. the two-level
524-
scaling NVFP4 path consumed by :class:`NVFP4StaticQuantizer`.
525-
"""
519+
"""True for E2M1 weights + E4M3 per-block scales in static layout (format-only check)."""
526520
return (
527521
self.is_static_block_quant
528522
and self._num_bits == (2, 1)

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,24 @@ def forward(self, *args, **kwargs):
900900
self._down_proj_linear = False
901901
return super().forward(*args, **kwargs)
902902

903+
def iter_weights_for_calibration(self):
904+
"""Yield ``(weight_slice, quantizer)`` per-expert pairs.
905+
906+
The base impl uses singular ``*_weight_quantizer`` and skips fused-
907+
experts modules, so weight-only calibration never reaches per-expert
908+
quantizers without this override.
909+
"""
910+
for weight_name, quantizers_name in (
911+
("gate_up_proj", "gate_up_proj_weight_quantizers"),
912+
("down_proj", "down_proj_weight_quantizers"),
913+
):
914+
weight = getattr(self, weight_name, None)
915+
quantizers = getattr(self, quantizers_name, None)
916+
if weight is None or quantizers is None:
917+
continue
918+
for idx, q in enumerate(quantizers):
919+
yield weight[idx], q
920+
903921
def fold_weight(self, keep_attrs: bool = False):
904922
"""Fold per-expert weight quantizers into the fused 3-D weights.
905923

0 commit comments

Comments
 (0)