11import inspect
2+ import math
23from itertools import product
34from typing import Iterator , Literal
45
@@ -84,6 +85,10 @@ class AWQModifier(Modifier, QuantizationMixin):
8485 # This change is only useful for MoE models with parallel transformer blocks,
8586 # and one should use the default value (None) in most cases.
8687 ignore: ["lm_head"]
88+ # Set search_observer to match the observer used in the final QuantizationModifier
89+ # for best scale alignment. "memoryless_mse" is recommended when pairing with
90+ # MSE-based weight quantization. Defaults to "memoryless_minmax".
91+ search_observer: "memoryless_minmax"
8792 config_groups:
8893 group_0:
8994 targets:
@@ -151,6 +156,21 @@ class AWQModifier(Modifier, QuantizationMixin):
151156 this specifies how many grid points should be used. To decrease the runtime,
152157 at the possible cost of slightly worse scales, this can be decreased.
153158 Defaults to 20
159+ :param search_observer: name of the observer used to simulate quantization during
160+ the grid search. For best accuracy, this should match the observer used in the
161+ final QuantizationModifier (e.g. "memoryless_mse" when using MSE-based weight
162+ quantization). Defaults to "memoryless_minmax" for backward compatibility.
163+ Valid options: "memoryless_minmax", "memoryless_mse"
164+ :param n_shrink_grid: number of shrinkage grid points to jointly search over
165+ alongside the scale grid. When > 1, for each scale candidate (α) the grid
166+         search also sweeps over shrink factors p ∈ [1 - maxshrink, 1], selecting the
167+ (α, p) pair that minimizes output MSE jointly. This implements the joint
168+ scale + shrinkage optimization described in issue #2479.
169+ Defaults to 1 (disabled, shrinkage determined solely by the observer).
170+ Recommended value: 5-10 for a lightweight joint search.
171+ :param maxshrink: maximum shrinkage factor used when n_shrink_grid > 1.
172+ Shrink factors are swept from 1.0 down to (1 - maxshrink).
173+ Defaults to 0.20 (matching the memoryless_mse observer default).
154174 """
155175
156176 # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
@@ -162,6 +182,9 @@ class AWQModifier(Modifier, QuantizationMixin):
162182 offload_device : torch .device | None | Sentinel = Sentinel ("not_provided" )
163183 duo_scaling : bool | Literal ["both" ] = True
164184 n_grid : int = 20
185+ search_observer : str = "memoryless_minmax"
186+ n_shrink_grid : int = 1
187+ maxshrink : float = 0.20
165188
166189 # Private vars set during initialization, cleared during finalization
167190 _resolved_mappings : list [ResolvedMapping ] = PrivateAttr (default_factory = list )
@@ -673,8 +696,9 @@ def _compute_best_scale(
673696 n_grid = self .n_grid
674697 duo_scalings = [self .duo_scaling ]
675698
676- # Where appropriate, replace observers with memoryless_minmax
677- # for duration of grid search
699+ # Where appropriate, replace observers with search_observer
700+ # for duration of grid search. This should match the observer
701+ # used in the final QuantizationModifier for best scale alignment.
678702 balance_layers_to_patch = [
679703 balance_layer
680704 for balance_layer in mapping .balance_layers
@@ -686,7 +710,7 @@ def _compute_best_scale(
686710 "weight_observer" ,
687711 [
688712 Observer .load_from_registry (
689- "memoryless_minmax" ,
713+ self . search_observer ,
690714 base_name = "weight" ,
691715 args = balance_layer .quantization_scheme .weights ,
692716 module = balance_layer ,
@@ -717,65 +741,103 @@ def _compute_best_scale(
717741 scales [torch .isnan (scales )] = 1
718742 _scalesview = scales .view (1 , - 1 ).to (device )
719743
720- # Q(W * s)
721- for balance_layer in balance_layers_to_patch :
722- if not hasattr (balance_layer , "quantization_scheme" ) or not hasattr (
723- balance_layer .quantization_scheme , "weights"
724- ):
725- continue
726-
727- w_qscheme = balance_layer .quantization_scheme .weights
728- balance_layer .weight .data .copy_ (
729- orig_layer_weights [balance_layer ].to (_scalesview .device )
730- * _scalesview
731- )
744+ # Build shrink factor candidates:
745+ # n_shrink_grid=1 means no shrinkage search (p=1.0, minmax behaviour)
746+ # n_shrink_grid>1 sweeps p from 1.0 down to (1 - maxshrink)
747+ if self .n_shrink_grid > 1 :
748+ shrink_factors = [
749+ 1.0 - (self .maxshrink * i / (self .n_shrink_grid - 1 ))
750+ for i in range (self .n_shrink_grid )
751+ ]
752+ else :
753+ shrink_factors = [1.0 ]
754+
755+ for shrink_p in shrink_factors :
756+ # Q(W * s) with optional shrinkage applied to observer range
757+ for balance_layer in balance_layers_to_patch :
758+ if not hasattr (balance_layer , "quantization_scheme" ) or not hasattr (
759+ balance_layer .quantization_scheme , "weights"
760+ ):
761+ continue
762+
763+ w_qscheme = balance_layer .quantization_scheme .weights
764+ balance_layer .weight .data .copy_ (
765+ orig_layer_weights [balance_layer ].to (_scalesview .device )
766+ * _scalesview
767+ )
732768
733- should_calculate_gparam = (
734- w_qscheme .strategy == QuantizationStrategy .TENSOR_GROUP
735- )
736- call_observer (
737- balance_layer ,
738- "weight" ,
739- balance_layer .weight ,
740- should_calculate_gparam = should_calculate_gparam ,
741- )
742- balance_layer .weight .data = (
743- forward_quantize (
744- balance_layer ,
745- balance_layer .weight ,
746- "weight" ,
747- w_qscheme ,
769+ should_calculate_gparam = (
770+ w_qscheme .strategy == QuantizationStrategy .TENSOR_GROUP
748771 )
749- / _scalesview
750- ).to (balance_layer .weight .dtype )
751-
752- # Apply fused global scales for TENSOR_GROUP during grid search
753- # to match inference behavior
754- if balance_layers_to_patch and all (
755- getattr (layer .quantization_scheme .weights , "strategy" , None )
756- == QuantizationStrategy .TENSOR_GROUP
757- for layer in balance_layers_to_patch
758- ):
759- update_fused_layer_weight_global_scales (mapping .parent )
760772
761- # W * X
762- int_w_outputs = self ._run_samples (mapping .parent )
773+ if shrink_p < 1.0 :
774+ # Joint shrinkage: override observer min/max with shrunk range
775+ # using output MSE as the objective (not weight MSE)
776+ w = balance_layer .weight .data
777+ w_min = w .amin (dim = - 1 , keepdim = True ) * shrink_p
778+ w_max = w .amax (dim = - 1 , keepdim = True ) * shrink_p
779+ from compressed_tensors .quantization .utils import calculate_qparams
780+ scale , zp = calculate_qparams (
781+ min_vals = w_min .squeeze (- 1 ),
782+ max_vals = w_max .squeeze (- 1 ),
783+ quantization_args = w_qscheme ,
784+ )
785+ # store directly into the layer's weight_scale / weight_zero_point
786+ from compressed_tensors .utils import update_parameter_data
787+ update_parameter_data (balance_layer , scale , "weight_scale" )
788+ update_parameter_data (balance_layer , zp , "weight_zero_point" )
789+ else :
790+ call_observer (
791+ balance_layer ,
792+ "weight" ,
793+ balance_layer .weight ,
794+ should_calculate_gparam = should_calculate_gparam ,
795+ )
796+
797+ balance_layer .weight .data = (
798+ forward_quantize (
799+ balance_layer ,
800+ balance_layer .weight ,
801+ "weight" ,
802+ w_qscheme ,
803+ )
804+ / _scalesview
805+ ).to (balance_layer .weight .dtype )
806+
807+ # Apply fused global scales for TENSOR_GROUP during grid search
808+ # to match inference behavior
809+ if balance_layers_to_patch and all (
810+ getattr (layer .quantization_scheme .weights , "strategy" , None )
811+ == QuantizationStrategy .TENSOR_GROUP
812+ for layer in balance_layers_to_patch
813+ ):
814+ update_fused_layer_weight_global_scales (mapping .parent )
815+
816+ # W * X
817+ int_w_outputs = self ._run_samples (mapping .parent )
763818
764- # compute mean squared error (L2 norm)
765- loss = self ._compute_loss (fp16_outputs , int_w_outputs )
766- del int_w_outputs
819+ # compute mean squared error (L2 norm)
820+ loss = self ._compute_loss (fp16_outputs , int_w_outputs )
821+ del int_w_outputs
767822
768- if initial_error is None :
769- initial_error = loss
823+ # skip non-finite losses (can occur with aggressive MSE clipping)
824+ if not math .isfinite (loss ):
825+ history .append (
826+ {"ratio" : ratio , "shrink" : shrink_p , "duo_scaling" : use_duo_scaling , "error" : loss }
827+ )
828+ continue
770829
771- history .append (
772- {"ratio" : ratio , "duo_scaling" : use_duo_scaling , "error" : loss }
773- )
774- if loss < best_error :
775- best_error = loss
776- best_ratio = ratio
777- best_scales = scales .clone ()
778- pbar .set_postfix ({"best_error" : f"{ best_error :.3e} " })
830+ if initial_error is None :
831+ initial_error = loss
832+
833+ history .append (
834+ {"ratio" : ratio , "shrink" : shrink_p , "duo_scaling" : use_duo_scaling , "error" : loss }
835+ )
836+ if loss < best_error :
837+ best_error = loss
838+ best_ratio = ratio
839+ best_scales = scales .clone ()
840+ pbar .set_postfix ({"best_error" : f"{ best_error :.3e} " })
779841
780842 if best_ratio == - 1 :
781843 logger .debug (history )
@@ -985,6 +1047,19 @@ def validate_duo_scaling(cls, v):
9851047 raise ValueError (f"duo_scaling must be True, False, or 'both', got { v !r} " )
9861048 return v
9871049
1050+ @field_validator ("search_observer" )
1051+ @classmethod
1052+ def validate_search_observer (cls , v ):
1053+ """Validate that search_observer is a memoryless observer (stateless per call)"""
1054+ valid = {"memoryless_minmax" , "memoryless_mse" }
1055+ if v not in valid :
1056+ raise ValueError (
1057+ f"search_observer must be one of { valid } , got { v !r} . "
1058+ "Only memoryless observers are supported to avoid accumulating "
1059+ "statistics across grid search iterations."
1060+ )
1061+ return v
1062+
9881063
9891064def _check_layers_are_compatible (
9901065 smooth_layer , smooth_name , balance_layers , balance_names
@@ -1053,4 +1128,4 @@ def _allreduce_data_sum(data: list[torch.Tensor]) -> list[torch.Tensor]:
10531128 )
10541129 )
10551130 wait_for_comms (pending_comms )
1056- return data
1131+ return data
0 commit comments