@@ -6,6 +6,7 @@
 import numpy as np

 from keras.src import activations
+from keras.src import backend
 from keras.src import constraints
 from keras.src import dtype_policies
 from keras.src import initializers
@@ -741,12 +742,27 @@ def grad_fn(*args, upstream=None):
                 inputs_scale = self._adjust_scale_for_quant(
                     inputs_scale, "input"
                 )
+                x = ops.einsum(self.equation, inputs, kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
             else:
-                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
-            x = ops.einsum(self.equation, inputs, kernel)
-            # De-scale outputs
-            x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int8 kernel).
+                if backend.backend() == "torch":
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, kernel)
+                    # De-scale outputs
+                    x = ops.cast(x, self.compute_dtype)
+                    x = ops.divide(x, kernel_scale)
             return x, grad_fn

         x = einsum_with_inputs_gradient(
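Aside: a minimal standalone sketch of the int8 weight-only branch above, assuming a made-up equation, shapes, and per-channel scale (none of these come from the diff). It illustrates why the torch branch dequantizes the kernel before the einsum, and that the result matches the generic einsum-then-de-scale path.

import torch

equation = "ab,bc->ac"                       # hypothetical einsum equation
inputs = torch.randn(2, 4)                   # float32 activations
kernel = torch.randint(-127, 128, (4, 3), dtype=torch.int8)
kernel_scale = torch.full((1, 3), 0.05)      # hypothetical per-channel scale

# torch.einsum(equation, inputs, kernel) rejects the mixed float/int8
# operands (the limitation noted in the comment above), so the torch
# branch dequantizes the kernel first.
float_kernel = kernel.to(inputs.dtype) / kernel_scale
out_torch_branch = torch.einsum(equation, inputs, float_kernel)

# Generic branch: einsum first, then de-scale the output. The kernel is
# cast to float here only so this reference also runs under torch.
out_generic = torch.einsum(equation, inputs, kernel.to(inputs.dtype))
out_generic = out_generic / kernel_scale

assert torch.allclose(out_torch_branch, out_generic, atol=1e-5)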
@@ -823,16 +839,29 @@ def grad_fn(*args, upstream=None):
                 inputs_scale = self._adjust_scale_for_quant(
                     inputs_scale, "input"
                 )
+                x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
+                # De-scale outputs
+                x = ops.cast(x, self.compute_dtype)
+                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
             else:
-                inputs_q = inputs
-                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
-
-            # Compute einsum on quantized inputs and unpacked int4 kernel.
-            x = ops.einsum(self.equation, inputs_q, unpacked_kernel)
-
-            # De-scale outputs.
-            x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+                # Weight-only quantization: dequantize kernel and use float
+                # einsum. This is a workaround for PyTorch's einsum which
+                # doesn't support mixed-precision inputs (float input,
+                # int4 kernel).
+                if backend.backend() == "torch":
+                    # Align `kernel_scale` to the same layout as
+                    # `unpacked_kernel`.
+                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                        kernel_scale,
+                    )
+                    x = ops.einsum(self.equation, inputs, float_kernel)
+                else:
+                    x = ops.einsum(self.equation, inputs, unpacked_kernel)
+                    # De-scale outputs
+                    x = ops.cast(x, self.compute_dtype)
+                    x = ops.divide(x, kernel_scale)
             return x, grad_fn

         x = einsum_with_inputs_gradient(
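Aside: the int4 path additionally works on a kernel that is stored packed, two 4-bit values per int8 byte, which is why it uses `unpacked_kernel` and why `_adjust_scale_for_dequant` has to align `kernel_scale` with the unpacked layout. The sketch below is a generic nibble-unpacking illustration only, not the exact keras.src.quantizers packing scheme; the byte values, low-nibble-first layout, and scale are assumptions.

import numpy as np

packed = np.array([[0x21, 0xF3]], dtype=np.uint8)  # two 4-bit values per byte
low = (packed & 0x0F).astype(np.int8)              # low nibbles
high = (packed >> 4).astype(np.int8)               # high nibbles
# Sign-extend the 4-bit two's-complement values into int8.
low = np.where(low >= 8, low - 16, low)
high = np.where(high >= 8, high - 16, high)
unpacked = np.concatenate([low, high], axis=-1)    # [[1, 3, 2, -1]]

kernel_scale = np.float32(0.25)                    # hypothetical per-tensor scale
float_kernel = unpacked.astype(np.float32) / kernel_scale
print(float_kernel)                                # values: 4, 12, 8, -4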