Skip to content

Commit 200ae6f

Browse files
committed
Inline local-Hessian activation capture; drop the QuantModule hook API
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent 7eaec3d commit 200ae6f

5 files changed

Lines changed: 121 additions & 154 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -465,10 +465,10 @@ def _make_weight_mse_calibrator(
465465
)
466466
if backend is not None and backend_factory is not None:
467467
if error_func is not None:
468-
# Registered backends can't take a custom error_func; skip Hessian refinement.
468+
# Registered backend factories don't accept a custom error_func.
469469
warnings.warn(
470-
f"local_hessian: backend '{backend}' does not support a custom error "
471-
"function; skipping Hessian-weighted calibration for this quantizer."
470+
f"backend '{backend}' does not support a custom error function; skipping "
471+
"error-function-weighted MSE calibration for this quantizer."
472472
)
473473
return None
474474
return backend_factory(initial_amax, axis, quant_func)
@@ -670,6 +670,80 @@ def _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, war
670670
_warn_if_block_size_mismatch(weight_quantizer, block_size, name)
671671

672672

673+
def _is_quant_fused_experts(module: nn.Module) -> bool:
674+
"""Whether ``module`` is a converted HF fused-MoE-experts wrapper with per-expert quantizers."""
675+
return hasattr(module, "_current_expert_idx") and hasattr(
676+
module, "gate_up_proj_weight_quantizers"
677+
)
678+
679+
680+
def _register_local_hessian_input_hooks(model, name_to_module, capture, block_size, warned):
681+
"""Register forward hooks feeding each weight's input activations to ``capture``.
682+
683+
Local-Hessian-specific (kept here rather than as a general ``QuantModule`` API): dense
684+
quantized linears hook the layer input; HF fused-MoE experts hook the shared input quantizers,
685+
keyed by the active expert (``_current_expert_idx``). Weights without a hook (conv,
686+
SequentialQuantizer, non-eager experts) fall back to plain MSE. Returns removable handles.
687+
"""
688+
handles: list = []
689+
690+
def _make_expert_hook(expert_module, weight_name, quantizers, enabled):
691+
def _expert_hook(_input_quantizer, args):
692+
if not args:
693+
return
694+
idx = expert_module._current_expert_idx
695+
if idx in enabled:
696+
# Read the weight fresh (valid under accelerate/FSDP re-materialization).
697+
capture(quantizers[idx], getattr(expert_module, weight_name)[idx], args[0])
698+
699+
return _expert_hook
700+
701+
for name, module in name_to_module.items():
702+
if is_quantized_linear(module) and isinstance(module.weight_quantizer, TensorQuantizer):
703+
with enable_weight_access_and_writeback(module, model, name_to_module):
704+
# ``weight`` may be absent (e.g. TE GroupedLinear exposes weight0..N, not weight);
705+
# such modules have no single 2-D weight to pair and fall back to plain MSE.
706+
weight = getattr(module, "weight", None)
707+
if weight is None or weight.dim() != 2 or not module.weight_quantizer.is_enabled:
708+
continue
709+
_warn_local_hessian_fallback(
710+
name, weight, module.weight_quantizer, block_size, warned
711+
)
712+
713+
def _dense_hook(linear, args):
714+
if args:
715+
capture(linear.weight_quantizer, linear.weight, args[0])
716+
717+
handles.append(module.register_forward_pre_hook(_dense_hook))
718+
elif _is_quant_fused_experts(module):
719+
with enable_weight_access_and_writeback(module, model, name_to_module):
720+
for weight_name, quantizers_name, input_q_name in (
721+
(
722+
"gate_up_proj",
723+
"gate_up_proj_weight_quantizers",
724+
"gate_up_proj_input_quantizer",
725+
),
726+
("down_proj", "down_proj_weight_quantizers", "down_proj_input_quantizer"),
727+
):
728+
weight = getattr(module, weight_name, None)
729+
quantizers = getattr(module, quantizers_name, None)
730+
input_quantizer = getattr(module, input_q_name, None)
731+
if weight is None or quantizers is None or input_quantizer is None:
732+
continue
733+
_warn_local_hessian_fallback(
734+
f"{name}.{weight_name}", weight[0], quantizers[0], block_size, warned
735+
)
736+
# Snapshot which experts are enabled now, before the caching forward silences
737+
# all weight quantizers — so we don't capture (and discard) disabled experts.
738+
enabled = {i for i, q in enumerate(quantizers) if q.is_enabled}
739+
handles.append(
740+
input_quantizer.register_forward_pre_hook(
741+
_make_expert_hook(module, weight_name, quantizers, enabled)
742+
)
743+
)
744+
return handles
745+
746+
673747
@torch.no_grad()
674748
def local_hessian_calibrate(
675749
model: nn.Module,
@@ -731,53 +805,19 @@ def capture(weight_quantizer, weight, input_tensor):
731805
accumulators[id(weight_quantizer)] = acc
732806
acc.accumulate(input_local)
733807

734-
# Phase 2: register capture hooks, disable weight fake-quant (input quantizers left as-is,
735-
# matching prior behavior), run one forward to accumulate Hessians. Hooks live only for it.
736-
handles: list = []
737-
silenced_weight_quantizers: list[TensorQuantizer] = []
808+
# Phase 2: capture each weight's input activations during a forward with weight fake-quant
809+
# disabled (so H = ΣXᵀX reflects full-precision weights); input quantizers are left as-is.
738810
warned: set = set()
739-
seen_modules: set[int] = set()
740-
for name, module in name_to_module.items():
741-
if not isinstance(module, QuantModule) or id(module) in seen_modules:
742-
continue
743-
seen_modules.add(id(module))
744-
with enable_weight_access_and_writeback(module, model, name_to_module):
745-
captures = module.register_calibration_input_hooks(capture)
746-
handles.extend(captures)
747-
for weight, weight_quantizer in module.iter_weights_for_calibration():
748-
# Silence weight fake-quant (incl. SequentialQuantizer leaves) so the capture
749-
# forward uses full-precision weights and downstream Hessians aren't corrupted.
750-
leaves = (
751-
list(weight_quantizer)
752-
if isinstance(weight_quantizer, SequentialQuantizer)
753-
else [weight_quantizer]
754-
)
755-
silenced_weight_quantizers.extend(
756-
q
757-
for q in leaves
758-
if isinstance(q, TensorQuantizer) and q.is_enabled and q._if_quant
759-
)
760-
# Only TensorQuantizer weights are refined (same as mse_calibrate); other types
761-
# (e.g. SequentialQuantizer) are unsupported and left at their max-cal scale.
762-
if not isinstance(weight_quantizer, TensorQuantizer):
763-
if weight_quantizer.is_enabled and "unsupported" not in warned:
764-
warned.add("unsupported")
765-
warn_rank_0(
766-
"local_hessian: only TensorQuantizer weights are calibrated; other "
767-
"types (e.g. SequentialQuantizer) stay at their max-calibrated scale."
768-
)
769-
continue
770-
if captures:
771-
_warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, warned)
772-
773-
for weight_quantizer in silenced_weight_quantizers:
774-
weight_quantizer.disable_quant()
811+
handles = _register_local_hessian_input_hooks(
812+
model, name_to_module, capture, block_size, warned
813+
)
775814
print_rank_0("local_hessian: Caching activations and computing local Hessian...")
776815
try:
777-
forward_loop(model)
816+
with set_quantizer_by_cfg_context(
817+
model, [{"quantizer_name": "*weight_quantizer", "enable": False}]
818+
):
819+
forward_loop(model)
778820
finally:
779-
for weight_quantizer in silenced_weight_quantizers:
780-
weight_quantizer.enable_quant()
781821
for handle in handles:
782822
handle.remove()
783823

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import contextlib
1919
import warnings
20-
from collections.abc import Callable
2120
from typing import Any
2221

