
Commit cfe4a4a

more reviewer feedback
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent b5e2c71 commit cfe4a4a

4 files changed: 81 additions, 21 deletions

modelopt/torch/export/moe_utils.py

Lines changed: 3 additions & 4 deletions
@@ -90,13 +90,11 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             and w_quantizer._amax is not None
             and w_quantizer._amax.dim() >= 1
         ):
-            amax = w_quantizer._amax  # CPU float32
+            amax = w_quantizer._amax
             amax_dim0 = amax.shape[0]
-            if amax_dim0 % fused_total == 0:
+            if fused_total % amax_dim0 == 0:
                 slice_start = fused_start * amax_dim0 // fused_total
                 slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
-                # Bypass amax.setter (which forbids shape changes); w_quantizer is a
-                # deepcopy for gate/up so mutating it is safe.
                 w_quantizer._amax = amax[slice_start:slice_end].contiguous()
             else:
                 warnings.warn(
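
For intuition, a minimal standalone sketch of the slicing arithmetic this hunk corrects; the expert counts, row counts, and amax values here are invented for illustration and are not taken from the module:

import torch

# Hypothetical fused layout: 2 experts x 4 rows each -> fused_total = 8 rows.
fused_total = 8
amax = torch.tensor([0.5, 0.7, 0.9, 1.1])  # amax_dim0 = 4: one entry per 2 rows
amax_dim0 = amax.shape[0]

# Corrected direction of the divisibility check: the fused row count must be a
# whole multiple of the amax leading dim, so each expert's row range maps onto
# whole amax entries. The old check (amax_dim0 % fused_total == 0) only passed
# when amax_dim0 was a multiple of fused_total, the opposite relationship.
assert fused_total % amax_dim0 == 0

fused_start, expert_rows = 4, 4  # second expert; expert_rows stands in for weight_slice.shape[0]
slice_start = fused_start * amax_dim0 // fused_total                  # -> 2
slice_end = (fused_start + expert_rows) * amax_dim0 // fused_total    # -> 4
print(amax[slice_start:slice_end])  # tensor([0.9000, 1.1000])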
@@ -114,6 +112,7 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             hasattr(w_quantizer, "_amax")
             and w_quantizer._amax is not None
             and w_quantizer._amax.numel() > 1
+            and (getattr(w_quantizer, "block_sizes", None) or {}).get(-1) is not None
         ):
             amax_cpu = w_quantizer._amax
             invalid_mask = ~(
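
The new block_sizes guard restricts the invalid-amax repair path to quantizers that actually block along the last dimension. A small illustration of how the expression evaluates (Q is a hypothetical stand-in, not a modelopt class):

class Q:
    """Hypothetical stand-in for a quantizer that may carry block_sizes."""
    def __init__(self, block_sizes=None):
        if block_sizes is not None:
            self.block_sizes = block_sizes

for q in (Q(), Q({}), Q({-1: 16}), Q({-2: 32})):
    has_last_dim_blocks = (getattr(q, "block_sizes", None) or {}).get(-1) is not None
    print(getattr(q, "block_sizes", "<unset>"), "->", has_last_dim_blocks)
# <unset> -> False, {} -> False, {-1: 16} -> True, {-2: 32} -> False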

modelopt/torch/quantization/model_calib.py

Lines changed: 53 additions & 14 deletions
@@ -410,7 +410,12 @@ def mse_calibrate(
             quant_func=partial(_mse_quant_func, quantizer=module),
         )

-    # Identify weight quantizers by checking if they have corresponding weight parameters
+    # Collect weight quantizers (standard + fused-experts per-expert lists).
+    try:
+        from modelopt.torch.quantization.plugins.huggingface import _QuantFusedExperts as _qfe_cls
+    except ImportError:
+        _qfe_cls = None  # type: ignore[misc]
+
     name_to_module = dict(model.named_modules())
     for parent_module in name_to_module.values():
         if parent_module in seen_modules:
@@ -421,22 +426,56 @@ def mse_calibrate(
             if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
                 if getattr(weight_quantizer, "_calibrator", None) is not None:
                     weight_quantizers.append((parent_module, weight_name, weight_quantizer))
-        # _QuantFusedExperts stores per-expert weight quantizers as nn.ModuleList named
-        # {param_name}_weight_quantizers (plural). Detect this pattern and enqueue each
-        # per-expert quantizer individually. The isinstance(qlist, nn.ModuleList) +
-        # isinstance(wq, TensorQuantizer) check below guards against false positives on
-        # unrelated modules that happen to have similarly-named attributes.
-        for param_name, _ in parent_module.named_parameters(recurse=False):
-            qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
-            if not isinstance(qlist, nn.ModuleList):
-                continue
-            for expert_idx, wq in enumerate(qlist):
-                if isinstance(wq, TensorQuantizer) and wq.is_enabled:
-                    if getattr(wq, "_calibrator", None) is not None:
-                        weight_quantizers.append((parent_module, (param_name, expert_idx), wq))
+        # Enqueue per-expert quantizers from {param}_weight_quantizers ModuleLists.
+        if _qfe_cls is not None and isinstance(parent_module, _qfe_cls):
+            for param_name, param in parent_module.named_parameters(recurse=False):
+                qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
+                if not isinstance(qlist, nn.ModuleList):
+                    continue
+                if len(qlist) != param.shape[0]:
+                    warnings.warn(
+                        f"Skipping {param_name}_weight_quantizers: list length {len(qlist)} "
+                        f"does not match parameter leading dimension {param.shape[0]}. "
+                        "This may indicate a misconfigured fused-experts module.",
+                        stacklevel=2,
+                    )
+                    continue
+                for expert_idx, wq in enumerate(qlist):
+                    if isinstance(wq, TensorQuantizer) and wq.is_enabled:
+                        if getattr(wq, "_calibrator", None) is not None:
+                            weight_quantizers.append((parent_module, (param_name, expert_idx), wq))

         seen_modules.add(parent_module)

+    # Warn about enabled weight quantizers that weren't scheduled for MSE calibration.
+    picked_ids = {id(wq) for _, _, wq in weight_quantizers}
+
+    def _is_active_unpicked(q: Any) -> bool:
+        return (
+            isinstance(q, TensorQuantizer)
+            and q.is_enabled
+            and getattr(q, "_calibrator", None) is not None
+            and id(q) not in picked_ids
+        )
+
+    missed: list[str] = []
+    for mod_name, module in name_to_module.items():
+        for attr_name, attr in module._modules.items():
+            if isinstance(attr, TensorQuantizer) and attr_name.endswith("weight_quantizer"):
+                if _is_active_unpicked(attr):
+                    missed.append(f"{mod_name}.{attr_name}")
+            elif isinstance(attr, nn.ModuleList) and attr_name.endswith("_weight_quantizers"):
+                for i, wq in enumerate(attr):
+                    if _is_active_unpicked(wq):
+                        missed.append(f"{mod_name}.{attr_name}[{i}]")
+    if missed:
+        warnings.warn(
+            f"MSE weight calibration: {len(missed)} weight quantizer(s) are enabled but were "
+            f"not scheduled for calibration and will retain max-calibration amax values. "
+            f"First {min(5, len(missed))}: {missed[:5]}",
+            stacklevel=2,
+        )
+
     # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
     # This prevents massive memory accumulation seen in large models
     for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
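
To make the {param}_weight_quantizers convention concrete, a self-contained mock of the enumeration the new block performs; FakeQuantizer and FakeExperts are stand-ins invented here, not modelopt classes:

import torch
import torch.nn as nn

class FakeQuantizer(nn.Module):
    """Stand-in for an enabled TensorQuantizer with a live calibrator."""
    is_enabled = True
    _calibrator = object()

class FakeExperts(nn.Module):
    def __init__(self, num_experts=4):
        super().__init__()
        # Fused parameter: leading dim = number of experts.
        self.gate_up_proj = nn.Parameter(torch.zeros(num_experts, 8, 8))
        # Per-expert quantizers under the {param}_weight_quantizers name.
        self.gate_up_proj_weight_quantizers = nn.ModuleList(
            FakeQuantizer() for _ in range(num_experts)
        )

mod = FakeExperts()
picked = []
for param_name, param in mod.named_parameters(recurse=False):
    qlist = getattr(mod, f"{param_name}_weight_quantizers", None)
    if not isinstance(qlist, nn.ModuleList) or len(qlist) != param.shape[0]:
        continue  # mirrors the length-vs-leading-dim guard above
    for expert_idx, wq in enumerate(qlist):
        if wq.is_enabled and getattr(wq, "_calibrator", None) is not None:
            picked.append((param_name, expert_idx))
print(picked)  # [('gate_up_proj', 0), ('gate_up_proj', 1), ('gate_up_proj', 2), ('gate_up_proj', 3)]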

tests/unit/torch/quantization/plugins/test_fused_experts.py

Lines changed: 23 additions & 2 deletions
@@ -713,7 +713,7 @@ def _quantize_with_block_sizes(self):

     @pytest.mark.parametrize("zero_amax", [False, True])
     def test_fallback_warning_emitted(self, zero_amax):
-        """Fallback warning must fire for uncalibrated (_amax=None) and zero-amax experts."""
+        """Fallback warning must fire and produce valid per-block _amax + global_amax."""
         import warnings
         from unittest.mock import patch

@@ -725,8 +725,16 @@ def test_fallback_warning_emitted(self, zero_amax):
             converted.gate_up_proj_weight_quantizers[idx]._amax = bad_amax
             converted.down_proj_weight_quantizers[idx]._amax = bad_amax

+        captured_wrappers = []
+
+        def _capture(wrapper, dtype):
+            captured_wrappers.append(wrapper)
+
         with (
-            patch("modelopt.torch.export.unified_export_hf._export_quantized_weight"),
+            patch(
+                "modelopt.torch.export.unified_export_hf._export_quantized_weight",
+                side_effect=_capture,
+            ),
             warnings.catch_warnings(record=True) as caught,
         ):
             warnings.simplefilter("always")
@@ -735,4 +743,17 @@ def test_fallback_warning_emitted(self, zero_amax):
         assert any("weight-derived per-block amax" in str(w.message) for w in caught), (
             f"No fallback warning emitted for {'zero' if zero_amax else 'None'} amax — Bug 3 regression"
         )
+
+        # Every per-block weight quantizer must have a repaired per-block _amax and global_amax.
+        for wrapper in captured_wrappers:
+            wq = wrapper.weight_quantizer
+            if not (getattr(wq, "block_sizes", None) or {}).get(-1):
+                continue
+            assert wq._amax is not None and wq._amax.numel() > 1, (
+                "Fallback did not produce per-block _amax"
+            )
+            assert hasattr(wq, "global_amax") and wq.global_amax > 0, (
+                "global_amax missing or zero after fallback"
+            )
+
         self._cleanup_registry(expert_type)
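
The side_effect=_capture pattern in the second hunk (stubbing a heavy export function while still recording its call arguments) is a standard unittest.mock idiom. A minimal standalone illustration with invented names:

from unittest.mock import patch

def export_quantized_weight(wrapper, dtype):
    """Hypothetical heavy function we want to stub out in a test."""
    raise RuntimeError("too expensive to run in a unit test")

captured = []

def _capture(wrapper, dtype):
    # side_effect replaces the real body; the patched callable still routes here.
    captured.append(wrapper)

with patch(f"{__name__}.export_quantized_weight", side_effect=_capture):
    export_quantized_weight("wrapper-0", "bfloat16")
    export_quantized_weight("wrapper-1", "bfloat16")

assert captured == ["wrapper-0", "wrapper-1"]  # arguments available for later asserts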

tests/unit/torch/quantization/test_nvfp4_tensor.py

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ def test_no_zero_scales_for_tiny_weights(self):
         """Tiny per-block amax (<<FP8 min) must not underflow to zero after FP8 cast."""
         block_size = 16
         tiny_weight = torch.full((4, block_size), 1e-10)
-        wsf2 = torch.tensor(1e-10 / (6.0 * 448.0))
+        # wsf2=1.0 → per_block_scale = amax/(6*wsf2) ≈ 1.7e-11 << 2^-9, exercises FP8-min clamp
+        wsf2 = torch.tensor(1.0)

         per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(tiny_weight, block_size, wsf2)
         per_block_scale_f32 = per_block_scale.float()
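
The arithmetic behind the new comment, assuming per_block_scale = amax / (6 * wsf2) as stated there and 2**-9 as the smallest positive FP8 E4M3 value (requires a PyTorch build with torch.float8_e4m3fn):

import torch

amax = torch.tensor(1e-10)               # per-block amax of the tiny test weight
wsf2 = torch.tensor(1.0)
per_block_scale = amax / (6.0 * wsf2)    # ~1.67e-11, far below FP8's smallest value

fp8_min = 2.0 ** -9                      # smallest positive float8_e4m3fn (subnormal)
print(per_block_scale.item() < fp8_min)  # True

# A naive FP8 cast underflows to zero; clamping first is what the test exercises.
print(per_block_scale.to(torch.float8_e4m3fn).float().item())                     # 0.0
print(per_block_scale.clamp(min=fp8_min).to(torch.float8_e4m3fn).float().item())  # 0.001953125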
