Commit 1ad0b1c

Merge branch 'banner' of https://github.com/buildwithsuhana/keras-io into banner

2 parents 02cfa1f + 2ee506a, commit 1ad0b1c

22 files changed: +4377 −16 lines

.github/workflows/continuous_integration.yml

Lines changed: 4 additions & 4 deletions

@@ -12,10 +12,10 @@ jobs:
   black:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v1
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
-          python-version: 3.10.18
+          python-version: '3.10'
       - name: Ensure files are formatted with black
         run: |
           pip install --upgrade pip
@@ -24,7 +24,7 @@ jobs:
   docker-image:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - name: Ensure the docker image works and can start.
         run: |
           make container-test

examples/vision/ipynb/vivit.ipynb

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 "\n",
 "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ayush Thakur](https://twitter.com/ayushthakur0) (equal contribution)<br>\n",
 "**Date created:** 2022/01/12<br>\n",
-"**Last modified:** 2024/01/15<br>\n",
+"**Last modified:** 2025/10/16<br>\n",
 "**Description:** A Transformer-based architecture for video classification."
 ]
 },

examples/vision/md/vivit.md

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 
 **Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ayush Thakur](https://twitter.com/ayushthakur0) (equal contribution)<br>
 **Date created:** 2022/01/12<br>
-**Last modified:** 2024/01/15<br>
+**Last modified:** 2025/10/16<br>
 **Description:** A Transformer-based architecture for video classification.
 
 
@@ -370,7 +370,7 @@ def run_experiment():
 
 model = run_experiment()
 ```
-
+<div class="k-default-codeblock">
 ```
 Test accuracy: 76.72%
 Test top 5 accuracy: 97.54%

examples/vision/vivit.py

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 Title: Video Vision Transformer
 Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ayush Thakur](https://twitter.com/ayushthakur0) (equal contribution)
 Date created: 2022/01/12
-Last modified: 2024/01/15
+Last modified: 2025/10/16
 Description: A Transformer-based architecture for video classification.
 Accelerator: GPU
 """
Lines changed: 140 additions & 0 deletions

@@ -0,0 +1,140 @@
"""
Title: GPTQ Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/16
Last modified: 2025/10/16
Description: How to run weight-only GPTQ quantization for Keras & KerasHub models.
Accelerator: GPU
"""

"""
## What is GPTQ?

GPTQ ("Generative Pre-Training Quantization") is a post-training, weight-only
quantization method that uses a second-order approximation of the loss (via a
Hessian estimate) to minimize the error introduced when compressing weights to
lower precision, typically 4-bit integers.

Unlike standard post-training techniques, GPTQ keeps activations in higher
precision and quantizes only the weights. This often preserves model quality
in low bit-width settings while still providing large storage and memory
savings.

Keras supports GPTQ quantization for KerasHub models via the
`keras.quantizers.GPTQConfig` class.
"""
"""
## Load a KerasHub model

This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B
parameter) causal language model.
"""

import keras
from keras_hub.models import Gemma3CausalLM
from datasets import load_dataset


prompt = "Keras is a"

model = Gemma3CausalLM.from_preset("gemma3_1b")

outputs = model.generate(prompt, max_length=30)
print(outputs)
"""
47+
## Configure & run GPTQ quantization
48+
49+
You can configure GPTQ quantization via the `keras.quantizers.GPTQConfig` class.
50+
51+
The GPTQ configuration requires a calibration dataset and tokenizer, which it
52+
uses to estimate the Hessian and quantization error. Here, we use a small slice
53+
of the WikiText-2 dataset for calibration.
54+
55+
You can tune several parameters to trade off speed, memory, and accuracy. The
56+
most important of these are `weight_bits` (the bit-width to quantize weights to)
57+
and `group_size` (the number of weights to quantize together). The group size
58+
controls the granularity of quantization: smaller groups typically yield better
59+
accuracy but are slower to quantize and may use more memory. A good starting
60+
point is `group_size=128` for 4-bit quantization (`weight_bits=4`).
61+
62+
In this example, we first prepare a tiny calibration set, and then run GPTQ on
63+
the model using the `.quantize(...)` API.
64+
"""
# Calibration slice (use a larger, representative set in practice)
texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]

calibration_dataset = [
    s + "." for text in texts for s in map(str.strip, text.split(".")) if s
]

gptq_config = keras.quantizers.GPTQConfig(
    dataset=calibration_dataset,
    tokenizer=model.preprocessor.tokenizer,
    weight_bits=4,
    group_size=128,
    num_samples=256,
    sequence_length=256,
    hessian_damping=0.01,
    symmetric=False,
    activation_order=False,
)

model.quantize("gptq", config=gptq_config)

outputs = model.generate(prompt, max_length=30)
print(outputs)
"""
91+
## Model Export
92+
93+
The GPTQ quantized model can be saved to a preset and reloaded elsewhere, just
94+
like any other KerasHub model.
95+
"""
96+
97+
model.save_to_preset("gemma3_gptq_w4gs128_preset")
98+
model_from_preset = Gemma3CausalLM.from_preset("gemma3_gptq_w4gs128_preset")
99+
output = model_from_preset.generate(prompt, max_length=30)
100+
print(output)
101+
102+
"""
103+
## Performance & Benchmarking
104+
105+
Micro-benchmarks collected on a single NVIDIA 4070 Ti Super (16 GB).
106+
Baselines are FP32.
107+
108+
Dataset: WikiText-2.
109+
110+
111+
| Model (preset) | Perplexity Increase % (↓ better) | Disk Storage Reduction Δ % (↓ better) | VRAM Reduction Δ % (↓ better) | First-token Latency Δ % (↓ better) | Throughput Δ % (↑ better) |
112+
| --------------------------------- | -------------------------------: | ------------------------------------: | ----------------------------: | ---------------------------------: | ------------------------: |
113+
| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0% | -50.1% ↓ | -41.1% ↓ | +0.7% ↑ | +20.1% ↑ |
114+
| OPT (opt_125m_en) | 10.0% | -49.8% ↓ | -47.0% ↓ | +6.7% ↑ | -15.7% ↓ |
115+
| Bloom (bloom_1.1b_multi) | 7.0% | -47.0% ↓ | -54.0% ↓ | +1.8% ↑ | -15.7% ↓ |
116+
| Gemma3 (gemma3_1b) | 3.0% | -51.5% ↓ | -51.8% ↓ | +39.5% ↑ | +5.7% ↑ |
117+
118+
119+
Detailed benchmarking numbers and scripts are available
120+
[here](https://github.com/keras-team/keras/pull/21641).
121+
122+
### Analysis
123+
124+
There is notable reduction in disk space and VRAM usage across all models, with
125+
disk space savings around 50% and VRAM savings ranging from 41% to 54%. The
126+
reported disk savings understate the true weight compression because presets
127+
also include non-weight assets.
128+
129+
Perplexity increases only marginally, indicating model quality is largely
130+
preserved after quantization.
131+
"""
"""
## Practical tips

* GPTQ is a post-training technique; training after quantization is not supported.
* Always use the model's own tokenizer for calibration.
* Use a representative calibration set; small slices like the one above are only for demos.
* Start with `weight_bits=4` and `group_size=128`; tune per model and task.
"""

0 commit comments