2322
import torch
@@ -128,17 +127,6 @@ def iter_weights_for_calibration(self):
128127
weight_quantizer = getattr(self, quantizer_attr_names(weight_name).weight_quantizer)
129128
yield getattr(self, weight_name), weight_quantizer
130129

131-
def register_calibration_input_hooks(
132-
self, callback: Callable[[TensorQuantizer, torch.Tensor, torch.Tensor], None]
133-
) -> list:
134-
"""Register forward hooks calling ``callback(weight_quantizer, weight, input)`` per weight.
135-
136-
Activation-side counterpart to :meth:`iter_weights_for_calibration`, used by
137-
activation-aware calibration (e.g. local-Hessian). Returns removable handles; the base
138-
default is ``[]`` (no pairing available -> plain weight calibration). Override per module.
139-
"""
140-
return []
141-
142130
def fold_weight(self, keep_attrs: bool = False):
143131
"""Fold the weight for faster eval."""
144132
# Handle all attributes that end with _weight_quantizer
@@ -259,27 +247,6 @@ def _setup(self):
259247
self._register_temp_attribute("_enable_weight_quantization", False)
260248
self._register_dynamic_attribute("weight", self._get_quantized_weight)
261249

262-
def register_calibration_input_hooks(self, callback):
263-
"""Pair the weight quantizer with the forward input.
264-
265-
Only a 2-D weight with an enabled ``TensorQuantizer`` is hooked; conv (4-D) and
266-
``SequentialQuantizer`` weights are unsupported and fall back to plain calibration.
267-
"""
268-
weight = getattr(self, "weight", None)
269-
if (
270-
weight is None
271-
or weight.dim() != 2
272-
or not isinstance(self.weight_quantizer, TensorQuantizer)
273-
or not self.weight_quantizer.is_enabled
274-
):
275-
return []
276-
277-
def _pre_hook(module, args):
278-
if args:
279-
callback(module.weight_quantizer, module.weight, args[0])
280-
281-
return [self.register_forward_pre_hook(_pre_hook)]
282-
283250

