Commit b710449

address reviews
1 parent bf4aa31 commit b710449

File tree: 6 files changed (+84, −34 lines)


guides/awq_quantization_in_keras.py

Lines changed: 8 additions & 5 deletions

@@ -51,9 +51,9 @@
 parameter) causal language model.

 """
+from datasets import load_dataset
 import keras
 from keras_hub.models import Gemma3CausalLM
-from datasets import load_dataset


 prompt = "Keras is a"
@@ -93,9 +93,12 @@
 # Calibration slice (use a larger/representative set in practice)
 texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]

-calibration_dataset = [
-    s + "." for text in texts for s in map(str.strip, text.split(".")) if s
-]
+calibration_dataset = []
+for text in texts:
+    for s in text.split("."):
+        s = s.strip()
+        if s:
+            calibration_dataset.append(s + ".")

 awq_config = keras.quantizers.AWQConfig(
     dataset=calibration_dataset,
@@ -161,7 +164,7 @@
 | ------ | --- | ---- |
 | **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization |
 | **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) |
-| **Bit-widths supported** | only 4-bit supported for now | 2/3/4/8-bit |
+| **Bit-widths supported** | 4-bit | 2/3/4/8-bit |
 | **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs |
 | **Memory during quantization** | Lower | Higher (Hessian storage) |
 | **Calibration sensitivity** | Less prone to overfitting | May overfit calibration set, affecting out-of-distribution performance |
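The change above swaps a nested comprehension for an explicit loop, which the reviewers apparently found easier to read. Pulled out as a standalone sketch (the helper name is illustrative; the guide builds the list inline):

```python
def build_calibration_dataset(texts):
    """Split raw text chunks into trimmed, period-terminated sentences."""
    calibration_dataset = []
    for text in texts:
        for s in text.split("."):
            s = s.strip()
            if s:  # drop empty fragments between consecutive periods
                calibration_dataset.append(s + ".")
    return calibration_dataset

print(build_calibration_dataset(["Keras is an API. It runs on JAX. ", ""]))
# ['Keras is an API.', 'It runs on JAX.']
```

Both the old comprehension and the loop produce the same list; the loop form simply makes the strip-and-filter steps explicit.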

guides/gptq_quantization_in_keras.py

Lines changed: 2 additions & 2 deletions

@@ -31,9 +31,9 @@
 parameter) causal language model.

 """
+from datasets import load_dataset
 import keras
 from keras_hub.models import Gemma3CausalLM
-from datasets import load_dataset


 prompt = "Keras is a"
@@ -140,7 +140,7 @@
 | ------ | ---- | --- |
 | **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales |
 | **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) |
-| **Bit-widths supported** | 2/3/4/8-bit | Only 4-bit supported for now |
+| **Bit-widths supported** | 2/3/4/8-bit | 4-bit |
 | **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models |
 | **Memory during quantization** | Higher (Hessian storage) | Lower |
 | **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting |
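The GPTQ-vs-AWQ table this commit touches in both guides boils down to a simple decision rule. A minimal sketch (the helper and its arguments are hypothetical, distilled from the table rather than from any Keras API):

```python
def pick_quantization_method(bits, prioritize_speed=False):
    """Choose a weight-only method per the GPTQ-vs-AWQ comparison table."""
    if bits not in (2, 3, 4, 8):
        raise ValueError(f"unsupported bit-width: {bits}")
    if bits != 4:
        return "gptq"  # AWQ supports 4-bit only
    # At 4-bit, AWQ quantizes faster and uses less memory during
    # calibration; GPTQ is often slightly more accurate on decoder LLMs.
    return "awq" if prioritize_speed else "gptq"

print(pick_quantization_method(4, prioritize_speed=True))  # awq
print(pick_quantization_method(2))  # gptq
```

The rule also encodes the calibration-sensitivity row: if out-of-distribution robustness matters more than peak accuracy, pass `prioritize_speed=True` to land on AWQ.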

guides/ipynb/awq_quantization_in_keras.ipynb

Lines changed: 8 additions & 5 deletions

@@ -76,9 +76,9 @@
 },
 "outputs": [],
 "source": [
+"from datasets import load_dataset\n",
 "import keras\n",
 "from keras_hub.models import Gemma3CausalLM\n",
-"from datasets import load_dataset\n",
 "\n",
 "\n",
 "prompt = \"Keras is a\"\n",
@@ -132,9 +132,12 @@
 "# Calibration slice (use a larger/representative set in practice)\n",
 "texts = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"train[:1%]\")[\"text\"]\n",
 "\n",
-"calibration_dataset = [\n",
-"    s + \".\" for text in texts for s in map(str.strip, text.split(\".\")) if s\n",
-"]\n",
+"calibration_dataset = []\n",
+"for text in texts:\n",
+"    for s in text.split(\".\"):\n",
+"        s = s.strip()\n",
+"        if s:\n",
+"            calibration_dataset.append(s + \".\")\n",
 "\n",
 "awq_config = keras.quantizers.AWQConfig(\n",
 "    dataset=calibration_dataset,\n",
@@ -225,7 +228,7 @@
 "| ------ | --- | ---- |\n",
 "| **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization |\n",
 "| **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) |\n",
-"| **Bit-widths supported** | only 4-bit supported for now | 2/3/4/8-bit |\n",
+"| **Bit-widths supported** | 4-bit | 2/3/4/8-bit |\n",
 "| **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs |\n",
 "| **Memory during quantization** | Lower | Higher (Hessian storage) |\n",
 "| **Calibration sensitivity** | Less prone to overfitting | May overfit calibration set, affecting out-of-distribution performance |\n",

guides/ipynb/gptq_quantization_in_keras.ipynb

Lines changed: 37 additions & 13 deletions

@@ -167,12 +167,12 @@
 "Dataset: WikiText-2.\n",
 "\n",
 "\n",
-"| Model (preset) | Perplexity Increase % (↓ better) | Disk Storage Reduction Δ % (↓ better) | VRAM Reduction Δ % (↓ better) | First-token Latency Δ % (↓ better) | Throughput Δ % (↑ better) |\n",
-"| ------------------------------------------- | -------------------------------: | ------------------------------------: | ----------------------------: | ---------------------------------: | ------------------------: |\n",
-"| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% | -41.1% | +0.7% | +20.1% |\n",
-"| OPT (opt_125m_en) | 10.0% | -49.8% | -47.0% | +6.7% | -15.7% |\n",
-"| Bloom (bloom_1.1b_multi) | 7.0% | -47.0% | -54.0% | +1.8% | -15.7% |\n",
-"| Gemma3 (gemma3_1b) | 3.0% | -51.5% | -51.8% | +39.5% | +5.7% |\n",
+"| Model (preset) | Perplexity Increase % (\u2193 better) | Disk Storage Reduction \u0394 % (\u2193 better) | VRAM Reduction \u0394 % (\u2193 better) | First-token Latency \u0394 % (\u2193 better) | Throughput \u0394 % (\u2191 better) |\n",
+"| --------------------------------- | -------------------------------: | ------------------------------------: | ----------------------------: | ---------------------------------: | ------------------------: |\n",
+"| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% \u2193 | -41.1% \u2193 | +0.7% \u2191 | +20.1% \u2191 |\n",
+"| OPT (opt_125m_en) | 10.0% | -49.8% \u2193 | -47.0% \u2193 | +6.7% \u2191 | -15.7% \u2193 |\n",
+"| Bloom (bloom_1.1b_multi) | 7.0% | -47.0% \u2193 | -54.0% \u2193 | +1.8% \u2191 | -15.7% \u2193 |\n",
+"| Gemma3 (gemma3_1b) | 3.0% | -51.5% \u2193 | -51.8% \u2193 | +39.5% \u2191 | +5.7% \u2191 |\n",
 "\n",
 "\n",
 "Detailed benchmarking numbers and scripts are available\n",
@@ -191,8 +191,37 @@
 },
 {
 "cell_type": "markdown",
-"source": "## GPTQ vs AWQ?\n\nBoth GPTQ and AWQ are weight-only quantization methods that require calibration\ndata. Here's how to choose between them:\n\n| Aspect | GPTQ | AWQ |\n| ------ | ---- | --- |\n| **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales |\n| **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) |\n| **Bit-widths supported** | 2/3/4/8-bit | Only 4-bit supported for now |\n| **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models |\n| **Memory during quantization** | Higher (Hessian storage) | Lower |\n| **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting |\n\n**Choose GPTQ when:**\n\n* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).\n* Maximum accuracy is critical and you can afford longer quantization time.\n* You're working with decoder-only LLMs where GPTQ may have a slight edge.\n\n**Choose AWQ when:**\n\n* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).\n* Memory during quantization is constrained.\n* 4-bit is sufficient for your use case.\n* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data).",
-"metadata": {}
+"metadata": {
+"colab_type": "text"
+},
+"source": [
+"## GPTQ vs AWQ?\n",
+"\n",
+"Both GPTQ and AWQ are weight-only quantization methods that require calibration\n",
+"data. Here's how to choose between them:\n",
+"\n",
+"| Aspect | GPTQ | AWQ |\n",
+"| ------ | ---- | --- |\n",
+"| **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales |\n",
+"| **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) |\n",
+"| **Bit-widths supported** | 2/3/4/8-bit | 4-bit |\n",
+"| **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models |\n",
+"| **Memory during quantization** | Higher (Hessian storage) | Lower |\n",
+"| **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting |\n",
+"\n",
+"**Choose GPTQ when:**\n",
+"\n",
+"* You need bit-widths other than 4 (e.g., 2-bit or 8-bit).\n",
+"* Maximum accuracy is critical and you can afford longer quantization time.\n",
+"* You're working with decoder-only LLMs where GPTQ may have a slight edge.\n",
+"\n",
+"**Choose AWQ when:**\n",
+"\n",
+"* You need faster quantization (AWQ is typically 2-3x faster than GPTQ).\n",
+"* Memory during quantization is constrained.\n",
+"* 4-bit is sufficient for your use case.\n",
+"* Your model will be used on diverse/out-of-distribution data (AWQ is less prone to overfitting on calibration data)."
+]
 },
 {
 "cell_type": "markdown",
@@ -207,11 +236,6 @@
 "* Use a representative calibration set; small slices are only for demos.\n",
 "* Start with W4 group_size=128; tune per model/task."
 ]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": []
 }
 ],
 "metadata": {

guides/md/awq_quantization_in_keras.md

Lines changed: 9 additions & 6 deletions

@@ -54,9 +54,9 @@ parameter) causal language model.


 ```python
+from datasets import load_dataset
 import keras
 from keras_hub.models import Gemma3CausalLM
-from datasets import load_dataset


 prompt = "Keras is a"
@@ -104,9 +104,12 @@ the model using the `.quantize(...)` API.
 # Calibration slice (use a larger/representative set in practice)
 texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]

-calibration_dataset = [
-    s + "." for text in texts for s in map(str.strip, text.split(".")) if s
-]
+calibration_dataset = []
+for text in texts:
+    for s in text.split("."):
+        s = s.strip()
+        if s:
+            calibration_dataset.append(s + ".")

 awq_config = keras.quantizers.AWQConfig(
     dataset=calibration_dataset,
@@ -126,7 +129,7 @@ print(outputs)

 <div class="k-default-codeblock">
 ```
-26/26 ━━━━━━━━━━━━━━━━━━━━ 240s 9s/step
+26/26 ━━━━━━━━━━━━━━━━━━━━ 239s 9s/step

 Keras is a Python library for deep learning. It is a high-level interface to the TensorFlow library.

@@ -192,7 +195,7 @@ data. Here's how to choose between them:
 | ------ | --- | ---- |
 | **Algorithm** | Grid search for activation-aware scales | Hessian-based second-order optimization |
 | **Quantization speed** | Faster (no Hessian computation) | Slower (requires Hessian estimation) |
-| **Bit-widths supported** | only 4-bit supported for now | 2/3/4/8-bit |
+| **Bit-widths supported** | 4-bit | 2/3/4/8-bit |
 | **Accuracy** | Competitive, especially on encoder models | Often slightly better on decoder LLMs |
 | **Memory during quantization** | Lower | Higher (Hessian storage) |
 | **Calibration sensitivity** | Less prone to overfitting | May overfit calibration set, affecting out-of-distribution performance |

guides/md/gptq_quantization_in_keras.md

Lines changed: 20 additions & 3 deletions

@@ -34,9 +34,9 @@ parameter) causal language model.


 ```python
+from datasets import load_dataset
 import keras
 from keras_hub.models import Gemma3CausalLM
-from datasets import load_dataset


 prompt = "Keras is a"
@@ -101,6 +101,24 @@ print(outputs)

 <div class="k-default-codeblock">
 ```
+/home/jyotinder/anaconda3/envs/keras-io/lib/python3.12/site-packages/keras/src/models/model.py:547: UserWarning: Layer InputLayer does not have a `quantize` method implemented.
+  warnings.warn(str(e))
+/home/jyotinder/anaconda3/envs/keras-io/lib/python3.12/site-packages/keras/src/models/model.py:547: UserWarning: Layer RMSNormalization does not have a `quantize` method implemented.
+  warnings.warn(str(e))
+/home/jyotinder/anaconda3/envs/keras-io/lib/python3.12/site-packages/keras/src/models/model.py:547: UserWarning: Layer RotaryEmbedding does not have a `quantize` method implemented.
+  warnings.warn(str(e))
+/home/jyotinder/anaconda3/envs/keras-io/lib/python3.12/site-packages/keras/src/models/model.py:547: UserWarning: Layer Softmax does not have a `quantize` method implemented.
+  warnings.warn(str(e))
+/home/jyotinder/anaconda3/envs/keras-io/lib/python3.12/site-packages/keras/src/models/model.py:547: UserWarning: Layer Dropout does not have a `quantize` method implemented.
+  warnings.warn(str(e))
+
+/home/jyotinder/anaconda3/envs/keras-io/lib/python3.12/site-packages/keras/src/models/model.py:547: UserWarning: Invalid quantization mode. Expected one of ('int8', 'int4'). Received: quantization_mode=gptq
+  warnings.warn(str(e))
+
+I0000 00:00:1769055417.299190 54325 cuda_solvers.cc:175] Creating GpuSolver handles for stream 0x1f32b460
+
+26/26 ━━━━━━━━━━━━━━━━━━━━ 1235s 47s/step
+
 Keras is a Python library for deep learning. It is a high-level interface to the TensorFlow library.

 Keras is a great library
@@ -119,7 +137,6 @@ model.save_to_preset("gemma3_gptq_w4gs128_preset")
 model_from_preset = Gemma3CausalLM.from_preset("gemma3_gptq_w4gs128_preset")
 output = model_from_preset.generate(prompt, max_length=30)
 print(output)
-
 ```

 <div class="k-default-codeblock">
@@ -170,7 +187,7 @@ data. Here's how to choose between them:
 | ------ | ---- | --- |
 | **Algorithm** | Hessian-based second-order optimization | Grid search for activation-aware scales |
 | **Quantization speed** | Slower (requires Hessian estimation) | Faster (no Hessian computation) |
-| **Bit-widths supported** | 2/3/4/8-bit | Only 4-bit supported for now |
+| **Bit-widths supported** | 2/3/4/8-bit | 4-bit |
 | **Accuracy** | Often slightly better on decoder LLMs | Competitive, especially on encoder models |
 | **Memory during quantization** | Higher (Hessian storage) | Lower |
 | **Calibration sensitivity** | May overfit calibration set, affecting out-of-distribution performance | Less prone to overfitting |
