"""
Title: AWQ Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/01/15
Last modified: 2025/01/15
Description: How to run weight-only AWQ quantization for Keras & KerasHub models.
Accelerator: GPU
"""

"""
## What is AWQ?

AWQ (Activation-aware Weight Quantization) is a post-training, weight-only
quantization method that uses activation statistics to identify and protect
salient weights during quantization.

The key insight of AWQ is that not all weights are equally important: a small
fraction of weights (typically <1%) are "salient" because they process
channels with large activation magnitudes. By protecting these weights from
quantization error, AWQ preserves model quality while achieving significant
compression.

Unlike GPTQ, which uses second-order (Hessian-based) optimization, AWQ uses a
simpler grid search to find per-channel scales that minimize activation-weighted
quantization error. This makes AWQ generally faster while achieving competitive
accuracy.

### How it works

1. Run a small calibration set through the model to collect per-channel
   activation magnitudes.
2. For each weight matrix, search for the per-channel scales that minimize the
   activation-weighted quantization error.
3. Multiply the weights by the optimal scales before quantization
   (expanding the salient weights).
4. Quantize the scaled weights to 4-bit (or another supported bit-width) integers.
5. During inference, dequantize the weights and divide by the scales to
   restore their original magnitude.

The scales are computed as `scales = activation_max^ratio`, where `ratio` is
searched over a grid from 0 to 1.

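Steps 2-4 can be sketched in NumPy. This is an illustrative approximation, not
the Keras implementation: the helper names are hypothetical, and the error term
simply weights reconstruction error by the cached per-channel activation maxima.

```python
import numpy as np


def quantize_int4(w):
    # Symmetric 4-bit quantization per output channel: round to integers in
    # [-7, 7], then dequantize back to floats.
    max_abs = np.max(np.abs(w), axis=0, keepdims=True) + 1e-8
    q = np.clip(np.round(w / max_abs * 7), -7, 7)
    return q / 7 * max_abs


def awq_scale_search(w, act_max, num_grid_points=20):
    # w: (in_features, out_features) weight matrix.
    # act_max: nonnegative per-input-channel activation maxima collected from
    # the calibration set.
    best_err, best_scales = np.inf, np.ones(w.shape[0])
    for i in range(num_grid_points):
        ratio = i / num_grid_points
        scales = act_max**ratio  # scales = activation_max^ratio
        # Scale up before quantizing, scale back down after.
        w_q = quantize_int4(w * scales[:, None]) / scales[:, None]
        # Activation-weighted reconstruction error: errors on channels with
        # large activations count more.
        err = float(np.mean((act_max[:, None] * (w - w_q)) ** 2))
        if err < best_err:
            best_err, best_scales = err, scales
    return best_scales
```

At `ratio = 0` every scale is 1 (plain quantization); larger ratios expand
high-activation channels so they lose less precision.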
Keras supports AWQ quantization for KerasHub models via the
`keras.quantizers.AWQConfig` class.
"""

"""
## Load a KerasHub model

This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B
parameter) causal language model.
"""
import keras
from keras_hub.models import Gemma3CausalLM
from datasets import load_dataset


prompt = "Keras is a"

model = Gemma3CausalLM.from_preset("gemma3_1b")

outputs = model.generate(prompt, max_length=30)
print(outputs)

"""
## Configure & run AWQ quantization

You can configure AWQ quantization via the `keras.quantizers.AWQConfig` class.

The AWQ configuration requires a calibration dataset and a tokenizer, which it
uses to collect activation statistics and search for the optimal scales. Here,
we use a small slice of the WikiText-2 dataset for calibration.

Key parameters:

* `weight_bits`: The bit-width to quantize weights to. AWQ currently supports
  only 4-bit quantization.
* `group_size`: The number of input features quantized together. Smaller
  groups typically yield better accuracy but use more memory. Use -1 for
  per-channel quantization (no grouping). A good starting point is 128.
* `num_grid_points`: The number of points to search over when finding the
  optimal scales. More points give finer granularity but increase calibration
  time. The default is 20.
* `num_samples`: The number of calibration samples to use for activation
  collection.
* `sequence_length`: The maximum sequence length of the calibration samples.

In this example, we first prepare a tiny calibration set and then run AWQ on
the model using the `.quantize(...)` API.
"""

# Calibration slice (use a larger, representative set in practice).
texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]

calibration_dataset = [
    s + "." for text in texts for s in map(str.strip, text.split(".")) if s
]

awq_config = keras.quantizers.AWQConfig(
    dataset=calibration_dataset,
    tokenizer=model.preprocessor.tokenizer,
    weight_bits=4,
    group_size=128,
    num_grid_points=20,
    num_samples=128,
    sequence_length=256,
)

model.quantize("awq", config=awq_config)

outputs = model.generate(prompt, max_length=30)
print(outputs)

"""
## Model Export

The AWQ-quantized model can be saved to a preset and reloaded elsewhere, just
like any other KerasHub model.
"""

model.save_to_preset("gemma3_awq_w4gs128_preset")
model_from_preset = Gemma3CausalLM.from_preset("gemma3_awq_w4gs128_preset")
output = model_from_preset.generate(prompt, max_length=30)
print(output)

"""
## Performance & Benchmarking

Micro-benchmarks collected on a single RTX 4070 Ti Super (16 GB).
Baselines are BF16 for Gemma3, and FP32 for Qwen3 and OPT.

Dataset: WikiText-2.

| Model | Pre PPL | Post PPL | PPL Change | Disk Size Change | GPU Mem Change | Throughput Change |
| ----- | ------: | -------: | ---------: | ---------------: | -------------: | ----------------: |
| Qwen3 1.7B | 37.65 | 45.79 | +21.64% | -70.7% | -69.9% | -10.4% |
| Gemma3 1B | 172.45 | 178.03 | +3.23% | -60.2% | -58.3% | -15.5% |
| OPT 125M | 77.06 | 84.75 | +9.97% | -58.3% | -40.9% | -3.3% |

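For reference, perplexity (PPL) is the exponential of the mean per-token
negative log-likelihood over the evaluation set. A minimal sketch, where
`token_nlls` is assumed to come from your own evaluation loop:

```python
import numpy as np


def perplexity(token_nlls):
    # token_nlls: per-token negative log-likelihoods (natural log), e.g.
    # collected from the model's cross-entropy loss during evaluation.
    return float(np.exp(np.mean(token_nlls)))


# Sanity check: uniform predictions over a 4-token vocabulary give PPL ~4.
print(perplexity([np.log(4.0)] * 10))
```
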
### Analysis

* **Disk size reduction**: 58-71% across models, due to 4-bit weight storage.
* **GPU memory reduction**: 41-70% during inference.
* **Perplexity degradation**: +3.2% (Gemma3 1B) to +21.6% (Qwen3 1.7B), model-dependent.
* **Throughput**: -3% to -15%, due to dequantization overhead.

AWQ provides substantial memory savings with modest quality degradation,
making it well suited to deploying large models on memory-constrained devices.
"""

"""
## AWQ vs GPTQ

Both AWQ and GPTQ are weight-only quantization methods that require calibration
data. Here's how to choose between them:

| Aspect | AWQ | GPTQ |
| ------ | --- | ---- |
| **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization |
| **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) |
| **Bit-widths supported** | 4-bit only (for now) | 2/3/4/8-bit |
| **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs |
| **Memory during quantization** | Lower | Higher (Hessian storage) |
| **Calibration sensitivity** | Less prone to overfitting | May overfit the calibration set, hurting out-of-distribution performance |

**Choose AWQ when:**

* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).
* Memory during quantization is constrained.
* 4-bit is sufficient for your use case.
* Your model will be used on diverse or out-of-distribution data (AWQ is less
  prone to overfitting on calibration data).

**Choose GPTQ when:**

* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).
* Maximum accuracy is critical and you can afford a longer quantization time.
* You're working with decoder-only LLMs, where GPTQ may have a slight edge.
"""

"""
## Practical tips

* AWQ is a post-training technique; training after quantization is not supported.
* Always use the model's own tokenizer for calibration.
* Use a representative calibration set; small slices are only for demos.
* Start with 4-bit weights and `group_size=128`; tune per model and task.
* AWQ currently supports only 4-bit quantization.
* For best results, use calibration data that matches your inference domain.

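As a back-of-envelope for the `group_size` trade-off: with grouped quantization,
each group stores one scale (and possibly a zero point) alongside the 4-bit
weights, so smaller groups cost more bits per weight. The 16-bit scale and
4-bit zero-point sizes below are assumptions for illustration, not the exact
Keras storage format.

```python
def bits_per_weight(weight_bits=4, group_size=128, scale_bits=16, zero_bits=4):
    # Each group of `group_size` weights shares one scale and one zero point,
    # so the per-weight metadata overhead shrinks as the group grows.
    return weight_bits + (scale_bits + zero_bits) / group_size


for g in (32, 64, 128):
    print(g, bits_per_weight(group_size=g))
```

Under these assumptions, `group_size=128` stores roughly 4.16 bits per weight,
while `group_size=32` needs about 4.63.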
## References

* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978)
* [MIT HAN Lab AWQ Repository](https://github.com/mit-han-lab/llm-awq)
"""