284251
class _LegacyQuantInputBaseMixin:
285252
"""A mixin to support legacy quantized modules which needs to have an __init__ method."""

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -918,36 +918,6 @@ def iter_weights_for_calibration(self):
918918
for idx, q in enumerate(quantizers):
919919
yield weight[idx], q
920920

921-
def register_calibration_input_hooks(self, callback):
922-
"""Pair each per-expert weight quantizer with its routed input activation.
923-
924-
Hooks the shared input quantizers, which the eager ``F.linear`` path calls per expert
925-
while ``_current_expert_idx`` is set. Batched/grouped kernels never call them, so those
926-
experts get no capture (fall back to plain weight calibration).
927-
"""
928-
handles = []
929-
for weight_name, quantizers_name, input_quantizer_name in (
930-
("gate_up_proj", "gate_up_proj_weight_quantizers", "gate_up_proj_input_quantizer"),
931-
("down_proj", "down_proj_weight_quantizers", "down_proj_input_quantizer"),
932-
):
933-
weight = getattr(self, weight_name, None)
934-
quantizers = getattr(self, quantizers_name, None)
935-
input_quantizer = getattr(self, input_quantizer_name, None)
936-
if weight is None or quantizers is None or input_quantizer is None:
937-
continue
938-
939-
def _pre_hook(_iq, args, _weight_name=weight_name, _quantizers=quantizers):
940-
if not args:
941-
return
942-
idx = self._current_expert_idx
943-
weight_quantizer = _quantizers[idx]
944-
if weight_quantizer.is_enabled:
945-
# Read the weight fresh (valid under accelerate/FSDP re-materialization).
946-
callback(weight_quantizer, getattr(self, _weight_name)[idx], args[0])
947-
948-
handles.append(input_quantizer.register_forward_pre_hook(_pre_hook))
949-
return handles
950-
951921
def fold_weight(self, keep_attrs: bool = False):
952922
"""Fold per-expert weight quantizers into the fused 3-D weights.
953923

tests/unit/torch/quantization/plugins/test_fused_experts.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -652,9 +652,8 @@ def forward_loop(m):
652652

653653
self._cleanup_registry(expert_type)
654654

655-
def test_local_hessian_per_expert_capture_and_refinement(self):
656-
"""The plugin's extension point pairs each per-expert weight quantizer with its routed
657-
input, and local_hessian uses that to refine every expert's weight amax."""
655+
def test_local_hessian_refines_per_expert_weights(self):
656+
"""local_hessian captures each expert's routed activations and refines its weight amax."""
658657
model = _TinyMoEModel()
659658
expert_type = type(model.moe.experts)
660659
self._cleanup_registry(expert_type)
@@ -679,28 +678,25 @@ def forward_loop(m):
679678
expert_quantizers = list(experts.gate_up_proj_weight_quantizers) + list(
680679
experts.down_proj_weight_quantizers
681680
)
682-
683-
# Extension point captures per-expert (weight_quantizer, weight_slice, cin).
684-
captured = []
685-
handles = experts.register_calibration_input_hooks(
686-
lambda wq, w, x: captured.append((id(wq), tuple(w.shape), x.shape[-1]))
687-
)
688-
assert len(handles) == 2 # one pre-hook per shared input quantizer (gate_up, down)
689-
with torch.no_grad():
690-
model(torch.randn(1, 8, HIDDEN_DIM))
691-
for h in handles:
692-
h.remove()
693-
valid_ids = {id(q) for q in expert_quantizers}
694-
shapes = {(2 * INTERMEDIATE_DIM, HIDDEN_DIM), (HIDDEN_DIM, INTERMEDIATE_DIM)}
695-
assert captured and all(
696-
wq_id in valid_ids and shape in shapes and cin == shape[1]
697-
for wq_id, shape, cin in captured
698-
)
699-
700-
# End-to-end: local_hessian refines per-expert weight amax via that capture.
701681
max_amax = {id(q): q.amax.clone() for q in expert_quantizers if q.amax is not None}
682+
# Expected (cout, cin) keyed by quantizer id, to verify each Hessian pairs with its
683+
# own expert's weight slice (catches gate_up/down swaps and stale-index mis-pairing).
684+
expected_shape = {}
685+
for quantizers, weight in (
686+
(experts.gate_up_proj_weight_quantizers, experts.gate_up_proj),
687+
(experts.down_proj_weight_quantizers, experts.down_proj),
688+
):
689+
for i, q in enumerate(quantizers):
690+
expected_shape[id(q)] = (weight[i].shape[0], weight[i].shape[1])
691+
702692
local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True)
703-
assert any(a.num_samples > 0 for a in model._local_hessian_accumulators.values())
693+
694+
# Each captured Hessian is keyed to a real per-expert quantizer with the matching weight
695+
# shape, spans multiple distinct experts, and the refinement moved at least one amax.
696+
routed = {qid: a for qid, a in model._local_hessian_accumulators.items() if a.num_samples}
697+
assert len(routed) >= 2, "expected multiple distinct experts to capture Hessians"
698+
for qid, acc in routed.items():
699+
assert (acc.cout, acc.cin) == expected_shape[qid]
704700
assert all(q.amax is not None and torch.isfinite(q.amax).all() for q in expert_quantizers)
705701
assert any(
706702
id(q) in max_amax and not torch.allclose(q.amax, max_amax[id(q)])

tests/unit/torch/quantization/test_local_hessian.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -158,36 +158,30 @@ def test_no_forward_loop_is_skipped(self):
158158
assert all(torch.equal(before[n], a) for n, a in _weight_amaxes(model).items())
159159

160160

161-
class TestActivationCaptureExtensionPoint:
162-
"""The extension point that decouples local-Hessian capture from module type."""
161+
class TestLocalHessianFallbacks:
162+
"""Weights local-Hessian can't pair with an input fall back to plain MSE (no Hessian)."""
163163

164-
def test_dense_captures_and_conv_falls_back(self):
164+
def test_conv_weight_falls_back_without_crash(self):
165165
torch.manual_seed(0)
166-
model = SimpleLinear()
167-
mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop())
168-
captured = []
169-
handles = model.net[0].register_calibration_input_hooks(
170-
lambda wq, w, x: captured.append((tuple(w.shape), x.shape[-1]))
171-
)
172-
assert len(handles) == 1
173-
with torch.no_grad():
174-
model(torch.randn(2, 16))
175-
for h in handles:
176-
h.remove()
177-
assert captured and captured[0] == ((32, 16), 16) # cin from activation matches weight
178-
179-
conv = SimpleConv()
180-
mtq.quantize(conv, INT8_WEIGHT_CFG, forward_loop=lambda m: m(SimpleConv.get_input()))
181-
assert conv.net[0].register_calibration_input_hooks(lambda *a: None) == [] # 4-D weight
182-
183-
def test_sequential_quantizer_weight_not_hooked(self):
166+
model = SimpleConv() # 4-D conv weights — no single 2-D weight to pair
167+
forward_loop = lambda m: m(SimpleConv.get_input()) # noqa: E731
168+
mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=forward_loop)
169+
local_hessian_calibrate(model, forward_loop, fp8_scale_sweep=False, debug=True)
170+
conv = model.net[0]
171+
assert id(conv.weight_quantizer) not in model._local_hessian_accumulators
172+
assert conv.weight_quantizer.amax is not None # still calibrated via plain MSE
173+
174+
def test_sequential_quantizer_weight_falls_back_without_crash(self):
184175
torch.manual_seed(0)
185176
model = SimpleLinear()
186177
mtq.quantize(model, INT8_WEIGHT_CFG, forward_loop=_make_forward_loop())
187178
linear = model.net[0]
188179
linear.weight_quantizer = SequentialQuantizer(TensorQuantizer(), TensorQuantizer())
189-
assert linear.register_calibration_input_hooks(lambda *a: None) == [] # unsupported
180+
local_hessian_calibrate(model, _make_forward_loop(), fp8_scale_sweep=False, debug=True)
181+
assert id(linear.weight_quantizer) not in model._local_hessian_accumulators
182+
190183

184+
class TestBlockSizeMismatchWarning:
191185
def test_block_size_mismatch_warns_only_on_mismatch(self):
192186
def q(block):
193187
return TensorQuantizer(

0 commit comments

Comments
 (0)