@@ -199,6 +199,7 @@ def __init__(
         op_types_to_quantize: tuple[str, ...] | None = None,
         quant_axes: tuple[tuple[str, int], ...] | None = None,
         bits: int = 4,
+        channel_wised_quantize: bool = False,
     ):
         """
         This is a class for weight only affine quantization configuration.
@@ -231,6 +232,9 @@ def __init__(
         self.is_symmetric = is_symmetric
         self.bits = bits
         self.accuracy_level = accuracy_level
+        self.channel_wised_quantize = channel_wised_quantize
+        if channel_wised_quantize and quant_format == QuantFormat.QOperator:
+            raise NotImplementedError("channel_wised_quantize is not yet supported with QuantFormat.QOperator")
 
 
 class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
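For context, a minimal usage sketch of the new flag. The class name `DefaultWeightOnlyQuantConfig` and the import path are assumptions inferred from this diff (the quantizer below consumes a `DefaultWeightOnlyQuantConfig`), not confirmed by it:

```python
# Hypothetical usage sketch; class name and import path assumed from the diff context.
from onnxruntime.quantization import QuantFormat
from onnxruntime.quantization.matmul_4bits_quantizer import DefaultWeightOnlyQuantConfig

config = DefaultWeightOnlyQuantConfig(
    quant_format=QuantFormat.QDQ,    # QOperator now raises NotImplementedError with the new flag
    is_symmetric=True,
    bits=4,
    channel_wised_quantize=True,     # one scale per output channel; the block spans all of K
)
```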
@@ -725,6 +729,26 @@ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, Gr
     return None, None
 
 
+# Transpose an int4 matrix packed as uint8 (two int4 values per byte, low nibble first)
+def transpose_packed_int4_matrix(packed, rows, cols):
+    # unpack into one int4 value per uint8 element
+    total = rows * cols
+    high = (packed >> 4) & 0x0F
+    low = packed & 0x0F
+    int4_vals = np.empty(total, dtype=np.uint8)
+    int4_vals[0::2] = low
+    int4_vals[1::2] = high
+    int4_matrix = int4_vals.reshape((rows, cols))
+
+    # transpose the int4 matrix
+    int4_matrix_transposed = int4_matrix.T
+
+    # repack adjacent pairs of int4 values into uint8
+    flat = int4_matrix_transposed.reshape(-1)
+    packed = ((flat[1::2] << 4) & 0xF0) | (flat[0::2] & 0x0F)
+    return packed.astype(np.uint8)
+
+
 class DefaultWeightOnlyQuantizer:
     def __init__(self, config: DefaultWeightOnlyQuantConfig):
         self.config = config
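A quick round-trip check of the packing convention `transpose_packed_int4_matrix` relies on (low nibble first, row-major). `pack_int4` is a hypothetical test helper, and `rows * cols` is assumed even so nibbles pair up cleanly:

```python
import numpy as np

def pack_int4(matrix):
    # hypothetical helper: pack int4 values (0..15) two per byte, low nibble first
    flat = matrix.reshape(-1).astype(np.uint8)
    return (((flat[1::2] << 4) & 0xF0) | (flat[0::2] & 0x0F)).astype(np.uint8)

rows, cols = 4, 6  # rows * cols must be even for the nibble pairing
m = np.random.randint(0, 16, size=(rows, cols), dtype=np.uint8)
# transposing in packed form matches packing the transposed matrix
assert np.array_equal(transpose_packed_int4_matrix(pack_int4(m), rows, cols), pack_int4(m.T))
```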
@@ -761,6 +785,10 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n
                 packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
             )
         else:
+            # block size equals rows (K) when channel-wise quantization is enabled
+            block_size = rows if self.config.channel_wised_quantize else self.config.block_size
+            k_blocks = (rows + block_size - 1) // block_size
+
             # QDQ format only support 4 bits quantization
             packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
             zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
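A small worked example of what this branch changes: with channel-wise quantization the block spans the whole K dimension, so each column ends up with exactly one scale (K = 4096 assumed for illustration):

```python
rows = 4096  # K dimension of the MatMul weight

# default block-wise path, e.g. block_size = 128
k_blocks = (rows + 128 - 1) // 128    # 32 blocks -> 32 scales per column

# channel-wise path: block_size == rows
k_blocks = (rows + rows - 1) // rows  # 1 block -> 1 scale per column
```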
@@ -803,6 +831,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
             )
             scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")
 
+        # If QDQ, channel-wise and symmetric quantization are all enabled, optimize for Intel NPU:
+        # transposing the weight to NHWC layout improves performance
+        qdq_opt_for_intel_npu_enabled = (
+            self.config.quant_format == QuantFormat.QDQ
+            and self.config.channel_wised_quantize
+            and self.config.is_symmetric
+        )
+        if qdq_opt_for_intel_npu_enabled:
+            rows, cols = b_ndarray.shape
+            packed = transpose_packed_int4_matrix(packed, rows, cols)
+            scales = scales.reshape((cols, 1))  # one scale per output channel
+            b_quant = onnx.helper.make_tensor(
+                b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True
+            )
+            scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")
+
         for input in b_graph.input:
             if input.name == input_b:
                 b_graph.input.remove(input)
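A toy numpy check that the transposed layout dequantizes to the same weight: scaling the (N, K) copy per output channel and transposing back matches dequantizing the original (K, N) weight, which is why the Transpose node added below preserves the MatMul result. Symmetric int4 values and random scales are assumed:

```python
import numpy as np

rows, cols = 4, 6
q = np.random.randint(-8, 8, size=(rows, cols)).astype(np.float32)  # (K, N) quantized weight
scales = np.random.rand(cols).astype(np.float32)                    # one scale per column

deq = q * scales                             # reference dequant: scales broadcast along N
deq_npu = (q.T * scales.reshape(cols, 1)).T  # (N, K) layout with (cols, 1) scales, then Transpose
assert np.allclose(deq_npu, deq)
```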
@@ -840,15 +878,21 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
         else:
             dq_input_names = [b_quant.name, scales_tensor.name]
             dq_output_names = [b_quant.name + "_output"]
-            matmul_input_names = [node.input[0], dq_output_names[0]]
+            tp_input_names = [dq_output_names[0]]
+            tp_output_names = [dq_output_names[0] + "_transposed"]
+            matmul_input_names = [node.input[0], tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0]]
             matmul_output_names = [node.output[0]]
             if not self.config.is_symmetric:
                 zp_tensor = onnx.helper.make_tensor(
                     b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True
                 )
                 dq_input_names.append(zp_tensor.name)
                 b_graph.initializer.extend([zp_tensor])
-            dq_kwargs = {"axis": 0, "block_size": self.config.block_size}
+            rows, cols = b_ndarray.shape
+            dq_kwargs = {
+                "axis": 1 if qdq_opt_for_intel_npu_enabled else 0,
+                "block_size": rows if self.config.channel_wised_quantize else self.config.block_size,
+            }
             dq_node = onnx.helper.make_node(
                 "DequantizeLinear",
                 inputs=dq_input_names,
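For reference, the three attribute combinations this change can produce for `DequantizeLinear` (a sketch; blocked dequantization via the `block_size` attribute assumes opset 21 or later):

```python
rows = 4096  # K dimension of the weight

# default QDQ path: weight stored (K, N), per-block along K
dq_kwargs = {"axis": 0, "block_size": 128}

# channel-wise: weight stored (K, N), one block spans all of K
dq_kwargs = {"axis": 0, "block_size": rows}

# channel-wise + symmetric on Intel NPU: weight stored (N, K), block along axis 1
dq_kwargs = {"axis": 1, "block_size": rows}
```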
@@ -862,7 +906,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
                 outputs=matmul_output_names,
                 name=node.name + f"_matmul_Q{bits}" if node.name else "",
             )
-            output_nodes.extend([dq_node, matmul_node])
+            if qdq_opt_for_intel_npu_enabled:
+                tp_node = onnx.helper.make_node(
+                    "Transpose",
+                    inputs=tp_input_names,
+                    outputs=tp_output_names,
+                    perm=[1, 0],
+                )
+                output_nodes.extend([dq_node, tp_node, matmul_node])
+            else:
+                output_nodes.extend([dq_node, matmul_node])
 
         return output_nodes
 
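Putting the pieces together, the subgraph the NPU branch emits is DequantizeLinear -> Transpose -> MatMul. A minimal sketch with hypothetical tensor names (the `onnx.helper.make_node` calls mirror the ones above):

```python
import onnx

dq = onnx.helper.make_node(
    "DequantizeLinear",
    inputs=["B_DQ_Q4", "B_DQ_scales"],  # (N, K) int4 weight plus (N, 1) scales
    outputs=["B_DQ_Q4_output"],
    axis=1,
    block_size=4096,                    # == K for channel-wise quantization
)
tp = onnx.helper.make_node(
    "Transpose",
    inputs=["B_DQ_Q4_output"],
    outputs=["B_DQ_Q4_output_transposed"],
    perm=[1, 0],                        # back to (K, N) for MatMul
)
mm = onnx.helper.make_node(
    "MatMul",
    inputs=["A", "B_DQ_Q4_output_transposed"],
    outputs=["Y"],
)
```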