diff --git a/guides/awq_quantization_in_keras.py b/guides/awq_quantization_in_keras.py
new file mode 100644
index 0000000000..0450ba227d
--- /dev/null
+++ b/guides/awq_quantization_in_keras.py
@@ -0,0 +1,200 @@
+"""
+Title: AWQ Quantization in Keras
+Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
+Date created: 2025/01/15
+Last modified: 2025/01/15
+Description: How to run weight-only AWQ quantization for Keras & KerasHub models.
+Accelerator: GPU
+"""
+
+"""
+## What is AWQ?
+
+AWQ (Activation-aware Weight Quantization) is a post-training, weight-only
+quantization method that uses activation statistics to identify and protect
+salient weights during quantization.
+
+The key insight of AWQ is that not all weights are equally important: a small
+fraction of weights (typically <1%) are "salient" because they process
+channels with large activation magnitudes. By protecting these weights from
+quantization error, AWQ preserves model quality while achieving significant
+compression.
+
+Unlike GPTQ, which uses second-order (Hessian-based) optimization, AWQ uses a
+simpler grid search to find per-channel scales that minimize activation-weighted
+quantization error. This makes AWQ generally faster while achieving competitive
+accuracy.
+
+### How it works
+
+1. Run a small calibration set through the model to collect per-channel
+ activation magnitudes.
+2. For each weight matrix, search for optimal per-channel scales that
+ minimize activation-weighted quantization error.
+3. Multiply weights by the optimal scales before quantization
+ (expanding salient weights).
+4. Quantize the scaled weights to 4-bit integers (the only bit-width AWQ
+   currently supports).
+5. During inference, dequantize the weights and divide by the scales to
+   restore their original magnitudes.
+
+The scales are computed as `scales = activation_max^ratio`, where `ratio` is
+searched over a grid from 0 to 1.
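
As a rough sketch of this search (toy sizes, a simplified symmetric 4-bit
quantizer, and random data standing in for real calibration activations — not
the Keras internals):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(64, 32))                  # weights (out_features, in_features)
act_max = np.abs(rng.normal(size=32)) + 1e-5   # per-input-channel activation magnitudes
X = rng.normal(size=(16, 32)) * act_max        # stand-in calibration activations

def quantize_4bit(w):
    # Symmetric per-output-channel 4-bit fake quantization (levels -8..7).
    s = np.abs(w).max(axis=1, keepdims=True) / 7.0
    return np.round(w / s).clip(-8, 7) * s

best_ratio, best_err = None, np.inf
for ratio in np.linspace(0.0, 1.0, 20):        # grid search over `ratio`
    scales = act_max**ratio                    # scales = activation_max^ratio
    w_q = quantize_4bit(W * scales) / scales   # quantize scaled weights, then undo
    err = np.mean((X @ W.T - X @ w_q.T) ** 2)  # activation-weighted output error
    if err < best_err:
        best_ratio, best_err = ratio, err
```

Because `ratio = 0` (plain quantization) is part of the grid, the selected
scales can only match or reduce the output error on the calibration data.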
+
+Keras supports AWQ quantization for KerasHub models via the
+`keras.quantizers.AWQConfig` class.
+"""
+
+"""
+## Load a KerasHub model
+
+This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B
+parameter) causal language model.
+
+"""
+from datasets import load_dataset
+import keras
+from keras_hub.models import Gemma3CausalLM
+
+
+prompt = "Keras is a"
+
+model = Gemma3CausalLM.from_preset("gemma3_1b")
+
+outputs = model.generate(prompt, max_length=30)
+print(outputs)
+
+"""
+## Configure & run AWQ quantization
+
+You can configure AWQ quantization via the `keras.quantizers.AWQConfig` class.
+
+The AWQ configuration requires a calibration dataset and tokenizer, which it
+uses to collect activation statistics and search for optimal scales. Here, we
+use a small slice of the WikiText-2 dataset for calibration.
+
+Key parameters:
+
+* `weight_bits`: The bit-width to quantize weights to. AWQ currently only
+ supports 4-bit quantization.
+* `group_size`: The number of input features to quantize together. Smaller
+ groups typically yield better accuracy but may use more memory. Use -1 for
+ per-channel (no grouping). A good starting point is 128.
+* `num_grid_points`: The number of points to search over when finding optimal
+ scales. More points give finer granularity but increase calibration time.
+ Default is 20.
+* `num_samples`: Number of calibration samples to use for activation
+ collection.
+* `sequence_length`: Maximum sequence length for calibration samples.
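
To make `group_size` concrete, here is a minimal sketch (a hypothetical
helper, not the Keras implementation) of fake-quantizing one weight row in
groups of 128 input features, each with its own scale:

```python
import numpy as np

def groupwise_quantize_4bit(w_row, group_size=128):
    """Fake-quantize one weight row, one scale per `group_size` input features."""
    out = np.empty_like(w_row)
    for start in range(0, len(w_row), group_size):
        g = w_row[start : start + group_size]
        scale = np.abs(g).max() / 7.0 + 1e-12  # one scale per group
        out[start : start + group_size] = np.round(g / scale).clip(-8, 7) * scale
    return out

rng = np.random.default_rng(0)
row = rng.normal(size=512)
dq = groupwise_quantize_4bit(row, group_size=128)  # 512 features -> 4 groups
```

With `group_size=-1` (per-channel), the whole row would share a single scale,
which is cheaper to store but usually less accurate.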
+
+In this example, we first prepare a tiny calibration set, and then run AWQ on
+the model using the `.quantize(...)` API.
+"""
+
+# Calibration slice (use a larger/representative set in practice)
+texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]
+
+calibration_dataset = []
+for text in texts:
+ for s in text.split("."):
+ s = s.strip()
+ if s:
+ calibration_dataset.append(s + ".")
+
+awq_config = keras.quantizers.AWQConfig(
+ dataset=calibration_dataset,
+ tokenizer=model.preprocessor.tokenizer,
+ weight_bits=4,
+ group_size=128,
+ num_grid_points=20,
+ num_samples=128,
+ sequence_length=256,
+)
+
+model.quantize("awq", config=awq_config)
+
+outputs = model.generate(prompt, max_length=30)
+print(outputs)
+
+"""
+## Model Export
+
+The AWQ-quantized model can be saved to a preset and reloaded elsewhere, just
+like any other KerasHub model.
+"""
+
+model.save_to_preset("gemma3_awq_w4gs128_preset")
+model_from_preset = Gemma3CausalLM.from_preset("gemma3_awq_w4gs128_preset")
+output = model_from_preset.generate(prompt, max_length=30)
+print(output)
+
+"""
+## Performance & Benchmarking
+
+Micro-benchmarks collected on a single RTX 4070 Ti Super (16 GB).
+Baselines are BF16 for Gemma3, and FP32 for Qwen3 and OPT.
+
+Dataset: WikiText-2.
+
+
+| Model | Pre PPL | Post PPL | PPL Change | Disk Size Change | GPU Mem Change | Throughput Change |
+| ----- | ------: | -------: | ---------: | ---------------: | -------------: | ----------------: |
+| Qwen3 1.7B | 37.65 | 45.79 | +21.64% | -70.7% | -69.9% | -10.4% |
+| Gemma3 1B | 172.45 | 178.03 | +3.23% | -60.2% | -58.3% | -15.5% |
+| OPT 125M | 77.06 | 84.75 | +9.97% | -58.3% | -40.9% | -3.3% |
+
+
+### Analysis
+
+* **Disk size reduction**: 58-71% across models due to 4-bit weight storage.
+* **GPU memory reduction**: 41-70% during inference.
+* **Perplexity degradation**: +3.2% (Gemma3 1B) to +21.6% (Qwen3 1.7B), model-dependent.
+* **Throughput**: -3% to -15% due to dequantization overhead.
+
+AWQ provides substantial memory savings with modest quality degradation,
+making it ideal for deploying large models on memory-constrained devices.
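
For reference on the PPL columns above, perplexity is the exponential of the
mean per-token negative log-likelihood (the values below are made up for
illustration):

```python
import math

nlls = [3.1, 2.8, 3.4]                 # hypothetical per-token NLLs (nats)
ppl = math.exp(sum(nlls) / len(nlls))  # perplexity = exp(mean NLL)
```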
+"""
+
+"""
+## AWQ vs GPTQ?
+
+Both AWQ and GPTQ are weight-only quantization methods that require calibration
+data. Here's how to choose between them:
+
+| Aspect | AWQ | GPTQ |
+| ------ | --- | ---- |
+| **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization |
+| **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) |
+| **Bit-widths supported** | 4-bit | 2/3/4/8-bit |
+| **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs |
+| **Memory during quantization** | Lower | Higher (Hessian storage) |
+| **Calibration sensitivity** | Less prone to overfitting | May overfit calibration set, affecting out-of-distribution performance |
+
+**Choose AWQ when:**
+
+* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).
+* Memory during quantization is constrained.
+* 4-bit is sufficient for your use case.
+* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data).
+
+**Choose GPTQ when:**
+
+* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).
+* Maximum accuracy is critical and you can afford longer quantization time.
+* You're working with decoder-only LLMs where GPTQ may have a slight edge.
+"""
+
+"""
+## Practical tips
+
+* AWQ is a post-training technique; training after quantization is not supported.
+* Always use the model's own tokenizer for calibration.
+* Use a representative calibration set that matches your inference domain;
+  small slices like the one above are only for demos.
+* Start with `weight_bits=4` and `group_size=128`; tune per model/task.
+* AWQ currently supports only 4-bit quantization.
+
+## References
+
+* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978)
+* [MIT HAN Lab AWQ Repository](https://github.com/mit-han-lab/llm-awq)
+"""
diff --git a/guides/gptq_quantization_in_keras.py b/guides/gptq_quantization_in_keras.py
index 4aaf040ca6..9f44209ad6 100644
--- a/guides/gptq_quantization_in_keras.py
+++ b/guides/gptq_quantization_in_keras.py
@@ -31,9 +31,9 @@
parameter) causal language model.
"""
+from datasets import load_dataset
import keras
from keras_hub.models import Gemma3CausalLM
-from datasets import load_dataset
prompt = "Keras is a"
@@ -130,6 +130,35 @@
preserved after quantization.
"""
+"""
+## GPTQ vs AWQ?
+
+Both GPTQ and AWQ are weight-only quantization methods that require calibration
+data. Here's how to choose between them:
+
+| Aspect | GPTQ | AWQ |
+| ------ | ---- | --- |
+| **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales |
+| **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) |
+| **Bit-widths supported** | 2/3/4/8-bit | 4-bit |
+| **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models |
+| **Memory during quantization** | Higher (Hessian storage) | Lower |
+| **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting |
+
+**Choose GPTQ when:**
+
+* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).
+* Maximum accuracy is critical and you can afford longer quantization time.
+* You're working with decoder-only LLMs where GPTQ may have a slight edge.
+
+**Choose AWQ when:**
+
+* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).
+* Memory during quantization is constrained.
+* 4-bit is sufficient for your use case.
+* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data).
+"""
+
"""
## Practical tips
diff --git a/guides/ipynb/awq_quantization_in_keras.ipynb b/guides/ipynb/awq_quantization_in_keras.ipynb
new file mode 100644
index 0000000000..0ef2581d02
--- /dev/null
+++ b/guides/ipynb/awq_quantization_in_keras.ipynb
@@ -0,0 +1,301 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# AWQ Quantization in Keras\n",
+ "\n",
+    "**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)<br>\n",
+    "**Date created:** 2025/01/15<br>\n",
+    "**Last modified:** 2025/01/15<br>\n",
+ "**Description:** How to run weight-only AWQ quantization for Keras & KerasHub models."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## What is AWQ?\n",
+ "\n",
+ "AWQ (Activation-aware Weight Quantization) is a post-training, weight-only\n",
+ "quantization method that uses activation statistics to identify and protect\n",
+ "salient weights during quantization.\n",
+ "\n",
+ "The key insight of AWQ is that not all weights are equally important: a small\n",
+ "fraction of weights (typically <1%) are \"salient\" because they process\n",
+ "channels with large activation magnitudes. By protecting these weights from\n",
+ "quantization error, AWQ preserves model quality while achieving significant\n",
+ "compression.\n",
+ "\n",
+    "Unlike GPTQ, which uses second-order (Hessian-based) optimization, AWQ uses a\n",
+ "simpler grid search to find per-channel scales that minimize activation-weighted\n",
+ "quantization error. This makes AWQ generally faster while achieving competitive\n",
+ "accuracy.\n",
+ "\n",
+ "### How it works\n",
+ "\n",
+ "1. Run a small calibration set through the model to collect per-channel\n",
+ " activation magnitudes.\n",
+ "2. For each weight matrix, search for optimal per-channel scales that\n",
+ " minimize activation-weighted quantization error.\n",
+ "3. Multiply weights by the optimal scales before quantization\n",
+ " (expanding salient weights).\n",
+    "4. Quantize the scaled weights to 4-bit integers (the only bit-width AWQ\n",
+    "   currently supports).\n",
+    "5. During inference, dequantize the weights and divide by the scales to\n",
+    "   restore their original magnitudes.\n",
+ "\n",
+    "The scales are computed as `scales = activation_max^ratio`, where `ratio` is\n",
+    "searched over a grid from 0 to 1.\n",
+ "\n",
+ "Keras supports AWQ quantization for KerasHub models via the\n",
+ "`keras.quantizers.AWQConfig` class."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Load a KerasHub model\n",
+ "\n",
+ "This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B\n",
+ "parameter) causal language model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "import keras\n",
+ "from keras_hub.models import Gemma3CausalLM\n",
+ "\n",
+ "\n",
+ "prompt = \"Keras is a\"\n",
+ "\n",
+ "model = Gemma3CausalLM.from_preset(\"gemma3_1b\")\n",
+ "\n",
+ "outputs = model.generate(prompt, max_length=30)\n",
+ "print(outputs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Configure & run AWQ quantization\n",
+ "\n",
+ "You can configure AWQ quantization via the `keras.quantizers.AWQConfig` class.\n",
+ "\n",
+ "The AWQ configuration requires a calibration dataset and tokenizer, which it\n",
+ "uses to collect activation statistics and search for optimal scales. Here, we\n",
+ "use a small slice of the WikiText-2 dataset for calibration.\n",
+ "\n",
+ "Key parameters:\n",
+ "\n",
+ "* `weight_bits`: The bit-width to quantize weights to. AWQ currently only\n",
+ " supports 4-bit quantization.\n",
+ "* `group_size`: The number of input features to quantize together. Smaller\n",
+ " groups typically yield better accuracy but may use more memory. Use -1 for\n",
+ " per-channel (no grouping). A good starting point is 128.\n",
+ "* `num_grid_points`: The number of points to search over when finding optimal\n",
+ " scales. More points give finer granularity but increase calibration time.\n",
+ " Default is 20.\n",
+ "* `num_samples`: Number of calibration samples to use for activation\n",
+ " collection.\n",
+ "* `sequence_length`: Maximum sequence length for calibration samples.\n",
+ "\n",
+ "In this example, we first prepare a tiny calibration set, and then run AWQ on\n",
+ "the model using the `.quantize(...)` API."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Calibration slice (use a larger/representative set in practice)\n",
+ "texts = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"train[:1%]\")[\"text\"]\n",
+ "\n",
+ "calibration_dataset = []\n",
+ "for text in texts:\n",
+ " for s in text.split(\".\"):\n",
+ " s = s.strip()\n",
+ " if s:\n",
+ " calibration_dataset.append(s + \".\")\n",
+ "\n",
+ "awq_config = keras.quantizers.AWQConfig(\n",
+ " dataset=calibration_dataset,\n",
+ " tokenizer=model.preprocessor.tokenizer,\n",
+ " weight_bits=4,\n",
+ " group_size=128,\n",
+ " num_grid_points=20,\n",
+ " num_samples=128,\n",
+ " sequence_length=256,\n",
+ ")\n",
+ "\n",
+ "model.quantize(\"awq\", config=awq_config)\n",
+ "\n",
+ "outputs = model.generate(prompt, max_length=30)\n",
+ "print(outputs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Model Export\n",
+ "\n",
+    "The AWQ-quantized model can be saved to a preset and reloaded elsewhere, just\n",
+ "like any other KerasHub model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "model.save_to_preset(\"gemma3_awq_w4gs128_preset\")\n",
+ "model_from_preset = Gemma3CausalLM.from_preset(\"gemma3_awq_w4gs128_preset\")\n",
+ "output = model_from_preset.generate(prompt, max_length=30)\n",
+ "print(output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Performance & Benchmarking\n",
+ "\n",
+ "Micro-benchmarks collected on a single RTX 4070 Ti Super (16 GB).\n",
+ "Baselines are BF16 for Gemma3, and FP32 for Qwen3 and OPT.\n",
+ "\n",
+ "Dataset: WikiText-2.\n",
+ "\n",
+ "\n",
+ "| Model | Pre PPL | Post PPL | PPL Change | Disk Size Change | GPU Mem Change | Throughput Change |\n",
+ "| ----- | ------: | -------: | ---------: | ---------------: | -------------: | ----------------: |\n",
+ "| Qwen3 1.7B | 37.65 | 45.79 | +21.64% | -70.7% | -69.9% | -10.4% |\n",
+ "| Gemma3 1B | 172.45 | 178.03 | +3.23% | -60.2% | -58.3% | -15.5% |\n",
+ "| OPT 125M | 77.06 | 84.75 | +9.97% | -58.3% | -40.9% | -3.3% |\n",
+ "\n",
+ "\n",
+ "### Analysis\n",
+ "\n",
+ "* **Disk size reduction**: 58-71% across models due to 4-bit weight storage.\n",
+ "* **GPU memory reduction**: 41-70% during inference.\n",
+ "* **Perplexity degradation**: +3.2% (Gemma3 1B) to +21.6% (Qwen3 1.7B), model-dependent.\n",
+ "* **Throughput**: -3% to -15% due to dequantization overhead.\n",
+ "\n",
+ "AWQ provides substantial memory savings with modest quality degradation,\n",
+ "making it ideal for deploying large models on memory-constrained devices."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## AWQ vs GPTQ?\n",
+ "\n",
+ "Both AWQ and GPTQ are weight-only quantization methods that require calibration\n",
+ "data. Here's how to choose between them:\n",
+ "\n",
+ "| Aspect | AWQ | GPTQ |\n",
+ "| ------ | --- | ---- |\n",
+ "| **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization |\n",
+ "| **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) |\n",
+ "| **Bit-widths supported** | 4-bit | 2/3/4/8-bit |\n",
+ "| **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs |\n",
+ "| **Memory during quantization** | Lower | Higher (Hessian storage) |\n",
+ "| **Calibration sensitivity** | Less prone to overfitting | May overfit calibration set, affecting out-of-distribution performance |\n",
+ "\n",
+ "**Choose AWQ when:**\n",
+ "\n",
+ "* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).\n",
+ "* Memory during quantization is constrained.\n",
+ "* 4-bit is sufficient for your use case.\n",
+ "* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data).\n",
+ "\n",
+ "**Choose GPTQ when:**\n",
+ "\n",
+ "* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).\n",
+ "* Maximum accuracy is critical and you can afford longer quantization time.\n",
+ "* You're working with decoder-only LLMs where GPTQ may have a slight edge."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Practical tips\n",
+ "\n",
+    "* AWQ is a post-training technique; training after quantization is not supported.\n",
+    "* Always use the model's own tokenizer for calibration.\n",
+    "* Use a representative calibration set that matches your inference domain; small slices are only for demos.\n",
+    "* Start with `weight_bits=4` and `group_size=128`; tune per model/task.\n",
+    "* AWQ currently supports only 4-bit quantization.\n",
+ "\n",
+ "## References\n",
+ "\n",
+ "* [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/abs/2306.00978)\n",
+ "* [MIT HAN Lab AWQ Repository](https://github.com/mit-han-lab/llm-awq)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "awq_quantization_in_keras",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/guides/ipynb/gptq_quantization_in_keras.ipynb b/guides/ipynb/gptq_quantization_in_keras.ipynb
index bcbac78e29..ad3ff407e4 100644
--- a/guides/ipynb/gptq_quantization_in_keras.ipynb
+++ b/guides/ipynb/gptq_quantization_in_keras.ipynb
@@ -167,12 +167,12 @@
"Dataset: WikiText-2.\n",
"\n",
"\n",
- "| Model (preset) | Perplexity Increase % (↓ better) | Disk Storage Reduction Δ % (↓ better) | VRAM Reduction Δ % (↓ better) | First-token Latency Δ % (↓ better) | Throughput Δ % (↑ better) |\n",
- "| ------------------------------------------- | -------------------------------: | ------------------------------------: | ----------------------------: | ---------------------------------: | ------------------------: |\n",
- "| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% ↓ | -41.1% ↓ | +0.7% ↑ | +20.1% ↑ |\n",
- "| OPT (opt_125m_en) | 10.0% | -49.8% ↓ | -47.0% ↓ | +6.7% ↑ | -15.7% ↓ |\n",
- "| Bloom (bloom_1.1b_multi) | 7.0% | -47.0% ↓ | -54.0% ↓ | +1.8% ↑ | -15.7% ↓ |\n",
- "| Gemma3 (gemma3_1b) | 3.0% | -51.5% ↓ | -51.8% ↓ | +39.5% ↑ | +5.7% ↑ |\n",
+ "| Model (preset) | Perplexity Increase % (\u2193 better) | Disk Storage Reduction \u0394 % (\u2193 better) | VRAM Reduction \u0394 % (\u2193 better) | First-token Latency \u0394 % (\u2193 better) | Throughput \u0394 % (\u2191 better) |\n",
+ "| --------------------------------- | -------------------------------: | ------------------------------------: | ----------------------------: | ---------------------------------: | ------------------------: |\n",
+ "| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% \u2193 | -41.1% \u2193 | +0.7% \u2191 | +20.1% \u2191 |\n",
+ "| OPT (opt_125m_en) | 10.0% | -49.8% \u2193 | -47.0% \u2193 | +6.7% \u2191 | -15.7% \u2193 |\n",
+ "| Bloom (bloom_1.1b_multi) | 7.0% | -47.0% \u2193 | -54.0% \u2193 | +1.8% \u2191 | -15.7% \u2193 |\n",
+ "| Gemma3 (gemma3_1b) | 3.0% | -51.5% \u2193 | -51.8% \u2193 | +39.5% \u2191 | +5.7% \u2191 |\n",
"\n",
"\n",
"Detailed benchmarking numbers and scripts are available\n",
@@ -189,6 +189,40 @@
"preserved after quantization."
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## GPTQ vs AWQ?\n",
+ "\n",
+ "Both GPTQ and AWQ are weight-only quantization methods that require calibration\n",
+ "data. Here's how to choose between them:\n",
+ "\n",
+ "| Aspect | GPTQ | AWQ |\n",
+ "| ------ | ---- | --- |\n",
+ "| **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales |\n",
+ "| **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) |\n",
+ "| **Bit-widths supported** | 2/3/4/8-bit | 4-bit |\n",
+ "| **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models |\n",
+ "| **Memory during quantization** | Higher (Hessian storage) | Lower |\n",
+ "| **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting |\n",
+ "\n",
+ "**Choose GPTQ when:**\n",
+ "\n",
+ "* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).\n",
+ "* Maximum accuracy is critical and you can afford longer quantization time.\n",
+ "* You're working with decoder-only LLMs where GPTQ may have a slight edge.\n",
+ "\n",
+ "**Choose AWQ when:**\n",
+ "\n",
+ "* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).\n",
+ "* Memory during quantization is constrained.\n",
+ "* 4-bit is sufficient for your use case.\n",
+ "* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data)."
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {
@@ -202,11 +236,6 @@
"* Use a representative calibration set; small slices are only for demos.\n",
"* Start with W4 group_size=128; tune per model/task."
]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": []
}
],
"metadata": {
@@ -238,4 +267,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
+}
\ No newline at end of file
diff --git a/guides/md/awq_quantization_in_keras.md b/guides/md/awq_quantization_in_keras.md
new file mode 100644
index 0000000000..f7d165dfb5
--- /dev/null
+++ b/guides/md/awq_quantization_in_keras.md
@@ -0,0 +1,230 @@
+# AWQ Quantization in Keras
+
+**Author:** [Jyotinder Singh](https://x.com/Jyotinder_Singh)
+**Date created:** 2025/01/15
+**Last modified:** 2025/01/15
+**Description:** How to run weight-only AWQ quantization for Keras & KerasHub models.
+
+
+[**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/awq_quantization_in_keras.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/awq_quantization_in_keras.py)
+
+
+
+---
+## What is AWQ?
+
+AWQ (Activation-aware Weight Quantization) is a post-training, weight-only
+quantization method that uses activation statistics to identify and protect
+salient weights during quantization.
+
+The key insight of AWQ is that not all weights are equally important: a small
+fraction of weights (typically <1%) are "salient" because they process
+channels with large activation magnitudes. By protecting these weights from
+quantization error, AWQ preserves model quality while achieving significant
+compression.
+
+Unlike GPTQ, which uses second-order (Hessian-based) optimization, AWQ uses a
+simpler grid search to find per-channel scales that minimize activation-weighted
+quantization error. This makes AWQ generally faster while achieving competitive
+accuracy.
+
+### How it works
+
+1. Run a small calibration set through the model to collect per-channel
+ activation magnitudes.
+2. For each weight matrix, search for optimal per-channel scales that
+ minimize activation-weighted quantization error.
+3. Multiply weights by the optimal scales before quantization
+ (expanding salient weights).
+4. Quantize the scaled weights to 4-bit integers (the only bit-width AWQ
+   currently supports).
+5. During inference, dequantize the weights and divide by the scales to
+   restore their original magnitudes.
+
+The scales are computed as `scales = activation_max^ratio`, where `ratio` is
+searched over a grid from 0 to 1.
+
+Keras supports AWQ quantization for KerasHub models via the
+`keras.quantizers.AWQConfig` class.
+
+---
+## Load a KerasHub model
+
+This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B
+parameter) causal language model.
+
+
+```python
+from datasets import load_dataset
+import keras
+from keras_hub.models import Gemma3CausalLM
+
+
+prompt = "Keras is a"
+
+model = Gemma3CausalLM.from_preset("gemma3_1b")
+
+outputs = model.generate(prompt, max_length=30)
+print(outputs)
+```
+
+