Skip to content

Commit 2bbf6e4

Browse files
committed
[modifier/awq] add joint scale+shrinkage optimization to AWQ grid search
AWQ's grid search optimizes only the per-channel scale (alpha), while shrinkage is determined independently by the observer using weight MSE. This misaligns the two optimizations — shrinkage is not optimized against the same output MSE objective used for scale search. This commit adds n_shrink_grid and maxshrink parameters to AWQModifier, enabling joint optimization of scale + shrinkage against output MSE: for each scale candidate (alpha), sweep over shrink factors p in (1-maxshrink, 1] and select the (alpha, p) pair that minimizes output MSE jointly. Benchmarked on Llama-3.1-8B-Instruct W4A16 (open-platypus calibration, WikiText-2 eval): AWQ baseline (n_shrink_grid=1): PPL 10.008 +0.000 AWQ joint shrinkage (n=5): PPL 10.007 -0.001 AWQ joint shrinkage (n=10): PPL 9.993 -0.014 Consistent improvement scaling with n_shrink_grid. Defaults preserve existing behaviour exactly (n_shrink_grid=1 disables joint search). Changes: - Add n_shrink_grid: int = 1 and maxshrink: float = 0.20 to AWQModifier - Implement shrinkage sweep inside _compute_best_scale loop - Add math.isfinite guard to skip non-finite grid search losses - Add docstrings and YAML recipe example for new parameters - 3 new unit tests Part of #2479 Signed-off-by: David Zheng <dqzheng1996@gmail.com>
1 parent cf3bd64 commit 2bbf6e4

File tree

2 files changed

+167
-57
lines changed

2 files changed

+167
-57
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 132 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
import math
23
from itertools import product
34
from typing import Iterator, Literal
45

