
Commit ee63423

Adds 8-bit Integer Quantization Documentation (#2193)
* Adds int8 quantization docs
* address reviews
* small fix
1 parent 62d2de3 commit ee63423

7 files changed: +716 -8 lines changed
Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
"""
Title: 8-bit Integer Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to using INT8 quantization in Keras and KerasHub.
Accelerator: GPU
"""

"""
11+
## What is INT8 quantization?
12+
13+
Quantization lowers the numerical precision of weights and activations to reduce memory use
14+
and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
15+
`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs
16+
`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also
17+
improve throughput and latency. Actual gains depend on your backend and device.
18+
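To make the ~4x figure concrete, here is a rough back-of-the-envelope calculation for a
hypothetical 1B-parameter model, counting weight storage only (activations and the small
per-channel scales are ignored):

```python
params = 1_000_000_000  # hypothetical 1B-parameter model
print(f"float32: {params * 4 / 1024**3:.2f} GiB")  # ~3.73 GiB (4 bytes per weight)
print(f"INT8:    {params * 1 / 1024**3:.2f} GiB")  # ~0.93 GiB (1 byte per weight)
```
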
### How it works

Quantization maps real values to 8-bit integers with a scale (see the sketch after this list):

* Integer domain: `[-128, 127]` (256 levels).
* For a tensor (often per-output-channel for weights) with values `w`:
  * Compute `a_max = max(abs(w))`.
  * Set scale `s = (2 * a_max) / 256`.
  * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.
* Inference uses `q` and `s` to reconstruct effective weights on the fly
  (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.

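A minimal NumPy sketch of this mapping (illustrative only: the `absmax_quantize` helper below
is ours, and Keras's built-in quantizers additionally handle per-channel scales and kernel
fusion):

```python
import numpy as np


def absmax_quantize(w):
    a_max = np.max(np.abs(w))
    s = (2.0 * a_max) / 256.0  # scale, as defined above
    q = np.clip(np.round(w / s), -128, 127).astype(np.int8)
    return q, s


w = np.random.randn(4, 8).astype("float32")
q, s = absmax_quantize(w)
w_hat = s * q.astype("float32")  # dequantized approximation: w ≈ s * q
print("max abs error:", np.max(np.abs(w - w_hat)))
```
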
### Benefits

* Memory / bandwidth bound models: When a model spends most of its time on memory I/O,
  reducing compute time alone does little for overall runtime. INT8 reduces bytes
  moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
  this often helps more than increasing raw FLOPs.
* Compute bound layers on supported hardware: On NVIDIA GPUs, INT8
  [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
  boosting throughput on compute-limited layers.
* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest
  drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.

### What Keras does in INT8 mode

* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.
* **Weights**: Per-output-channel scales to preserve accuracy (sketched below).
* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
  is rewritten so you can run or save immediately.
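
As a sketch of what per-output-channel weight scaling means for a `Dense` kernel (the array
names below are ours, and Keras's internal variable layout may differ):

```python
import numpy as np

# Hypothetical Dense kernel of shape (input_dim, output_dim).
kernel = np.random.randn(16, 4).astype("float32")

# One scale per output channel (column) instead of one scale for the whole tensor.
a_max = np.max(np.abs(kernel), axis=0)  # shape: (4,)
scales = (2.0 * a_max) / 256.0  # shape: (4,)
q = np.clip(np.round(kernel / scales), -128, 127).astype(np.int8)

# Dequantization applies the matching per-channel scale.
kernel_hat = q.astype("float32") * scales
print("per-channel max abs error:", np.max(np.abs(kernel - kernel_hat)))
```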
"""

"""
## Overview

This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:

1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)

## Quantizing a minimal functional model

We build a small functional model, capture a baseline output, quantize to INT8 in-place,
and then compare outputs with an MSE metric.
"""

import os
import numpy as np
import keras
from keras import layers


# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Compile and train briefly to materialize meaningful weights.
model.compile(optimizer="adam", loss="mse")
x_train = rng.random((256, 10)).astype("float32")
y_train = rng.random((256, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)

# Sample inputs for evaluation.
x_eval = rng.random((32, 10)).astype("float32")

# Baseline (FP) outputs.
y_fp32 = model(x_eval)

# Quantize the model in-place to INT8.
model.quantize("int8")

# INT8 outputs after quantization.
y_int8 = model(x_eval)

# Compute a simple MSE between FP and INT8 outputs.
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
print("Full-Precision vs INT8 MSE:", float(mse))


"""
The low MSE confirms that the INT8 quantized model produces outputs close to those of the
original FP32 model.

## Saving and reloading a quantized model

You can use the standard Keras saving and loading APIs with quantized models. Quantization
is preserved when saving to `.keras` and loading back.
"""

# Save the quantized model and reload to verify round-trip.
model.save("int8.keras")
int8_reloaded = keras.saving.load_model("int8.keras")
y_int8_reloaded = int8_reloaded(x_eval)
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))

"""
## Quantizing a KerasHub model

All KerasHub models support the `.quantize(...)` API for post-training quantization
and follow the same workflow as above.

In this example, we will:

1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
   preset from KerasHub.
2. Generate text using both the full-precision and quantized models, and compare outputs.
3. Save both models to disk and compute storage savings.
4. Reload the INT8 model and verify output consistency with the original quantized model.
"""

from keras_hub.models import Gemma3CausalLM

# Load from Gemma3 preset.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate text for a single prompt.
output = gemma3.generate("Keras is a", max_length=50)
print("Full-precision output:", output)

# Save FP32 Gemma3 model for size comparison.
gemma3.save_to_preset("gemma3_fp32")

# Quantize in-place to INT8 and generate again.
gemma3.quantize("int8")

output = gemma3.generate("Keras is a", max_length=50)
print("Quantized output:", output)

# Save INT8 Gemma3 model.
gemma3.save_to_preset("gemma3_int8")

# Reload and compare outputs.
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")

output = gemma3_int8.generate("Keras is a", max_length=50)
print("Quantized reloaded output:", output)


# Compute storage savings.
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

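
"""
Optionally, we can get a rough sense of generation speed before and after quantization. This
is only a sketch: the `time_generate` helper below is ours, wall-clock numbers vary widely
with backend, device, and INT8 kernel availability, and the first `generate` call includes
compilation overhead, so we warm up first and reload the full-precision model from the
`gemma3_fp32` preset directory saved above.
"""

import time

# Reload the full-precision model from the locally saved preset.
gemma3_fp32 = Gemma3CausalLM.from_preset("gemma3_fp32")


def time_generate(model, prompt="Keras is a", max_length=50):
    # Warm-up call to exclude one-time compilation overhead from the measurement.
    model.generate(prompt, max_length=max_length)
    start = time.perf_counter()
    model.generate(prompt, max_length=max_length)
    return time.perf_counter() - start


print(f"FP32 generate time: {time_generate(gemma3_fp32):.2f} s")
print(f"INT8 generate time: {time_generate(gemma3_int8):.2f} s")
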
"""
## Practical tips

* Post-training quantization (PTQ) is a one-time operation; you cannot continue training a
  model after quantizing it to INT8.
* Always materialize weights before quantization (e.g., call `build()` or run a forward
  pass); see the sketch below.
* Expect small numerical deltas; quantify them with a metric like MSE on a validation batch.
* Storage savings are immediate; speedups depend on backend/device kernels.

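For example, a minimal sketch of the second tip: a `Sequential` model defined without an
input shape has no weights until it is built, so build it (or call it on a batch of data)
before quantizing:

```python
import keras
from keras import layers

seq = keras.Sequential([layers.Dense(8, activation="relu"), layers.Dense(1)])
seq.build(input_shape=(None, 16))  # materialize weights before quantizing
seq.quantize("int8")
```
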
## References

* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations)
* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/)
"""
