
Commit 3a31239

mixed precision einsum fix for torch + fixed tf/jax tests
1 parent: a3668d5 · commit: 3a31239

File tree

4 files changed: +58 −18 lines

keras/src/layers/core/einsum_dense.py

Lines changed: 43 additions & 14 deletions
@@ -6,6 +6,7 @@
 import numpy as np
 
 from keras.src import activations
+from keras.src import backend
 from keras.src import constraints
 from keras.src import dtype_policies
 from keras.src import initializers
@@ -741,12 +742,27 @@ def grad_fn(*args, upstream=None):
                 inputs_scale = self._adjust_scale_for_quant(
                     inputs_scale, "input"
                 )
+                x = ops.einsum(self.equation, inputs, kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
             else:
-                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
-            x = ops.einsum(self.equation, inputs, kernel)
-            # De-scale outputs
-            x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int8 kernel).
+                if backend.backend() == "torch":
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, kernel)
+                    # De-scale outputs
+                    x = ops.cast(x, self.compute_dtype)
+                    x = ops.divide(x, kernel_scale)
             return x, grad_fn
 
         x = einsum_with_inputs_gradient(
@@ -823,16 +839,29 @@ def grad_fn(*args, upstream=None):
                 inputs_scale = self._adjust_scale_for_quant(
                     inputs_scale, "input"
                 )
+                x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
             else:
-                inputs_q = inputs
-                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
-
-            # Compute einsum on quantized inputs and unpacked int4 kernel.
-            x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
-
-            # De-scale outputs.
-            x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int4 kernel).
+                if backend.backend() == "torch":
+                    # Align `kernel_scale` to the same layout as
+                    # `unpacked_kernel`.
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, unpacked_kernel)
+                    # De-scale outputs
+                    x = ops.cast(x, self.compute_dtype)
+                    x = ops.divide(x, kernel_scale)
             return x, grad_fn
 
         x = einsum_with_inputs_gradient(
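
Note: the comments in the new else-branch describe the constraint this commit works around: torch.einsum, unlike the TF and JAX backends, rejects operands of different dtypes, so a float activation cannot be contracted directly against an int8/int4 kernel. Below is a minimal sketch of the failure and of the dequantize-first workaround in plain PyTorch; tensor names, shapes, and the per-column scale are illustrative only and are not taken from the layer.

import torch

# Float activations and an int8 weight-only-quantized kernel.
inputs = torch.randn(2, 4)                                    # float32
kernel = torch.randint(-127, 128, (4, 3), dtype=torch.int8)   # int8
kernel_scale = torch.rand(1, 3) + 0.5                         # per-column scales

# Mixed-dtype einsum is rejected by torch with a dtype-mismatch error.
try:
    torch.einsum("ab,bc->ac", inputs, kernel)
except RuntimeError as err:
    print("mixed-dtype einsum failed:", err)

# Workaround mirrored by the commit: dequantize the kernel to the compute
# dtype first, then run a purely float einsum. The output no longer needs
# de-scaling because the scale is already folded into the kernel.
float_kernel = kernel.to(inputs.dtype) / kernel_scale
x = torch.einsum("ab,bc->ac", inputs, float_kernel)
print(x.shape)  # torch.Size([2, 3])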

keras/src/layers/core/einsum_dense_test.py

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,12 @@ class EinsumDenseTest(testing.TestCase):
             {"axis": -1},
         ),
         ("int8_weight_only", "int8", {"axis": 0}, None),
+        (
+            "int4_weight_only",
+            "int4",
+            {"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"},
+            None,
+        ),
     )
     def test_einsum_dense_quantize(
         self, mode, weight_quantizer_args, activation_quantizer_args
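
Note: the new `int4_weight_only` case mirrors the int8 weight-only case but constrains weight values to the int4 range (-8, 7) stored in an int8 container, with no activation quantizer. A hedged end-user sketch of the path this exercises; the layer equation and shapes are made up, and it assumes the public `layer.quantize("int4")` entry point available in recent Keras releases.

import numpy as np
import keras

# Build a small EinsumDense layer, then apply weight-only int4 quantization.
layer = keras.layers.EinsumDense("ab,bc->ac", output_shape=8)
layer.build((None, 16))
layer.quantize("int4")  # packs the kernel to int4; activations stay float

x = np.random.rand(2, 16).astype("float32")
y = layer(x)  # runs the quantized call path covered by the new test case
print(y.shape)  # (2, 8)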

keras/src/quantizers/gptq_test.py

Lines changed: 5 additions & 3 deletions
@@ -625,14 +625,16 @@ def test_quantize_gptq_combinations(self, dataset, config):
             "mode": "gptq",
             "config": None,
             "expected_exception": ValueError,
-            "error_msg": "Mode 'gptq' requires a valid `config`",
+            "error_msg": "For GPTQ, you must pass a GPTQConfig "
+            "object explicitly.",
         },
         {
             "testcase_name": "gptq_with_base_quantization_config",
             "mode": "gptq",
             "config": QuantizationConfig(),
-            "expected_exception": ValueError,
-            "error_msg": "Mode 'gptq' requires a valid `config`",
+            "expected_exception": NotImplementedError,
+            "error_msg": "Do not instantiate "
+            "QuantizationConfig directly.",
         },
     )
     def test_quantize_scenarios(
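
Note: the expected messages now match what the quantization stack actually raises. A hedged sketch of the first scenario follows; the model here is arbitrary, the test's own fixture may differ, and the exact point at which the check fires is an assumption — only the error text is taken from the diff.

import keras

# GPTQ requires an explicit GPTQConfig, so quantizing without one should
# fail with the ValueError asserted above.
model = keras.Sequential([keras.layers.Dense(4)])
model.build((None, 8))
try:
    model.quantize("gptq")
except ValueError as err:
    print(err)  # "For GPTQ, you must pass a GPTQConfig object explicitly."

The second scenario (passing a bare QuantizationConfig) now expects the NotImplementedError introduced in quantization_config.py below.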

keras/src/quantizers/quantization_config.py

Lines changed: 4 additions & 1 deletion
@@ -11,7 +11,10 @@ def __init__(self, weight_quantizer=None, activation_quantizer=None):
 
     @property
     def mode(self):
-        raise NotImplementedError
+        raise NotImplementedError(
+            "Subclasses must implement this property. Do not instantiate "
+            "QuantizationConfig directly."
+        )
 
     def get_config(self):
         return {
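
Note: the richer message makes explicit that QuantizationConfig is an abstract base and that concrete configs must supply `mode`. A minimal hedged sketch of that pattern; the subclass is hypothetical, not one of the library's real configs, and the import path simply mirrors the changed file.

from keras.src.quantizers.quantization_config import QuantizationConfig


class MyInt8Config(QuantizationConfig):
    # Hypothetical subclass that supplies the required `mode` property.
    @property
    def mode(self):
        return "int8"


print(MyInt8Config().mode)  # "int8"

try:
    QuantizationConfig().mode  # the base class now raises the clearer message
except NotImplementedError as err:
    print(err)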