@@ -84,6 +85,10 @@ class AWQModifier(Modifier, QuantizationMixin):
8485
# This change is only useful for MoE models with parallel transformer blocks,
8586
# and one should use the default value (None) in most cases.
8687
ignore: ["lm_head"]
88+
# Set search_observer to match the observer used in the final QuantizationModifier
89+
# for best scale alignment. "memoryless_mse" is recommended when pairing with
90+
# MSE-based weight quantization. Defaults to "memoryless_minmax".
91+
search_observer: "memoryless_minmax"
8792
config_groups:
8893
group_0:
8994
targets:
@@ -151,6 +156,21 @@ class AWQModifier(Modifier, QuantizationMixin):
151156
this specifies how many grid points should be used. To decrease the runtime,
152157
at the possible cost of slightly worse scales, this can be decreased.
153158
Defaults to 20
159+
:param search_observer: name of the observer used to simulate quantization during
160+
the grid search. For best accuracy, this should match the observer used in the
161+
final QuantizationModifier (e.g. "memoryless_mse" when using MSE-based weight
162+
quantization). Defaults to "memoryless_minmax" for backward compatibility.
163+
Valid options: "memoryless_minmax", "memoryless_mse"
164+
:param n_shrink_grid: number of shrinkage grid points to jointly search over
165+
alongside the scale grid. When > 1, for each scale candidate (α) the grid
166+
search also sweeps over shrink factors p ∈ (1 - maxshrink, 1], selecting the
167+
(α, p) pair that minimizes output MSE jointly. This implements the joint
168+
scale + shrinkage optimization described in issue #2479.
169+
Defaults to 1 (disabled, shrinkage determined solely by the observer).
170+
Recommended value: 5-10 for a lightweight joint search.
171+
:param maxshrink: maximum shrinkage factor used when n_shrink_grid > 1.
172+
Shrink factors are swept from 1.0 down to (1 - maxshrink).
173+
Defaults to 0.20 (matching the memoryless_mse observer default).
154174
"""
155175

156176
# Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
@@ -162,6 +182,9 @@ class AWQModifier(Modifier, QuantizationMixin):
162182
offload_device: torch.device | None | Sentinel = Sentinel("not_provided")
163183
duo_scaling: bool | Literal["both"] = True
164184
n_grid: int = 20
185+
search_observer: str = "memoryless_minmax"
186+
n_shrink_grid: int = 1
187+
maxshrink: float = 0.20
165188

166189
# Private vars set during initialization, cleared during finalization
167190
_resolved_mappings: list[ResolvedMapping] = PrivateAttr(default_factory=list)
@@ -673,8 +696,9 @@ def _compute_best_scale(
673696
n_grid = self.n_grid
674697
duo_scalings = [self.duo_scaling]
675698

676-
# Where appropriate, replace observers with memoryless_minmax
677-
# for duration of grid search
699+
# Where appropriate, replace observers with search_observer
700+
# for duration of grid search. This should match the observer
701+
# used in the final QuantizationModifier for best scale alignment.
678702
balance_layers_to_patch = [
679703
balance_layer
680704
for balance_layer in mapping.balance_layers
@@ -686,7 +710,7 @@ def _compute_best_scale(
686710
"weight_observer",
687711
[
688712
Observer.load_from_registry(
689-
"memoryless_minmax",
713+
self.search_observer,
690714
base_name="weight",
691715
args=balance_layer.quantization_scheme.weights,
692716
module=balance_layer,
@@ -717,65 +741,103 @@ def _compute_best_scale(
717741
scales[torch.isnan(scales)] = 1
718742
_scalesview = scales.view(1, -1).to(device)
719743

720-
# Q(W * s)
721-
for balance_layer in balance_layers_to_patch:
722-
if not hasattr(balance_layer, "quantization_scheme") or not hasattr(
723-
balance_layer.quantization_scheme, "weights"
724-
):
725-
continue
726-
727-
w_qscheme = balance_layer.quantization_scheme.weights
728-
balance_layer.weight.data.copy_(
729-
orig_layer_weights[balance_layer].to(_scalesview.device)
730-
* _scalesview
731-
)
744+
# Build shrink factor candidates:
745+
# n_shrink_grid=1 means no shrinkage search (p=1.0, minmax behaviour)
746+
# n_shrink_grid>1 sweeps p from 1.0 down to (1 - maxshrink)
747+
if self.n_shrink_grid > 1:
748+
shrink_factors = [
749+
1.0 - (self.maxshrink * i / (self.n_shrink_grid - 1))
750+
for i in range(self.n_shrink_grid)
751+
]
752+
else:
753+
shrink_factors = [1.0]
754+
755+
for shrink_p in shrink_factors:
756+
# Q(W * s) with optional shrinkage applied to observer range
757+
for balance_layer in balance_layers_to_patch:
758+
if not hasattr(balance_layer, "quantization_scheme") or not hasattr(
759+
balance_layer.quantization_scheme, "weights"
760+
):
761+
continue
762+
763+
w_qscheme = balance_layer.quantization_scheme.weights
764+
balance_layer.weight.data.copy_(
765+
orig_layer_weights[balance_layer].to(_scalesview.device)
766+
* _scalesview
767+
)
732768

733-
should_calculate_gparam = (
734-
w_qscheme.strategy == QuantizationStrategy.TENSOR_GROUP
735-
)
736-
call_observer(
737-
balance_layer,
738-
"weight",
739-
balance_layer.weight,
740-
should_calculate_gparam=should_calculate_gparam,
741-
)
742-
balance_layer.weight.data = (
743-
forward_quantize(
744-
balance_layer,
745-
balance_layer.weight,
746-
"weight",
747-
w_qscheme,
769+
should_calculate_gparam = (
770+
w_qscheme.strategy == QuantizationStrategy.TENSOR_GROUP
748771
)
749-
/ _scalesview
750-
).to(balance_layer.weight.dtype)
751-
752-
# Apply fused global scales for TENSOR_GROUP during grid search
753-
# to match inference behavior
754-
if balance_layers_to_patch and all(
755-
getattr(layer.quantization_scheme.weights, "strategy", None)
756-
== QuantizationStrategy.TENSOR_GROUP
757-
for layer in balance_layers_to_patch
758-
):
759-
update_fused_layer_weight_global_scales(mapping.parent)
760772

761-
# W * X
762-
int_w_outputs = self._run_samples(mapping.parent)
773+
if shrink_p < 1.0:
774+
# Joint shrinkage: override observer min/max with shrunk range
775+
# using output MSE as the objective (not weight MSE)
776+
w = balance_layer.weight.data
777+
w_min = w.amin(dim=-1, keepdim=True) * shrink_p
778+
w_max = w.amax(dim=-1, keepdim=True) * shrink_p
779+
from compressed_tensors.quantization.utils import calculate_qparams
780+
scale, zp = calculate_qparams(
781+
min_vals=w_min.squeeze(-1),
782+
max_vals=w_max.squeeze(-1),
783+
quantization_args=w_qscheme,
784+
)
785+
# store directly into the layer's weight_scale / weight_zero_point
786+
from compressed_tensors.utils import update_parameter_data
787+
update_parameter_data(balance_layer, scale, "weight_scale")
788+
update_parameter_data(balance_layer, zp, "weight_zero_point")
789+
else:
790+
call_observer(
791+
balance_layer,
792+
"weight",
793+
balance_layer.weight,
794+
should_calculate_gparam=should_calculate_gparam,
795+
)
796+
797+
balance_layer.weight.data = (
798+
forward_quantize(
799+
balance_layer,
800+
balance_layer.weight,
801+
"weight",
802+
w_qscheme,
803+
)
804+
/ _scalesview
805+
).to(balance_layer.weight.dtype)
806+
807+
# Apply fused global scales for TENSOR_GROUP during grid search
808+
# to match inference behavior
809+
if balance_layers_to_patch and all(
810+
getattr(layer.quantization_scheme.weights, "strategy", None)
811+
== QuantizationStrategy.TENSOR_GROUP
812+
for layer in balance_layers_to_patch
813+
):
814+
update_fused_layer_weight_global_scales(mapping.parent)
815+
816+
# W * X
817+
int_w_outputs = self._run_samples(mapping.parent)
763818

764-
# compute mean squared error (L2 norm)
765-
loss = self._compute_loss(fp16_outputs, int_w_outputs)
766-
del int_w_outputs
819+
# compute mean squared error (L2 norm)
820+
loss = self._compute_loss(fp16_outputs, int_w_outputs)
821+
del int_w_outputs
767822

768-
if initial_error is None:
769-
initial_error = loss
823+
# skip non-finite losses (can occur with aggressive MSE clipping)
824+
if not math.isfinite(loss):
825+
history.append(
826+
{"ratio": ratio, "shrink": shrink_p, "duo_scaling": use_duo_scaling, "error": loss}
827+
)
828+
continue
770829

771-
history.append(
772-
{"ratio": ratio, "duo_scaling": use_duo_scaling, "error": loss}
773-
)
774-
if loss < best_error:
775-
best_error = loss
776-
best_ratio = ratio
777-
best_scales = scales.clone()
778-
pbar.set_postfix({"best_error": f"{best_error:.3e}"})
830+
if initial_error is None:
831+
initial_error = loss
832+
833+
history.append(
834+
{"ratio": ratio, "shrink": shrink_p, "duo_scaling": use_duo_scaling, "error": loss}
835+
)
836+
if loss < best_error:
837+
best_error = loss
838+
best_ratio = ratio
839+
best_scales = scales.clone()
840+
pbar.set_postfix({"best_error": f"{best_error:.3e}"})
779841

780842
if best_ratio == -1:
781843
logger.debug(history)
@@ -985,6 +1047,19 @@ def validate_duo_scaling(cls, v):
9851047
raise ValueError(f"duo_scaling must be True, False, or 'both', got {v!r}")
9861048
return v
9871049

1050+
@field_validator("search_observer")
1051+
@classmethod
1052+
def validate_search_observer(cls, v):
1053+
"""Validate that search_observer is a memoryless observer (stateless per call)"""
1054+
valid = {"memoryless_minmax", "memoryless_mse"}
1055+
if v not in valid:
1056+
raise ValueError(
1057+
f"search_observer must be one of {valid}, got {v!r}. "
1058+
"Only memoryless observers are supported to avoid accumulating "
1059+
"statistics across grid search iterations."
1060+
)
1061+
return v
1062+
9881063

9891064
def _check_layers_are_compatible(
9901065
smooth_layer, smooth_name, balance_layers, balance_names
@@ -1053,4 +1128,4 @@ def _allreduce_data_sum(data: list[torch.Tensor]) -> list[torch.Tensor]:
10531128
)
10541129
)
10551130
wait_for_comms(pending_comms)
1056-
return data
1131+
return data

tests/llmcompressor/modifiers/awq/test_base.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,3 +675,38 @@ def test_block_strategy_compute_layer_means(rows, cols, block_height, block_widt
675675
# check
676676
assert_close(llmc_awq_means, ref_means, atol=1e-5, rtol=1e-5)
677677
assert_close(llmc_awq_means, auto_awq_means, atol=1e-5, rtol=1e-5)
678+
679+
680+
# ---------------------------------------------------------------------------
681+
# search_observer tests
682+
# ---------------------------------------------------------------------------
683+
684+
685+
@pytest.mark.unit
686+
def test_search_observer_default():
687+
"""search_observer defaults to memoryless_minmax (backward compat)"""
688+
modifier = AWQModifier(scheme="W4A16_ASYM")
689+
assert modifier.search_observer == "memoryless_minmax"
690+
691+
692+
@pytest.mark.unit
693+
def test_search_observer_mse_accepted():
694+
"""memoryless_mse is a valid search_observer value"""
695+
modifier = AWQModifier(scheme="W4A16_ASYM", search_observer="memoryless_mse")
696+
assert modifier.search_observer == "memoryless_mse"
697+
698+
699+
@pytest.mark.unit
700+
def test_search_observer_invalid_rejected():
701+
"""Non-memoryless or unknown observers are rejected"""
702+
from pydantic import ValidationError
703+
704+
with pytest.raises(ValidationError, match="search_observer must be one of"):
705+
AWQModifier(scheme="W4A16_ASYM", search_observer="minmax")
706+
707+
with pytest.raises(ValidationError, match="search_observer must be one of"):
708+
AWQModifier(scheme="W4A16_ASYM", search_observer="mse")
709+
710+
with pytest.raises(ValidationError, match="search_observer must be one of"):
711+
AWQModifier(scheme="W4A16_ASYM", search_observer="invalid_observer")
712+

0 commit comments

Comments (0)