[AWQ] Add output_mse_shrinkage: per-group clipping via activation projection#2492
dzhengAP wants to merge 3 commits into vllm-project:main
Conversation
Summary of Changes: Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request addresses an inconsistency in AWQ's quantization process where the grid search for optimal scales used a different observer than the one applied during final quantization.
|
Code Review
This pull request introduces a configurable search_observer for the AWQ grid search, which is a valuable improvement for aligning the search objective with the final quantization observer. The changes are well-implemented, with backward compatibility maintained through a sensible default. The addition of a validator for the new parameter and a guard for non-finite loss values enhances robustness. The new unit tests effectively cover the changes. I have a couple of minor suggestions to improve code style and maintainability.
```diff
 Observer.load_from_registry(
-    "memoryless_minmax",
+    self.search_observer,
```
do we ever want this to be different from the observer defined in the quantization config? Wondering if we can save users another configuration option on AWQModifier, while also being more flexible to use different observers for heterogeneous quantization. Maybe something like this:
```diff
-    self.search_observer,
+    balance_layer.quantization_scheme.observer or "memoryless_minmax",
```
cc @HDCharles , wdyt?
AWQ's grid search optimizes only the per-channel scale (alpha), while shrinkage is determined independently by the observer using weight MSE. This misaligns the two optimizations — shrinkage is not optimized against the same output MSE objective used for scale search.

This commit adds n_shrink_grid and maxshrink parameters to AWQModifier, enabling joint optimization of scale + shrinkage against output MSE: for each scale candidate (alpha), sweep over shrink factors p in (1-maxshrink, 1] and select the (alpha, p) pair that minimizes output MSE jointly.

Benchmarked on Llama-3.1-8B-Instruct W4A16 (open-platypus calibration, WikiText-2 eval):

    AWQ baseline (n_shrink_grid=1):  PPL 10.008  +0.000
    AWQ joint shrinkage (n=5):       PPL 10.007  -0.001
    AWQ joint shrinkage (n=10):      PPL  9.993  -0.014

Consistent improvement scaling with n_shrink_grid. Defaults preserve existing behaviour exactly (n_shrink_grid=1 disables joint search).

Changes:
- Add n_shrink_grid: int = 1 and maxshrink: float = 0.20 to AWQModifier
- Implement shrinkage sweep inside _compute_best_scale loop
- Add math.isfinite guard to skip non-finite grid search losses
- Add docstrings and YAML recipe example for new parameters
- 3 new unit tests

Part of vllm-project#2479

Signed-off-by: David Zheng <dqzheng1996@gmail.com>
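The joint sweep described in this commit message can be sketched in a few lines of numpy. This is a toy illustration under stated assumptions, not the llm-compressor implementation: `fake_quant` is a simplified symmetric round-to-nearest quantizer, and the per-channel scale rule `mean(|x|)^alpha` is a stand-in for AWQ's actual scale candidates.

```python
import numpy as np

def fake_quant(w, p, n_bits=4):
    # Toy symmetric round-to-nearest quantizer; p shrinks the clipping range.
    qmax = 2 ** (n_bits - 1) - 1
    scale = p * np.abs(w).max() / qmax
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q * scale

def joint_search(w, x, n_alpha_grid=20, n_shrink_grid=5, maxshrink=0.20):
    """Pick the (alpha, p) pair that jointly minimizes output MSE.

    w: weight (out_ch, in_ch); x: calibration activations (n_tokens, in_ch).
    """
    ref_out = x @ w.T  # unquantized reference output
    best = (None, None, np.inf)
    for alpha in np.linspace(0.0, 1.0, n_alpha_grid):
        # Toy AWQ-style per-channel scale candidate (illustrative only).
        s = np.abs(x).mean(axis=0) ** alpha + 1e-8
        for p in np.linspace(1.0 - maxshrink, 1.0, n_shrink_grid):
            # Scale, quantize with shrink factor p, then unscale.
            w_q = fake_quant(w * s, p) / s
            loss = float(np.mean((x @ w_q.T - ref_out) ** 2))
            # Skip non-finite losses, mirroring the math.isfinite guard.
            if np.isfinite(loss) and loss < best[2]:
                best = (alpha, p, loss)
    return best
```

With `n_shrink_grid=1` the inner loop collapses to `p = 1.0`, which is how the default preserves existing behaviour.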
hmm, pretty insignificant relative to AWQ itself; we can probably cross this one off our list, it's a bunch of runtime/complexity for almost no gain.
Oh, I realize the problem: the shrinkage factor needs to be distinct for each quantization group; right now it's hardcoding one shrinkage for all groups.

This does seem workable, with some finessing. Normally the MSE observer reshapes the weight to be num_groups x group_size and then picks the best shrinkage for each group, since we can calculate the MSE for each group separately. Doing this for the module output is possible too for per-channel quantization, but becomes a little more complicated for group/block quantization.

What we could do: take the weight (out_channels x in_channels) and activations (n x in_channels); this would reduce the O(num_groups*num_shrinkage_factors) optimization problem to just O(num_shrinkage_factors).
[AWQ] Add output_mse_shrinkage: per-group clipping via activation projection
AWQ's grid search optimizes scale (alpha) against output MSE, but the
quantization clipping range (shrinkage) is determined independently by
the observer using weight MSE. These objectives are misaligned.
This commit adds output_mse_shrinkage: per-group shrinkage optimization
using the same output MSE objective as the scale search. For each scale
candidate, the best clipping factor p is selected per quantization group
by minimizing the activation-projected weight quantization error:
    w_err   = W_quant - W_scaled                          (out_ch, G, group_size)
    out_err = einsum('ogs,ngs->ogn', w_err, X_grouped)    (out_ch, G, n_tokens)
    err     = (out_err^2).sum(n)                          (out_ch, G)
X is collected from real calibration samples via a forward hook on the
balance layer, so the error reflects actual token distributions rather
than a proxy. Each group independently selects its optimal p, allowing
aggressive clipping where activations are small and conservative clipping
where they are large.
New parameters:
n_shrink_grid: int = 1 (1 = disabled, backward compatible)
maxshrink: float = 0.20 (search range: p in [1-maxshrink, 1.0])
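As a sanity check, the three formulas above (w_err, out_err, err) can be written out directly in numpy. Toy shapes only; this illustrates the math, not the PR's implementation.

```python
import numpy as np

def activation_projected_error(w_quant, w_scaled, x, group_size=128):
    """Per-group output-space quantization error.

    w_quant, w_scaled: (out_ch, in_ch); x: (n_tokens, in_ch).
    Returns err of shape (out_ch, G), G = in_ch // group_size.
    """
    out_ch, in_ch = w_scaled.shape
    G = in_ch // group_size
    # w_err = W_quant - W_scaled, reshaped to (out_ch, G, group_size)
    w_err = (w_quant - w_scaled).reshape(out_ch, G, group_size)
    # Activations grouped the same way: (n_tokens, G, group_size)
    x_grouped = x.reshape(-1, G, group_size)
    # out_err[o, g, n] = sum_s w_err[o, g, s] * x[n, g, s]
    out_err = np.einsum('ogs,ngs->ogn', w_err, x_grouped)
    # err = (out_err^2).sum over tokens -> (out_ch, G)
    return (out_err ** 2).sum(axis=-1)
```

Because `err` is indexed by group, each group can pick its own clip factor p by evaluating this error once per shrink candidate.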
Benchmarked on Llama-3.1-8B-Instruct W4A16 ASYM group=128,
open-platypus calibration, WikiText-2 eval:
Baseline (n_shrink_grid=1): PPL 9.995
output_mse_shrinkage (n=10): PPL 9.890 (-0.105) ← best
output_mse_shrinkage (n=50): PPL 9.953 (-0.042)
output_mse_shrinkage (n=100): PPL 9.941 (-0.054)
All n values improve over baseline. n=10 gives best result; improvement
is not strictly monotonic, suggesting diminishing returns from finer
shrinkage resolution on calibration data.
Implementation notes:
- Chunked einsum over out_ch to bound peak memory (~256MB per chunk)
- Activation samples capped at 2048 tokens to prevent OOM on large models
- 5 new unit tests
Part of vllm-project#2479
Signed-off-by: David Zheng <dqzheng1996@gmail.com>
Hi @HDCharles, I am attaching some new results following our discussion, please check below. The key takeaway is that the new method works, and the gains do not need to be monotonic in n: n = 10 shows the best result (probably because it avoids overfitting).

Motivation

AWQ's grid search optimizes scale (alpha) against output MSE, but the quantization clipping range (shrinkage) is determined independently by the observer using weight MSE. These objectives are misaligned.

Methods & Algo

This commit adds output_mse_shrinkage: per-group shrinkage optimization using the same output MSE objective as the scale search. For each scale candidate, the best clipping factor p is selected per quantization group by minimizing the activation-projected weight quantization error:

    w_err = W_quant - W_scaled    (out_ch, G, group_size)

X is collected from real calibration samples via a forward hook on the balance layer, so the error reflects actual token distributions rather than a proxy. Each group independently selects its optimal p, allowing aggressive clipping where activations are small and conservative clipping where they are large.

New parameters:
- n_shrink_grid: int = 1 (1 = disabled, backward compatible)
- maxshrink: float = 0.20 (search range: p in [1-maxshrink, 1.0])

Results

Benchmarked on Llama-3.1-8B-Instruct W4A16 ASYM group=128, open-platypus calibration, WikiText-2 eval:

    Baseline (n_shrink_grid=1):    PPL 9.995
    output_mse_shrinkage (n=10):   PPL 9.890 (-0.105)
    output_mse_shrinkage (n=50):   PPL 9.953 (-0.042)
    output_mse_shrinkage (n=100):  PPL 9.941 (-0.054)

All n values improve over baseline. n=10 gives the best result; the improvement is not strictly monotonic, suggesting diminishing returns from finer shrinkage resolution on calibration data.

Part of #2479
[AWQ] Add output_mse_shrinkage: per-group clipping via activation projection

Part of #2479

Problem

AWQ's grid search optimizes the per-channel scale (α) against output MSE via a full forward pass. However, the clipping range (shrinkage factor p) is determined independently by the observer using weight MSE. These two objectives are misaligned — the clipping range is never evaluated against the same output MSE objective that drives the scale search.

Solution

output_mse_shrinkage: for each scale candidate α, find the best clipping factor p per quantization group by minimizing the activation-projected weight quantization error.
Usage

```python
recipe = [
    AWQModifier(
        ignore=["lm_head"],
        scheme="W4A16_ASYM",
        targets=["Linear"],
        n_shrink_grid=10,  # number of shrink candidates (1 = disabled)
        maxshrink=0.20,    # search range: p in [1-maxshrink, 1.0]
    ),
    QuantizationModifier(...),
]
```

Ablation results

Model: meta-llama/Llama-3.1-8B-Instruct, W4A16 ASYM group=128, open-platypus 128 samples, WikiText-2 eval

Our method achieves the best PPL, improving -0.020 over vanilla AWQ and -0.202 over the RTN baseline.

Implementation notes
Changes
Problem
AWQ's grid search optimizes the per-channel scale (α) against output MSE via a full forward pass. However, the clipping range (shrinkage factor p) is determined independently by the observer using weight MSE. These two objectives are misaligned — the clipping range is never evaluated against the same output MSE objective that drives the scale search.
Solution
output_mse_shrinkage: for each scale candidate α, find the best clipping factor `p` per quantization group by minimizing the activation-projected weight quantization error. `X` is collected from real calibration samples via a forward hook on each balance layer, giving true output-space error per group. Each group independently selects its optimal clipping factor `p` — groups with outlier weights but small activations can be clipped aggressively, while groups with large activations remain conservative.

Usage
Ablation results
Model: meta-llama/Llama-3.1-8B-Instruct, W4A16 ASYM group=128, open-platypus 128 samples, WikiText-2 eval

Our method achieves the best PPL, improving -0.020 over vanilla AWQ and -0.202 over the RTN baseline.
Implementation notes
- Chunked einsum over `out_ch` (~256MB per chunk) to bound peak GPU memory
- `n_shrink_grid=1` (default) disables shrinkage entirely — fully backward compatible, no existing recipes affected

Changes
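The chunking note above can be illustrated with a small numpy sketch: compute the same per-group error as the full einsum, but over slices of `out_ch` so only one `(chunk, G, n_tokens)` intermediate exists at a time. Chunk size and numpy are illustrative stand-ins for the GPU implementation.

```python
import numpy as np

def chunked_group_error(w_err, x_grouped, chunk=64):
    """Per-group error computed in out_ch chunks to bound peak memory.

    w_err: (out_ch, G, group_size); x_grouped: (n_tokens, G, group_size).
    Returns (out_ch, G), identical to the unchunked einsum result.
    """
    out_ch = w_err.shape[0]
    parts = []
    for start in range(0, out_ch, chunk):
        we = w_err[start:start + chunk]                      # (chunk, G, gs)
        # Only this slice's (chunk, G, n_tokens) intermediate is materialized.
        out_err = np.einsum('ogs,ngs->ogn', we, x_grouped)
        parts.append((out_err ** 2).sum(axis=-1))            # (chunk, G)
    return np.concatenate(parts, axis=0)                     # (out_ch, G)
```

The peak intermediate shrinks from `out_ch * G * n_tokens` floats to `chunk * G * n_tokens`, which is the same trade the PR describes for its ~256MB chunks.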
- Add `n_shrink_grid: int = 1` and `maxshrink: float = 0.20` fields to `AWQModifier`
- Add `_apply_output_mse_shrinkage` — per-group clipping optimization via activation projection

Part of #2479
cc @HDCharles @brian-dellabetta