"""
Title: INT4 Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to using INT4 quantization in Keras and KerasHub.
Accelerator: GPU
"""

"""
## What is INT4 quantization?

Quantization lowers the numerical precision of weights and activations to reduce memory use
and often speed up inference, at the cost of a small accuracy drop. INT4 post-training
quantization (PTQ) stores model weights as 4-bit signed integers and dynamically quantizes
activations to 8-bit at runtime (a W4A8 scheme). Compared with FP32, this can shrink weight
storage by roughly 8x (2x vs. INT8) while retaining acceptable accuracy for many encoder
models and some decoder models. Compute still leverages widely available NVIDIA INT8
Tensor Cores.

4-bit compression is more aggressive than 8-bit and can cause larger quality regressions,
especially for large autoregressive language models.

## How it works

Quantization maps real values to 4-bit integers with a per-channel scale:

1. A per-output-channel scale is computed for each weight matrix (symmetric abs-max).
2. Weights are quantized to values in `[-8, 7]` (4 bits) and packed two-per-byte.
3. At inference, activations are dynamically scaled and quantized to INT8.
4. Packed INT4 weights are unpacked to an INT8 tensor (still holding INT4-range values).
5. The INT8 x INT8 matmul accumulates in INT32.
6. The result is dequantized with `(input_scale * per_channel_kernel_scale)`.

This mirrors the INT8 path described in the
[INT8 guide](https://keras.io/guides/int8_quantization_in_keras), with some added unpacking
overhead in exchange for stronger compression. A NumPy sketch of the scheme appears after
this introduction.

## Benefits

* Memory / bandwidth-bound models: When a model spends most of its time on memory I/O,
  speeding up arithmetic alone does little for overall runtime. INT4 reduces the bytes
  moved by ~8x vs. FP32, improving cache behavior and reducing memory stalls; this often
  helps more than adding raw FLOPs.
* Accuracy: Many architectures retain acceptable accuracy with INT4; encoder-only models
  often fare better than decoder LLMs. Always validate on your own dataset.
* Compute-bound layers on supported hardware: INT4 weights are unpacked to INT8 at
  inference, so on NVIDIA GPUs, INT8
  [Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/) can accelerate the
  matmuls and convolutions, boosting throughput on compute-limited layers.

### What Keras does in INT4 mode

* **Mapping**: Symmetric, linear quantization with INT4 plus a floating-point scale.
* **Weights**: Per-output-channel scales to preserve accuracy.
* **Activations**: **Dynamic AbsMax** scaling computed at runtime.
* **Graph rewrite**: Quantization is applied after weights are trained and built; the graph
  is rewritten so you can run or save immediately.
"""
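
"""
To make the mapping concrete, here is a small NumPy sketch of the W4A8 scheme described
above: per-output-channel abs-max scales, weights clipped to `[-8, 7]` and packed two per
byte, dynamic INT8 activation quantization, and an INT8 matmul with INT32 accumulation
followed by dequantization. The helper names and the packing layout are illustrative
assumptions, not the actual Keras implementation.
"""

import numpy as np

demo_rng = np.random.default_rng(0)


def quantize_weights_int4(kernel):
    # One symmetric abs-max scale per output channel (column) of the kernel.
    scale = np.max(np.abs(kernel), axis=0) / 7.0
    q = np.clip(np.round(kernel / scale), -8, 7).astype(np.int8)
    return q, scale


def pack_int4(q):
    # Pack pairs of signed 4-bit values along the input dimension into single bytes.
    if q.shape[0] % 2:
        q = np.concatenate([q, np.zeros((1, q.shape[1]), dtype=np.int8)])
    low, high = q[0::2], q[1::2]
    return ((high.astype(np.uint8) & 0x0F) << 4) | (low.astype(np.uint8) & 0x0F)


def unpack_int4(packed, rows):
    # Recover the signed 4-bit values as an INT8 tensor (still INT4-range values).
    low = (packed & 0x0F).astype(np.int8)
    high = (packed >> 4).astype(np.int8)
    low[low > 7] -= 16
    high[high > 7] -= 16
    return np.stack([low, high], axis=1).reshape(-1, packed.shape[1])[:rows]


# Quantize a random kernel and a batch of activations, then run the W4A8 matmul.
kernel = demo_rng.standard_normal((10, 32)).astype("float32")
acts = demo_rng.standard_normal((4, 10)).astype("float32")

q_w, w_scale = quantize_weights_int4(kernel)
packed_w = pack_int4(q_w)

act_scale = np.max(np.abs(acts)) / 127.0  # dynamic AbsMax activation scale
q_acts = np.clip(np.round(acts / act_scale), -128, 127).astype(np.int8)

acc = q_acts.astype(np.int32) @ unpack_int4(packed_w, kernel.shape[0]).astype(np.int32)
approx = acc.astype(np.float32) * (act_scale * w_scale)  # dequantize the INT32 result

print("Max abs error vs. float32 matmul:", float(np.max(np.abs(approx - acts @ kernel))))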

"""
## Overview

This guide shows how to use 4-bit (W4A8) post-training quantization in Keras:

1. [Quantizing a minimal functional model](#quantizing-a-minimal-functional-model)
2. [Saving and reloading a quantized model](#saving-and-reloading-a-quantized-model)
3. [Quantizing a KerasHub model](#quantizing-a-kerashub-model)
4. [When to use INT4 vs INT8](#when-should-i-use-int4-vs-int8)
5. [Performance benchmarks](#performance--benchmarking)
6. [Practical tips](#practical-tips)
7. [Limitations](#limitations)
"""

"""
## Quantizing a Minimal Functional Model

Below we build a small functional model, capture a baseline output, quantize to INT4
in place, and compare outputs with an MSE metric. (For real evaluation, use your
validation metric.)
"""

import numpy as np
import keras
from keras import layers

# Create a random number generator.
rng = np.random.default_rng()

# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)

# Baseline output with full-precision weights.
x_eval = rng.random((32, 10)).astype("float32")
y_fp32 = model(x_eval)


# Quantize the model in-place to INT4 (W4A8).
model.quantize("int4")

# Compare outputs (MSE).
y_int4 = model(x_eval)
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4))
print("Full-Precision vs INT4 MSE:", float(mse))

"""
The INT4-quantized model usually produces outputs close enough to the full-precision
baseline for many downstream tasks. Expect larger deltas than with INT8, so always
validate on your own data.
"""
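
"""
MSE alone can hide outliers, so it is also worth checking the worst-case element-wise
deviation between the FP32 and INT4 outputs (a quick extra check, not a substitute for
task metrics):
"""

max_abs_err = keras.ops.max(keras.ops.abs(y_fp32 - y_int4))
print("Full-Precision vs INT4 max absolute error:", float(max_abs_err))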

"""
## Saving and Reloading a Quantized Model

You can use standard Keras saving / loading APIs. Quantization metadata (including
scales and packed weights) is preserved.
"""

# Save the quantized model and reload to verify round-trip.
model.save("int4.keras")
int4_reloaded = keras.saving.load_model("int4.keras")
y_int4_reloaded = int4_reloaded(x_eval)

# Compare the reloaded model's outputs against the FP32 baseline (MSE).
roundtrip_mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int4_reloaded))
print("MSE (FP32 vs reloaded INT4):", float(roundtrip_mse))

"""
## Quantizing a KerasHub Model

All KerasHub models support the `.quantize(...)` API for post-training quantization
and follow the same workflow as above.

In this example, we will:

1. Load the [gemma3_1b](https://www.kaggle.com/models/keras/gemma3/keras/gemma3_1b)
   preset from KerasHub.
2. Generate text using both the full-precision and quantized models, and compare outputs.
3. Save both models to disk and compute the storage savings.
4. Reload the INT4 model and verify output consistency with the original quantized model.
"""
import os
from keras_hub.models import Gemma3CausalLM

# Load a Gemma3 preset from KerasHub.
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate with full-precision weights.
fp_output = gemma3.generate("Keras is a", max_length=30)
print("Full-precision output:", fp_output)

# Save the full-precision model to a preset.
gemma3.save_to_preset("gemma3_fp32")

# Quantize to INT4.
gemma3.quantize("int4")

# Generate with INT4 weights.
output = gemma3.generate("Keras is a", max_length=30)
print("Quantized output:", output)

# Save the INT4 model to a new preset.
gemma3.save_to_preset("gemma3_int4")

# Reload and compare outputs.
gemma3_int4 = Gemma3CausalLM.from_preset("gemma3_int4")

output = gemma3_int4.generate("Keras is a", max_length=30)
print("Quantized reloaded output:", output)


# Compute storage savings.
def bytes_to_mib(n):
    return n / (1024**2)


gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int4_size = os.path.getsize("gemma3_int4/model.weights.h5")

gemma_reduction = 100.0 * (1.0 - (gemma_int4_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT4 file size: {bytes_to_mib(gemma_int4_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")

"""
## Performance & Benchmarking

Micro-benchmarks were collected on a single NVIDIA L4 GPU (22.5 GB). Baselines are FP32.

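
For a quick, rough latency measurement on your own hardware, you can time the forward pass
directly, as in the snippet below (run here against the small functional model from
earlier). This is only a coarse sanity check; the tables that follow were produced with the
linked Colab notebooks, not with this snippet.
"""

import time


def average_latency(fn, warmup=5, iters=50):
    # Warm up to exclude one-time tracing / allocation costs, then average over iterations.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters


latency_s = average_latency(lambda: int4_reloaded(x_eval))
print(f"INT4 functional model: {latency_s * 1e3:.3f} ms per batch of {x_eval.shape[0]}")

"""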
### Text Classification (DistilBERT Base on SST-2)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/77e874187d6da3f8280c053192f78c06/int4-quantization-micro-benchmark-distilbert.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Accuracy (↑) | 91.06% | 90.14% | -0.92pp |
| Model Size (MB, ↓) | 255.86 | 159.49 | -37.67% |
| Peak GPU Memory (MiB, ↓) | 1554.00 | 1243.26 | -20.00% |
| Latency (ms/sample, ↓) | 6.43 | 5.73 | -10.83% |
| Throughput (samples/s, ↑) | 155.60 | 174.50 | +12.15% |

**Analysis**: The accuracy drop is modest (<1pp) with notable speed and memory gains;
encoder-only models tend to retain fidelity under heavier weight compression.

### Text Generation (Falcon 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/19ab238e0f5b29ae24c0faf4128e7d7e/int4_quantization_micro_benchmark_falcon.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 7.44 | 9.98 | +34.15% |
| Model Size (GB, ↓) | 4.8884 | 0.9526 | -80.51% |
| Peak GPU Memory (MiB, ↓) | 8021.12 | 5483.46 | -31.64% |
| First Token Latency (ms, ↓) | 128.87 | 122.50 | -4.95% |
| Sequence Latency (ms, ↓) | 338.29 | 181.93 | -46.22% |
| Token Throughput (tokens/s, ↑) | 174.41 | 256.96 | +47.33% |

**Analysis**: INT4 gives large size (-80%) and memory (-32%) reductions. Perplexity
increases (expected for aggressive compression), yet sequence latency drops and token
throughput rises by ~47%.

### Text Generation (Gemma3 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/9ca7813971868d5d1a16cd7998d0e352/int4_quantization_micro_benchmark_gemma3.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 6.17 | 10.46 | +69.61% |
| Model Size (GB, ↓) | 3.7303 | 1.4576 | -60.92% |
| Peak GPU Memory (MiB, ↓) | 6844.67 | 5008.14 | -26.83% |
| First Token Latency (ms, ↓) | 57.42 | 64.21 | +11.83% |
| Sequence Latency (ms, ↓) | 239.78 | 161.18 | -32.78% |
| Token Throughput (tokens/s, ↑) | 246.06 | 366.05 | +48.76% |

**Analysis**: INT4 gives large size (-61%) and memory (-27%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~50%.

### Text Generation (Llama 3.2 1B)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/310f50a0ca0eba3754de41c612b3b8ef/int4_quantization_micro_benchmark_llama3.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 6.38 | 14.16 | +121.78% |
| Model Size (GB, ↓) | 5.5890 | 2.4186 | -56.73% |
| Peak GPU Memory (MiB, ↓) | 9509.49 | 6810.26 | -28.38% |
| First Token Latency (ms, ↓) | 209.41 | 219.09 | +4.62% |
| Sequence Latency (ms, ↓) | 322.33 | 262.15 | -18.67% |
| Token Throughput (tokens/s, ↑) | 183.82 | 230.78 | +25.55% |

**Analysis**: INT4 gives large size (-57%) and memory (-28%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~25%.

### Text Generation (OPT 125M)

<img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> [**View in Colab**](https://colab.research.google.com/gist/JyotinderSingh/918fcdb8a1433dea12800f8ca4a240f5/int4_quantization_micro_benchmark_opt.ipynb)

| Metric | FP32 | INT4 | Change |
| ------ | ---- | ---- | ------ |
| Perplexity (↓) | 13.85 | 21.02 | +51.79% |
| Model Size (MB, ↓) | 468.3 | 284.0 | -39.37% |
| Peak GPU Memory (MiB, ↓) | 1007.23 | 659.28 | -34.54% |
| First Token Latency (ms/sample, ↓) | 95.79 | 97.87 | +2.18% |
| Sequence Latency (ms/sample, ↓) | 60.35 | 54.64 | -9.46% |
| Throughput (samples/s, ↑) | 973.41 | 1075.15 | +10.45% |

**Analysis**: INT4 gives large size (-39%) and memory (-35%) reductions. Perplexity
increases (expected for aggressive compression) yet sequence latency drops and
throughput rises ~10%.
"""

"""
## When should I use INT4 vs INT8?

| Goal / Constraint | Prefer INT8 | Prefer INT4 (W4A8) |
| ----------------- | ----------- | ------------------ |
| Minimal accuracy drop is critical | ✔︎ | |
| Maximum compression (disk / RAM) | | ✔︎ |
| Bandwidth-bound inference | Possible | Often better |
| Decoder LLM | ✔︎ | Try with eval first |
| Encoder / classification models | ✔︎ | ✔︎ |
| Available kernels / tooling maturity | ✔︎ | Emerging |

* Start with INT8; if memory or distribution size is still a bottleneck, evaluate INT4.
* For LLMs, measure task-specific metrics (perplexity, exact match, etc.) after INT4.
* Combine INT4 weights with LoRA adapters for efficient fine-tuning workflows (see the
  sketch below).
"""
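
"""
Below is a minimal sketch of that LoRA + INT4 workflow, assuming (as the tips above imply)
that LoRA adapters can be enabled on an INT4-quantized KerasHub backbone. The rank,
sequence length, learning rate, and single-sentence dataset are illustrative placeholders,
not a tuned recipe.
"""

# Enable low-rank adapters on the INT4 Gemma3 model: the 4-bit base weights stay frozen,
# and only the small LoRA matrices receive gradient updates.
gemma3_int4.backbone.enable_lora(rank=4)

# Keep sequences short to limit memory use for this toy example.
gemma3_int4.preprocessor.sequence_length = 128

# Compile and fine-tune with the standard KerasHub task workflow.
gemma3_int4.compile(optimizer=keras.optimizers.Adam(5e-5))
gemma3_int4.fit(["Keras is a multi-backend deep learning framework."], epochs=1)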

"""
## Practical Tips

* Post-training quantization (PTQ) is a one-time operation; you cannot continue training
  the quantized weights after converting a model to INT4.
* Always materialize weights before quantization (e.g., call `build()` or run a forward
  pass); see the snippet below.
* Evaluate on a representative validation set; track task metrics, not just MSE.
* Use LoRA for further fine-tuning, as shown above.

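
A minimal sketch of the materialize-then-quantize pattern (the `Sequential` model here is a
made-up example):
"""

# A Sequential model defined without an input shape has no weights yet.
seq = keras.Sequential([layers.Dense(16, activation="relu"), layers.Dense(1)])
seq.build((None, 8))  # materialize the weights (a forward pass also works)
seq.quantize("int4")  # now the model can be quantized

y_check = seq(np.zeros((1, 8), dtype="float32"))
print("Output shape after INT4 quantization:", y_check.shape)

"""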
## Limitations

* Runtime unpacking adds overhead (weights are decompressed layer-by-layer on each forward
  pass).
* Aggressive compression can cost accuracy, especially for decoder-only LLMs.
* The LoRA export path is lossy (dequantize -> add delta -> requantize).
* Keras does not yet ship native fused INT4 kernels; it relies on unpacking plus INT8
  matmul.
"""