"""
Title: 8-bit Integer Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to using INT8 quantization in Keras and KerasHub.
Accelerator: GPU
"""

"""
## What is INT8 quantization?

Quantization lowers the numerical precision of weights and activations to reduce memory use
and often speed up inference, at the cost of a small accuracy drop. Moving from `float32` to
`float16` halves the memory requirements; `float32` to INT8 is ~4x smaller (and ~2x vs
`float16`). On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), this can also
improve throughput and latency. Actual gains depend on your backend and device.
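
As a quick back-of-the-envelope check of those ratios (weights only; activations and any
quantization metadata are extra), consider a hypothetical 1B-parameter model:

```python
params = 1_000_000_000  # hypothetical 1B-parameter model
for dtype, bytes_per_param in [("float32", 4), ("float16", 2), ("int8", 1)]:
    print(f"{dtype}: ~{params * bytes_per_param / 1024**3:.2f} GiB of weights")
```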

### How it works

Quantization maps real values to 8-bit integers with a scale:

* Integer domain: `[-128, 127]` (256 levels).
* For a tensor (often per-output-channel for weights) with values `w`:
    * Compute `a_max = max(abs(w))`.
    * Set scale `s = (2 * a_max) / 256`.
    * Quantize: `q = clip(round(w / s), -128, 127)` (stored as INT8) and keep `s`.
* Inference uses `q` and `s` to reconstruct effective weights on the fly
  (`w ≈ s · q`) or folds `s` into the matmul/conv for efficiency.
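
To make the recipe concrete, here is a tiny per-tensor NumPy sketch of the same AbsMax
scheme (purely illustrative; for weights, Keras applies the idea per output channel):

```python
import numpy as np

w = np.array([-0.62, 0.15, 0.98, -0.33], dtype="float32")
a_max = np.abs(w).max()
s = (2 * a_max) / 256
q = np.clip(np.round(w / s), -128, 127).astype("int8")  # -> [-81, 20, 127, -43]
w_hat = s * q.astype("float32")                         # dequantized values ≈ w
```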

### Benefits

* Memory / bandwidth bound models: When a model spends most of its time on memory I/O rather
  than computation, speeding up the math alone does little for overall runtime. INT8 reduces
  the bytes moved by ~4x vs `float32`, improving cache behavior and reducing memory stalls;
  this often helps more than adding raw FLOPs.
* Compute bound layers on supported hardware: On NVIDIA GPUs, INT8
  [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) speed up matmul/conv,
  boosting throughput on compute-limited layers.
* Accuracy: Many models retain near-FP accuracy with `float16`; INT8 may introduce a modest
  drop (often ~1-5% depending on task/model/data). Always validate on your own dataset.

### What Keras does in INT8 mode

* **Mapping**: Symmetric, linear quantization with INT8 plus a floating-point scale.
* **Weights**: Per-output-channel scales to preserve accuracy.
* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
  is rewritten so you can run or save immediately.
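
The sketch below shows how these pieces fit together for a single Dense-style matmul:
per-output-channel weight scales, a dynamic per-tensor activation scale, and both scales
folded back in after the quantized matmul. It is a simplified illustration, not the actual
Keras kernels:

```python
import numpy as np

x = np.random.randn(4, 8).astype("float32")   # activations (batch, features)
w = np.random.randn(8, 16).astype("float32")  # kernel (features, units)

# Per-output-channel weight scales: one scale per kernel column.
w_scale = (2 * np.abs(w).max(axis=0, keepdims=True)) / 256
q_w = np.clip(np.round(w / w_scale), -128, 127)

# Dynamic AbsMax activation scale, computed from the incoming batch.
x_scale = (2 * np.abs(x).max()) / 256
q_x = np.clip(np.round(x / x_scale), -128, 127)

# Quantized matmul, then fold both scales back in.
# (Real kernels keep q_x / q_w in INT8 and accumulate in INT32.)
y_approx = (q_x @ q_w) * (x_scale * w_scale)
print(np.abs(y_approx - x @ w).max())  # small quantization error
```

Because the activation scale is recomputed from each incoming batch, no calibration dataset
is required.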
"""

"""
## Overview

This guide shows how to use 8-bit integer post-training quantization (PTQ) in Keras:

1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)

## Quantizing a minimal functional model

We build a small functional model, capture a baseline output, quantize it to INT8 in place,
and then compare outputs with an MSE metric.
"""

import os
import numpy as np
import keras
from keras import layers


# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Compile and train briefly to materialize meaningful weights.
model.compile(optimizer="adam", loss="mse")
x_train = rng.random((256, 10)).astype("float32")
y_train = rng.random((256, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)

# Sample inputs for evaluation.
x_eval = rng.random((32, 10)).astype("float32")

# Baseline (FP32) outputs.
y_fp32 = model(x_eval)

# Quantize the model in-place to INT8.
model.quantize("int8")

# INT8 outputs after quantization.
y_int8 = model(x_eval)

# Compute a simple MSE between FP32 and INT8 outputs.
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
print("Full-Precision vs INT8 MSE:", float(mse))


"""
The low MSE confirms that the INT8 model's outputs stay close to those of the original
FP32 model.

## Saving and reloading a quantized model

You can use the standard Keras saving and loading APIs with quantized models. Quantization
is preserved when saving to `.keras` and loading back.
"""

# Save the quantized model and reload to verify the round trip.
model.save("int8.keras")
int8_reloaded = keras.saving.load_model("int8.keras")
y_int8_reloaded = int8_reloaded(x_eval)
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))

"""
## Quantizing a KerasHub model

All KerasHub models support the `.quantize(...)` API for post-training quantization
and follow the same workflow as above.

In this example, we will:

1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
   preset from KerasHub.
2. Generate text using both the full-precision and quantized models, and compare outputs.
3. Save both models to disk and compute storage savings.
4. Reload the INT8 model and verify output consistency with the original quantized model.
"""

from keras_hub.models import Gemma3CausalLM

# Load from Gemma3 preset.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate text for a single prompt.
output = gemma3.generate("Keras is a", max_length=50)
print("Full-precision output:", output)

# Save FP32 Gemma3 model for size comparison.
gemma3.save_to_preset("gemma3_fp32")

# Quantize in-place to INT8 and generate again.
gemma3.quantize("int8")

output = gemma3.generate("Keras is a", max_length=50)
print("Quantized output:", output)

# Save INT8 Gemma3 model.
gemma3.save_to_preset("gemma3_int8")

# Reload and compare outputs.
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")

output = gemma3_int8.generate("Keras is a", max_length=50)
print("Quantized reloaded output:", output)


# Compute storage savings.
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

"""
## Practical tips

* Post-training quantization (PTQ) is a one-time operation; you cannot train a model
  after quantizing it to INT8.
* Always materialize weights before quantization (e.g., call `build()` or run a forward
  pass); see the sketch below.
* Expect small numerical deltas; quantify them with a metric like MSE on a validation batch.
* Storage savings are immediate; speedups depend on backend/device kernels.
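
As a minimal sketch of the "materialize weights first" tip (layer sizes and input shape are
illustrative), a freshly constructed `Sequential` model has no weights until it is built:

```python
import keras
from keras import layers

fresh = keras.Sequential([layers.Dense(16, activation="relu"), layers.Dense(1)])
fresh.build(input_shape=(None, 10))  # materialize the kernels (or run a forward pass)
fresh.quantize("int8")               # safe to quantize now that the weights exist
```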

## References

* [Milvus: How does 8-bit quantization or float16 affect the accuracy and speed of Sentence Transformer embeddings and similarity calculations?](https://milvus.io/ai-quick-reference/how-does-quantization-such-as-int8-quantization-or-using-float16-affect-the-accuracy-and-speed-of-sentence-transformer-embeddings-and-similarity-calculations)
* [NVIDIA Developer Blog: Achieving FP32 accuracy for INT8 inference using quantization-aware training with TensorRT](https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/)
"""