@@ -465,10 +465,10 @@ def _make_weight_mse_calibrator(
465465 )
466466 if backend is not None and backend_factory is not None :
467467 if error_func is not None :
468- # Registered backends can 't take a custom error_func; skip Hessian refinement .
468+ # Registered backend factories don 't accept a custom error_func.
469469 warnings .warn (
470- f"local_hessian: backend '{ backend } ' does not support a custom error "
471- "function; skipping Hessian -weighted calibration for this quantizer."
470+ f"backend '{ backend } ' does not support a custom error function; skipping "
471+ "error- function-weighted MSE calibration for this quantizer."
472472 )
473473 return None
474474 return backend_factory (initial_amax , axis , quant_func )
@@ -670,6 +670,80 @@ def _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, war
670670 _warn_if_block_size_mismatch (weight_quantizer , block_size , name )
671671
672672
673+ def _is_quant_fused_experts (module : nn .Module ) -> bool :
674+ """Whether ``module`` is a converted HF fused-MoE-experts wrapper with per-expert quantizers."""
675+ return hasattr (module , "_current_expert_idx" ) and hasattr (
676+ module , "gate_up_proj_weight_quantizers"
677+ )
678+
679+
680+ def _register_local_hessian_input_hooks (model , name_to_module , capture , block_size , warned ):
681+ """Register forward hooks feeding each weight's input activations to ``capture``.
682+
683+ Local-Hessian-specific (kept here rather than as a general ``QuantModule`` API): dense
684+ quantized linears hook the layer input; HF fused-MoE experts hook the shared input quantizers,
685+ keyed by the active expert (``_current_expert_idx``). Weights without a hook (conv,
686+ SequentialQuantizer, non-eager experts) fall back to plain MSE. Returns removable handles.
687+ """
688+ handles : list = []
689+
690+ def _make_expert_hook (expert_module , weight_name , quantizers , enabled ):
691+ def _expert_hook (_input_quantizer , args ):
692+ if not args :
693+ return
694+ idx = expert_module ._current_expert_idx
695+ if idx in enabled :
696+ # Read the weight fresh (valid under accelerate/FSDP re-materialization).
697+ capture (quantizers [idx ], getattr (expert_module , weight_name )[idx ], args [0 ])
698+
699+ return _expert_hook
700+
701+ for name , module in name_to_module .items ():
702+ if is_quantized_linear (module ) and isinstance (module .weight_quantizer , TensorQuantizer ):
703+ with enable_weight_access_and_writeback (module , model , name_to_module ):
704+ # ``weight`` may be absent (e.g. TE GroupedLinear exposes weight0..N, not weight);
705+ # such modules have no single 2-D weight to pair and fall back to plain MSE.
706+ weight = getattr (module , "weight" , None )
707+ if weight is None or weight .dim () != 2 or not module .weight_quantizer .is_enabled :
708+ continue
709+ _warn_local_hessian_fallback (
710+ name , weight , module .weight_quantizer , block_size , warned
711+ )
712+
713+ def _dense_hook (linear , args ):
714+ if args :
715+ capture (linear .weight_quantizer , linear .weight , args [0 ])
716+
717+ handles .append (module .register_forward_pre_hook (_dense_hook ))
718+ elif _is_quant_fused_experts (module ):
719+ with enable_weight_access_and_writeback (module , model , name_to_module ):
720+ for weight_name , quantizers_name , input_q_name in (
721+ (
722+ "gate_up_proj" ,
723+ "gate_up_proj_weight_quantizers" ,
724+ "gate_up_proj_input_quantizer" ,
725+ ),
726+ ("down_proj" , "down_proj_weight_quantizers" , "down_proj_input_quantizer" ),
727+ ):
728+ weight = getattr (module , weight_name , None )
729+ quantizers = getattr (module , quantizers_name , None )
730+ input_quantizer = getattr (module , input_q_name , None )
731+ if weight is None or quantizers is None or input_quantizer is None :
732+ continue
733+ _warn_local_hessian_fallback (
734+ f"{ name } .{ weight_name } " , weight [0 ], quantizers [0 ], block_size , warned
735+ )
736+ # Snapshot which experts are enabled now, before the caching forward silences
737+ # all weight quantizers — so we don't capture (and discard) disabled experts.
738+ enabled = {i for i , q in enumerate (quantizers ) if q .is_enabled }
739+ handles .append (
740+ input_quantizer .register_forward_pre_hook (
741+ _make_expert_hook (module , weight_name , quantizers , enabled )
742+ )
743+ )
744+ return handles
745+
746+
673747@torch .no_grad ()
674748def local_hessian_calibrate (
675749 model : nn .Module ,
@@ -731,53 +805,19 @@ def capture(weight_quantizer, weight, input_tensor):
731805 accumulators [id (weight_quantizer )] = acc
732806 acc .accumulate (input_local )
733807
734- # Phase 2: register capture hooks, disable weight fake-quant (input quantizers left as-is,
735- # matching prior behavior), run one forward to accumulate Hessians. Hooks live only for it.
736- handles : list = []
737- silenced_weight_quantizers : list [TensorQuantizer ] = []
808+ # Phase 2: capture each weight's input activations during a forward with weight fake-quant
809+ # disabled (so H = ΣXᵀX reflects full-precision weights); input quantizers are left as-is.
738810 warned : set = set ()
739- seen_modules : set [int ] = set ()
740- for name , module in name_to_module .items ():
741- if not isinstance (module , QuantModule ) or id (module ) in seen_modules :
742- continue
743- seen_modules .add (id (module ))
744- with enable_weight_access_and_writeback (module , model , name_to_module ):
745- captures = module .register_calibration_input_hooks (capture )
746- handles .extend (captures )
747- for weight , weight_quantizer in module .iter_weights_for_calibration ():
748- # Silence weight fake-quant (incl. SequentialQuantizer leaves) so the capture
749- # forward uses full-precision weights and downstream Hessians aren't corrupted.
750- leaves = (
751- list (weight_quantizer )
752- if isinstance (weight_quantizer , SequentialQuantizer )
753- else [weight_quantizer ]
754- )
755- silenced_weight_quantizers .extend (
756- q
757- for q in leaves
758- if isinstance (q , TensorQuantizer ) and q .is_enabled and q ._if_quant
759- )
760- # Only TensorQuantizer weights are refined (same as mse_calibrate); other types
761- # (e.g. SequentialQuantizer) are unsupported and left at their max-cal scale.
762- if not isinstance (weight_quantizer , TensorQuantizer ):
763- if weight_quantizer .is_enabled and "unsupported" not in warned :
764- warned .add ("unsupported" )
765- warn_rank_0 (
766- "local_hessian: only TensorQuantizer weights are calibrated; other "
767- "types (e.g. SequentialQuantizer) stay at their max-calibrated scale."
768- )
769- continue
770- if captures :
771- _warn_local_hessian_fallback (name , weight , weight_quantizer , block_size , warned )
772-
773- for weight_quantizer in silenced_weight_quantizers :
774- weight_quantizer .disable_quant ()
811+ handles = _register_local_hessian_input_hooks (
812+ model , name_to_module , capture , block_size , warned
813+ )
775814 print_rank_0 ("local_hessian: Caching activations and computing local Hessian..." )
776815 try :
777- forward_loop (model )
816+ with set_quantizer_by_cfg_context (
817+ model , [{"quantizer_name" : "*weight_quantizer" , "enable" : False }]
818+ ):
819+ forward_loop (model )
778820 finally :
779- for weight_quantizer in silenced_weight_quantizers :
780- weight_quantizer .enable_quant ()
781821 for handle in handles :
782822 handle .remove ()
783823
0 commit comments