diff --git a/guides/awq_quantization_in_keras.py b/guides/awq_quantization_in_keras.py new file mode 100644 index 0000000000..0450ba227d --- /dev/null +++ b/guides/awq_quantization_in_keras.py @@ -0,0 +1,200 @@ +""" +Title: AWQ Quantization in Keras +Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh) +Date created: 2025/01/15 +Last modified: 2025/01/15 +Description: How to run weight-only AWQ quantization for Keras & KerasHub models. +Accelerator: GPU +""" + +""" +## What is AWQ? + +AWQ (Activation-aware Weight Quantization) is a post-training, weight-only +quantization method that uses activation statistics to identify and protect +salient weights during quantization. + +The key insight of AWQ is that not all weights are equally important: a small +fraction of weights (typically <1%) are "salient" because they process +channels with large activation magnitudes. By protecting these weights from +quantization error, AWQ preserves model quality while achieving significant +compression. + +Unlike GPTQ which uses second-order (Hessian-based) optimization, AWQ uses a +simpler grid search to find per-channel scales that minimize activation-weighted +quantization error. This makes AWQ generally faster while achieving competitive +accuracy. + +### How it works + +1. Run a small calibration set through the model to collect per-channel + activation magnitudes. +2. For each weight matrix, search for optimal per-channel scales that + minimize activation-weighted quantization error. +3. Multiply weights by the optimal scales before quantization + (expanding salient weights). +4. Quantize the scaled weights to 4-bit (or other supported bit-width) integers. +5. During inference, dequantize weights and divide by scales to + restore original magnitude. + +The scale formula uses: `scales = activation_max^ratio` where ratio is +searched over a grid from 0 to 1. + +Keras supports AWQ quantization for KerasHub models via the +`keras.quantizers.AWQConfig` class. +""" + +""" +## Load a KerasHub model + +This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B +parameter) causal language model. + +""" +from datasets import load_dataset +import keras +from keras_hub.models import Gemma3CausalLM + + +prompt = "Keras is a" + +model = Gemma3CausalLM.from_preset("gemma3_1b") + +outputs = model.generate(prompt, max_length=30) +print(outputs) + +""" +## Configure & run AWQ quantization + +You can configure AWQ quantization via the `keras.quantizers.AWQConfig` class. + +The AWQ configuration requires a calibration dataset and tokenizer, which it +uses to collect activation statistics and search for optimal scales. Here, we +use a small slice of the WikiText-2 dataset for calibration. + +Key parameters: + +* `weight_bits`: The bit-width to quantize weights to. AWQ currently only + supports 4-bit quantization. +* `group_size`: The number of input features to quantize together. Smaller + groups typically yield better accuracy but may use more memory. Use -1 for + per-channel (no grouping). A good starting point is 128. +* `num_grid_points`: The number of points to search over when finding optimal + scales. More points give finer granularity but increase calibration time. + Default is 20. +* `num_samples`: Number of calibration samples to use for activation + collection. +* `sequence_length`: Maximum sequence length for calibration samples. + +In this example, we first prepare a tiny calibration set, and then run AWQ on +the model using the `.quantize(...)` API. +""" + +# Calibration slice (use a larger/representative set in practice) +texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"] + +calibration_dataset = [] +for text in texts: + for s in text.split("."): + s = s.strip() + if s: + calibration_dataset.append(s + ".") + +awq_config = keras.quantizers.AWQConfig( + dataset=calibration_dataset, + tokenizer=model.preprocessor.tokenizer, + weight_bits=4, + group_size=128, + num_grid_points=20, + num_samples=128, + sequence_length=256, +) + +model.quantize("awq", config=awq_config) + +outputs = model.generate(prompt, max_length=30) +print(outputs) + +""" +## Model Export + +The AWQ quantized model can be saved to a preset and reloaded elsewhere, just +like any other KerasHub model. +""" + +model.save_to_preset("gemma3_awq_w4gs128_preset") +model_from_preset = Gemma3CausalLM.from_preset("gemma3_awq_w4gs128_preset") +output = model_from_preset.generate(prompt, max_length=30) +print(output) + +""" +## Performance & Benchmarking + +Micro-benchmarks collected on a single RTX 4070 Ti Super (16 GB). +Baselines are BF16 for Gemma3, and FP32 for Qwen3 and OPT. + +Dataset: WikiText-2. + + +| Model | Pre PPL | Post PPL | PPL Change | Disk Size Change | GPU Mem Change | Throughput Change | +| ----- | ------: | -------: | ---------: | ---------------: | -------------: | ----------------: | +| Qwen3 1.7B | 37.65 | 45.79 | +21.64% | -70.7% | -69.9% | -10.4% | +| Gemma3 1B | 172.45 | 178.03 | +3.23% | -60.2% | -58.3% | -15.5% | +| OPT 125M | 77.06 | 84.75 | +9.97% | -58.3% | -40.9% | -3.3% | + + +### Analysis + +* **Disk size reduction**: 58-71% across models due to 4-bit weight storage. +* **GPU memory reduction**: 41-70% during inference. +* **Perplexity degradation**: +3.2% (Gemma3 1B) to +21.6% (Qwen3 1.7B), model-dependent. +* **Throughput**: -3% to -15% due to dequantization overhead. + +AWQ provides substantial memory savings with modest quality degradation, +making it ideal for deploying large models on memory-constrained devices. +""" + +""" +## AWQ vs GPTQ? + +Both AWQ and GPTQ are weight-only quantization methods that require calibration +data. Here's how to choose between them: + +| Aspect | AWQ | GPTQ | +| ------ | --- | ---- | +| **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization | +| **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) | +| **Bit-widths supported** | 4-bit | 2/3/4/8-bit | +| **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs | +| **Memory during quantization** | Lower | Higher (Hessian storage) | +| **Calibration sensitivity** | Less prone to overfitting | May overfit calibration set, affecting out-of-distribution performance | + +**Choose AWQ when:** + +* You need faster quantization (AWQ is typically 2-3x faster than GPTQ). +* Memory during quantization is constrained. +* 4-bit is sufficient for your use case. +* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data). + +**Choose GPTQ when:** + +* You need bit-widths other than 4 (e.g., 2-bit or 8-bit). +* Maximum accuracy is critical and you can afford longer quantization time. +* You're working with decoder-only LLMs where GPTQ may have a slight edge. +""" + +""" +## Practical tips + +* AWQ is a post-training technique; training after quantization is not supported. +* Always use the model's own tokenizer for calibration. +* Use a representative calibration set; small slices are only for demos. +* Start with W4 group_size=128; tune per model/task. +* AWQ only supports 4-bit quantization currently. +* For best results, use calibration data that matches your inference domain. + +## References + +* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978) +* [MIT HAN Lab AWQ Repository](https://github.com/mit-han-lab/llm-awq) +""" diff --git a/guides/gptq_quantization_in_keras.py b/guides/gptq_quantization_in_keras.py index 4aaf040ca6..9f44209ad6 100644 --- a/guides/gptq_quantization_in_keras.py +++ b/guides/gptq_quantization_in_keras.py @@ -31,9 +31,9 @@ parameter) causal language model. """ +from datasets import load_dataset import keras from keras_hub.models import Gemma3CausalLM -from datasets import load_dataset prompt = "Keras is a" @@ -130,6 +130,35 @@ preserved after quantization. """ +""" +## GPTQ vs AWQ? + +Both GPTQ and AWQ are weight-only quantization methods that require calibration +data. Here's how to choose between them: + +| Aspect | GPTQ | AWQ | +| ------ | ---- | --- | +| **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales | +| **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) | +| **Bit-widths supported** | 2/3/4/8-bit | 4-bit | +| **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models | +| **Memory during quantization** | Higher (Hessian storage) | Lower | +| **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting | + +**Choose GPTQ when:** + +* You need bit-widths other than 4 (e.g., 2-bit or 8-bit). +* Maximum accuracy is critical and you can afford longer quantization time. +* You're working with decoder-only LLMs where GPTQ may have a slight edge. + +**Choose AWQ when:** + +* You need faster quantization (AWQ is typically 2-3x faster than GPTQ). +* Memory during quantization is constrained. +* 4-bit is sufficient for your use case. +* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data). +""" + """ ## Practical tips diff --git a/guides/ipynb/awq_quantization_in_keras.ipynb b/guides/ipynb/awq_quantization_in_keras.ipynb new file mode 100644 index 0000000000..0ef2581d02 --- /dev/null +++ b/guides/ipynb/awq_quantization_in_keras.ipynb @@ -0,0 +1,301 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "# AWQ Quantization in Keras\n", + "\n", + "**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)
\n", + "**Date created:** 2025/01/15
\n", + "**Last modified:** 2025/01/15
\n", + "**Description:** How to run weight-only AWQ quantization for Keras & KerasHub models." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## What is AWQ?\n", + "\n", + "AWQ (Activation-aware Weight Quantization) is a post-training, weight-only\n", + "quantization method that uses activation statistics to identify and protect\n", + "salient weights during quantization.\n", + "\n", + "The key insight of AWQ is that not all weights are equally important: a small\n", + "fraction of weights (typically <1%) are \"salient\" because they process\n", + "channels with large activation magnitudes. By protecting these weights from\n", + "quantization error, AWQ preserves model quality while achieving significant\n", + "compression.\n", + "\n", + "Unlike GPTQ which uses second-order (Hessian-based) optimization, AWQ uses a\n", + "simpler grid search to find per-channel scales that minimize activation-weighted\n", + "quantization error. This makes AWQ generally faster while achieving competitive\n", + "accuracy.\n", + "\n", + "### How it works\n", + "\n", + "1. Run a small calibration set through the model to collect per-channel\n", + " activation magnitudes.\n", + "2. For each weight matrix, search for optimal per-channel scales that\n", + " minimize activation-weighted quantization error.\n", + "3. Multiply weights by the optimal scales before quantization\n", + " (expanding salient weights).\n", + "4. Quantize the scaled weights to 4-bit (or other supported bit-width) integers.\n", + "5. During inference, dequantize weights and divide by scales to\n", + " restore original magnitude.\n", + "\n", + "The scale formula uses: `scales = activation_max^ratio` where ratio is\n", + "searched over a grid from 0 to 1.\n", + "\n", + "Keras supports AWQ quantization for KerasHub models via the\n", + "`keras.quantizers.AWQConfig` class." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Load a KerasHub model\n", + "\n", + "This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B\n", + "parameter) causal language model." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "import keras\n", + "from keras_hub.models import Gemma3CausalLM\n", + "\n", + "\n", + "prompt = \"Keras is a\"\n", + "\n", + "model = Gemma3CausalLM.from_preset(\"gemma3_1b\")\n", + "\n", + "outputs = model.generate(prompt, max_length=30)\n", + "print(outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Configure & run AWQ quantization\n", + "\n", + "You can configure AWQ quantization via the `keras.quantizers.AWQConfig` class.\n", + "\n", + "The AWQ configuration requires a calibration dataset and tokenizer, which it\n", + "uses to collect activation statistics and search for optimal scales. Here, we\n", + "use a small slice of the WikiText-2 dataset for calibration.\n", + "\n", + "Key parameters:\n", + "\n", + "* `weight_bits`: The bit-width to quantize weights to. AWQ currently only\n", + " supports 4-bit quantization.\n", + "* `group_size`: The number of input features to quantize together. Smaller\n", + " groups typically yield better accuracy but may use more memory. Use -1 for\n", + " per-channel (no grouping). A good starting point is 128.\n", + "* `num_grid_points`: The number of points to search over when finding optimal\n", + " scales. More points give finer granularity but increase calibration time.\n", + " Default is 20.\n", + "* `num_samples`: Number of calibration samples to use for activation\n", + " collection.\n", + "* `sequence_length`: Maximum sequence length for calibration samples.\n", + "\n", + "In this example, we first prepare a tiny calibration set, and then run AWQ on\n", + "the model using the `.quantize(...)` API." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "# Calibration slice (use a larger/representative set in practice)\n", + "texts = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"train[:1%]\")[\"text\"]\n", + "\n", + "calibration_dataset = []\n", + "for text in texts:\n", + " for s in text.split(\".\"):\n", + " s = s.strip()\n", + " if s:\n", + " calibration_dataset.append(s + \".\")\n", + "\n", + "awq_config = keras.quantizers.AWQConfig(\n", + " dataset=calibration_dataset,\n", + " tokenizer=model.preprocessor.tokenizer,\n", + " weight_bits=4,\n", + " group_size=128,\n", + " num_grid_points=20,\n", + " num_samples=128,\n", + " sequence_length=256,\n", + ")\n", + "\n", + "model.quantize(\"awq\", config=awq_config)\n", + "\n", + "outputs = model.generate(prompt, max_length=30)\n", + "print(outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Model Export\n", + "\n", + "The AWQ quantized model can be saved to a preset and reloaded elsewhere, just\n", + "like any other KerasHub model." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "model.save_to_preset(\"gemma3_awq_w4gs128_preset\")\n", + "model_from_preset = Gemma3CausalLM.from_preset(\"gemma3_awq_w4gs128_preset\")\n", + "output = model_from_preset.generate(prompt, max_length=30)\n", + "print(output)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Performance & Benchmarking\n", + "\n", + "Micro-benchmarks collected on a single RTX 4070 Ti Super (16 GB).\n", + "Baselines are BF16 for Gemma3, and FP32 for Qwen3 and OPT.\n", + "\n", + "Dataset: WikiText-2.\n", + "\n", + "\n", + "| Model | Pre PPL | Post PPL | PPL Change | Disk Size Change | GPU Mem Change | Throughput Change |\n", + "| ----- | ------: | -------: | ---------: | ---------------: | -------------: | ----------------: |\n", + "| Qwen3 1.7B | 37.65 | 45.79 | +21.64% | -70.7% | -69.9% | -10.4% |\n", + "| Gemma3 1B | 172.45 | 178.03 | +3.23% | -60.2% | -58.3% | -15.5% |\n", + "| OPT 125M | 77.06 | 84.75 | +9.97% | -58.3% | -40.9% | -3.3% |\n", + "\n", + "\n", + "### Analysis\n", + "\n", + "* **Disk size reduction**: 58-71% across models due to 4-bit weight storage.\n", + "* **GPU memory reduction**: 41-70% during inference.\n", + "* **Perplexity degradation**: +3.2% (Gemma3 1B) to +21.6% (Qwen3 1.7B), model-dependent.\n", + "* **Throughput**: -3% to -15% due to dequantization overhead.\n", + "\n", + "AWQ provides substantial memory savings with modest quality degradation,\n", + "making it ideal for deploying large models on memory-constrained devices." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## AWQ vs GPTQ?\n", + "\n", + "Both AWQ and GPTQ are weight-only quantization methods that require calibration\n", + "data. Here's how to choose between them:\n", + "\n", + "| Aspect | AWQ | GPTQ |\n", + "| ------ | --- | ---- |\n", + "| **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization |\n", + "| **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) |\n", + "| **Bit-widths supported** | 4-bit | 2/3/4/8-bit |\n", + "| **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs |\n", + "| **Memory during quantization** | Lower | Higher (Hessian storage) |\n", + "| **Calibration sensitivity** | Less prone to overfitting | May overfit calibration set, affecting out-of-distribution performance |\n", + "\n", + "**Choose AWQ when:**\n", + "\n", + "* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).\n", + "* Memory during quantization is constrained.\n", + "* 4-bit is sufficient for your use case.\n", + "* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data).\n", + "\n", + "**Choose GPTQ when:**\n", + "\n", + "* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).\n", + "* Maximum accuracy is critical and you can afford longer quantization time.\n", + "* You're working with decoder-only LLMs where GPTQ may have a slight edge." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Practical tips\n", + "\n", + "* AWQ is a post-training technique; training after quantization is not supported.\n", + "* Always use the model's own tokenizer for calibration.\n", + "* Use a representative calibration set; small slices are only for demos.\n", + "* Start with W4 group_size=128; tune per model/task.\n", + "* AWQ only supports 4-bit quantization currently.\n", + "* For best results, use calibration data that matches your inference domain.\n", + "\n", + "## References\n", + "\n", + "* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978)\n", + "* [MIT HAN Lab AWQ Repository](https://github.com/mit-han-lab/llm-awq)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "awq_quantization_in_keras", + "private_outputs": false, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/guides/ipynb/gptq_quantization_in_keras.ipynb b/guides/ipynb/gptq_quantization_in_keras.ipynb index bcbac78e29..ad3ff407e4 100644 --- a/guides/ipynb/gptq_quantization_in_keras.ipynb +++ b/guides/ipynb/gptq_quantization_in_keras.ipynb @@ -167,12 +167,12 @@ "Dataset: WikiText-2.\n", "\n", "\n", - "| Model (preset) | Perplexity Increase % (↓ better) | Disk Storage Reduction Δ % (↓ better) | VRAM Reduction Δ % (↓ better) | First-token Latency Δ % (↓ better) | Throughput Δ % (↑ better) |\n", - "| ------------------------------------------- | -------------------------------: | ------------------------------------: | ----------------------------: | ---------------------------------: | ------------------------: |\n", - "| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% ↓ | -41.1% ↓ | +0.7% ↑ | +20.1% ↑ |\n", - "| OPT (opt_125m_en) | 10.0% | -49.8% ↓ | -47.0% ↓ | +6.7% ↑ | -15.7% ↓ |\n", - "| Bloom (bloom_1.1b_multi) | 7.0% | -47.0% ↓ | -54.0% ↓ | +1.8% ↑ | -15.7% ↓ |\n", - "| Gemma3 (gemma3_1b) | 3.0% | -51.5% ↓ | -51.8% ↓ | +39.5% ↑ | +5.7% ↑ |\n", + "| Model (preset) | Perplexity Increase % (\u2193 better) | Disk Storage Reduction \u0394 % (\u2193 better) | VRAM Reduction \u0394 % (\u2193 better) | First-token Latency \u0394 % (\u2193 better) | Throughput \u0394 % (\u2191 better) |\n", + "| --------------------------------- | -------------------------------: | ------------------------------------: | ----------------------------: | ---------------------------------: | ------------------------: |\n", + "| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% \u2193 | -41.1% \u2193 | +0.7% \u2191 | +20.1% \u2191 |\n", + "| OPT (opt_125m_en) | 10.0% | -49.8% \u2193 | -47.0% \u2193 | +6.7% \u2191 | -15.7% \u2193 |\n", + "| Bloom (bloom_1.1b_multi) | 7.0% | -47.0% \u2193 | -54.0% \u2193 | +1.8% \u2191 | -15.7% \u2193 |\n", + "| Gemma3 (gemma3_1b) | 3.0% | -51.5% \u2193 | -51.8% \u2193 | +39.5% \u2191 | +5.7% \u2191 |\n", "\n", "\n", "Detailed benchmarking numbers and scripts are available\n", @@ -189,6 +189,40 @@ "preserved after quantization." ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## GPTQ vs AWQ?\n", + "\n", + "Both GPTQ and AWQ are weight-only quantization methods that require calibration\n", + "data. Here's how to choose between them:\n", + "\n", + "| Aspect | GPTQ | AWQ |\n", + "| ------ | ---- | --- |\n", + "| **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales |\n", + "| **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) |\n", + "| **Bit-widths supported** | 2/3/4/8-bit | 4-bit |\n", + "| **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models |\n", + "| **Memory during quantization** | Higher (Hessian storage) | Lower |\n", + "| **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting |\n", + "\n", + "**Choose GPTQ when:**\n", + "\n", + "* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).\n", + "* Maximum accuracy is critical and you can afford longer quantization time.\n", + "* You're working with decoder-only LLMs where GPTQ may have a slight edge.\n", + "\n", + "**Choose AWQ when:**\n", + "\n", + "* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).\n", + "* Memory during quantization is constrained.\n", + "* 4-bit is sufficient for your use case.\n", + "* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data)." + ] + }, { "cell_type": "markdown", "metadata": { @@ -202,11 +236,6 @@ "* Use a representative calibration set; small slices are only for demos.\n", "* Start with W4 group_size=128; tune per model/task." ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": { @@ -238,4 +267,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/guides/md/awq_quantization_in_keras.md b/guides/md/awq_quantization_in_keras.md new file mode 100644 index 0000000000..f7d165dfb5 --- /dev/null +++ b/guides/md/awq_quantization_in_keras.md @@ -0,0 +1,230 @@ +# AWQ Quantization in Keras + +**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)
+**Date created:** 2025/01/15
+**Last modified:** 2025/01/15
+**Description:** How to run weight-only AWQ quantization for Keras & KerasHub models. + + + [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/awq_quantization_in_keras.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/awq_quantization_in_keras.py) + + + +--- +## What is AWQ? + +AWQ (Activation-aware Weight Quantization) is a post-training, weight-only +quantization method that uses activation statistics to identify and protect +salient weights during quantization. + +The key insight of AWQ is that not all weights are equally important: a small +fraction of weights (typically <1%) are "salient" because they process +channels with large activation magnitudes. By protecting these weights from +quantization error, AWQ preserves model quality while achieving significant +compression. + +Unlike GPTQ which uses second-order (Hessian-based) optimization, AWQ uses a +simpler grid search to find per-channel scales that minimize activation-weighted +quantization error. This makes AWQ generally faster while achieving competitive +accuracy. + +### How it works + +1. Run a small calibration set through the model to collect per-channel + activation magnitudes. +2. For each weight matrix, search for optimal per-channel scales that + minimize activation-weighted quantization error. +3. Multiply weights by the optimal scales before quantization + (expanding salient weights). +4. Quantize the scaled weights to 4-bit (or other supported bit-width) integers. +5. During inference, dequantize weights and divide by scales to + restore original magnitude. + +The scale formula uses: `scales = activation_max^ratio` where ratio is +searched over a grid from 0 to 1. + +Keras supports AWQ quantization for KerasHub models via the +`keras.quantizers.AWQConfig` class. + +--- +## Load a KerasHub model + +This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B +parameter) causal language model. + + +```python +from datasets import load_dataset +import keras +from keras_hub.models import Gemma3CausalLM + + +prompt = "Keras is a" + +model = Gemma3CausalLM.from_preset("gemma3_1b") + +outputs = model.generate(prompt, max_length=30) +print(outputs) +``` + +
+``` +Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning +``` +
+ +--- +## Configure & run AWQ quantization + +You can configure AWQ quantization via the `keras.quantizers.AWQConfig` class. + +The AWQ configuration requires a calibration dataset and tokenizer, which it +uses to collect activation statistics and search for optimal scales. Here, we +use a small slice of the WikiText-2 dataset for calibration. + +Key parameters: + +* `weight_bits`: The bit-width to quantize weights to. AWQ currently only + supports 4-bit quantization. +* `group_size`: The number of input features to quantize together. Smaller + groups typically yield better accuracy but may use more memory. Use -1 for + per-channel (no grouping). A good starting point is 128. +* `num_grid_points`: The number of points to search over when finding optimal + scales. More points give finer granularity but increase calibration time. + Default is 20. +* `num_samples`: Number of calibration samples to use for activation + collection. +* `sequence_length`: Maximum sequence length for calibration samples. + +In this example, we first prepare a tiny calibration set, and then run AWQ on +the model using the `.quantize(...)` API. + + +```python +# Calibration slice (use a larger/representative set in practice) +texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"] + +calibration_dataset = [] +for text in texts: + for s in text.split("."): + s = s.strip() + if s: + calibration_dataset.append(s + ".") + +awq_config = keras.quantizers.AWQConfig( + dataset=calibration_dataset, + tokenizer=model.preprocessor.tokenizer, + weight_bits=4, + group_size=128, + num_grid_points=20, + num_samples=128, + sequence_length=256, +) + +model.quantize("awq", config=awq_config) + +outputs = model.generate(prompt, max_length=30) +print(outputs) +``` + +
+``` +26/26 ━━━━━━━━━━━━━━━━━━━━ 239s 9s/step + +Keras is a Python library for deep learning. It is a high-level interface to the TensorFlow library. + +Keras is a great library +``` +
+ +--- +## Model Export + +The AWQ quantized model can be saved to a preset and reloaded elsewhere, just +like any other KerasHub model. + + +```python +model.save_to_preset("gemma3_awq_w4gs128_preset") +model_from_preset = Gemma3CausalLM.from_preset("gemma3_awq_w4gs128_preset") +output = model_from_preset.generate(prompt, max_length=30) +print(output) +``` + +
+``` +Keras is a Python library for deep learning. It is a high-level interface to the TensorFlow library. + +Keras is a great library +``` +
+ +--- +## Performance & Benchmarking + +Micro-benchmarks collected on a single RTX 4070 Ti Super (16 GB). +Baselines are BF16 for Gemma3, and FP32 for Qwen3 and OPT. + +Dataset: WikiText-2. + + +| Model | Pre PPL | Post PPL | PPL Change | Disk Size Change | GPU Mem Change | Throughput Change | +| ----- | ------: | -------: | ---------: | ---------------: | -------------: | ----------------: | +| Qwen3 1.7B | 37.65 | 45.79 | +21.64% | -70.7% | -69.9% | -10.4% | +| Gemma3 1B | 172.45 | 178.03 | +3.23% | -60.2% | -58.3% | -15.5% | +| OPT 125M | 77.06 | 84.75 | +9.97% | -58.3% | -40.9% | -3.3% | + + +### Analysis + +* **Disk size reduction**: 58-71% across models due to 4-bit weight storage. +* **GPU memory reduction**: 41-70% during inference. +* **Perplexity degradation**: +3.2% (Gemma3 1B) to +21.6% (Qwen3 1.7B), model-dependent. +* **Throughput**: -3% to -15% due to dequantization overhead. + +AWQ provides substantial memory savings with modest quality degradation, +making it ideal for deploying large models on memory-constrained devices. + +--- +## AWQ vs GPTQ? + +Both AWQ and GPTQ are weight-only quantization methods that require calibration +data. Here's how to choose between them: + +| Aspect | AWQ | GPTQ | +| ------ | --- | ---- | +| **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization | +| **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) | +| **Bit-widths supported** | 4-bit | 2/3/4/8-bit | +| **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs | +| **Memory during quantization** | Lower | Higher (Hessian storage) | +| **Calibration sensitivity** | Less prone to overfitting | May overfit calibration set, affecting out-of-distribution performance | + +**Choose AWQ when:** + +* You need faster quantization (AWQ is typically 2-3x faster than GPTQ). +* Memory during quantization is constrained. +* 4-bit is sufficient for your use case. +* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data). + +**Choose GPTQ when:** + +* You need bit-widths other than 4 (e.g., 2-bit or 8-bit). +* Maximum accuracy is critical and you can afford longer quantization time. +* You're working with decoder-only LLMs where GPTQ may have a slight edge. + +--- +## Practical tips + +* AWQ is a post-training technique; training after quantization is not supported. +* Always use the model's own tokenizer for calibration. +* Use a representative calibration set; small slices are only for demos. +* Start with W4 group_size=128; tune per model/task. +* AWQ only supports 4-bit quantization currently. +* For best results, use calibration data that matches your inference domain. + +--- +## References + +* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978) +* [MIT HAN Lab AWQ Repository](https://github.com/mit-han-lab/llm-awq) diff --git a/guides/md/gptq_quantization_in_keras.md b/guides/md/gptq_quantization_in_keras.md index 76a7a4deda..6e1d81b01e 100644 --- a/guides/md/gptq_quantization_in_keras.md +++ b/guides/md/gptq_quantization_in_keras.md @@ -34,9 +34,9 @@ parameter) causal language model. ```python +from datasets import load_dataset import keras from keras_hub.models import Gemma3CausalLM -from datasets import load_dataset prompt = "Keras is a" @@ -119,7 +119,6 @@ model.save_to_preset("gemma3_gptq_w4gs128_preset") model_from_preset = Gemma3CausalLM.from_preset("gemma3_gptq_w4gs128_preset") output = model_from_preset.generate(prompt, max_length=30) print(output) - ```
@@ -160,6 +159,34 @@ also include non-weight assets. Perplexity increases only marginally, indicating model quality is largely preserved after quantization. +--- +## GPTQ vs AWQ? + +Both GPTQ and AWQ are weight-only quantization methods that require calibration +data. Here's how to choose between them: + +| Aspect | GPTQ | AWQ | +| ------ | ---- | --- | +| **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales | +| **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) | +| **Bit-widths supported** | 2/3/4/8-bit | 4-bit | +| **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models | +| **Memory during quantization** | Higher (Hessian storage) | Lower | +| **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting | + +**Choose GPTQ when:** + +* You need bit-widths other than 4 (e.g., 2-bit or 8-bit). +* Maximum accuracy is critical and you can afford longer quantization time. +* You're working with decoder-only LLMs where GPTQ may have a slight edge. + +**Choose AWQ when:** + +* You need faster quantization (AWQ is typically 2-3x faster than GPTQ). +* Memory during quantization is constrained. +* 4-bit is sufficient for your use case. +* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data). + --- ## Practical tips diff --git a/scripts/guides_master.py b/scripts/guides_master.py index 6f9676e214..7fa461781b 100644 --- a/scripts/guides_master.py +++ b/scripts/guides_master.py @@ -139,6 +139,10 @@ "path": "gptq_quantization_in_keras", "title": "GPTQ quantization in Keras", }, + { + "path": "awq_quantization_in_keras", + "title": "AWQ quantization in Keras", + }, { "path": "writing_quantization_compatible_layers", "title": "Writing quantization-compatible layers in Keras",