12 changes: 12 additions & 0 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -8,6 +8,18 @@
 from keras.src.quantizers import get as get
 from keras.src.quantizers import serialize as serialize
 from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
+from keras.src.quantizers.quantization_config import (
+    Float8QuantizationConfig as Float8QuantizationConfig,
+)
+from keras.src.quantizers.quantization_config import (
+    Int4QuantizationConfig as Int4QuantizationConfig,
+)
+from keras.src.quantizers.quantization_config import (
+    Int8QuantizationConfig as Int8QuantizationConfig,
+)
+from keras.src.quantizers.quantization_config import (
+    QuantizationConfig as QuantizationConfig,
+)
 from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
 from keras.src.quantizers.quantizers import Quantizer as Quantizer
 from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
12 changes: 12 additions & 0 deletions keras/api/quantizers/__init__.py
@@ -8,6 +8,18 @@
 from keras.src.quantizers import get as get
 from keras.src.quantizers import serialize as serialize
 from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
+from keras.src.quantizers.quantization_config import (
+    Float8QuantizationConfig as Float8QuantizationConfig,
+)
+from keras.src.quantizers.quantization_config import (
+    Int4QuantizationConfig as Int4QuantizationConfig,
+)
+from keras.src.quantizers.quantization_config import (
+    Int8QuantizationConfig as Int8QuantizationConfig,
+)
+from keras.src.quantizers.quantization_config import (
+    QuantizationConfig as QuantizationConfig,
+)
 from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
 from keras.src.quantizers.quantizers import Quantizer as Quantizer
 from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
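Taken together, the new exports enable passing a per-layer quantization config. A minimal usage sketch, mirroring the pattern exercised in dense_test.py below (shapes and values are illustrative):

import numpy as np
import keras

layer = keras.layers.Dense(units=32)
layer.build((None, 8))

# Pair a per-output-channel weight quantizer with a per-feature
# activation quantizer, matching the Dense defaults in this PR.
config = keras.quantizers.Int8QuantizationConfig(
    weight_quantizer=keras.quantizers.AbsMaxQuantizer(axis=0),
    activation_quantizer=keras.quantizers.AbsMaxQuantizer(axis=-1),
)
layer.quantize("int8", config=config)

y = layer(np.random.random((2, 8)).astype("float32"))  # shape (2, 32)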
68 changes: 47 additions & 21 deletions keras/src/layers/core/dense.py
@@ -11,6 +11,8 @@
 from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import validate_and_resolve_config
 from keras.src.quantizers.quantizers import dequantize_with_sz_map


@@ -372,9 +374,9 @@ def variable_serialization_spec(self):

     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
@@ -383,8 +385,13 @@ def quantized_build(self, kernel_shape, mode, config=None):
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
+    def _int8_build(self, kernel_shape, config=None):
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+        )
+
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
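`QuantizationConfig.activation_quantizer_or_default` is not part of this diff; judging from its call sites and the weight-only test case below, its behavior is presumably the following (an assumption, not the actual implementation):

    @classmethod
    def activation_quantizer_or_default(cls, config, default):
        # Assumed semantics: with no config, keep the historical default;
        # with a config, honor it exactly, so a config whose
        # activation_quantizer is None disables input quantization
        # (the weight-only path asserted in dense_test.py).
        if config is None:
            return default
        return config.activation_quantizer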
@@ -483,7 +490,7 @@ def _gptq_call(self, inputs, training=False):
             y = self.activation(y)
         return y

-    def _int4_build(self, kernel_shape):
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         `kernel_shape` is the *original* float32 kernel shape
@@ -492,8 +499,10 @@ def _int4_build(self, kernel_shape):
         int8 byte.
         """
         # Per-channel int8 quantizer for the last axis (features).
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=-1,
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
         )
         input_dim, output_dim = kernel_shape
         packed_rows = (input_dim + 1) // 2  # ceil for odd dims
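The packed-rows arithmetic and the two-nibbles-per-byte scheme from the docstring, as a standalone sketch (the nibble order here is illustrative; `quantizers.pack_int4` is the actual implementation):

input_dim = 5
packed_rows = (input_dim + 1) // 2  # ceil(5 / 2) == 3 rows after packing

def pack_pair(lo, hi):
    # Two signed int4 values (range [-8, 7]) share one int8 byte:
    # `lo` in the low nibble, `hi` in the high nibble.
    return ((hi & 0x0F) << 4) | (lo & 0x0F)

def unpack_nibble(v):
    return v - 16 if v >= 8 else v  # sign-extend a 4-bit value

b = pack_pair(-3, 5)                 # 0x5D
lo = unpack_nibble(b & 0x0F)         # -3
hi = unpack_nibble((b >> 4) & 0x0F)  # 5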
@@ -582,11 +591,15 @@ def grad_fn(*args, upstream=None):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)

-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, kernel)
             # De-scale outputs
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
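The de-scale step relies on the identity x_float ≈ (q_x @ q_w) / (s_x * s_w) for abs-max "multiplier" scales, i.e. q = round(x * s). A numpy sketch of that identity (mapping the absolute maximum onto 127 is the usual int8 abs-max convention, assumed here):

import numpy as np

x = np.random.uniform(-1, 1, (2, 8)).astype("float32")
w = np.random.uniform(-1, 1, (8, 4)).astype("float32")

s_x = 127.0 / np.max(np.abs(x), axis=-1, keepdims=True)  # per-row input scale
s_w = 127.0 / np.max(np.abs(w), axis=0, keepdims=True)   # per-column kernel scale

q_x = np.round(x * s_x)
q_w = np.round(w * s_w)

# Integer-domain matmul, then a single divide by the combined scale,
# mirroring `output_scale = ops.multiply(output_scale, inputs_scale)`.
y = (q_x @ q_w) / (s_x * s_w)
print(np.max(np.abs(y - x @ w)))  # small quantization error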
@@ -633,10 +646,15 @@ def grad_fn(*args, upstream=None):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)

-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, unpacked_kernel)
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
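When `inputs_quantizer` is None (a weight-only config), the inputs stay in float and only the kernel scale is divided out; a sketch of that branch under the same assumed convention:

import numpy as np

x = np.random.uniform(-1, 1, (2, 8)).astype("float32")
w = np.random.uniform(-1, 1, (8, 4)).astype("float32")

s_w = 127.0 / np.max(np.abs(w), axis=0, keepdims=True)
q_w = np.round(w * s_w)

y = (x @ q_w) / s_w  # output_scale == kernel_scale alone
print(np.max(np.abs(y - x @ w)))  # small quantization error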
@@ -753,33 +771,41 @@ def quantize(self, mode, type_check=True, config=None):
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)

+        config = validate_and_resolve_config(mode, config)
+        mode = config.mode
+
         kernel_shape = self._kernel.shape
         if mode == "int8":
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=0, to_numpy=True
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=0)
             )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
+            )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
             # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, config)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
             # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=0,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(
+                    axis=0, value_range=(-8, 7), output_dtype="int8"
+                ),
             )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
+            )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             # 2. Pack two int4 values into a single int8 byte.
             packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
             del self._kernel
             # Build variables using the original kernel shape; _int4_build will
             # compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, config)
             # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
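A worked example of what the default int4 weight quantizer, AbsMaxQuantizer(axis=0, value_range=(-8, 7), output_dtype="int8"), computes per output column, assuming the usual abs-max convention of mapping each column's absolute maximum onto the top of the positive range:

import numpy as np

w = np.array([[0.50, -0.20],
              [-0.25, 0.10]], dtype="float32")

# Per output column (axis=0): scale = 7 / max|w|, q = round(w * scale).
scale = 7.0 / np.max(np.abs(w), axis=0, keepdims=True)  # [[14., 35.]]
q = np.clip(np.round(w * scale), -8, 7).astype("int8")
# q == [[ 7, -7], [-4,  4]]; de-quantizing with q / scale approximates w.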
58 changes: 58 additions & 0 deletions keras/src/layers/core/dense_test.py
@@ -17,9 +17,67 @@
 from keras.src import testing
 from keras.src.backend.common import keras_tensor
 from keras.src.quantizers.gptq_config import GPTQConfig
+from keras.src.quantizers.quantization_config import Int4QuantizationConfig
+from keras.src.quantizers.quantization_config import Int8QuantizationConfig
+from keras.src.quantizers.quantizers import AbsMaxQuantizer


 class DenseTest(testing.TestCase):
+    @parameterized.named_parameters(
+        ("int8", "int8", {"axis": 0}, {"axis": -1}),
+        (
+            "int4",
+            "int4",
+            {"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"},
+            {"axis": -1},
+        ),
+        ("int8_weight_only", "int8", {"axis": 0}, None),
+    )
+    def test_dense_quantize_config(
+        self, mode, weight_quantizer_args, activation_quantizer_args
+    ):
+        """Test Dense quantization with QuantizationConfig."""
+        layer = layers.Dense(units=32)
+        layer.build((None, 8))
+
+        weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args)
+        if activation_quantizer_args is not None:
+            activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args)
+        else:
+            activation_quantizer = None
+
+        if mode == "int8":
+            config = Int8QuantizationConfig(
+                weight_quantizer=weight_quantizer,
+                activation_quantizer=activation_quantizer,
+            )
+        elif mode == "int4":
+            config = Int4QuantizationConfig(
+                weight_quantizer=weight_quantizer,
+                activation_quantizer=activation_quantizer,
+            )
+
+        layer.quantize(mode, config=config)
+
+        if activation_quantizer_args is not None:
+            # Verify inputs_quantizer is set correctly
+            self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer)
+            self.assertEqual(layer.inputs_quantizer.axis, (-1,))
+        else:
+            # Verify inputs_quantizer is None
+            self.assertIsNone(layer.inputs_quantizer)
+
+        # Verify call works
+        x = np.random.random((2, 8)).astype("float32")
+        y = layer(x)
+        self.assertEqual(y.shape, (2, 32))
+
+        if mode == "int4":
+            # Verify kernel is int8 (packed int4)
+            self.assertEqual(
+                backend.standardize_dtype(layer._kernel.dtype), "int8"
+            )
+
     @pytest.mark.requires_trainable_backend
     def test_dense_basics(self):
         # 2D case, no bias.