diff --git a/keras/api/_tf_keras/keras/quantizers/__init__.py b/keras/api/_tf_keras/keras/quantizers/__init__.py index 299e467ac1bb..205183264c03 100644 --- a/keras/api/_tf_keras/keras/quantizers/__init__.py +++ b/keras/api/_tf_keras/keras/quantizers/__init__.py @@ -8,6 +8,18 @@ from keras.src.quantizers import get as get from keras.src.quantizers import serialize as serialize from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig +from keras.src.quantizers.quantization_config import ( + Float8QuantizationConfig as Float8QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + Int4QuantizationConfig as Int4QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + Int8QuantizationConfig as Int8QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + QuantizationConfig as QuantizationConfig, +) from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize diff --git a/keras/api/quantizers/__init__.py b/keras/api/quantizers/__init__.py index 299e467ac1bb..205183264c03 100644 --- a/keras/api/quantizers/__init__.py +++ b/keras/api/quantizers/__init__.py @@ -8,6 +8,18 @@ from keras.src.quantizers import get as get from keras.src.quantizers import serialize as serialize from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig +from keras.src.quantizers.quantization_config import ( + Float8QuantizationConfig as Float8QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + Int4QuantizationConfig as Int4QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + Int8QuantizationConfig as Int8QuantizationConfig, +) +from keras.src.quantizers.quantization_config import ( + QuantizationConfig as QuantizationConfig, +) from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer from keras.src.quantizers.quantizers import Quantizer as Quantizer from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index 8af0120c5101..48883c9f0d37 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -11,6 +11,8 @@ from keras.src.api_export import keras_export from keras.src.layers.input_spec import InputSpec from keras.src.layers.layer import Layer +from keras.src.quantizers.quantization_config import QuantizationConfig +from keras.src.quantizers.quantization_config import validate_and_resolve_config from keras.src.quantizers.quantizers import dequantize_with_sz_map @@ -372,9 +374,9 @@ def variable_serialization_spec(self): def quantized_build(self, kernel_shape, mode, config=None): if mode == "int8": - self._int8_build(kernel_shape) + self._int8_build(kernel_shape, config) elif mode == "int4": - self._int4_build(kernel_shape) + self._int4_build(kernel_shape, config) elif mode == "float8": self._float8_build() elif mode == "gptq": @@ -383,8 +385,13 @@ def quantized_build(self, kernel_shape, mode, config=None): raise self._quantization_mode_error(mode) self._is_quantized = True - def _int8_build(self, kernel_shape): - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + def _int8_build(self, kernel_shape, config=None): + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) + ) + 
self._kernel = self.add_weight( name="kernel", shape=kernel_shape, @@ -483,7 +490,7 @@ def _gptq_call(self, inputs, training=False): y = self.activation(y) return y - def _int4_build(self, kernel_shape): + def _int4_build(self, kernel_shape, config=None): """Build variables for int4 quantization. `kernel_shape` is the *original* float32 kernel shape @@ -492,8 +499,10 @@ def _int4_build(self, kernel_shape): int8 byte. """ # Per-channel int8 quantizer for the last axis (features). - self.inputs_quantizer = quantizers.AbsMaxQuantizer( - axis=-1, + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) ) input_dim, output_dim = kernel_shape packed_rows = (input_dim + 1) // 2 # ceil for odd dims @@ -582,11 +591,15 @@ def grad_fn(*args, upstream=None): inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel)) return (inputs_grad, None, None) - inputs, inputs_scale = self.inputs_quantizer(inputs) + output_scale = kernel_scale + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + output_scale = ops.multiply(output_scale, inputs_scale) + x = ops.matmul(inputs, kernel) # De-scale outputs x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + x = ops.divide(x, output_scale) return x, grad_fn x = matmul_with_inputs_gradient( @@ -633,10 +646,15 @@ def grad_fn(*args, upstream=None): inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel)) return (inputs_grad, None, None) - inputs, inputs_scale = self.inputs_quantizer(inputs) + output_scale = kernel_scale + + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + output_scale = ops.multiply(output_scale, inputs_scale) + x = ops.matmul(inputs, unpacked_kernel) x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + x = ops.divide(x, output_scale) return x, grad_fn x = matmul_with_inputs_gradient( @@ -753,25 +771,33 @@ def quantize(self, mode, type_check=True, config=None): if type_check and (type(self) is not Dense): raise self._not_implemented_error(self.quantize) + config = validate_and_resolve_config(mode, config) + mode = config.mode + kernel_shape = self._kernel.shape if mode == "int8": - kernel_value, kernel_scale = quantizers.abs_max_quantize( - self._kernel, axis=0, to_numpy=True + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=0) + ) + kernel_value, kernel_scale = weight_quantizer( + self._kernel, to_numpy=True ) kernel_scale = ops.squeeze(kernel_scale, axis=0) del self._kernel # Build variables for int8 mode - self.quantized_build(kernel_shape, mode) + self.quantized_build(kernel_shape, mode, config) self._kernel.assign(kernel_value) self.kernel_scale.assign(kernel_scale) elif mode == "int4": # 1. Quantize to int4 values (still int8 dtype, range [-8,7]) - kernel_value_int4, kernel_scale = quantizers.abs_max_quantize( - self._kernel, - axis=0, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=0, value_range=(-8, 7), output_dtype="int8" + ), + ) + kernel_value_int4, kernel_scale = weight_quantizer( + self._kernel, to_numpy=True ) kernel_scale = ops.squeeze(kernel_scale, axis=0) # 2. Pack two int4 values into a single int8 byte. 
@@ -779,7 +805,7 @@ def quantize(self, mode, type_check=True, config=None): del self._kernel # Build variables using the original kernel shape; _int4_build will # compute the packed shape internally. - self.quantized_build(kernel_shape, mode) + self.quantized_build(kernel_shape, mode, config) # Assign packed values. self._kernel.assign(packed_kernel_value) self.kernel_scale.assign(kernel_scale) diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py index 802ca10a1d41..f7acc28c0ae3 100644 --- a/keras/src/layers/core/dense_test.py +++ b/keras/src/layers/core/dense_test.py @@ -17,9 +17,67 @@ from keras.src import testing from keras.src.backend.common import keras_tensor from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer class DenseTest(testing.TestCase): + @parameterized.named_parameters( + ("int8", "int8", {"axis": 0}, {"axis": -1}), + ( + "int4", + "int4", + {"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"}, + {"axis": -1}, + ), + ("int8_weight_only", "int8", {"axis": 0}, None), + ) + def test_dense_quantize_config( + self, mode, weight_quantizer_args, activation_quantizer_args + ): + """Test Dense quantization with QuantizationConfig.""" + layer = layers.Dense(units=32) + layer.build((None, 8)) + + weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args) + if activation_quantizer_args is not None: + activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args) + else: + activation_quantizer = None + + if mode == "int8": + config = Int8QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + elif mode == "int4": + config = Int4QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + + layer.quantize(mode, config=config) + + if activation_quantizer_args is not None: + # Verify inputs_quantizer is set correctly + self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer) + self.assertEqual(layer.inputs_quantizer.axis, (-1,)) + else: + # Verify inputs_quantizer is None + self.assertIsNone(layer.inputs_quantizer) + + # Verify call works + x = np.random.random((2, 8)).astype("float32") + y = layer(x) + self.assertEqual(y.shape, (2, 32)) + + if mode == "int4": + # Verify kernel is int8 (packed int4) + self.assertEqual( + backend.standardize_dtype(layer._kernel.dtype), "int8" + ) + @pytest.mark.requires_trainable_backend def test_dense_basics(self): # 2D case, no bias. 
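Reviewer note: a minimal usage sketch of the new `Dense` config path, distilled from `test_dense_quantize_config` above. The layer sizes are arbitrary; passing `activation_quantizer=None` exercises the new weight-only branch (no `inputs_quantizer`, so outputs are de-scaled by `kernel_scale` alone).

```python
import numpy as np

from keras import layers
from keras.src.quantizers.quantization_config import Int8QuantizationConfig
from keras.src.quantizers.quantizers import AbsMaxQuantizer

# Per-output-channel weight quantizer; activation_quantizer=None selects
# weight-only quantization, so inputs stay in float.
config = Int8QuantizationConfig(
    weight_quantizer=AbsMaxQuantizer(axis=0),
    activation_quantizer=None,
)

layer = layers.Dense(units=32)
layer.build((None, 8))
layer.quantize("int8", config=config)

assert layer.inputs_quantizer is None  # weight-only path
y = layer(np.random.random((2, 8)).astype("float32"))  # shape (2, 32)
```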
diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py index 23d98fe3ec04..110b96efa096 100644 --- a/keras/src/layers/core/einsum_dense.py +++ b/keras/src/layers/core/einsum_dense.py @@ -6,6 +6,7 @@ import numpy as np from keras.src import activations +from keras.src import backend from keras.src import constraints from keras.src import dtype_policies from keras.src import initializers @@ -15,6 +16,7 @@ from keras.src.api_export import keras_export from keras.src.layers.input_spec import InputSpec from keras.src.layers.layer import Layer +from keras.src.quantizers.quantization_config import QuantizationConfig from keras.src.quantizers.quantizers import dequantize_with_sz_map @@ -444,9 +446,9 @@ def variable_serialization_spec(self): def quantized_build(self, kernel_shape, mode, config=None): if mode == "int8": - self._int8_build(kernel_shape) + self._int8_build(kernel_shape, config) elif mode == "int4": - self._int4_build(kernel_shape) + self._int4_build(kernel_shape, config) elif mode == "float8": self._float8_build() elif mode == "gptq": @@ -455,10 +457,13 @@ def quantized_build(self, kernel_shape, mode, config=None): raise self._quantization_mode_error(mode) self._is_quantized = True - def _int8_build(self, kernel_shape): + def _int8_build(self, kernel_shape, config=None): self._set_quantization_info() - self.inputs_quantizer = quantizers.AbsMaxQuantizer( - axis=self._input_reduced_axes + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes), + ) ) self._kernel = self.add_weight( name="kernel", @@ -591,7 +596,7 @@ def _gptq_call(self, inputs, training=False): y = self.activation(y) return y - def _int4_build(self, kernel_shape): + def _int4_build(self, kernel_shape, config=None): """Build variables for int4 quantization. The packed int4 kernel stores two int4 values within a single int8 @@ -603,8 +608,11 @@ def _int4_build(self, kernel_shape): self._set_quantization_info() # Quantizer for the inputs (per the reduced axes) - self.inputs_quantizer = quantizers.AbsMaxQuantizer( - axis=self._input_reduced_axes + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes), + ) ) # Choose the axis to perform int4 packing - use the first reduced axis @@ -727,13 +735,34 @@ def grad_fn(*args, upstream=None): ) return (inputs_grad, None, None) - inputs, inputs_scale = self.inputs_quantizer(inputs) - x = ops.einsum(self.equation, inputs, kernel) - # Deal with `inputs_scale` - inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input") - # De-scale outputs - x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + # Align `inputs_scale` axes with the output + # for correct broadcasting + inputs_scale = self._adjust_scale_for_quant( + inputs_scale, "input" + ) + x = ops.einsum(self.equation, inputs, kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + else: + # Weight-only quantization: dequantize kernel and use float + # einsum. This is a workaround for PyTorch's einsum which + # doesn't support mixed-precision inputs (float input, + # int8 kernel). 
+ if backend.backend() == "torch": + kernel_scale = self._adjust_scale_for_dequant(kernel_scale) + float_kernel = ops.divide( + ops.cast(kernel, dtype=self.compute_dtype), + kernel_scale, + ) + x = ops.einsum(self.equation, inputs, float_kernel) + else: + x = ops.einsum(self.equation, inputs, kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, kernel_scale) return x, grad_fn x = einsum_with_inputs_gradient( @@ -803,17 +832,36 @@ def grad_fn(*args, upstream=None): return (inputs_grad, None, None) # Quantize inputs per `self.inputs_quantizer`. - inputs_q, inputs_scale = self.inputs_quantizer(inputs) - - # Compute einsum on quantized inputs and unpacked int4 kernel. - x = ops.einsum(self.equation, inputs_q, unpacked_kernel) - - # Align `inputs_scale` axes with the output for correct broadcasting - inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input") - - # De-scale outputs. - x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + if self.inputs_quantizer: + inputs_q, inputs_scale = self.inputs_quantizer(inputs) + # Align `inputs_scale` axes with the output + # for correct broadcasting + inputs_scale = self._adjust_scale_for_quant( + inputs_scale, "input" + ) + x = ops.einsum(self.equation, inputs_q, unpacked_kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + else: + # Weight-only quantization: dequantize kernel and use float + # einsum. This is a workaround for PyTorch's einsum which + # doesn't support mixed-precision inputs (float input, + # int4 kernel). + if backend.backend() == "torch": + # Align `kernel_scale` to the same layout as + # `unpacked_kernel`. + kernel_scale = self._adjust_scale_for_dequant(kernel_scale) + float_kernel = ops.divide( + ops.cast(unpacked_kernel, dtype=self.compute_dtype), + kernel_scale, + ) + x = ops.einsum(self.equation, inputs, float_kernel) + else: + x = ops.einsum(self.equation, inputs, unpacked_kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, kernel_scale) return x, grad_fn x = einsum_with_inputs_gradient( @@ -938,19 +986,27 @@ def quantize(self, mode, type_check=True, config=None): if mode == "int8": # Quantize `self._kernel` to int8 and compute corresponding scale - kernel_value, kernel_scale = quantizers.abs_max_quantize( - self._kernel, axis=self._kernel_reduced_axes, to_numpy=True + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes), + ) + kernel_value, kernel_scale = weight_quantizer( + self._kernel, to_numpy=True ) kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel") del self._kernel elif mode == "int4": # Quantize to int4 values (stored in int8 dtype, range [-8, 7]) - kernel_value_int4, kernel_scale = quantizers.abs_max_quantize( - self._kernel, - axis=self._kernel_reduced_axes, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=self._kernel_reduced_axes, + value_range=(-8, 7), + output_dtype="int8", + ), + ) + kernel_value_int4, kernel_scale = weight_quantizer( + self._kernel, to_numpy=True ) kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel") diff --git a/keras/src/layers/core/einsum_dense_test.py b/keras/src/layers/core/einsum_dense_test.py index 92496f5f9d7a..4f7dfef9fd5b 100644 --- 
a/keras/src/layers/core/einsum_dense_test.py +++ b/keras/src/layers/core/einsum_dense_test.py @@ -16,9 +16,77 @@ from keras.src import saving from keras.src import testing from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer class EinsumDenseTest(testing.TestCase): + @parameterized.named_parameters( + ("int8", "int8", {"axis": 0}, {"axis": -1}), + ( + "int4", + "int4", + {"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"}, + {"axis": -1}, + ), + ("int8_weight_only", "int8", {"axis": 0}, None), + ( + "int4_weight_only", + "int4", + {"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"}, + None, + ), + ) + def test_einsum_dense_quantize( + self, mode, weight_quantizer_args, activation_quantizer_args + ): + """Test EinsumDense quantization with QuantizationConfig.""" + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + + weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args) + if activation_quantizer_args is not None: + activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args) + else: + activation_quantizer = None + + if mode == "int8": + config = Int8QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + elif mode == "int4": + config = Int4QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + + layer.quantize(mode, config=config) + + if activation_quantizer_args is not None: + # Verify inputs_quantizer is set correctly + self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer) + self.assertEqual(layer.inputs_quantizer.axis, (-1,)) + else: + # Verify inputs_quantizer is None + self.assertIsNone(layer.inputs_quantizer) + + # Verify call works + x = np.random.random((2, 3)).astype("float32") + y = layer(x) + self.assertEqual(y.shape, (2, 8, 32)) + + if mode == "int4": + # Verify kernel is int8 (packed int4) + self.assertEqual( + backend.standardize_dtype(layer._kernel.dtype), "int8" + ) + @parameterized.named_parameters( { "testcase_name": "_1d_end_weight", diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py index c1cb3b6b0117..f4ee2d4e52d9 100644 --- a/keras/src/layers/core/embedding.py +++ b/keras/src/layers/core/embedding.py @@ -10,6 +10,8 @@ from keras.src.api_export import keras_export from keras.src.backend import KerasTensor from keras.src.layers.layer import Layer +from keras.src.quantizers.quantization_config import QuantizationConfig +from keras.src.quantizers.quantization_config import validate_and_resolve_config @keras_export("keras.layers.Embedding") @@ -315,16 +317,16 @@ def variable_serialization_spec(self): ], } - def quantized_build(self, embeddings_shape, mode): + def quantized_build(self, embeddings_shape, mode, config=None): if mode == "int8": - self._int8_build(embeddings_shape) + self._int8_build(embeddings_shape, config) elif mode == "int4": - self._int4_build(embeddings_shape) + self._int4_build(embeddings_shape, config) else: raise self._quantization_mode_error(mode) self._is_quantized = True - def _int8_build(self, embeddings_shape): + def _int8_build(self, embeddings_shape, config=None): self._embeddings = self.add_weight( name="embeddings", shape=embeddings_shape, @@ -342,7 +344,7 @@ def _int8_build(self, 
embeddings_shape): trainable=False, ) - def _int4_build(self, embeddings_shape): + def _int4_build(self, embeddings_shape, config=None): input_dim, output_dim = embeddings_shape packed_rows = (output_dim + 1) // 2 # ceil for odd dims @@ -412,26 +414,37 @@ def quantize(self, mode, type_check=True, config=None): if type_check and (type(self) is not Embedding): raise self._not_implemented_error(self.quantize) + config = validate_and_resolve_config(mode, config) + mode = config.mode + embeddings_shape = (self.input_dim, self.output_dim) if mode == "int8": # Quantize `self._embeddings` to int8 and compute corresponding # scale. - embeddings_value, embeddings_scale = quantizers.abs_max_quantize( - self._embeddings, axis=-1, to_numpy=True + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer(axis=-1), + ) + embeddings_value, embeddings_scale = weight_quantizer( + self._embeddings, to_numpy=True ) embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) del self._embeddings - self.quantized_build(embeddings_shape, mode) + self.quantized_build(embeddings_shape, mode, config) self._embeddings.assign(embeddings_value) self.embeddings_scale.assign(embeddings_scale) elif mode == "int4": # Quantize to int4 values (stored in int8 dtype, range [-8, 7]). - embeddings_value, embeddings_scale = quantizers.abs_max_quantize( - self._embeddings, - axis=-1, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=-1, + value_range=(-8, 7), + output_dtype="int8", + ), + ) + embeddings_value, embeddings_scale = weight_quantizer( + self._embeddings, to_numpy=True ) embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) # 2. Pack two int4 values into a single int8 byte. 
@@ -439,7 +452,7 @@ def quantize(self, mode, type_check=True, config=None): embeddings_value, axis=-1 ) del self._embeddings - self.quantized_build(embeddings_shape, mode) + self.quantized_build(embeddings_shape, mode, config) self._embeddings.assign(packed_embeddings_value) self.embeddings_scale.assign(embeddings_scale) else: diff --git a/keras/src/layers/core/embedding_test.py b/keras/src/layers/core/embedding_test.py index 68b4ca1d9c15..4337fda7cd6a 100644 --- a/keras/src/layers/core/embedding_test.py +++ b/keras/src/layers/core/embedding_test.py @@ -12,10 +12,50 @@ from keras.src import ops from keras.src import quantizers from keras.src import saving +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer from keras.src.testing import test_case class EmbeddingTest(test_case.TestCase): + @parameterized.named_parameters( + ("int8", "int8", {"axis": -1}), + ( + "int4", + "int4", + {"axis": -1, "value_range": (-8, 7), "output_dtype": "int8"}, + ), + ("int8_custom", "int8", {"axis": -1}), + ) + def test_embedding_quantize_config(self, mode, weight_quantizer_args): + """Test Embedding quantization with QuantizationConfig.""" + layer = layers.Embedding(input_dim=10, output_dim=6) + layer.build((None,)) + + weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args) + if mode == "int8": + config = Int8QuantizationConfig( + weight_quantizer=weight_quantizer, activation_quantizer=None + ) + elif mode == "int4": + config = Int4QuantizationConfig( + weight_quantizer=weight_quantizer, activation_quantizer=None + ) + + layer.quantize(mode, config=config) + + # Verify weights are quantized + self.assertEqual( + backend.standardize_dtype(layer._embeddings.dtype), "int8" + ) + self.assertTrue(hasattr(layer, "embeddings_scale")) + + # Verify call works + x = np.random.randint(0, 10, size=(2, 3)) + y = layer(x) + self.assertEqual(y.shape, (2, 3, 6)) + @pytest.mark.requires_trainable_backend def test_embedding_basics(self): self.run_layer_test( diff --git a/keras/src/layers/core/reversible_embedding.py b/keras/src/layers/core/reversible_embedding.py index ae8ea8f4c4f7..41b0f88b0aea 100644 --- a/keras/src/layers/core/reversible_embedding.py +++ b/keras/src/layers/core/reversible_embedding.py @@ -6,6 +6,8 @@ from keras.src import quantizers from keras.src.api_export import keras_export from keras.src.backend import KerasTensor +from keras.src.quantizers.quantization_config import QuantizationConfig +from keras.src.quantizers.quantization_config import validate_and_resolve_config @keras_export("keras.layers.ReversibleEmbedding") @@ -172,20 +174,25 @@ def variable_serialization_spec(self): variable_spec.append("reverse_embeddings_scale") return _spec - def quantized_build(self, embeddings_shape, mode): + def quantized_build(self, embeddings_shape, mode, config=None): if mode == "int8": - self._int8_build(embeddings_shape) + self._int8_build(embeddings_shape, config) elif mode == "int4": - self._int4_build(embeddings_shape) + self._int4_build(embeddings_shape, config) else: raise self._quantization_mode_error(mode) self._is_quantized = True - def _int8_build(self, embeddings_shape): + def _int8_build(self, embeddings_shape, config=None): if embeddings_shape is None: embeddings_shape = (self.input_dim, self.output_dim) super()._int8_build(embeddings_shape=embeddings_shape) - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + + 
self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) + ) if not self.tie_weights: self.reverse_embeddings = self.add_weight( name="reverse_embeddings", @@ -201,11 +208,16 @@ def _int8_build(self, embeddings_shape): trainable=False, ) - def _int4_build(self, embeddings_shape): + def _int4_build(self, embeddings_shape, config=None): if embeddings_shape is None: embeddings_shape = (self.input_dim, self.output_dim) - super()._int4_build(embeddings_shape=embeddings_shape) - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + super()._int4_build(embeddings_shape=embeddings_shape, config=config) + + self.inputs_quantizer = ( + QuantizationConfig.activation_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) + ) if not self.tie_weights: packed_rows = (self.output_dim + 1) // 2 # ceil for odd dims self.reverse_embeddings = self.add_weight( @@ -232,7 +244,10 @@ def _int8_call(self, inputs, reverse=False): else: kernel = self.reverse_embeddings scale = self.reverse_embeddings_scale - inputs, inputs_scale = self.inputs_quantizer(inputs) + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + else: + inputs_scale = ops.ones((1,), dtype=self.compute_dtype) logits = ops.matmul(inputs, kernel) # De-scale outputs logits = ops.cast(logits, self.compute_dtype) @@ -258,7 +273,10 @@ def _int4_call(self, inputs, reverse=False): unpacked_embeddings = quantizers.unpack_int4( embeddings, self.output_dim, axis=0 ) - inputs, inputs_scale = self.inputs_quantizer(inputs) + if self.inputs_quantizer: + inputs, inputs_scale = self.inputs_quantizer(inputs) + else: + inputs_scale = ops.ones((1,), dtype=self.compute_dtype) logits = ops.matmul(inputs, unpacked_embeddings) # De-scale outputs logits = ops.cast(logits, self.compute_dtype) @@ -272,30 +290,40 @@ def _int4_call(self, inputs, reverse=False): return logits def quantize(self, mode, type_check=True, config=None): - del config if type_check and type(self) is not ReversibleEmbedding: raise self._not_implemented_error(self.quantize) + config = validate_and_resolve_config(mode, config) + mode = config.mode + embeddings_shape = (self.input_dim, self.output_dim) if mode == "int8": # Quantize `self._embeddings` to int8 and compute corresponding # scale. 
- embeddings_value, embeddings_scale = quantizers.abs_max_quantize( - self._embeddings, axis=-1, to_numpy=True + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=-1) + ) + embeddings_value, embeddings_scale = weight_quantizer( + self._embeddings, to_numpy=True ) embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) del self._embeddings if not self.tie_weights: + reverse_weight_quantizer = ( + QuantizationConfig.weight_quantizer_or_default( + config, quantizers.AbsMaxQuantizer(axis=0) + ) + ) reverse_embeddings_value, reverse_embeddings_scale = ( - quantizers.abs_max_quantize( - self.reverse_embeddings, axis=0, to_numpy=True + reverse_weight_quantizer( + self.reverse_embeddings, to_numpy=True ) ) reverse_embeddings_scale = ops.squeeze( reverse_embeddings_scale, axis=0 ) del self.reverse_embeddings - self.quantized_build(embeddings_shape, mode) + self.quantized_build(embeddings_shape, mode, config) self._embeddings.assign(embeddings_value) self.embeddings_scale.assign(embeddings_scale) if not self.tie_weights: @@ -303,12 +331,16 @@ def quantize(self, mode, type_check=True, config=None): self.reverse_embeddings_scale.assign(reverse_embeddings_scale) elif mode == "int4": # Quantize to int4 values (stored in int8 dtype, range [-8, 7]). - embeddings_value, embeddings_scale = quantizers.abs_max_quantize( - self._embeddings, - axis=-1, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + weight_quantizer = QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=-1, + value_range=(-8, 7), + output_dtype="int8", + ), + ) + embeddings_value, embeddings_scale = weight_quantizer( + self._embeddings, to_numpy=True ) embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) # 2. Pack two int4 values into a single int8 byte. 
@@ -317,13 +349,19 @@ def quantize(self, mode, type_check=True, config=None): ) del self._embeddings if not self.tie_weights: + reverse_weight_quantizer = ( + QuantizationConfig.weight_quantizer_or_default( + config, + quantizers.AbsMaxQuantizer( + axis=0, + value_range=(-8, 7), + output_dtype="int8", + ), + ) + ) reverse_embeddings_value, reverse_embeddings_scale = ( - quantizers.abs_max_quantize( - self.reverse_embeddings, - axis=0, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + reverse_weight_quantizer( + self.reverse_embeddings, to_numpy=True ) ) reverse_embeddings_scale = ops.squeeze( @@ -334,7 +372,7 @@ def quantize(self, mode, type_check=True, config=None): reverse_embeddings_value, axis=0 ) del self.reverse_embeddings - self.quantized_build(embeddings_shape, mode) + self.quantized_build(embeddings_shape, mode, config) self._embeddings.assign(packed_embeddings_value) self.embeddings_scale.assign(embeddings_scale) if not self.tie_weights: diff --git a/keras/src/layers/core/reversible_embedding_test.py b/keras/src/layers/core/reversible_embedding_test.py index 043c734aea01..95822ea45a2d 100644 --- a/keras/src/layers/core/reversible_embedding_test.py +++ b/keras/src/layers/core/reversible_embedding_test.py @@ -9,11 +9,64 @@ from keras.src import models from keras.src import ops from keras.src import saving +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantizers import AbsMaxQuantizer from keras.src.testing import test_case from keras.src.testing.test_utils import named_product class ReversibleEmbeddingTest(test_case.TestCase): + @parameterized.named_parameters( + ("int8", "int8", {"axis": -1}, {"axis": -1}), + ( + "int4", + "int4", + {"axis": -1, "value_range": (-8, 7), "output_dtype": "int8"}, + {"axis": -1}, + ), + ("int8_weight_only", "int8", {"axis": -1}, None), + ) + def test_reversible_embedding_quantize( + self, mode, weight_quantizer_args, activation_quantizer_args + ): + """Test ReversibleEmbedding quantization with QuantizationConfig.""" + layer = layers.ReversibleEmbedding( + input_dim=10, output_dim=6, tie_weights=True + ) + layer.build((None,)) + + weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args) + if activation_quantizer_args is not None: + activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args) + else: + activation_quantizer = None + + if mode == "int8": + config = Int8QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + elif mode == "int4": + config = Int4QuantizationConfig( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + + layer.quantize(mode, config=config) + + if activation_quantizer_args is not None: + # Verify inputs_quantizer is set correctly + self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer) + else: + # Verify inputs_quantizer is None + self.assertIsNone(layer.inputs_quantizer) + + # Verify reverse call works + x = np.random.random((2, 6)).astype("float32") + y = layer(x, reverse=True) + self.assertEqual(y.shape, (2, 10)) + @parameterized.named_parameters( ("tie_weights", True), ("untie_weights", False), diff --git a/keras/src/models/model.py b/keras/src/models/model.py index 37f4b3bef7ef..5b671401ee5f 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -7,10 +7,11 @@ from keras.src import backend from keras.src import utils from keras.src.api_export import 
keras_export +from keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES from keras.src.layers.layer import Layer from keras.src.models.variable_mapping import map_saveable_variables -from keras.src.quantizers.gptq_config import GPTQConfig from keras.src.quantizers.gptq_core import gptq_quantize +from keras.src.quantizers.quantization_config import validate_and_resolve_config from keras.src.quantizers.utils import should_quantize_layer from keras.src.saving import saving_api from keras.src.trainers import trainer as base_trainer @@ -424,7 +425,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs): **kwargs, ) - def get_quantization_layer_structure(self, mode): + def get_quantization_layer_structure(self, mode=None): """Returns the quantization structure for the model. This method is intended to be overridden by model authors to provide @@ -464,8 +465,6 @@ def quantize(self, mode, config=None, filters=None, **kwargs): layers which match the filter conditions will be quantized. """ - from keras.src.dtype_policies import QUANTIZATION_MODES - # Validate inputs. type_check = kwargs.pop("type_check", True) if kwargs: @@ -488,18 +487,8 @@ def quantize(self, mode, config=None, filters=None, **kwargs): f"{type(filters)}" ) - if mode == "gptq": - if not isinstance(config, GPTQConfig): - raise ValueError( - "Mode 'gptq' requires a valid `config` argument of type " - f"`GPTQConfig`. Received: {type(config)}" - ) - elif config is not None: - # All other modes must not receive a config - raise ValueError( - f"The `config` argument is only supported for 'gptq' mode, " - f"but received mode='{mode}' and a non-None config." - ) + config = validate_and_resolve_config(mode, config) + mode = config.mode graph_modified = False for layer in self._flatten_layers(): diff --git a/keras/src/quantizers/__init__.py b/keras/src/quantizers/__init__.py index 586530204588..1e80a9cb7dc3 100644 --- a/keras/src/quantizers/__init__.py +++ b/keras/src/quantizers/__init__.py @@ -1,6 +1,10 @@ import inspect from keras.src.api_export import keras_export +from keras.src.quantizers.quantization_config import Float8QuantizationConfig +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantization_config import QuantizationConfig from keras.src.quantizers.quantizers import AbsMaxQuantizer from keras.src.quantizers.quantizers import Quantizer from keras.src.quantizers.quantizers import abs_max_quantize @@ -13,7 +17,14 @@ from keras.src.saving import serialization_lib from keras.src.utils.naming import to_snake_case -ALL_OBJECTS = {Quantizer, AbsMaxQuantizer} +ALL_OBJECTS = { + Quantizer, + AbsMaxQuantizer, + QuantizationConfig, + Int8QuantizationConfig, + Int4QuantizationConfig, + Float8QuantizationConfig, +} ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} ALL_OBJECTS_DICT.update( {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS} diff --git a/keras/src/quantizers/gptq_config.py b/keras/src/quantizers/gptq_config.py index edcb465ce4c2..0ea159b56548 100644 --- a/keras/src/quantizers/gptq_config.py +++ b/keras/src/quantizers/gptq_config.py @@ -1,8 +1,9 @@ from keras.src.api_export import keras_export +from keras.src.quantizers.quantization_config import QuantizationConfig @keras_export("keras.quantizers.GPTQConfig") -class GPTQConfig: +class GPTQConfig(QuantizationConfig): """Configuration class for the GPTQ (Gradient-based Post-Training Quantization) algorithm. 
@@ -154,6 +155,7 @@ def __init__( activation_order: bool = False, quantization_layer_structure: dict = None, ): + super().__init__() if weight_bits not in [2, 3, 4, 8]: raise ValueError( f"Unsupported weight_bits {weight_bits}. " @@ -183,6 +185,10 @@ def __init__( self.activation_order = activation_order self.quantization_layer_structure = quantization_layer_structure + @property + def mode(self): + return "gptq" + def dtype_policy_string(self): """Returns the dtype policy string for this configuration. diff --git a/keras/src/quantizers/gptq_test.py b/keras/src/quantizers/gptq_test.py index d6fe0048ac3f..a2af07c27155 100644 --- a/keras/src/quantizers/gptq_test.py +++ b/keras/src/quantizers/gptq_test.py @@ -14,6 +14,7 @@ from keras.src.quantizers.gptq import _stable_permutation from keras.src.quantizers.gptq import gptq_quantize_matrix from keras.src.quantizers.gptq_config import GPTQConfig +from keras.src.quantizers.quantization_config import QuantizationConfig from keras.src.quantizers.quantizers import dequantize_with_sz_map from keras.src.quantizers.quantizers import dequantize_with_zero_point from keras.src.quantizers.quantizers import quantize_with_zero_point @@ -621,18 +622,27 @@ def test_quantize_gptq_combinations(self, dataset, config): @parameterized.named_parameters( { - "testcase_name": "gptq_with_invalid_config", + "testcase_name": "gptq_with_invalid_config_type", "mode": "gptq", "config": {"weight_bits": 4}, "expected_exception": ValueError, - "error_msg": "Mode 'gptq' requires a valid `config`", + "error_msg": "Argument `config` must be an instance of " + "`QuantizationConfig`", }, { - "testcase_name": "non_gptq_with_unsupported_config", - "mode": "int8", - "config": GPTQConfig(dataset=["a"], tokenizer=lambda x: x), + "testcase_name": "gptq_with_none_config", + "mode": "gptq", + "config": None, "expected_exception": ValueError, - "error_msg": "only supported for 'gptq'", + "error_msg": "For GPTQ, you must pass a GPTQConfig " + "object explicitly.", + }, + { + "testcase_name": "gptq_with_base_quantization_config", + "mode": "gptq", + "config": QuantizationConfig(), + "expected_exception": NotImplementedError, + "error_msg": "Do not instantiate QuantizationConfig directly.", }, { "testcase_name": "gptq_missing_structure", diff --git a/keras/src/quantizers/quantization_config.py b/keras/src/quantizers/quantization_config.py new file mode 100644 index 000000000000..362bc866e67d --- /dev/null +++ b/keras/src/quantizers/quantization_config.py @@ -0,0 +1,213 @@ +from keras.src.api_export import keras_export +from keras.src.dtype_policies import QUANTIZATION_MODES +from keras.src.saving import serialization_lib + + +@keras_export("keras.quantizers.QuantizationConfig") +class QuantizationConfig: + """Base class for quantization configs. + + Subclasses must implement the `mode` property and the `get_config` and + `from_config` class methods. + + Args: + weight_quantizer: Quantizer for weights. + activation_quantizer: Quantizer for activations. + """ + + def __init__(self, weight_quantizer=None, activation_quantizer=None): + self.weight_quantizer = weight_quantizer + self.activation_quantizer = activation_quantizer + + @property + def mode(self): + raise NotImplementedError( + "Subclasses must implement this property. Do not instantiate " + "QuantizationConfig directly." 
+ ) + + def get_config(self): + return { + "weight_quantizer": serialization_lib.serialize_keras_object( + self.weight_quantizer + ), + "activation_quantizer": serialization_lib.serialize_keras_object( + self.activation_quantizer + ), + } + + @classmethod + def from_config(cls, config): + weight_quantizer = serialization_lib.deserialize_keras_object( + config.get("weight_quantizer") + ) + activation_quantizer = serialization_lib.deserialize_keras_object( + config.get("activation_quantizer") + ) + return cls( + weight_quantizer=weight_quantizer, + activation_quantizer=activation_quantizer, + ) + + @staticmethod + def weight_quantizer_or_default(config, default): + if config and config.weight_quantizer: + return config.weight_quantizer + return default + + @staticmethod + def activation_quantizer_or_default(config, default): + if config and config.activation_quantizer: + return config.activation_quantizer + elif config and config.activation_quantizer is None: + return None + return default + + +@keras_export("keras.quantizers.Int8QuantizationConfig") +class Int8QuantizationConfig(QuantizationConfig): + """Int8 quantization config. + + Args: + weight_quantizer: Quantizer for weights. + activation_quantizer: Quantizer for activations. If "default", uses + AbsMaxQuantizer with axis=-1. + """ + + def __init__(self, weight_quantizer=None, activation_quantizer="default"): + from keras.src.quantizers.quantizers import AbsMaxQuantizer + + if activation_quantizer == "default": + activation_quantizer = AbsMaxQuantizer(axis=-1) + super().__init__(weight_quantizer, activation_quantizer) + if self.weight_quantizer: + if hasattr(self.weight_quantizer, "value_range"): + if self.weight_quantizer.value_range != (-127, 127): + raise ValueError( + "Int8QuantizationConfig requires a weight_quantizer " + "with value_range=(-127, 127). Received: " + f"value_range={self.weight_quantizer.value_range}" + ) + + @property + def mode(self): + return "int8" + + +@keras_export("keras.quantizers.Int4QuantizationConfig") +class Int4QuantizationConfig(QuantizationConfig): + """Int4 quantization config. + + Args: + weight_quantizer: Quantizer for weights. + activation_quantizer: Quantizer for activations. If "default", uses + AbsMaxQuantizer with axis=-1. + """ + + def __init__(self, weight_quantizer=None, activation_quantizer="default"): + from keras.src.quantizers.quantizers import AbsMaxQuantizer + + if activation_quantizer == "default": + activation_quantizer = AbsMaxQuantizer(axis=-1) + super().__init__(weight_quantizer, activation_quantizer) + if self.weight_quantizer: + if hasattr(self.weight_quantizer, "value_range"): + if self.weight_quantizer.value_range != (-8, 7): + raise ValueError( + "Int4QuantizationConfig requires a weight_quantizer " + "with value_range=(-8, 7). Received: " + f"value_range={self.weight_quantizer.value_range}" + ) + + @property + def mode(self): + return "int4" + + +@keras_export("keras.quantizers.Float8QuantizationConfig") +class Float8QuantizationConfig(QuantizationConfig): + """FP8 quantization config. + + FP8 mixed-precision training does not support user defined quantizers. + This config is only used to indicate that FP8 mixed-precision training + should be used. + """ + + def __init__(self): + super().__init__(None, None) + + @property + def mode(self): + return "float8" + + +def validate_and_resolve_config(mode, config, name=None): + """Validate and resolve quantization config. + + This function validates the quantization config and resolves the mode. 
+ If mode is not provided, it is inferred from the config. + If config is not provided, a default config is inferred from the mode. + + Args: + mode: Quantization mode. + config: Quantization config. + name: Name of the quantization config. + """ + # 1. Backwards Compatibility: Handle string shortcuts + if isinstance(config, str): + mode = config + config = None + + # 2. Resolve "mode" into a Config object + if config is None: + if mode == "int8": + config = Int8QuantizationConfig() + elif mode == "int4": + config = Int4QuantizationConfig() + elif mode == "float8": + config = Float8QuantizationConfig() + elif mode == "gptq": + raise ValueError( + "For GPTQ, you must pass a GPTQConfig object explicitly." + ) + else: + if mode is not None: + raise ValueError( + f"Invalid quantization mode. Received: mode={mode}" + ) + raise ValueError( + "You must provide either `mode` or `config` to `quantize`." + ) + else: + if not isinstance(config, QuantizationConfig): + raise ValueError( + "Argument `config` must be an instance of " + "`QuantizationConfig`. " + f"Received: config={config} (of type {type(config)})" + ) + + # 3. Validation: Prevent contradictions + if mode is not None and config.mode != mode: + raise ValueError( + f"Contradictory arguments: mode='{mode}' but " + f"config.mode='{config.mode}'" + ) + + # 4. Execution + mode = config.mode # Ensure mode is consistent + if mode not in QUANTIZATION_MODES: + raise ValueError( + "Invalid quantization mode. " + f"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}" + ) + + if mode == "gptq": + from keras.src.quantizers.gptq_config import GPTQConfig + + if not isinstance(config, GPTQConfig): + raise ValueError( + "Mode 'gptq' requires a valid `config` argument of type " + f"`GPTQConfig`. Received: {type(config)}" + ) + + return config diff --git a/keras/src/quantizers/quantization_config_test.py b/keras/src/quantizers/quantization_config_test.py new file mode 100644 index 000000000000..f7c94c6cc2e4 --- /dev/null +++ b/keras/src/quantizers/quantization_config_test.py @@ -0,0 +1,106 @@ +from keras.src import testing +from keras.src.quantizers.quantization_config import Int4QuantizationConfig +from keras.src.quantizers.quantization_config import Int8QuantizationConfig +from keras.src.quantizers.quantization_config import QuantizationConfig +from keras.src.quantizers.quantization_config import validate_and_resolve_config +from keras.src.quantizers.quantizers import AbsMaxQuantizer + + +class QuantizationConfigTest(testing.TestCase): + def test_base_quantization_config(self): + config = QuantizationConfig() + with self.assertRaises(NotImplementedError): + _ = config.mode + + def test_int8_quantization_config_valid(self): + config = Int8QuantizationConfig() + self.assertEqual(config.mode, "int8") + self.assertIsNone(config.weight_quantizer) + + # Valid weight quantizer + q = AbsMaxQuantizer(axis=0, value_range=(-127, 127)) + config = Int8QuantizationConfig(weight_quantizer=q) + self.assertEqual(config.weight_quantizer, q) + + def test_int8_quantization_config_invalid(self): + # Invalid value_range + q = AbsMaxQuantizer(axis=0, value_range=(-8, 7)) + with self.assertRaisesRegex(ValueError, "value_range"): + Int8QuantizationConfig(weight_quantizer=q) + + def test_int4_quantization_config_valid(self): + config = Int4QuantizationConfig() + self.assertEqual(config.mode, "int4") + self.assertIsNone(config.weight_quantizer) + + # Valid weight quantizer + q = AbsMaxQuantizer(axis=0, value_range=(-8, 7)) + config = 
Int4QuantizationConfig(weight_quantizer=q) + self.assertEqual(config.weight_quantizer, q) + + def test_int4_quantization_config_invalid(self): + # Invalid value_range + q = AbsMaxQuantizer(axis=0, value_range=(-127, 127)) + with self.assertRaisesRegex(ValueError, "value_range"): + Int4QuantizationConfig(weight_quantizer=q) + + def test_quantization_config_serialization(self): + config = Int8QuantizationConfig( + weight_quantizer=AbsMaxQuantizer(axis=0), + activation_quantizer=AbsMaxQuantizer(axis=-1), + ) + serialized = config.get_config() + deserialized = Int8QuantizationConfig.from_config(serialized) + self.assertIsInstance(deserialized, Int8QuantizationConfig) + self.assertIsInstance(deserialized.weight_quantizer, AbsMaxQuantizer) + self.assertIsInstance( + deserialized.activation_quantizer, AbsMaxQuantizer + ) + self.assertEqual(deserialized.weight_quantizer.axis, (0,)) + self.assertEqual(deserialized.activation_quantizer.axis, (-1,)) + + def test_validate_and_resolve_config(self): + # 1. String mode + config = validate_and_resolve_config("int8", None) + self.assertIsInstance(config, Int8QuantizationConfig) + self.assertEqual(config.mode, "int8") + + config = validate_and_resolve_config("int4", None) + self.assertIsInstance(config, Int4QuantizationConfig) + self.assertEqual(config.mode, "int4") + + # 2. Config object + config_in = Int8QuantizationConfig() + config_out = validate_and_resolve_config(None, config_in) + self.assertIs(config_out, config_in) + + # 3. Mode + Config (matching) + config_in = Int8QuantizationConfig() + config_out = validate_and_resolve_config("int8", config_in) + self.assertIs(config_out, config_in) + + # 4. Mode + Config (mismatch) + config_in = Int8QuantizationConfig() + with self.assertRaisesRegex(ValueError, "Contradictory arguments"): + validate_and_resolve_config("int4", config_in) + + # 5. Invalid mode + with self.assertRaisesRegex(ValueError, "Invalid quantization mode"): + validate_and_resolve_config("invalid_mode", None) + + # 6. GPTQ without config + with self.assertRaisesRegex(ValueError, "must pass a GPTQConfig"): + validate_and_resolve_config("gptq", None) + + # 7. Contradictory config + with self.assertRaisesRegex(ValueError, "Contradictory arguments"): + validate_and_resolve_config("gptq", Int8QuantizationConfig()) + + # 8. GPTQ with invalid config type (but correct mode) + class FakeGPTQConfig(QuantizationConfig): + @property + def mode(self): + return "gptq" + + with self.assertRaisesRegex(ValueError, "requires a valid `config`"): + validate_and_resolve_config("gptq", FakeGPTQConfig()) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index d9ef671b6fc9..708a143504c9 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -117,9 +117,14 @@ def __init__( self.value_range = value_range self.epsilon = epsilon - def __call__(self, x): + def __call__(self, x, to_numpy=False): quantized_x, scale = abs_max_quantize( - x, self.axis, self.value_range, self.output_dtype, self.epsilon + x, + self.axis, + self.value_range, + self.output_dtype, + self.epsilon, + to_numpy, ) return quantized_x, scale
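Reviewer note: end to end, `validate_and_resolve_config` lets `Model.quantize` accept a bare mode string (resolved to a default config), an explicit `QuantizationConfig`, or both when they agree. Below is a sketch of the int4 case using the public `keras.quantizers` re-exports added above; the model and shapes are illustrative only.

```python
import numpy as np

import keras

model = keras.Sequential(
    [keras.Input(shape=(8,)), keras.layers.Dense(16)]
)

# An explicit int4 config: the weight quantizer must use
# value_range=(-8, 7), as enforced by Int4QuantizationConfig.
config = keras.quantizers.Int4QuantizationConfig(
    weight_quantizer=keras.quantizers.AbsMaxQuantizer(
        axis=0, value_range=(-8, 7), output_dtype="int8"
    ),
)
model.quantize("int4", config=config)

# Passing a contradictory mode raises, e.g.:
#   model.quantize("int8", config=config)  -> ValueError
y = model(np.random.random((2, 8)).astype("float32"))  # shape (2, 16)
```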