"""
Title: GPTQ Quantization in Keras
Author: [Jyotinder Singh](https://x.com/Jyotinder_Singh)
Date created: 2025/10/16
Last modified: 2025/10/16
Description: How to run weight-only GPTQ quantization for Keras & KerasHub models.
Accelerator: GPU
"""

"""
## What is GPTQ?

GPTQ ("Generative Pre-trained Transformer Quantization") is a post-training,
weight-only quantization method that uses a second-order approximation of the
loss (via a Hessian estimate) to minimize the error introduced when compressing
weights to lower precision, typically 4-bit integers.
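
Concretely, for a layer with weight matrix `W` and calibration inputs `X`, GPTQ
picks quantized weights `W_q` that approximately minimize the reconstruction
error `||W X - W_q X||^2`. The Hessian of this objective, `H = 2 X X^T`, is
estimated from the calibration data and used to decide how the error introduced
by each quantized weight is compensated for in the weights that have not been
quantized yet.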

Unlike standard post-training techniques, GPTQ keeps activations in higher
precision and only quantizes the weights. This often preserves model quality in
low bit-width settings while still providing large storage and memory savings.

Keras supports GPTQ quantization for KerasHub models via the
`keras.quantizers.GPTQConfig` class.
"""

"""
## Load a KerasHub model

This guide uses the `Gemma3CausalLM` model from KerasHub, a small (1B
parameter) causal language model.
"""
import keras
from keras_hub.models import Gemma3CausalLM
from datasets import load_dataset


prompt = "Keras is a"

model = Gemma3CausalLM.from_preset("gemma3_1b")

# Generate with the full-precision model to get a baseline for comparison.
outputs = model.generate(prompt, max_length=30)
print(outputs)

"""
## Configure & run GPTQ quantization

You can configure GPTQ quantization via the `keras.quantizers.GPTQConfig` class.

The GPTQ configuration requires a calibration dataset and tokenizer, which it
uses to estimate the Hessian and quantization error. Here, we use a small slice
of the WikiText-2 dataset for calibration.

You can tune several parameters to trade off speed, memory, and accuracy. The
most important of these are `weight_bits` (the bit-width to quantize weights to)
and `group_size` (the number of weights that share one quantization scale). The
group size controls the granularity of quantization: smaller groups typically
yield better accuracy but are slower to quantize and may use more memory. A good
starting point is `group_size=128` for 4-bit quantization (`weight_bits=4`).
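
To build intuition for what `weight_bits` and `group_size` mean, here is a rough
sketch of plain group-wise round-to-nearest quantization (illustrative only, not
the actual Keras implementation): each group of `group_size` weights gets its
own scale and zero point, and each weight is rounded to one of
`2 ** weight_bits` levels.

```python
import numpy as np


def groupwise_quantize_dequantize(row, weight_bits=4, group_size=128):
    # Asymmetric round-to-nearest per group; returns the dequantized values.
    qmax = 2**weight_bits - 1
    out = np.empty_like(row)
    for start in range(0, len(row), group_size):
        group = row[start : start + group_size]
        scale = max((group.max() - group.min()) / qmax, 1e-8)
        zero_point = np.round(-group.min() / scale)
        q = np.clip(np.round(group / scale + zero_point), 0, qmax)
        out[start : start + group_size] = (q - zero_point) * scale
    return out


weights = np.random.randn(1024).astype("float32")
error = np.abs(weights - groupwise_quantize_dequantize(weights)).mean()
print(f"Mean absolute rounding error: {error:.5f}")
```

GPTQ improves on this round-to-nearest baseline: it quantizes weights one column
at a time and uses the Hessian estimated from the calibration data to adjust the
remaining, not-yet-quantized weights so that the layer's output error stays
small.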

In this example, we first prepare a tiny calibration set, and then run GPTQ on
the model using the `.quantize(...)` API.
"""

# Calibration slice (use a larger/representative set in practice).
texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")["text"]

# Split the raw paragraphs into short sentence-like strings for calibration.
calibration_dataset = [
    s + "." for text in texts for s in map(str.strip, text.split(".")) if s
]

gptq_config = keras.quantizers.GPTQConfig(
    dataset=calibration_dataset,
    tokenizer=model.preprocessor.tokenizer,
    weight_bits=4,  # quantize weights to 4 bits
    group_size=128,  # number of weights sharing a quantization scale
    num_samples=256,  # calibration samples to use
    sequence_length=256,  # token length of each calibration sample
    hessian_damping=0.01,  # diagonal damping for numerical stability
    symmetric=False,  # asymmetric quantization
    activation_order=False,  # set True for activation-order quantization
)

model.quantize("gptq", config=gptq_config)

# Generate with the quantized model to compare against the baseline output.
outputs = model.generate(prompt, max_length=30)
print(outputs)

"""
## Model Export

The GPTQ-quantized model can be saved to a preset and reloaded elsewhere, just
like any other KerasHub model.
"""

model.save_to_preset("gemma3_gptq_w4gs128_preset")
model_from_preset = Gemma3CausalLM.from_preset("gemma3_gptq_w4gs128_preset")
output = model_from_preset.generate(prompt, max_length=30)
print(output)
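
# (Optional) Check the preset's size on disk. This sketch only uses the Python
# standard library; the exact number depends on the model and on the non-weight
# assets stored alongside the quantized weights.
from pathlib import Path

preset_bytes = sum(
    f.stat().st_size
    for f in Path("gemma3_gptq_w4gs128_preset").rglob("*")
    if f.is_file()
)
print(f"Preset size on disk: {preset_bytes / 1e6:.1f} MB")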

"""
## Performance & Benchmarking

Micro-benchmarks were collected on a single NVIDIA GeForce RTX 4070 Ti Super
(16 GB). Baselines are FP32.

Dataset: WikiText-2.

| Model (preset)                    | Perplexity Increase % (↓ better) | Disk Storage Δ % (↓ better) | VRAM Δ % (↓ better) | First-token Latency Δ % (↓ better) | Throughput Δ % (↑ better) |
| --------------------------------- | -------------------------------: | ---------------------------: | -------------------: | ----------------------------------: | -------------------------: |
| GPT2 (gpt2_base_en_cnn_dailymail) | 1.0%                              | -50.1% ↓                     | -41.1% ↓             | +0.7% ↑                              | +20.1% ↑                    |
| OPT (opt_125m_en)                 | 10.0%                             | -49.8% ↓                     | -47.0% ↓             | +6.7% ↑                              | -15.7% ↓                    |
| Bloom (bloom_1.1b_multi)          | 7.0%                              | -47.0% ↓                     | -54.0% ↓             | +1.8% ↑                              | -15.7% ↓                    |
| Gemma3 (gemma3_1b)                | 3.0%                              | -51.5% ↓                     | -51.8% ↓             | +39.5% ↑                             | +5.7% ↑                     |

Detailed benchmarking numbers and scripts are available
[here](https://github.com/keras-team/keras/pull/21641).

### Analysis

There is a notable reduction in disk space and VRAM usage across all models,
with disk space savings around 50% and VRAM savings ranging from 41% to 54%.
The reported disk savings understate the true weight compression because presets
also include non-weight assets such as tokenizer and configuration files.

Perplexity increases only marginally, indicating that model quality is largely
preserved after quantization.
"""

"""
## Practical tips

* GPTQ is a post-training technique; further training after quantization is not supported.
* Always use the model's own tokenizer for calibration.
* Use a representative calibration set; small slices are only for demos.
* Start with `weight_bits=4` and `group_size=128`, then tune per model/task (see the example below).
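
If quality drops too much at the defaults, a common next step is to use smaller
groups and activation ordering. The configuration below is an illustrative
starting point (reusing the `calibration_dataset` from above), not a
recommendation; the best values depend on the model and task.

```python
more_accurate_config = keras.quantizers.GPTQConfig(
    dataset=calibration_dataset,
    tokenizer=model.preprocessor.tokenizer,
    weight_bits=4,
    group_size=64,  # smaller groups: finer-grained scales, slower quantization
    symmetric=False,
    activation_order=True,  # quantize the most influential columns first
)
```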
"""