Skip to content

Commit 570920b

Browse files
authored
[NVbug 6142360] Share fused gate_up amax fallback to keep weight_scale_2 consistent (#1411)
## Summary - Pre-fill the source `gate_up_proj_weight_quantizers[idx]._amax` from the **fused** `gate_up[idx]` tensor when it is uncalibrated, **before** the per-projection deepcopy in `_export_fused_experts`. Both gate's clone and up's clone then inherit the same scalar amax, so `weight_scale_2 = amax / (6 · 448)` matches across the W1/W3 fusion that vLLM expects at load time. - Add a unit regression test that builds a fused-experts module with intentionally mismatched gate vs. up weight magnitudes, leaves every expert uncalibrated, and asserts the gate and up wrappers carry the same amax into the FP4 quantization step. Fails on `main`, passes with this fix. ## Why NVbug 6142360: Qwen3.5-MoE / Qwen3-Next NVFP4 checkpoints produced by ModelOpt yielded garbled output under vLLM (`1\n1\n1\n…` style degenerate token loops). The vLLM log showed `w1_weight_scale_2 must match w3_weight_scale_2. Accuracy may be affected.`. Root cause: in `_export_fused_experts`, when an expert receives no calibration tokens (common for low-frequency experts even with `--calib_size 512` on 128-expert MoEs), the per-projection fallback computed `amax` independently from each split slice: ```python w_quantizer.amax = weight_slice.abs().amax().to(torch.float32) ``` `weight_slice` is the gate-only or up-only 2-D half of the fused `gate_up_proj`. Since gate and up have different magnitudes, gate and up end up with different scalar amax values — and therefore different `weight_scale_2`. vLLM fuses W1 and W3 into a single weight at load time and asserts a single shared scale; the half whose scale was discarded is now off by the gate/up magnitude ratio (~10× was typical), which catastrophically corrupts the MoE output. The fix derives the fallback amax once from the whole `gate_up[idx]` tensor before the deepcopies, so gate and up share the same amax — exactly what calibration would have produced if any token had hit the expert. Calibrated experts are unaffected (the new code path is gated on `_amax` being missing or zero). `down_proj` keeps its existing per-projection fallback because it has its own quantizer with no fusion partner. ## End-to-end vLLM verification Used `vllm/vllm-openai:nightly` Docker on RTX 6000 Ada, Marlin NVFP4 backend, with a real Qwen3-30B-A3B-Instruct NVFP4 checkpoint. The bug requires the export-time scale skew, so I A/B-tested by mutating `gate_proj.weight_scale_2` directly: | Variant | gate vs up `weight_scale_2` | Output for "Write an article about AI." | |---|---|---| | Baseline (original checkpoint) | matched (calibrated) | coherent: "The Rise of Artificial Intelligence: Transforming the World One Algorithm at a Time…" | | Bug-simulated (gate = up / 10 for all 6143 expert pairs) | mismatched ~10× | `__':\n__':\n__':\n…` (degenerate loop, same shape as user's `1\n1\n1\n…`) | | Fix-simulated (gate restored to up) | matched | coherent: same article as baseline | The 10× skew matches what the bug actually emits for uncalibrated experts (e.g. `gate=2.77e-5, up=2.30e-4` from a synthetic repro), and reproduces the same garbled-loop pattern the user reported. ## Test plan - [x] New unit test `tests/unit/torch/quantization/plugins/test_fused_experts.py::TestExportFusedExperts::test_uncalibrated_expert_gate_up_share_amax` passes. - [x] Full `tests/unit/torch/quantization/plugins/test_fused_experts.py`: 28/28 pass. - [x] Broader `tests/unit/torch/export/` + `tests/unit/torch/quantization/`: 609 passed, 9 skipped. - [x] vLLM end-to-end A/B (table above): bug reproduces with mismatched scales, fix produces coherent output identical to baseline. ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ / ❌ / N/A <!--- If ❌, explain why. --> - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A <!--- Mandatory --> - Did you write any new necessary tests?: ✅ / ❌ / N/A <!--- Mandatory for new features or examples. --> - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A <!--- Only for new features, API changes, critical bug fixes or backward incompatible changes. --> - Did you get Claude approval on this PR?: ✅ / ❌ / N/A <!--- Run `/claude review`. NVIDIA org members can self-trigger for complex changes; orthogonal to CodeRabbit. --> ### Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Enhanced handling of uncalibrated mixture-of-experts weight quantizers during export with improved consistency checks and informative warnings. * **Tests** * Added regression test for uncalibrated expert quantization behavior. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: weimingc <weimingc@nvidia.com> Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
1 parent 1f9c0bf commit 570920b

2 files changed

Lines changed: 111 additions & 0 deletions

File tree

modelopt/torch/export/moe_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,29 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
6262
for idx in range(n):
6363
expert = nn.Module()
6464

65+
# If the gate_up source quantizer was never calibrated (rare expert
66+
# that received no calibration tokens), derive its amax once from the
67+
# FUSED tensor so gate and up share the same weight_scale_2 below.
68+
# Why: vLLM fuses W1 (gate) and W3 (up) at load time and asserts a
69+
# single per-tensor scale across the fusion. The per-projection
70+
# fallback further down would otherwise compute amax independently from
71+
# each half — gate's max and up's max generally differ — producing
72+
# mismatched weight_scale_2 and garbled MoE output at inference.
73+
gate_up_q = module.gate_up_proj_weight_quantizers[idx]
74+
if getattr(gate_up_q, "is_enabled", False) and (
75+
not hasattr(gate_up_q, "_amax")
76+
or gate_up_q._amax is None
77+
or torch.all(gate_up_q._amax == 0)
78+
):
79+
gate_up_q.amax = gate_up[idx].abs().amax().to(torch.float32)
80+
warnings.warn(
81+
f"Expert {idx} gate_up_proj weight quantizer was not calibrated "
82+
f"(amax missing or zero). Using fused-tensor amax as fallback "
83+
f"(shared by gate and up so weight_scale_2 stays consistent). "
84+
f"Consider increasing calibration size to activate all experts.",
85+
stacklevel=2,
86+
)
87+
6588
projections = [
6689
("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True),
6790
("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True),

tests/unit/torch/quantization/plugins/test_fused_experts.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,94 @@ def test_export_creates_per_expert_submodules(self):
300300
if QuantModuleRegistry.get(expert_type) is not None:
301301
QuantModuleRegistry.unregister(expert_type)
302302

303+
def test_uncalibrated_expert_gate_up_share_amax(self, monkeypatch):
304+
"""gate_proj and up_proj must share weight_scale_2 even when an expert
305+
was never routed during calibration.
306+
307+
Regression for the bug where ``_export_fused_experts``'s per-projection
308+
fallback computed amax independently from the gate and up halves of the
309+
fused tensor — producing mismatched ``weight_scale_2`` values for any
310+
uncalibrated expert. vLLM fuses W1 (gate) and W3 (up) at load time and
311+
asserts a single shared scale; mismatched scales corrupted MoE output.
312+
The fix derives the fallback amax once from the fused ``gate_up[idx]``
313+
tensor before the deepcopies, so gate's clone and up's clone start with
314+
the same amax.
315+
"""
316+
from modelopt.torch.export.moe_utils import _export_fused_experts
317+
318+
# Build experts where gate and up have very different magnitudes —
319+
# any per-half fallback would clearly produce different amaxes.
320+
experts = _SyntheticFusedExperts()
321+
gate = torch.randn(NUM_EXPERTS, INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02
322+
up = torch.randn(NUM_EXPERTS, INTERMEDIATE_DIM, HIDDEN_DIM) * 0.20
323+
with torch.no_grad():
324+
experts.gate_up_proj.copy_(torch.cat([gate, up], dim=1))
325+
326+
expert_type = type(experts)
327+
if QuantModuleRegistry.get(expert_type) is None:
328+
QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})(
329+
_QuantFusedExperts
330+
)
331+
try:
332+
converted = QuantModuleRegistry.convert(experts)
333+
334+
# Leave every expert weight quantizer uncalibrated (no _amax).
335+
# Mark them enabled to exercise the export-time fallback path.
336+
for q in converted.gate_up_proj_weight_quantizers:
337+
q._disabled = False
338+
for q in converted.down_proj_weight_quantizers:
339+
q._disabled = False
340+
341+
# Capture the amax each per-projection wrapper carries into the
342+
# FP4 quantization step. Patching here avoids needing CUDA / FP4.
343+
seen = {} # (expert_idx, proj_name) -> amax tensor
344+
345+
def _spy_export(wrapper, dtype):
346+
# Identify which expert/projection this wrapper belongs to by
347+
# matching the weight tensor against the fused parameters.
348+
w = wrapper.weight.data
349+
# gate_up_proj is (N, 2*INTER, HIDDEN); split halves are
350+
# contiguous .data views or .contiguous() copies — we can match
351+
# by shape and value identity for this synthetic case.
352+
amax = wrapper.weight_quantizer._amax.detach().clone()
353+
# Identify by matching against gate vs. up slices of each expert.
354+
for idx in range(NUM_EXPERTS):
355+
g_slice = converted.gate_up_proj.data[idx, :INTERMEDIATE_DIM, :]
356+
u_slice = converted.gate_up_proj.data[idx, INTERMEDIATE_DIM:, :]
357+
d_slice = converted.down_proj.data[idx]
358+
if w.shape == g_slice.shape and torch.equal(w, g_slice):
359+
seen[(idx, "gate_proj")] = amax
360+
return
361+
if w.shape == u_slice.shape and torch.equal(w, u_slice):
362+
seen[(idx, "up_proj")] = amax
363+
return
364+
if w.shape == d_slice.shape and torch.equal(w, d_slice):
365+
seen[(idx, "down_proj")] = amax
366+
return
367+
368+
monkeypatch.setattr(
369+
"modelopt.torch.export.unified_export_hf._export_quantized_weight",
370+
_spy_export,
371+
)
372+
373+
_export_fused_experts(converted, torch.float16)
374+
375+
# Assert: for every expert, gate's amax matches up's amax.
376+
for idx in range(NUM_EXPERTS):
377+
g_amax = seen.get((idx, "gate_proj"))
378+
u_amax = seen.get((idx, "up_proj"))
379+
assert g_amax is not None and u_amax is not None, (
380+
f"Expert {idx}: missing recorded amax (gate={g_amax}, up={u_amax})"
381+
)
382+
assert torch.allclose(g_amax, u_amax), (
383+
f"Expert {idx}: gate amax {g_amax.item()} != up amax {u_amax.item()}. "
384+
f"Uncalibrated fused experts must share gate/up amax so that "
385+
f"weight_scale_2 stays consistent across the fusion."
386+
)
387+
finally:
388+
if QuantModuleRegistry.get(expert_type) is not None:
389+
QuantModuleRegistry.unregister(expert_type)
390+
303391

304392
# ---------------------------------------------------------------------------
305393
# Tests for force_eager_experts_impl_on_the_fly

0 commit comments

Comments
 (0)