Commit 2ae1e37

Introduces customizable quantization API
1 parent 9fc8185 commit 2ae1e37

15 files changed: 752 additions, 135 deletions
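
In short: `quantize()` now accepts a `QuantizationConfig` so callers can swap in their own weight and activation quantizers instead of the built-in abs-max defaults. A minimal sketch of the new surface, using the import paths and the test usage from the diff below:

    import numpy as np
    from keras import layers
    from keras.src.quantizers.quantization_config import Int8QuantizationConfig
    from keras.src.quantizers.quantizers import AbsMaxQuantizer

    layer = layers.Dense(units=32)
    layer.build((None, 8))

    # Customize both quantizers; omit the config entirely to keep the defaults.
    config = Int8QuantizationConfig(
        weight_quantizer=AbsMaxQuantizer(axis=0),       # per output channel
        activation_quantizer=AbsMaxQuantizer(axis=-1),  # per input feature
    )
    layer.quantize("int8", config=config)

    y = layer(np.random.random((2, 8)).astype("float32"))  # shape (2, 32)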

keras/src/layers/core/dense.py

Lines changed: 59 additions & 24 deletions
@@ -9,8 +9,11 @@
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
+from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import validate_and_resolve_config
 from keras.src.quantizers.quantizers import dequantize_with_sz_map

@@ -370,9 +373,9 @@ def variable_serialization_spec(self):

     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
@@ -381,8 +384,14 @@ def quantized_build(self, kernel_shape, mode, config=None):
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
+    def _int8_build(self, kernel_shape, config=None):
+        # Per-channel int8 quantizer for the last axis (features).
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+        )
+
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
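
The contract of `activation_quantizer_or_default`, as implied by the weight-only test later in this commit: when a config is supplied, its `activation_quantizer` is used as-is, including `None` (which disables input quantization); only a missing config falls back to the default. A hypothetical equivalent, for orientation only (the real implementation lives in quantization_config.py and is not shown in this diff):

    # Sketch of the implied behavior; not the actual Keras source.
    def activation_quantizer_or_default(config, default):
        if config is None:
            return default
        # May be None, which turns off input quantization (weight-only).
        return config.activation_quantizer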
@@ -481,7 +490,7 @@ def _gptq_call(self, inputs, training=False):
         y = self.activation(y)
         return y

-    def _int4_build(self, kernel_shape):
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         `kernel_shape` is the *original* float32 kernel shape
@@ -490,8 +499,10 @@ def _int4_build(self, kernel_shape):
         int8 byte.
         """
         # Per-channel int8 quantizer for the last axis (features).
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=-1,
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
         )
         input_dim, output_dim = kernel_shape
         packed_rows = (input_dim + 1) // 2  # ceil for odd dims
@@ -515,8 +526,6 @@ def _int4_build(self, kernel_shape):
         self._orig_input_dim = input_dim

     def _float8_build(self):
-        from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
-
         # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set
         # `amax_history_length` to its default value.
         amax_history_length = getattr(
@@ -580,7 +589,15 @@ def grad_fn(*args, upstream=None):
             inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
             return (inputs_grad, None, None)

-        inputs, inputs_scale = self.inputs_quantizer(inputs)
+        if self.inputs_quantizer:
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+        else:
+            # Weight-only quantization: inputs are not quantized
+            # We still need inputs_scale for the formula:
+            # x = x / (inputs_scale * kernel_scale)
+            # If inputs are not quantized, inputs_scale should be 1.
+            inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
+
         x = ops.matmul(inputs, kernel)
         # De-scale outputs
         x = ops.cast(x, self.compute_dtype)
@@ -631,7 +648,10 @@ def grad_fn(*args, upstream=None):
             inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
             return (inputs_grad, None, None)

-        inputs, inputs_scale = self.inputs_quantizer(inputs)
+        if self.inputs_quantizer:
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+        else:
+            inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
         x = ops.matmul(inputs, unpacked_kernel)
         x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
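
Aside: the ones-scale fallback keeps the de-scale step uniform across both call paths. Keras' abs-max scheme computes scale = value_range_max / abs_max and quantizes as round(x * scale), so dequantization divides by the scale; with unquantized inputs the divisor reduces to `kernel_scale` alone. A standalone numpy sketch of that identity (illustrative only, not code from this commit):

    import numpy as np

    x = np.random.randn(2, 8).astype("float32")
    w = np.random.randn(8, 4).astype("float32")

    # Symmetric per-output-channel abs-max quantization of the kernel:
    # scale = 127 / max|w|, w_q = round(w * scale), dequant = w_q / scale.
    kernel_scale = 127.0 / (np.abs(w).max(axis=0) + 1e-7)
    w_q = np.clip(np.round(w * kernel_scale), -127, 127).astype(np.int8)

    inputs_scale = np.ones((1,), dtype="float32")  # inputs left in float
    y = (x @ w_q.astype("float32")) / (inputs_scale * kernel_scale)

    # y matches the float matmul up to int8 rounding error.
    np.testing.assert_allclose(y, x @ w, atol=0.5)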
@@ -746,38 +766,53 @@ def grad(*args, upstream=None, variables=None):
             x = self.activation(x)
         return x

-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)

+        config = validate_and_resolve_config(mode, config)
+        mode = config.mode
+
         kernel_shape = self._kernel.shape
         if mode == "int8":
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=0, to_numpy=True
+            # Handle weight quantization
+            # Quantize `self._kernel` to int8 and compute corresponding scale
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=0)
             )
-            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
+            )
+
+            if len(kernel_scale.shape) > 0:
+                kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
             del self._kernel
             # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, config)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
             # 1. Quantize to int4 values (still int8 dtype, range [-8, 7])
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=0,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(
+                    axis=0, value_range=(-8, 7), output_dtype="int8"
+                ),
             )
-            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
+            )
+
+            if len(kernel_scale.shape) > 0:
+                kernel_scale = ops.squeeze(kernel_scale, axis=0)
             # 2. Pack two int4 values into a single int8 byte.
             packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
             del self._kernel
             # Build variables using the original kernel shape; _int4_build will
             # compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, config)
             # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
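
With the fallbacks above, weight-only int8 quantization is just a config whose `activation_quantizer` is `None`; inputs then stay in `compute_dtype` and only the kernel is stored as int8. A minimal sketch mirroring the `int8_weight_only` test case below:

    import numpy as np
    from keras import layers
    from keras.src.quantizers.quantization_config import Int8QuantizationConfig
    from keras.src.quantizers.quantizers import AbsMaxQuantizer

    layer = layers.Dense(units=32)
    layer.build((None, 8))

    config = Int8QuantizationConfig(
        weight_quantizer=AbsMaxQuantizer(axis=0),
        activation_quantizer=None,  # weight-only: inputs are not quantized
    )
    layer.quantize("int8", config=config)

    assert layer.inputs_quantizer is None
    y = layer(np.random.random((2, 8)).astype("float32"))  # shape (2, 32)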

keras/src/layers/core/dense_test.py

Lines changed: 58 additions & 0 deletions
@@ -17,9 +17,67 @@
 from keras.src import testing
 from keras.src.backend.common import keras_tensor
 from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.quantization_config import Int4QuantizationConfig
+from keras.src.quantizers.quantization_config import Int8QuantizationConfig
+from keras.src.quantizers.quantizers import AbsMaxQuantizer


 class DenseTest(testing.TestCase):
+    @parameterized.named_parameters(
+        ("int8", "int8", {"axis": 0}, {"axis": -1}),
+        (
+            "int4",
+            "int4",
+            {"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"},
+            {"axis": -1},
+        ),
+        ("int8_weight_only", "int8", {"axis": 0}, None),
+    )
+    def test_dense_quantize_config(
+        self, mode, weight_quantizer_args, activation_quantizer_args
+    ):
+        """Test Dense quantization with QuantizationConfig."""
+        layer = layers.Dense(units=32)
+        layer.build((None, 8))
+
+        weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args)
+        if activation_quantizer_args is not None:
+            activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args)
+        else:
+            activation_quantizer = None
+
+        if mode == "int8":
+            config = Int8QuantizationConfig(
+                weight_quantizer=weight_quantizer,
+                activation_quantizer=activation_quantizer,
+            )
+        elif mode == "int4":
+            config = Int4QuantizationConfig(
+                weight_quantizer=weight_quantizer,
+                activation_quantizer=activation_quantizer,
+            )
+
+        layer.quantize(mode, config=config)
+
+        if activation_quantizer_args is not None:
+            # Verify inputs_quantizer is set correctly
+            self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer)
+            self.assertEqual(layer.inputs_quantizer.axis, (-1,))
+        else:
+            # Verify inputs_quantizer is None
+            self.assertIsNone(layer.inputs_quantizer)
+
+        # Verify call works
+        x = np.random.random((2, 8)).astype("float32")
+        y = layer(x)
+        self.assertEqual(y.shape, (2, 32))
+
+        if mode == "int4":
+            # Verify kernel is int8 (packed int4)
+            self.assertEqual(
+                backend.standardize_dtype(layer._kernel.dtype), "int8"
+            )
+
     @pytest.mark.requires_trainable_backend
     def test_dense_basics(self):
         # 2D case, no bias.
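
The int4 path follows the same pattern with `Int4QuantizationConfig`: the weight quantizer targets the [-8, 7] range and the resulting values are packed two-per-int8-byte. A minimal sketch mirroring the test parameters above:

    from keras import layers
    from keras.src.quantizers.quantization_config import Int4QuantizationConfig
    from keras.src.quantizers.quantizers import AbsMaxQuantizer

    layer = layers.Dense(units=32)
    layer.build((None, 8))

    config = Int4QuantizationConfig(
        weight_quantizer=AbsMaxQuantizer(
            axis=0, value_range=(-8, 7), output_dtype="int8"
        ),
        activation_quantizer=AbsMaxQuantizer(axis=-1),
    )
    layer.quantize("int4", config=config)
    # The stored kernel dtype is int8: two int4 values packed per byte.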

keras/src/layers/core/einsum_dense.py

Lines changed: 51 additions & 25 deletions
@@ -13,8 +13,10 @@
 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
+from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
 from keras.src.quantizers.quantizers import dequantize_with_sz_map

@@ -444,9 +446,9 @@ def variable_serialization_spec(self):

     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
@@ -455,10 +457,13 @@ def quantized_build(self, kernel_shape, mode, config=None):
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
+    def _int8_build(self, kernel_shape, config=None):
         self._set_quantization_info()
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=self._input_reduced_axes
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes),
+            )
         )
         self._kernel = self.add_weight(
             name="kernel",
@@ -591,7 +596,7 @@ def _gptq_call(self, inputs, training=False):
         y = self.activation(y)
         return y

-    def _int4_build(self, kernel_shape):
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         The packed int4 kernel stores two int4 values within a single int8
@@ -603,8 +608,11 @@ def _int4_build(self, kernel_shape):
         self._set_quantization_info()

         # Quantizer for the inputs (per the reduced axes)
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=self._input_reduced_axes
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(axis=self._input_reduced_axes),
+            )
         )

         # Choose the axis to perform int4 packing - use the first reduced axis
@@ -643,8 +651,6 @@ def _int4_build(self, kernel_shape):
         )

     def _float8_build(self):
-        from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
-
         # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set
         # `amax_history_length` to its default value.
         amax_history_length = getattr(
@@ -727,10 +733,15 @@ def grad_fn(*args, upstream=None):
             )
             return (inputs_grad, None, None)

-        inputs, inputs_scale = self.inputs_quantizer(inputs)
+        if self.inputs_quantizer:
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            # Deal with `inputs_scale`
+            inputs_scale = self._adjust_scale_for_quant(
+                inputs_scale, "input"
+            )
+        else:
+            inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
         x = ops.einsum(self.equation, inputs, kernel)
-        # Deal with `inputs_scale`
-        inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input")
         # De-scale outputs
         x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
@@ -803,14 +814,20 @@ def grad_fn(*args, upstream=None):
             return (inputs_grad, None, None)

         # Quantize inputs per `self.inputs_quantizer`.
-        inputs_q, inputs_scale = self.inputs_quantizer(inputs)
+        if self.inputs_quantizer:
+            inputs_q, inputs_scale = self.inputs_quantizer(inputs)
+            # Align `inputs_scale` axes with the output
+            # for correct broadcasting
+            inputs_scale = self._adjust_scale_for_quant(
+                inputs_scale, "input"
+            )
+        else:
+            inputs_q = inputs
+            inputs_scale = ops.ones((1,), dtype=self.compute_dtype)

         # Compute einsum on quantized inputs and unpacked int4 kernel.
         x = ops.einsum(self.equation, inputs_q, unpacked_kernel)

-        # Align `inputs_scale` axes with the output for correct broadcasting
-        inputs_scale = self._adjust_scale_for_quant(inputs_scale, "input")
-
         # De-scale outputs.
         x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
@@ -933,24 +950,33 @@ def quantize(self, mode, type_check=True, config=None):
             raise self._not_implemented_error(self.quantize)

         kernel_shape = self._kernel.shape
+
         if mode in ("int8", "int4", "gptq"):
             self._set_quantization_info()

         if mode == "int8":
             # Quantize `self._kernel` to int8 and compute corresponding scale
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=self._kernel_reduced_axes, to_numpy=True
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes),
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")
             del self._kernel
         elif mode == "int4":
             # Quantize to int4 values (stored in int8 dtype, range [-8, 7])
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=self._kernel_reduced_axes,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(
+                    axis=self._kernel_reduced_axes,
+                    value_range=(-8, 7),
+                    output_dtype="int8",
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = self._adjust_scale_for_quant(kernel_scale, "kernel")

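EinsumDense mirrors the Dense changes, except the default quantizers reduce over the axes contracted by the einsum equation rather than fixed axes. A minimal sketch by analogy with the Dense usage above (EinsumDense is not exercised in the test file shown in this diff):

    from keras import layers

    # "ab,bc->ac" contracts over axis b; the default AbsMaxQuantizer
    # reduces over that axis when no custom config is passed.
    layer = layers.EinsumDense("ab,bc->ac", output_shape=32)
    layer.build((None, 8))
    layer.quantize("int8")  # falls back to the AbsMaxQuantizer defaults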
