Commit ade0c30

Adds 4-bit Integer Quantization Documentation (#2194)

* Adds INT4 quantization docs
* Update int4_quantization_in_keras.py
* fix grammar
* address reviews

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

1 parent ee63423 commit ade0c30
"""
Title: INT4 Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to using INT4 quantization in Keras and KerasHub.
Accelerator: GPU
"""

"""
11+
## What is INT4 quantization?
12+
13+
Quantization lowers the numerical precision of weights and activations to reduce memory use
14+
and often speed up inference, at the cost of a small accuracy drop. INT4 post-training
15+
quantization (PTQ) stores model weights in 4-bit signed integers and dynamically quantizes
16+
activations to 8-bit at runtime (a W4A8 scheme). Compared with FP32 this can shrink weight
17+
storage ~8x (2x vs INT8) while retaining acceptable accuracy for many encoder models and
18+
some decoder models. Compute still leverages widely available NVIDIA INT8 Tensor Cores.
19+
20+
4-bit is a more aggressive compression than 8-bit and may induce larger quality regressions,
21+
especially for large autoregressive language models.
22+
23+
## How it works
24+
25+
Quantization maps real values to 4-bit integers with a scale:
26+
27+
1. Per-output-channel scale computed for each weight matrix (symmetric abs-max).
28+
2. Weights are quantized to values in `[-8, 7]` (4 bits) and packed two-per-byte.
29+
3. At inference, activations are dynamically scaled and quantized to INT8.
30+
4. Packed INT4 weights are unpacked to an INT8 tensor (still with INT4-range values).
31+
5. INT8 x INT8 matmul accumulates in INT32.
32+
6. Result is dequantized using `(input_scale * per_channel_kernel_scale)`.
33+
34+
This mirrors the INT8 path described in the
35+
[INT8 guide](https://keras.io/guides/int8_quantization_in_keras) with some added unpack
36+
overhead for stronger compression.
37+
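"""

"""
To make the recipe above concrete, here is a minimal NumPy sketch of the W4A8 arithmetic for a
single dense layer. This is an illustration of the scheme only, not Keras's internal
implementation; all variable names are made up for this example, and the two-per-byte packing
step is omitted.
"""

import numpy as np

# Toy float32 kernel: 6 input features, 4 output channels.
w = np.random.default_rng(0).standard_normal((6, 4)).astype("float32")

# 1. Per-output-channel scale (symmetric abs-max), so each column lands in the INT4 range.
w_scale = np.abs(w).max(axis=0) / 7.0

# 2. Quantize weights to INT4-range values in [-8, 7]. (Keras additionally packs two
#    4-bit values per byte; that step is skipped here for clarity.)
w_q = np.clip(np.round(w / w_scale), -8, 7).astype("int8")

# 3. Dynamically quantize an activation batch to INT8 with a per-tensor abs-max scale.
x = np.random.default_rng(1).standard_normal((2, 6)).astype("float32")
x_scale = np.abs(x).max() / 127.0
x_q = np.clip(np.round(x / x_scale), -128, 127).astype("int8")

# 4.-5. Integer matmul with INT32 accumulation.
acc = x_q.astype("int32") @ w_q.astype("int32")

# 6. Dequantize with the product of the input scale and the per-channel kernel scales.
y_approx = acc.astype("float32") * (x_scale * w_scale)

print("Max abs error vs. float matmul:", np.abs(y_approx - x @ w).max())

"""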
## Benefits

* Memory- / bandwidth-bound models: When inference spends most of its time on memory I/O,
reducing compute time alone does little for overall runtime. INT4 reduces the bytes
moved by ~8x vs FP32, improving cache behavior and reducing memory stalls;
this often helps more than adding raw FLOPs. For example, a 4096 x 4096 dense kernel
occupies 64 MiB in FP32 but only about 8 MiB when packed as INT4.
* Accuracy: Many architectures retain acceptable accuracy with INT4; encoder-only models
often fare better than decoder LLMs. Always validate on your own dataset.
* Compute-bound layers on supported hardware: 4-bit weights are unpacked to INT8 at inference,
so on NVIDIA GPUs the INT8 [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)
accelerate matmuls and convolutions, boosting throughput on compute-limited layers.

### What Keras does in INT4 mode

* **Mapping**: Symmetric, linear quantization with INT4 values plus a floating-point scale.
* **Weights**: Per-output-channel scales to preserve accuracy.
* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
is rewritten so you can run or save immediately.
"""

"""
## Overview

This guide shows how to use 4-bit (W4A8) post-training quantization in Keras:

1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)
4. [When to use INT4 vs INT8](#when-should-i-use-int4-vs-int8)
5. [Performance benchmarks](#performance--benchmarking)
6. [Practical Tips](#practical-tips)
7. [Limitations](#limitations)
"""

"""
## Quantizing a Minimal Functional Model

Below, we build a small functional model, capture a baseline output, quantize the model to
INT4 in place, and compare outputs with an MSE metric. (For a real evaluation, use your
validation metric.)
"""

import numpy as np
import keras
from keras import layers

# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Baseline output with full-precision weights.
x_eval = rng.random((32, 10)).astype("float32")
y_fp32 = model(x_eval)

# Quantize the model in-place to INT4 (W4A8).
model.quantize("int4")

# Compare outputs (MSE).
y_int4 = model(x_eval)
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4))
print("Full-Precision vs INT4 MSE:", float(mse))

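"""
You can also inspect what quantization did to a layer by listing its variables. Treat the
printout below as illustrative: exact variable names and shapes depend on your Keras version,
but you should see the kernel stored in packed integer form alongside floating-point scales.
"""

# Inspect the variables of the first Dense layer after INT4 quantization.
for v in model.layers[1].weights:
    print(v.name, v.dtype, v.shape)
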
"""
The INT4 quantized model usually produces outputs close enough for many downstream
tasks. Expect larger deltas than with INT8, so always validate on your own data.
"""

"""
## Saving and Reloading a Quantized Model

You can use the standard Keras saving / loading APIs. Quantization metadata (including
scales and packed weights) is preserved.
"""

# Save the quantized model and reload it to verify the round trip.
model.save("int4.keras")
int4_reloaded = keras.saving.load_model("int4.keras")
y_int4_reloaded = int4_reloaded(x_eval)

# Compare outputs (MSE) between the in-memory INT4 model and the reloaded copy.
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int4 - y_int4_reloaded))
print("MSE (INT4 vs reloaded INT4):", float(roundtrip_mse))

"""
## Quantizing a KerasHub Model

All KerasHub models support the `.quantize(...)` API for post-training quantization
and follow the same workflow as above.

In this example, we will:

1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
preset from KerasHub.
2. Generate text using both the full-precision and quantized models, and compare outputs.
3. Save both models to disk and compute storage savings.
4. Reload the INT4 model and verify output consistency with the original quantized model.
"""
import os
from keras_hub.models import Gemma3CausalLM

# Load a Gemma3 preset from KerasHub.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate with full-precision weights.
fp_output = gemma3.generate("Keras is a", max_length=30)
print("Full-precision output:", fp_output)

# Save the full-precision model to a preset.
gemma3.save_to_preset("gemma3_fp32")

# Quantize to INT4.
gemma3.quantize("int4")

# Generate with INT4 weights.
output = gemma3.generate("Keras is a", max_length=30)
print("Quantized output:", output)

# Save the INT4 model to a new preset.
gemma3.save_to_preset("gemma3_int4")

# Reload the INT4 preset and compare outputs.
gemma3_int4 = Gemma3CausalLM.from_preset("gemma3_int4")

output = gemma3_int4.generate("Keras is a", max_length=30)
print("Quantized reloaded output:", output)


# Compute storage savings.
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int4_size = os.path.getsize("gemma3_int4/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int4_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT4 file size: {bytes_to_mib(gemma_int4_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

"""
## Performance & Benchmarking

Micro-benchmarks collected on a single NVIDIA L4 (22.5 GB). Baselines are FP32.

### Text Classification (DistilBERT Base on SST-2)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/77e874187d6da3f8280c053192f78c06/int4-quantization-micro-benchmark-distilbert.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Accuracy (↑) | 91.06% | 90.14% | -0.92pp |
| Model Size (MB, ↓) | 255.86 | 159.49 | -37.67% |
| Peak GPU Memory (MiB, ↓) | 1554.00 | 1243.26 | -20.00% |
| Latency (ms/sample, ↓) | 6.43 | 5.73 | -10.83% |
| Throughput (samples/s, ↑) | 155.60 | 174.50 | +12.15% |

**Analysis**: Accuracy drop is modest (<1pp) with notable speed and memory gains;
encoder-only models tend to retain fidelity under heavier weight compression.

### Text Generation (Falcon 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/19ab238e0f5b29ae24c0faf4128e7d7e/int4_quantization_micro_benchmark_falcon.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 7.44 | 9.98 | +34.15% |
| Model Size (GB, ↓) | 4.8884 | 0.9526 | -80.51% |
| Peak GPU Memory (MiB, ↓) | 8021.12 | 5483.46 | -31.64% |
| First Token Latency (ms, ↓) | 128.87 | 122.50 | -4.95% |
| Sequence Latency (ms, ↓) | 338.29 | 181.93 | -46.22% |
| Token Throughput (tokens/s, ↑) | 174.41 | 256.96 | +47.33% |

**Analysis**: INT4 gives large size (-80%) and memory (-32%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~50%.

### Text Generation (Gemma3 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/9ca7813971868d5d1a16cd7998d0e352/int4_quantization_micro_benchmark_gemma3.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 6.17 | 10.46 | +69.61% |
| Model Size (GB, ↓) | 3.7303 | 1.4576 | -60.92% |
| Peak GPU Memory (MiB, ↓) | 6844.67 | 5008.14 | -26.83% |
| First Token Latency (ms, ↓) | 57.42 | 64.21 | +11.83% |
| Sequence Latency (ms, ↓) | 239.78 | 161.18 | -32.78% |
| Token Throughput (tokens/s, ↑) | 246.06 | 366.05 | +48.76% |

**Analysis**: INT4 gives large size (-61%) and memory (-27%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~50%.

### Text Generation (Llama 3.2 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/310f50a0ca0eba3754de41c612b3b8ef/int4_quantization_micro_benchmark_llama3.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 6.38 | 14.16 | +121.78% |
| Model Size (GB, ↓) | 5.5890 | 2.4186 | -56.73% |
| Peak GPU Memory (MiB, ↓) | 9509.49 | 6810.26 | -28.38% |
| First Token Latency (ms, ↓) | 209.41 | 219.09 | +4.62% |
| Sequence Latency (ms, ↓) | 322.33 | 262.15 | -18.67% |
| Token Throughput (tokens/s, ↑) | 183.82 | 230.78 | +25.55% |

**Analysis**: INT4 gives large size (-57%) and memory (-28%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~25%.

### Text Generation (OPT 125M)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/918fcdb8a1433dea12800f8ca4a240f5/int4_quantization_micro_benchmark_opt.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 13.85 | 21.02 | +51.79% |
| Model Size (MB, ↓) | 468.3 | 284.0 | -39.37% |
| Peak GPU Memory (MiB, ↓) | 1007.23 | 659.28 | -34.54% |
| First Token Latency (ms/sample, ↓) | 95.79 | 97.87 | +2.18% |
| Sequence Latency (ms/sample, ↓) | 60.35 | 54.64 | -9.46% |
| Throughput (samples/s, ↑) | 973.41 | 1075.15 | +10.45% |

**Analysis**: INT4 gives large size (-39%) and memory (-35%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~10%.
"""

"""
## When should I use INT4 vs INT8?

| Goal / Constraint | Prefer INT8 | Prefer INT4 (W4A8) |
| ----------------- | ----------- | ------------------ |
| Minimal accuracy drop is critical | ✔︎ | |
| Maximum compression (disk / RAM) | | ✔︎ |
| Bandwidth-bound inference | Possible | Often better |
| Decoder LLM | ✔︎ | Try with eval first |
| Encoder / classification models | ✔︎ | ✔︎ |
| Available kernels / tooling maturity | ✔︎ | Emerging |

* Start with INT8; if memory or distribution size is still a bottleneck, evaluate INT4.
* For LLMs, measure task-specific metrics (perplexity, exact match, etc.) after INT4 quantization.
* Combine INT4 weights with LoRA adapters for efficient fine-tuning workflows, as sketched below.
"""

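"""
One possible shape of the LoRA workflow mentioned above, sketched with the KerasHub API. This is
a minimal, non-authoritative sketch: the dataset placeholder (`my_dataset`), the hyperparameters,
and the ordering of the LoRA and quantization steps are assumptions; consult the KerasHub LoRA
fine-tuning guides for the recommended recipe.

```python
import keras
from keras_hub.models import Gemma3CausalLM

causal_lm = Gemma3CausalLM.from_preset("gemma3_1b")

# Add low-rank adapters to the backbone; only the LoRA weights are trainable.
causal_lm.backbone.enable_lora(rank=4)

# Fine-tune on your own data (`my_dataset` is a placeholder, e.g. a list of strings
# or a tf.data.Dataset handled by the model's preprocessor).
causal_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
causal_lm.fit(my_dataset, epochs=1)

# Quantize the fine-tuned model to INT4 for deployment (post-training, one-time).
causal_lm.quantize("int4")
causal_lm.save_to_preset("gemma3_1b_lora_int4")
```
"""
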
"""
## Practical Tips

* Post-training quantization (PTQ) is a one-time operation; you cannot train a model
after quantizing it to INT4.
* Always materialize weights before quantization (e.g., via `build()` or a forward pass);
see the sketch after this list.
* Evaluate on a representative validation set; track task metrics, not just MSE.
* Use LoRA for further fine-tuning.

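"""

"""
A minimal sketch of the "materialize weights first" tip: a `Sequential` (or subclassed) model has
no weights until it is built, so call `build()` or run a forward pass before quantizing.
"""

# Sequential models have no weights until built; build first, then quantize.
seq = keras.Sequential([layers.Dense(16, activation="relu"), layers.Dense(1)])
seq.build((None, 8))  # materializes the kernels and biases
seq.quantize("int4")  # works now that the weights exist
_ = seq(np.ones((1, 8), dtype="float32"))  # the quantized model runs as usual

"""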
## Limitations

* Runtime unpacking adds overhead (weights are decompressed layer-wise on each forward pass).
* Aggressive compression can cause noticeable accuracy drops (especially for decoder-only LLMs).
* The LoRA export path is lossy (dequantize -> add delta -> requantize).
* Keras does not yet ship native fused INT4 kernels; it relies on unpacking plus an INT8 matmul.
"""
