@@ -50,8 +50,8 @@ def __init__(
quant_axes (dict[str, int], optional):
op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
customized_weight_config:
- customized weight config for nodes if needed.
- If both customized_weight_config and nodes_to_exclude are set, nodes_to_exclude overwrites customized_weight_config.
+ customized weight config for nodes if needed. It is a dictionary with node name as key,
+ and the value is a dict of customized config.
"""
self.algorithm = algorithm
self.quant_format = quant_format
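The new docstring describes customized_weight_config as a mapping from node name to a per-node config dict. A minimal sketch of such a mapping; the node names and per-node keys below are hypothetical illustrations, not values taken from this diff:

```python
# Hypothetical example of a customized_weight_config mapping: the node names and
# per-node keys here are illustrative assumptions, not taken from this diff.
customized_weight_config = {
    "/layers.0/attn/qkv_proj/MatMul": {"bits": 8, "block_size": 32},
    "/embed_tokens/Gather": {"bits": 4},
}
```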
@@ -81,6 +81,9 @@ def __init__(
Defaults to QuantFormat.QOperator.
op_types_to_quantize (optional):
set of operator types to quantize.
+ customized_weight_config:
+ customized weight config for nodes if needed. It is a dictionary with node name as key,
+ and the value is a dict of customized config.
"""
assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"
@@ -220,6 +223,8 @@ def __init__(
set of operator types to quantize.
quant_axes (dict[str, int], optional):
op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
+ bits (int, optional):
+ number of bits per element after quantization. Default 4.
"""
super().__init__(
algorithm="DEFAULT",
@@ -654,32 +659,36 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
b_array_torch = torch.from_numpy(b_array)
if torch.cuda.is_available():
b_array_torch = b_array_torch.cuda()
+
+ bits = self.config.bits
quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal(
- b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size
+ b_array_torch.T, bits=bits, group_size=self.config.block_size
)
quant_weight_torch = quant_weight_torch.contiguous()
scales_torch = scales_torch.contiguous()
zero_points_torch = zero_points_torch.contiguous()

+ packed_size = 8 // bits  # number of elements packed into one byte
+
packed_torch = torch.zeros(
- (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2),
+ (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // packed_size),
dtype=torch.uint8,
device=quant_weight_torch.device,
)
- self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits)
+ self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, bits)
scales = scales_torch.cpu().numpy()
zero_points = zero_points_torch.cpu().numpy()
# reshape to the predefined shape in MatmulNbits
scales = scales.reshape(-1)
zero_points = zero_points.reshape(-1)
rows, cols = b_array_torch.shape
block_size = self.config.block_size
- blob_size = block_size // 2
+ blob_size = block_size // packed_size
k_blocks = (rows + block_size - 1) // block_size
packed_torch = packed_torch.reshape(cols, k_blocks, blob_size)

b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
- b_quant.name = b_pb.name + "_Q4"
+ b_quant.name = b_pb.name + "_Q" + str(bits)
for input in bs_graph.input:
if input.name == input_b:
bs_graph.input.remove(input)
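The `packed_size = 8 // bits` change generalizes the old hard-coded halving (two 4-bit values per byte) to any bit width that divides 8. Below is a standalone sketch of row-wise packing under that assumption; it is an illustration only, not the actual `pack_on_row_fast_248bit`, and the bit order within a byte is an assumption:

```python
import numpy as np

def pack_rows(q: np.ndarray, bits: int) -> np.ndarray:
    """Pack unsigned quantized values (each < 2**bits) into bytes, row by row.

    Assumes bits divides 8 and each row length is a multiple of 8 // bits.
    """
    per_byte = 8 // bits  # 2 values per byte for 4 bits, 1 for 8 bits
    rows, cols = q.shape
    grouped = q.astype(np.uint8).reshape(rows, cols // per_byte, per_byte)
    packed = np.zeros((rows, cols // per_byte), dtype=np.uint8)
    for i in range(per_byte):
        # value i of each group occupies bit positions [i*bits, (i+1)*bits)
        packed |= (grouped[:, :, i] << (i * bits)).astype(np.uint8)
    return packed

# Under this bit order, pack_rows(np.array([[1, 2, 3, 4]], dtype=np.uint8), bits=4)
# yields [[0x21, 0x43]]; with bits=8 the packed array equals the input.
```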
@@ -699,21 +708,21 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
rows, cols = b_array.shape
kwargs["K"] = rows
kwargs["N"] = cols
- kwargs["bits"] = self.config.bits
+ kwargs["bits"] = bits
kwargs["block_size"] = self.config.block_size

- matmul_q4_node = onnx.helper.make_node(
+ matmul_q_node = onnx.helper.make_node(
"MatMulNBits",
inputs=input_names,
outputs=[node.output[0]],
- name=node.name + "_Q4" if node.name else "",
+ name=node.name + "_Q" + str(bits) if node.name else "",
domain="com.microsoft",
**kwargs,
)

logger.info(f"complete quantization of {node.name} ...")

- return [matmul_q4_node]
+ return [matmul_q_node]


def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
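The attributes set here (K, N, bits, block_size) determine the packed buffer geometry computed in the previous hunk. A small worked sketch using the same formulas as the diff; the concrete K, N, and block_size values are arbitrary illustrations:

```python
# Worked example of the buffer geometry used above (same formulas as in the diff).
def matmul_nbits_shapes(K: int, N: int, block_size: int, bits: int):
    packed_size = 8 // bits                        # elements per byte
    blob_size = block_size // packed_size          # bytes per quantized block
    k_blocks = (K + block_size - 1) // block_size  # blocks along the K dimension
    return k_blocks, blob_size, (N, k_blocks, blob_size)

print(matmul_nbits_shapes(K=4096, N=4096, block_size=32, bits=4))  # (128, 16, (4096, 128, 16))
print(matmul_nbits_shapes(K=4096, N=4096, block_size=32, bits=8))  # (128, 32, (4096, 128, 32))
```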
@@ -761,7 +770,7 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n
packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
)
else:
- # QDQ format only support 4 bits quantization
+ assert qbits == 4, "QDQ format only support 4 bits quantization"
packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype)
@@ -1095,14 +1104,13 @@ def quantize_awq(self, model: ModelProto | str) -> ModelProto:
return quantized_model


- # TODO(fajin): change class name
- class MatMul4BitsQuantizer:
+ class MatMulNBitsQuantizer:
"""
Target node:    QOperator node:         QDQ nodes:
MatMul          MatMulNBits             DeQuantizeLinear -> MatMul
Gather          GatherBlockQuantized    Gather, Gather, Gather (optional) -> DequantizeLinear

- Perform 4b quantization of constant weights for target nodes.
+ Perform 4/8 bits quantization of constant weights for target nodes.
If algo_config.quant_format is QOperator:
- nodes are replaced by the corresponding QOperator nodes.
- quantized weights are stored in the contrib ops.
@@ -1114,6 +1122,7 @@ class MatMul4BitsQuantizer:
Note:
- for quantized gather, the memory usage of "DequantizeLinear + Gather" is the same as the original Gather
during runtime. Therefore it is not recommended.
+ - when a node is in nodes_to_exclude, the node configuration in algo_config.customized_weight_config will be ignored.
"""

def __init__(
@@ -1148,8 +1157,13 @@ def __init__(
quant_format=quant_format,
op_types_to_quantize=op_types_to_quantize,
quant_axes=quant_axes,
+ bits=4,  # default to 4 bits
)
+
self.algo_config = algo_config
+ if hasattr(self.algo_config, "bits"):
+ assert self.algo_config.bits in [4, 8], "Only support 4 or 8 bits quantization"
+
if algo_config.algorithm == "HQQ":
self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
elif algo_config.algorithm == "DEFAULT":
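With the rename and the new bits field, end-to-end use of the quantizer might look like the sketch below. The import path, the DefaultWeightOnlyQuantConfig name, the process() call, and the model attribute are assumptions based on the surrounding tool rather than lines shown in this diff:

```python
# Usage sketch under stated assumptions: the module path, DefaultWeightOnlyQuantConfig,
# process(), and quant.model are not shown in this diff and are assumed here.
import onnx

from onnxruntime.quantization.matmul_4bits_quantizer import (  # assumed module path
    DefaultWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
)

model = onnx.load("model.onnx")  # hypothetical input model
algo_config = DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=True, bits=8)
quant = MatMulNBitsQuantizer(model, algo_config=algo_config)
quant.process()  # replaces eligible MatMul/Gather weights with quantized equivalents
onnx.save(quant.model.model, "model_q8.onnx")  # assumed: quant.model wraps the ModelProto
```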
@@ -1511,7 +1525,7 @@ def parse_args():
else:
raise ValueError(f"Unsupported quantization method: {args.quant_method}")

- quant = MatMul4BitsQuantizer(
+ quant = MatMulNBitsQuantizer(
model=model,
accuracy_level=args.accuracy_level,
nodes_to_exclude=args.nodes_to_exclude,