Commit bf4aa31

Adds AWQ guide

1 parent b14d03b commit bf4aa31

File tree

7 files changed
+789
-1
lines changed

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
"""
Title: AWQ Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/01/15
Last modified: 2025/01/15
Description: How to run weight-only AWQ quantization for Keras & KerasHub models.
Accelerator: GPU
"""

"""
## What is AWQ?

AWQ (Activation-aware Weight Quantization) is a post-training, weight-only
quantization method that uses activation statistics to identify and protect
salient weights during quantization.

The key insight of AWQ is that not all weights are equally important: a small
fraction of weights (typically <1%) are "salient" because they process
channels with large activation magnitudes. By protecting these weights from
quantization error, AWQ preserves model quality while achieving significant
compression.

Unlike GPTQ, which relies on second-order (Hessian-based) optimization, AWQ
uses a simpler grid search to find per-channel scales that minimize
activation-weighted quantization error. This makes AWQ generally faster while
achieving competitive accuracy.

### How it works

1. Run a small calibration set through the model to collect per-channel
   activation magnitudes.
2. For each weight matrix, search for optimal per-channel scales that
   minimize activation-weighted quantization error.
3. Multiply weights by the optimal scales before quantization
   (expanding salient weights).
4. Quantize the scaled weights to 4-bit (or other supported bit-width) integers.
5. During inference, dequantize weights and divide by scales to
   restore original magnitude.

The scales are computed as `scales = activation_max^ratio`, where `ratio` is
searched over a grid from 0 to 1.

Keras supports AWQ quantization for KerasHub models via the
`keras.quantizers.AWQConfig` class.
"""

"""
## Load a KerasHub model

This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B
parameter) causal language model.
"""

import keras
from keras_hub.models import Gemma3CausalLM
from datasets import load_dataset


prompt = "Keras is a"

model = Gemma3CausalLM.from_preset("gemma3_1b")

outputs = model.generate(prompt, max_length=30)
print(outputs)

"""
67+
## Configure & run AWQ quantization
68+
69+
You can configure AWQ quantization via the `keras.quantizers.AWQConfig` class.
70+
71+
The AWQ configuration requires a calibration dataset and tokenizer, which it
72+
uses to collect activation statistics and search for optimal scales. Here, we
73+
use a small slice of the WikiText-2 dataset for calibration.
74+
75+
Key parameters:
76+
77+
* `weight_bits`: The bit-width to quantize weights to. AWQ currently only
78+
supports 4-bit quantization.
79+
* `group_size`: The number of input features to quantize together. Smaller
80+
groups typically yield better accuracy but may use more memory. Use -1 for
81+
per-channel (no grouping). A good starting point is 128.
82+
* `num_grid_points`: The number of points to search over when finding optimal
83+
scales. More points give finer granularity but increase calibration time.
84+
Default is 20.
85+
* `num_samples`: Number of calibration samples to use for activation
86+
collection.
87+
* `sequence_length`: Maximum sequence length for calibration samples.
88+
89+
In this example, we first prepare a tiny calibration set, and then run AWQ on
90+
the model using the `.quantize(...)` API.
91+
"""

# Calibration slice (use a larger/representative set in practice)
texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]

calibration_dataset = [
    s + "." for text in texts for s in map(str.strip, text.split(".")) if s
]

awq_config = keras.quantizers.AWQConfig(
    dataset=calibration_dataset,
    tokenizer=model.preprocessor.tokenizer,
    weight_bits=4,
    group_size=128,
    num_grid_points=20,
    num_samples=128,
    sequence_length=256,
)

model.quantize("awq", config=awq_config)

outputs = model.generate(prompt, max_length=30)
print(outputs)

