6464 "max_calibrate" ,
6565 "smoothquant" ,
6666 "svdquant" ,
67+ "sync_grouped_weight_global_amax" ,
6768]
6869
70+
71+ # Sibling weight-quantizer name groups whose ``global_amax`` should share an
72+ # FP8 scale-of-scales. All members of a group sit under the same parent module
73+ # (e.g. one self-attention or one MLP block) and either consume the same input
74+ # tensor or get fused at deployment, so a divergent global_amax across siblings
75+ # would split their FP8 grids and skew the round.
76+ _GROUPED_WEIGHT_QUANTIZER_PATTERNS : tuple [tuple [str , ...], ...] = (
77+ # Standard self-attention (skipped for fused qkv_proj — single weight).
78+ ("q_proj" , "k_proj" , "v_proj" ),
79+ # Gated MLP, modern naming (Llama / Qwen / Mistral / etc.).
80+ ("gate_proj" , "up_proj" ),
81+ # Gated MLP, older Mixtral-style naming.
82+ ("w1" , "w3" ),
83+ )
84+
85+
86+ def _is_calibrated_nvfp4_static_weight_quantizer (q ) -> bool :
87+ """True for an NVFP4-static weight quantizer that ``max_calibrate`` already
88+ populated with a per-block ``_amax`` and that is currently enabled.
89+ """
90+ return (
91+ isinstance (q , TensorQuantizer )
92+ and not q ._disabled
93+ and q .is_nvfp4_static
94+ and hasattr (q , "_amax" )
95+ and q ._amax is not None
96+ )
97+
98+
99+ def _collect_grouped_linears (model : nn .Module ) -> list [list [nn .Module ]]:
100+ """Find groups of Linear-like submodules whose NVFP4-static weight quantizers
101+ should share ``global_amax`` (Q/K/V under one attention parent; gate/up under
102+ one MLP parent).
103+ """
104+ groups : list [list [nn .Module ]] = []
105+ wq_attr = quantizer_attr_names ("weight" ).weight_quantizer
106+ for parent in model .modules ():
107+ for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS :
108+ members : list [nn .Module ] = []
109+ for n in sibling_names :
110+ child = getattr (parent , n , None )
111+ if child is None :
112+ continue
113+ wq = getattr (child , wq_attr , None )
114+ if _is_calibrated_nvfp4_static_weight_quantizer (wq ):
115+ members .append (child )
116+ if len (members ) >= 2 :
117+ groups .append (members )
118+ return groups
119+
120+
121+ @torch .no_grad ()
122+ def sync_grouped_weight_global_amax (model : nn .Module ) -> int :
123+ """Sync ``global_amax`` across sibling NVFP4-static weight quantizers.
124+
125+ For each group of siblings (Q/K/V projections under one attention parent;
126+ gate/up — a.k.a. ``w1``/``w3`` — under one MLP parent) unifies the
127+ NVFP4 ``global_amax`` so the per-block FP8 round picks scales against a
128+ consistent FP8 grid across the group during MSE / local-Hessian search.
129+
130+ Reuses :func:`modelopt.torch.export.quant_utils.preprocess_linear_fusion`
131+ (whose ``NVFP4StaticQuantizer`` branch performs the same
132+ ``max(stack(global_amax))`` unification at export time). To call it before
133+ MSE, this helper first promotes each grouped weight quantizer to
134+ :class:`NVFP4StaticQuantizer` with its local ``global_amax`` (=
135+ ``reduce_amax(_amax)``); ``preprocess_linear_fusion`` then unifies in
136+ place.
137+
138+ Must be called after ``max_calibrate`` has populated each weight
139+ quantizer's ``_amax``. Idempotent. Returns the number of groups synced.
140+ """
141+ from modelopt .torch .export .quant_utils import preprocess_linear_fusion
142+
143+ n_groups = 0
144+ for group in _collect_grouped_linears (model ):
145+ # Promote each member's weight quantizer so `preprocess_linear_fusion`
146+ # sees post-conversion NVFP4StaticQuantizers (its NVFP4 branch reads
147+ # `global_amax`, which only exists post-promotion).
148+ wq_attr = quantizer_attr_names ("weight" ).weight_quantizer
149+ for child in group :
150+ wq = getattr (child , wq_attr )
151+ if not isinstance (wq , NVFP4StaticQuantizer ):
152+ local_global = reduce_amax (wq ._amax , axis = None )
153+ NVFP4StaticQuantizer .from_tensor_quantizer (wq , global_amax = local_global )
154+ preprocess_linear_fusion (group )
155+ n_groups += 1
156+ return n_groups
157+
158+
69159CalibratorFactory : TypeAlias = Callable [
70160 [torch .Tensor , int | tuple | list | None , Callable [..., torch .Tensor ]], _Calibrator
71161]
@@ -349,6 +439,13 @@ def mse_calibrate(
349439 # Step 1: First get initial amax using max calibration
350440 max_calibrate (model , forward_loop , distributed_sync )
351441
442+ # Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers
443+ # (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one
444+ # MLP block) so their FP8 scale-of-scales matches and the per-block FP8
445+ # round uses a consistent grid. No-op when there are no sibling groups
446+ # (e.g. fused QKV / fused gate_up_proj).
447+ sync_grouped_weight_global_amax (model )
448+
352449 # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
353450 # and identify weight quantizers
354451 weight_quantizers = []
@@ -360,19 +457,16 @@ def mse_calibrate(
360457 # Get the initial amax from max calibration
361458 initial_amax = module ._amax .clone ().detach ()
362459
363- is_nvfp4_static = (
364- module .is_static_block_quant
365- and module ._num_bits == (2 , 1 )
366- and module ._block_sizes is not None
367- and module ._block_sizes .get ("scale_bits" ) == (4 , 3 )
368- )
460+ is_nvfp4_static = module .is_nvfp4_static
369461
370462 if is_nvfp4_static :
371- # Compute and set global_amax
372- global_amax = reduce_amax (initial_amax , axis = None )
373-
374- # Convert to NVFP4StaticQuantizer in-place
375- NVFP4StaticQuantizer .from_tensor_quantizer (module , global_amax = global_amax )
463+ # If sync_grouped_weight_global_amax already promoted this
464+ # quantizer (it's a sibling in a Q/K/V or gate/up group),
465+ # its global_amax has been unified across the group; just
466+ # leave it. Otherwise convert + set local global_amax.
467+ if not isinstance (module , NVFP4StaticQuantizer ):
468+ global_amax = reduce_amax (initial_amax , axis = None )
469+ NVFP4StaticQuantizer .from_tensor_quantizer (module , global_amax = global_amax )
376470
377471 if fp8_scale_sweep :
378472 # Check if backend has a registered custom calibrator factory.
@@ -612,6 +706,11 @@ def forward(self, input, *args, **kwargs):
612706 print_rank_0 ("local_hessian: Running max calibration for all quantizers..." )
613707 max_calibrate (model , forward_loop , distributed_sync )
614708
709+ # Sync global_amax across sibling NVFP4-static weight quantizers
710+ # (q/k/v_proj, gate/up_proj a.k.a. w1/w3) so the FP8 scale-of-scales
711+ # is consistent across the group. Idempotent; no-op when fused.
712+ sync_grouped_weight_global_amax (model )
713+
615714 # Setup helpers for all quantized linear modules
616715 name_to_module = dict (model .named_modules ())
617716 weight_quantizers_info = []
@@ -666,14 +765,9 @@ def quant_func(x, amax, quantizer=weight_quantizer):
666765
667766 return xq
668767
669- is_nvfp4_static = (
670- weight_quantizer .is_static_block_quant
671- and weight_quantizer ._num_bits == (2 , 1 )
672- and weight_quantizer ._block_sizes is not None
673- and weight_quantizer ._block_sizes .get ("scale_bits" ) == (4 , 3 )
674- )
768+ is_nvfp4_static = weight_quantizer .is_nvfp4_static
675769
676- if is_nvfp4_static :
770+ if is_nvfp4_static and not isinstance ( weight_quantizer , NVFP4StaticQuantizer ) :
677771 global_amax = reduce_amax (initial_amax , axis = None )
678772 NVFP4StaticQuantizer .from_tensor_quantizer (weight_quantizer , global_amax = global_amax )
679773
0 commit comments