
[AWQ] Add output_mse_shrinkage: per-group clipping via activation projection#2492

Open
dzhengAP wants to merge 3 commits into vllm-project:main from dzhengAP:awq-mse-observer-alignment

Conversation

@dzhengAP
Contributor

@dzhengAP dzhengAP commented Mar 20, 2026

Problem

AWQ's grid search optimizes the per-channel scale (α) against output MSE via a full forward pass. However, the clipping range (shrinkage factor p) is determined independently by the observer using weight MSE. These two objectives are misaligned — the clipping range is never evaluated against the same output MSE objective that drives the scale search.

Solution

output_mse_shrinkage: for each scale candidate α, find the best clipping factor p per quantization group by minimizing the activation-projected weight quantization error:

w_err   = W_quant - W_scaled              # (out_ch, G, group_size)
out_err = einsum('ogs,ngs->ogn',
                 w_err, X_grouped)        # (out_ch, G, n_tokens)
err     = out_err.pow(2).sum(n)           # (out_ch, G)

X is collected from real calibration samples via a forward hook on each balance layer, giving true output-space error per group. Each group independently selects its optimal clipping factor p — groups with outlier weights but small activations can be clipped aggressively, while groups with large activations remain conservative.
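The per-group search described above can be sketched in PyTorch as follows. This is a minimal illustration, not the PR's actual code: `fake_quantize`, `best_clip_per_group`, and all parameter names are stand-ins, and in the real modifier this loop would run once per scale candidate α.

```python
import torch

def fake_quantize(wg, clip, n_bits=4):
    # Illustrative symmetric per-group fake-quant with a clipped range.
    qmax = 2 ** (n_bits - 1) - 1
    max_val = wg.abs().amax(dim=-1, keepdim=True) * clip
    scale = max_val.clamp(min=1e-8) / qmax
    q = (wg / scale).round().clamp(-qmax - 1, qmax)
    return q * scale

def best_clip_per_group(w_scaled, x, n_shrink_grid=10, maxshrink=0.20,
                        group_size=128):
    # w_scaled: (out_ch, in_ch) weight for one scale candidate alpha;
    # x: (n_tokens, in_ch) calibration activations from the forward hook.
    out_ch, in_ch = w_scaled.shape
    G = in_ch // group_size
    wg = w_scaled.view(out_ch, G, group_size)
    xg = x.view(-1, G, group_size)

    best_err = torch.full((out_ch, G), float("inf"))
    best_p = torch.ones(out_ch, G)
    for i in range(n_shrink_grid):
        # Sweep p over [1 - maxshrink, 1.0]; p = 1.0 on the first step.
        p = 1.0 - maxshrink * i / max(n_shrink_grid - 1, 1)
        w_err = fake_quantize(wg, p) - wg                  # (out_ch, G, s)
        out_err = torch.einsum("ogs,ngs->ogn", w_err, xg)  # (out_ch, G, n)
        err = out_err.pow(2).sum(dim=-1)                   # (out_ch, G)
        better = err < best_err
        best_err = torch.where(better, err, best_err)
        best_p = torch.where(better, torch.full_like(best_p, p), best_p)
    return best_p, best_err
```

Each group keeps whichever p minimized its activation-projected error, so the returned `best_p` varies independently across groups.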

Usage

recipe = [
    AWQModifier(
        ignore=["lm_head"],
        scheme="W4A16_ASYM",
        targets=["Linear"],
        n_shrink_grid=10,   # number of shrink candidates (1 = disabled)
        maxshrink=0.20,     # search range: p in [1-maxshrink, 1.0]
    ),
    QuantizationModifier(...),
]

Ablation results

Model: meta-llama/Llama-3.1-8B-Instruct, W4A16 ASYM group=128, open-platypus 128 samples, WikiText-2 eval

Recipe                       PPL     Δ vs RTN+minmax   Time
RTN + minmax (lower bound)   10.165  +0.000            0.7m
GPTQ only                    10.038  -0.127 ✓          14.3m
AWQ + minmax (vanilla AWQ)   9.982   -0.182 ✓          8.5m
AWQ + output MSE (ours)      9.963   -0.202            34.8m

Our method achieves the best PPL, improving by 0.020 over vanilla AWQ and by 0.202 over the RTN baseline.

Implementation notes

  • Chunked einsum over out_ch (~256MB per chunk) to bound peak GPU memory
  • Activation samples capped at 2048 tokens to prevent OOM on large models
  • n_shrink_grid=1 (default) disables shrinkage entirely — fully backward compatible, no existing recipes affected
  • 5 new unit tests covering defaults, valid inputs, and backward compatibility
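The chunked einsum in the first bullet might look roughly like this (a sketch under the assumption of a fixed byte budget for the `(chunk, G, n_tokens)` intermediate; the function name and signature are illustrative, not the PR's API):

```python
import torch

def chunked_group_error(w_err, xg, chunk_bytes=256 * 1024 ** 2):
    # w_err: (out_ch, G, group_size); xg: (n_tokens, G, group_size).
    # Process out_ch in chunks so the (chunk, G, n_tokens) intermediate
    # produced by the einsum stays under roughly chunk_bytes.
    out_ch, G, _ = w_err.shape
    n = xg.shape[0]
    bytes_per_out_ch = G * n * w_err.element_size()
    chunk = max(1, chunk_bytes // bytes_per_out_ch)
    errs = []
    for i in range(0, out_ch, chunk):
        out_err = torch.einsum("ogs,ngs->ogn", w_err[i:i + chunk], xg)
        errs.append(out_err.pow(2).sum(dim=-1))
    return torch.cat(errs, dim=0)  # (out_ch, G)
```

Only the reduced `(chunk, G)` error survives each iteration, so peak memory is bounded by the chunk size rather than by `out_ch`.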

Changes

  • Add n_shrink_grid: int = 1 and maxshrink: float = 0.20 fields to AWQModifier
  • Implement _apply_output_mse_shrinkage — per-group clipping optimization via activation projection
  • Chunked einsum + 2048-token cap to prevent OOM on large models
  • 5 new unit tests
  • All tests passed

Part of #2479
cc @HDCharles @brian-dellabetta

@github-actions

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: this is required to complete the testing suite; please only add the label once the PR is code complete and local testing has been performed.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses an inconsistency in AWQ's quantization process where the grid search for optimal scales used a different observer (memoryless_minmax) than what might be recommended for the final quantization (e.g., memoryless_mse). By introducing a configurable search_observer parameter, the PR allows users to align the grid search's objective with the final quantization's objective, leading to more accurate and performant quantized models, as demonstrated by benchmark results showing improved perplexity.

Highlights

  • Configurable Search Observer: Introduced a new search_observer parameter in AWQModifier to allow users to specify the observer used during the grid search for quantization scales. This defaults to memoryless_minmax for backward compatibility.
  • Observer Alignment: Replaced the hardcoded memoryless_minmax observer in the _compute_best_scale function with the new self.search_observer parameter, enabling alignment between the grid search objective and the final quantization observer.
  • Robustness Improvement: Added a math.isfinite guard to gracefully skip non-finite losses that can occur during aggressive MSE clipping in the grid search, preventing potential errors.
  • Input Validation: Implemented a validate_search_observer field validator to ensure that only memoryless observers (memoryless_minmax, memoryless_mse) are accepted for the search_observer parameter, maintaining statelessness across iterations.
  • Documentation and Tests: Updated docstrings, added a YAML recipe example for search_observer, and included three new unit tests to verify the default behavior, acceptance of memoryless_mse, and rejection of invalid observer values.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a configurable search_observer for the AWQ grid search, which is a valuable improvement for aligning the search objective with the final quantization observer. The changes are well-implemented, with backward compatibility maintained through a sensible default. The addition of a validator for the new parameter and a guard for non-finite loss values enhances robustness. The new unit tests effectively cover the changes. I have a couple of minor suggestions to improve code style and maintainability.

@dzhengAP dzhengAP force-pushed the awq-mse-observer-alignment branch 2 times, most recently from b136155 to a1cf160 on March 20, 2026 06:31
@dzhengAP dzhengAP force-pushed the awq-mse-observer-alignment branch from a1cf160 to 5f99e89 on March 20, 2026 07:23
Observer.load_from_registry(
-    "memoryless_minmax",
+    self.search_observer,
Collaborator


do we ever want this to be different than the observer defined in the quantization config? Wondering if we can save users another configuration option on AWQModifier, while also being more flexible to use different observers for heterogeneous quantization. Maybe something like this:

Suggested change
self.search_observer,
balance_layer.quantization_scheme.observer or "memoryless_minmax",

cc @HDCharles, wdyt?

Collaborator


Yeah agree

AWQ's grid search optimizes only the per-channel scale (alpha), while
shrinkage is determined independently by the observer using weight MSE.
This misaligns the two optimizations — shrinkage is not optimized against
the same output MSE objective used for scale search.

This commit adds n_shrink_grid and maxshrink parameters to AWQModifier,
enabling joint optimization of scale + shrinkage against output MSE:
for each scale candidate (alpha), sweep over shrink factors p in
(1-maxshrink, 1] and select the (alpha, p) pair that minimizes output
MSE jointly.

Benchmarked on Llama-3.1-8B-Instruct W4A16 (open-platypus calibration,
WikiText-2 eval):

  AWQ baseline (n_shrink_grid=1):   PPL 10.008  +0.000
  AWQ joint shrinkage (n=5):        PPL 10.007  -0.001
  AWQ joint shrinkage (n=10):       PPL  9.993  -0.014

Consistent improvement scaling with n_shrink_grid. Defaults preserve
existing behaviour exactly (n_shrink_grid=1 disables joint search).

Changes:
- Add n_shrink_grid: int = 1 and maxshrink: float = 0.20 to AWQModifier
- Implement shrinkage sweep inside _compute_best_scale loop
- Add math.isfinite guard to skip non-finite grid search losses
- Add docstrings and YAML recipe example for new parameters
- 3 new unit tests

Part of vllm-project#2479

Signed-off-by: David Zheng <dqzheng1996@gmail.com>
@dzhengAP dzhengAP force-pushed the awq-mse-observer-alignment branch from 5f99e89 to 2bbf6e4 on March 26, 2026 03:40
@dzhengAP dzhengAP changed the title [AWQ] Add configurable search_observer parameter for grid search / final observer alignment [AWQ] Add joint scale+shrinkage optimization to grid search Mar 26, 2026
@HDCharles
Collaborator

Hmm, pretty insignificant relative to AWQ itself; probably we can cross this one off our list. It's a bunch of runtime/complexity for almost no gain.

@HDCharles
Collaborator

HDCharles commented Mar 31, 2026

Oh, I realize the problem: the shrinkage factor needs to be distinct for each quantization group; right now it hardcodes the shrinkage for all groups.

This does seem workable, with some finessing.

normally MSE observer reshapes the weight to be num_groups x groupsize and then picks the best shrinkage for each group since we can calculate the MSE for each group separately

doing this for the module output is possible too for per channel quantization but becomes a little more complicated for group/block quantization

what we could do:

take the weight (out_channels x in_channels) and activations (n x in_channels), then:

  • dot them together for each sample: (n x out_channels x in_channels)
  • subtract from the original dotted with activations: (n x out_channels x in_channels)
  • square and sum over n: (out_channels x in_channels)
  • reshape for quantization: (num_groups x group_size)
  • sum each group: (num_groups)

Then you have an error value for each group and can pick the best one for each group, similar to how it's done in the MSE observer.

this would reduce the O(num_groups*num_shrinkage_factors) optimization problem to just O(num_shrinkage_factors)
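One literal reading of those steps in PyTorch (a sketch only; it squares the per-element error exactly as described, and `per_group_error` is a hypothetical helper name, not code from this PR):

```python
import torch

def per_group_error(w_quant, w, x, group_size=128):
    # w_quant, w: (out_channels, in_channels); x: (n, in_channels).
    # Dot each weight element with each activation sample, subtract the
    # original dotted with activations, square, sum over n, then sum
    # within each quantization group.
    diff = (w_quant - w).unsqueeze(0) * x.unsqueeze(1)  # (n, out, in)
    err = diff.pow(2).sum(dim=0)                        # (out, in)
    out_ch, in_ch = w.shape
    err = err.view(out_ch, in_ch // group_size, group_size)
    return err.sum(dim=-1)                              # (out, num_groups)
```

Evaluating this once per shrink factor yields one scalar per group, which is what lets each group pick its own best factor without a separate search per group.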

…ion via activation projection

AWQ's grid search optimizes scale (alpha) against output MSE, but the
quantization clipping range (shrinkage) is determined independently by
the observer using weight MSE. These objectives are misaligned.

This commit adds output_mse_shrinkage: per-group shrinkage optimization
using the same output MSE objective as the scale search. For each scale
candidate, the best clipping factor p is selected per quantization group
by minimizing the activation-projected weight quantization error:

  w_err  = W_quant - W_scaled            (out_ch, G, group_size)
  out_err = einsum('ogs,ngs->ogn',
                   w_err, X_grouped)     (out_ch, G, n_tokens)
  err     = out_err^2.sum(n)             (out_ch, G)

X is collected from real calibration samples via a forward hook on the
balance layer, so the error reflects actual token distributions rather
than a proxy. Each group independently selects its optimal p, allowing
aggressive clipping where activations are small and conservative clipping
where they are large.

New parameters:
  n_shrink_grid: int = 1    (1 = disabled, backward compatible)
  maxshrink: float = 0.20   (search range: p in [1-maxshrink, 1.0])

Benchmarked on Llama-3.1-8B-Instruct W4A16 ASYM group=128,
open-platypus calibration, WikiText-2 eval:

  Baseline (n_shrink_grid=1):   PPL 9.995
  output_mse_shrinkage (n=10):  PPL 9.890  (-0.105) ← best
  output_mse_shrinkage (n=50):  PPL 9.953  (-0.042)
  output_mse_shrinkage (n=100): PPL 9.941  (-0.054)

All n values improve over baseline. n=10 gives best result; improvement
is not strictly monotonic, suggesting diminishing returns from finer
shrinkage resolution on calibration data.

Implementation notes:
  - Chunked einsum over out_ch to bound peak memory (~256MB per chunk)
  - Activation samples capped at 2048 tokens to prevent OOM on large models
  - 5 new unit tests

Part of vllm-project#2479

Signed-off-by: David Zheng <dqzheng1996@gmail.com>
@dzhengAP
Contributor Author

dzhengAP commented Apr 1, 2026

Hi @HDCharles, I am attaching some new results following our discussion; please see below. Key takeaway: the new method works, and improvement need not be monotonic in n, as n = 10 shows the best result (probably due to not overfitting).

Motivation

AWQ's grid search optimizes scale (alpha) against output MSE, but the quantization clipping range (shrinkage) is determined independently by the observer using weight MSE. These objectives are misaligned.

Methods & Algo

This commit adds output_mse_shrinkage: per-group shrinkage optimization using the same output MSE objective as the scale search. For each scale candidate, the best clipping factor p is selected per quantization group by minimizing the activation-projected weight quantization error:

w_err   = W_quant - W_scaled                         # (out_ch, G, group_size)
out_err = einsum('ogs,ngs->ogn', w_err, X_grouped)   # (out_ch, G, n_tokens)
err     = out_err.pow(2).sum(n)                      # (out_ch, G)

X is collected from real calibration samples via a forward hook on the balance layer, so the error reflects actual token distributions rather than a proxy. Each group independently selects its optimal p, allowing aggressive clipping where activations are small and conservative clipping where they are large.

New parameters:
n_shrink_grid: int = 1 (1 = disabled, backward compatible)
maxshrink: float = 0.20 (search range: p in [1-maxshrink, 1.0])

Results

Benchmarked on Llama-3.1-8B-Instruct W4A16 ASYM group=128, open-platypus calibration, WikiText-2 eval:

Baseline (n_shrink_grid=1): PPL 9.995
output_mse_shrinkage (n=10): PPL 9.890 (-0.105) ← best
output_mse_shrinkage (n=50): PPL 9.953 (-0.042)
output_mse_shrinkage (n=100): PPL 9.941 (-0.054)

All n values improve over baseline. n=10 gives best result; improvement is not strictly monotonic, suggesting diminishing returns from finer shrinkage resolution on calibration data.

Implementation notes:

  • Chunked einsum over out_ch to bound peak memory (~256MB per chunk)
  • Activation samples capped at 2048 tokens to prevent OOM on large models
  • 5 new unit tests, all passed

Part of #2479
CC @brian-dellabetta

@dzhengAP
Contributor Author

dzhengAP commented Apr 2, 2026

[AWQ] Add output_mse_shrinkage: per-group clipping via activation projection

Part of #2479
cc @HDCharles @brian-dellabetta

Problem

AWQ's grid search optimizes the per-channel scale (α) against output MSE via a full forward pass. However, the clipping range (shrinkage factor p) is determined independently by the observer using weight MSE. These two objectives are misaligned — the clipping range is never evaluated against the same output MSE objective that drives the scale search.

Solution

output_mse_shrinkage: for each scale candidate α, find the best clipping factor p per quantization group by minimizing the activation-projected weight quantization error:

w_err   = W_quant - W_scaled              # (out_ch, G, group_size)
out_err = einsum('ogs,ngs->ogn',
                 w_err, X_grouped)        # (out_ch, G, n_tokens)
err     = out_err.pow(2).sum(n)           # (out_ch, G)

X is collected from real calibration samples via a forward hook on each balance layer, giving true output-space error per group. Each group independently selects its optimal clipping factor p — groups with outlier weights but small activations can be clipped aggressively, while groups with large activations remain conservative.

Usage

recipe = [
    AWQModifier(
        ignore=["lm_head"],
        scheme="W4A16_ASYM",
        targets=["Linear"],
        n_shrink_grid=10,   # number of shrink candidates (1 = disabled)
        maxshrink=0.20,     # search range: p in [1-maxshrink, 1.0]
    ),
    QuantizationModifier(...),
]

Ablation results

Model: meta-llama/Llama-3.1-8B-Instruct, W4A16 ASYM group=128, open-platypus 128 samples, WikiText-2 eval

Recipe                       PPL     Δ vs RTN+minmax   Time
RTN + minmax (lower bound)   10.165  +0.000            0.7m
GPTQ only                    10.038  -0.127 ✓          14.3m
AWQ + minmax (vanilla AWQ)   9.982   -0.182 ✓          8.5m
AWQ + output MSE (ours)      9.963   -0.202            34.8m

Our method achieves the best PPL, improving by 0.020 over vanilla AWQ and by 0.202 over the RTN baseline.

Implementation notes

  • Chunked einsum over out_ch (~256MB per chunk) to bound peak GPU memory
  • Activation samples capped at 2048 tokens to prevent OOM on large models
  • n_shrink_grid=1 (default) disables shrinkage entirely — fully backward compatible, no existing recipes affected
  • 5 new unit tests covering defaults, valid inputs, and backward compatibility

Changes

  • Add n_shrink_grid: int = 1 and maxshrink: float = 0.20 fields to AWQModifier
  • Implement _apply_output_mse_shrinkage — per-group clipping optimization via activation projection
  • Chunked einsum + 2048-token cap to prevent OOM on large models
  • 5 new unit tests; all passed
  • Updated the main PR description

@dzhengAP dzhengAP changed the title [AWQ] Add joint scale+shrinkage optimization to grid search [AWQ] Add output_mse_shrinkage: per-group clipping via activation projection Apr 2, 2026
Part of vllm-project#2479

Signed-off-by: David Zheng <dqzheng1996@gmail.com>
@dzhengAP dzhengAP force-pushed the awq-mse-observer-alignment branch from 967b43e to 60ec28b on April 2, 2026 18:39