"""
116+
## Model Export
117+
118+
The AWQ quantized model can be saved to a preset and reloaded elsewhere, just
119+
like any other KerasHub model.
120+
"""
121+
122+
model.save_to_preset("gemma3_awq_w4gs128_preset")
123+
model_from_preset = Gemma3CausalLM.from_preset("gemma3_awq_w4gs128_preset")
124+
output = model_from_preset.generate(prompt, max_length=30)
125+
print(output)
126+
127+
"""
128+
## Performance & Benchmarking
129+
130+
Micro-benchmarks collected on a single RTX 4070 Ti Super (16 GB).
131+
Baselines are BF16 for Gemma3, and FP32 for Qwen3 and OPT.
132+
133+
Dataset: WikiText-2.
134+
135+
136+
| Model | Pre PPL | Post PPL | PPL Change | Disk Size Change | GPU Mem Change | Throughput Change |
137+
| ----- | ------: | -------: | ---------: | ---------------: | -------------: | ----------------: |
138+
| Qwen3 1.7B | 37.65 | 45.79 | +21.64% | -70.7% | -69.9% | -10.4% |
139+
| Gemma3 1B | 172.45 | 178.03 | +3.23% | -60.2% | -58.3% | -15.5% |
140+
| OPT 125M | 77.06 | 84.75 | +9.97% | -58.3% | -40.9% | -3.3% |
141+
142+
143+
### Analysis
144+
145+
* **Disk size reduction**: 58-71% across models due to 4-bit weight storage.
146+
* **GPU memory reduction**: 41-70% during inference.
147+
* **Perplexity degradation**: +3.2% (Gemma3 1B) to +21.6% (Qwen3 1.7B), model-dependent.
148+
* **Throughput**: -3% to -15% due to dequantization overhead.
149+
150+
AWQ provides substantial memory savings with modest quality degradation,
151+
making it ideal for deploying large models on memory-constrained devices.
152+
"""

"""
## AWQ vs GPTQ?

Both AWQ and GPTQ are weight-only quantization methods that require calibration
data. Here's how to choose between them:

| Aspect | AWQ | GPTQ |
| ------ | --- | ---- |
| **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization |
| **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) |
| **Bit-widths supported** | Only 4-bit supported for now | 2/3/4/8-bit |
| **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs |
| **Memory during quantization** | Lower | Higher (Hessian storage) |
| **Calibration sensitivity** | Less prone to overfitting | May overfit calibration set, affecting out-of-distribution performance |

**Choose AWQ when:**

* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).
* Memory during quantization is constrained.
* 4-bit is sufficient for your use case.
* Your model will be used on diverse/out-of-distribution data (AWQ is less
  prone to overfitting on calibration data).

**Choose GPTQ when:**

* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).
* Maximum accuracy is critical and you can afford longer quantization time.
* You're working with decoder-only LLMs where GPTQ may have a slight edge.
"""

"""
## Practical tips

* AWQ is a post-training technique; training after quantization is not supported.
* Always use the model's own tokenizer for calibration.
* Use a representative calibration set; small slices are only for demos.
* Start with 4-bit weights and `group_size=128`; tune per model/task.
* AWQ currently supports only 4-bit quantization.
* For best results, use calibration data that matches your inference domain.

## References

* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978)
* [MIT HAN Lab AWQ Repository](https://github.com/mit-han-lab/llm-awq)
"""

guides/gptq_quantization_in_keras.py

Lines changed: 29 additions & 0 deletions
@@ -130,6 +130,35 @@
preserved after quantization.
"""

"""
## GPTQ vs AWQ?

Both GPTQ and AWQ are weight-only quantization methods that require calibration
data. Here's how to choose between them:

| Aspect | GPTQ | AWQ |
| ------ | ---- | --- |
| **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales |
| **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) |
| **Bit-widths supported** | 2/3/4/8-bit | Only 4-bit supported for now |
| **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models |
| **Memory during quantization** | Higher (Hessian storage) | Lower |
| **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting |

**Choose GPTQ when:**

* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).
* Maximum accuracy is critical and you can afford longer quantization time.
* You're working with decoder-only LLMs where GPTQ may have a slight edge.

**Choose AWQ when:**

* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).
* Memory during quantization is constrained.
* 4-bit is sufficient for your use case.
* Your model will be used on diverse/out-of-distribution data (AWQ is less
  prone to overfitting on calibration data).
"""

"""
## Practical tips