 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
+from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import validate_and_resolve_config
 from keras.src.quantizers.quantizers import dequantize_with_sz_map


@@ -370,9 +373,9 @@ def variable_serialization_spec(self):

     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
@@ -381,8 +384,14 @@ def quantized_build(self, kernel_shape, mode, config=None):
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
+    def _int8_build(self, kernel_shape, config=None):
+        # Per-channel int8 quantizer for the last axis (features).
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+        )
+
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
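
The `activation_quantizer_or_default` helper is only consumed at call sites like the one above, so its exact semantics are inferred rather than shown in this diff: it must return the caller-supplied default when no config is given, and it may return `None` for weight-only configs, which is what the `if self.inputs_quantizer:` guards later in this patch rely on. A minimal sketch under those assumptions (hypothetical reconstruction, not the actual implementation in keras.src.quantizers.quantization_config):

    @staticmethod
    def activation_quantizer_or_default(config, default):
        if config is None:
            # No config: preserve the pre-existing default behavior.
            return default
        # A weight-only config carries no activation quantizer; returning
        # None makes the quantized call paths skip input quantization.
        return config.activation_quantizer
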
@@ -481,7 +490,7 @@ def _gptq_call(self, inputs, training=False):
             y = self.activation(y)
         return y

-    def _int4_build(self, kernel_shape):
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         `kernel_shape` is the *original* float32 kernel shape
@@ -490,8 +499,10 @@ def _int4_build(self, kernel_shape):
         int8 byte.
         """
         # Per-channel int8 quantizer for the last axis (features).
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=-1,
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
         )
         input_dim, output_dim = kernel_shape
         packed_rows = (input_dim + 1) // 2  # ceil for odd dims
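
`packed_rows = (input_dim + 1) // 2` is the ceiling of `input_dim / 2` because `quantizers.pack_int4` stores two int4 values (range [-8, 7]) per int8 byte. A toy illustration of the nibble arithmetic involved (not the real `pack_int4`, which operates on whole tensors):

    def pack_pair(low, high):
        # Two signed int4 values in [-8, 7] -> one byte: keep the low
        # nibble of each and place `high` in the upper four bits.
        return ((high & 0x0F) << 4) | (low & 0x0F)

    def unpack_pair(byte):
        # Sign-extend each nibble back to a signed int4 value.
        def sign_extend(v):
            return v - 16 if v >= 8 else v
        return sign_extend(byte & 0x0F), sign_extend((byte >> 4) & 0x0F)

    assert unpack_pair(pack_pair(-3, 5)) == (-3, 5)
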
@@ -515,8 +526,6 @@ def _int4_build(self, kernel_shape):
         self._orig_input_dim = input_dim

     def _float8_build(self):
-        from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
-
         # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set
         # `amax_history_length` to its default value.
         amax_history_length = getattr(
@@ -580,7 +589,15 @@ def grad_fn(*args, upstream=None):
             inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
             return (inputs_grad, None, None)

-        inputs, inputs_scale = self.inputs_quantizer(inputs)
+        if self.inputs_quantizer:
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+        else:
+            # Weight-only quantization: inputs are not quantized.
+            # We still need inputs_scale for the formula:
+            #     x = x / (inputs_scale * kernel_scale)
+            # If inputs are not quantized, inputs_scale should be 1.
+            inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
+
         x = ops.matmul(inputs, kernel)
         # De-scale outputs
         x = ops.cast(x, self.compute_dtype)
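
The `inputs_scale = 1` fallback works because of how abs-max de-scaling composes. A standalone NumPy check of the identity the comment above refers to (toy example, not Keras code):

    import numpy as np

    rng = np.random.default_rng(0)
    inputs = rng.normal(size=(2, 4)).astype("float32")
    kernel = rng.normal(size=(4, 3)).astype("float32")

    # Symmetric abs-max quantization of the kernel per output channel
    # (axis=0 reduction, matching AbsMaxQuantizer(axis=0) below).
    kernel_scale = 127.0 / np.abs(kernel).max(axis=0)
    q_kernel = np.round(kernel * kernel_scale)

    # Weight-only: inputs stay float, so inputs_scale is 1 and the
    # de-scaling formula reduces to dividing by kernel_scale alone.
    inputs_scale = 1.0
    x = (inputs @ q_kernel) / (inputs_scale * kernel_scale)
    assert np.allclose(x, inputs @ kernel, atol=1e-1)
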
@@ -631,7 +648,10 @@ def grad_fn(*args, upstream=None):
             inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
             return (inputs_grad, None, None)

-        inputs, inputs_scale = self.inputs_quantizer(inputs)
+        if self.inputs_quantizer:
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+        else:
+            # Weight-only quantization: see the int8 call path above.
+            inputs_scale = ops.ones((1,), dtype=self.compute_dtype)
         x = ops.matmul(inputs, unpacked_kernel)
         x = ops.cast(x, self.compute_dtype)
         x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
@@ -746,38 +766,53 @@ def grad(*args, upstream=None, variables=None):
         x = self.activation(x)
         return x

-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)

+        config = validate_and_resolve_config(mode, config)
+        mode = config.mode
+
         kernel_shape = self._kernel.shape
         if mode == "int8":
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=0, to_numpy=True
+            # Handle weight quantization.
+            # Quantize `self._kernel` to int8 and compute the corresponding
+            # scale.
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=0)
             )
-            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
+            )
+
+            if len(kernel_scale.shape) > 0:
+                kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
             del self._kernel
             # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, config)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
             # 1. Quantize to int4 values (still int8 dtype, range [-8, 7]).
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=0,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(
+                    axis=0, value_range=(-8, 7), output_dtype="int8"
+                ),
             )
-            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
+            )
+
+            if len(kernel_scale.shape) > 0:
+                kernel_scale = ops.squeeze(kernel_scale, axis=0)
             # 2. Pack two int4 values into a single int8 byte.
             packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
             del self._kernel
             # Build variables using the original kernel shape; _int4_build
             # will compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, config)
             # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
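
Taken together, the reworked signature keeps the old string-mode entry point while letting a config object drive everything. A usage sketch (the `QuantizationConfig` constructor arguments are not shown in this diff, so that call is left indicative only):

    from keras import layers

    layer = layers.Dense(16)
    layer.build((None, 8))

    # Legacy path: a bare mode string is resolved into a default config
    # by validate_and_resolve_config(), so existing callers keep working.
    layer.quantize("int8")

    # Config-driven path: the mode comes from config.mode, and the config
    # may supply custom weight/activation quantizers (or omit the
    # activation quantizer for weight-only quantization). Constructor
    # arguments are hypothetical:
    # layer.quantize(config=QuantizationConfig(...))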