Skip to content

Commit f0a48a6

Browse files
Disables implicit GPTQ quantization using dtype_policy setter (#21895)
* disables implicit gptq quantization using dtype_policy setter * Update layer.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update layer_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent d130816 commit f0a48a6

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

keras/src/layers/layer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,15 @@ def dtype_policy(self, value):
757757
self._dtype_policy = policy
758758
if policy.quantization_mode is not None:
759759
if self.built and not getattr(self, "_is_quantized", False):
760+
if policy.quantization_mode == "gptq":
761+
raise ValueError(
762+
"Implicitly enabling GPTQ quantization by setting "
763+
f"`dtype_policy` to '{value}' is not supported. "
764+
"GPTQ requires a calibration dataset and a "
765+
"`GPTQConfig` object.\n\n"
766+
"Please use the `.quantize('gptq', config=...)` method "
767+
"on the layer or model instead."
768+
)
760769
self.quantize(policy.quantization_mode)
761770

762771
@property

keras/src/layers/layer_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,16 @@ def test_quantized_layer_with_remat(self):
233233
self.assertLen(mock_remat.rematted_functions, 1)
234234
next(iter(mock_remat.rematted_functions.values())).assert_called()
235235

236+
def test_gptq_quantization_by_setting_dtype(self):
237+
"""Tests error being raised when dtype is set to GPTQ."""
238+
with self.assertRaisesRegex(
239+
ValueError,
240+
"Implicitly enabling GPTQ quantization.*is not supported",
241+
):
242+
layer = layers.Dense(3)
243+
layer.build((2, 4))
244+
layer.dtype_policy = "gptq/4/-1_from_float32"
245+
236246
def test_functional_model_with_remat(self):
237247
if backend.backend() in ("openvino", "numpy"):
238248
self.skipTest(

0 commit comments

Comments
 (0)