
Commit 0d26928 (parent cd9c02f)

Rename matmul_4bits_quantizer.py to matmul_nbits_quantizer.py (#24472)

### Description

* Rename the file and the class, since the quantizer supports both 4 and 8 bits.
* Update HQQWeightOnlyQuantizer to support 8 bits.
* Update some comments.

### Motivation and Context

#24384 added 8-bit support to the default weight-only quantizer.
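After the rename, usage is unchanged apart from the module and class names. A minimal sketch under stated assumptions ("model.onnx" and "model_quant.onnx" are placeholder paths; `bits=8` simply exercises the 8-bit path this commit references, with 4 bits remaining the default):

```python
# Hedged sketch of the renamed API; paths are placeholders, not repo files.
import onnx

from onnxruntime.quantization.matmul_nbits_quantizer import (
    DefaultWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
)

model = onnx.load("model.onnx")
# bits defaults to 4; bits=8 uses the 8-bit support referenced by this commit.
algo_config = DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=True, bits=8)
quant = MatMulNBitsQuantizer(model, algo_config=algo_config)
quant.process()
quant.model.save_model_to_file("model_quant.onnx", True)
```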

File tree

5 files changed: +57 −32 lines changed

Diff for: onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py renamed to onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py

+30-16
```diff
@@ -50,8 +50,8 @@ def __init__(
             quant_axes (dict[str, int], optional):
                 op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
             customized_weight_config:
-                customized weight config for nodes if needed.
-                If both customized_weight_config and nodes_to_exclude are set, nodes_to_exclude overwrites customized_weight_config.
+                customized weight config for nodes if needed. It is dictionary with node name as key,
+                and the value is a dict of customized config.
         """
         self.algorithm = algorithm
         self.quant_format = quant_format
@@ -81,6 +81,9 @@ def __init__(
                 Defaults to QuantFormat.QOperator.
             op_types_to_quantize (optional):
                 set of operator types to quantize.
+            customized_weight_config:
+                customized weight config for nodes if needed. It is dictionary with node name as key,
+                and the value is a dict of customized config.
         """
         assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"

@@ -220,6 +223,8 @@ def __init__(
                 set of operator types to quantize.
             quant_axes (dict[str, int], optional):
                 op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
+            bits (int, optional):
+                number of bits per element after quantization. Default 4.
         """
         super().__init__(
             algorithm="DEFAULT",
@@ -654,32 +659,36 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
         b_array_torch = torch.from_numpy(b_array)
         if torch.cuda.is_available():
             b_array_torch = b_array_torch.cuda()
+
+        bits = self.config.bits
         quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal(
-            b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size
+            b_array_torch.T, bits=bits, group_size=self.config.block_size
         )
         quant_weight_torch = quant_weight_torch.contiguous()
         scales_torch = scales_torch.contiguous()
         zero_points_torch = zero_points_torch.contiguous()

+        packed_size = 8 // bits  # number of elements packed into one byte
+
         packed_torch = torch.zeros(
-            (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2),
+            (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // packed_size),
             dtype=torch.uint8,
             device=quant_weight_torch.device,
         )
-        self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits)
+        self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, bits)
         scales = scales_torch.cpu().numpy()
         zero_points = zero_points_torch.cpu().numpy()
         # reshape to the predefined shape in MatmulNbits
         scales = scales.reshape(-1)
         zero_points = zero_points.reshape(-1)
         rows, cols = b_array_torch.shape
         block_size = self.config.block_size
-        blob_size = block_size // 2
+        blob_size = block_size // packed_size
         k_blocks = (rows + block_size - 1) // block_size
         packed_torch = packed_torch.reshape(cols, k_blocks, blob_size)

         b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
-        b_quant.name = b_pb.name + "_Q4"
+        b_quant.name = b_pb.name + "_Q" + str(bits)
         for input in bs_graph.input:
             if input.name == input_b:
                 bs_graph.input.remove(input)
```
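For context on the `packed_size = 8 // bits` change above: with 4 bits two quantized elements share one byte, while with 8 bits each element occupies a full byte, so the packed width becomes `shape[1] // packed_size`. A standalone illustration, not the ORT implementation (the low-bits-first ordering here is an assumption made only for the example):

```python
# Illustrative packing of quantized values into bytes (not ORT code).
import numpy as np


def pack_rows(q: np.ndarray, bits: int) -> np.ndarray:
    """Pack uint8 values (each < 2**bits) along the last axis into bytes."""
    packed_size = 8 // bits  # elements per output byte: 2 for 4-bit, 1 for 8-bit
    assert bits in (2, 4, 8) and q.shape[-1] % packed_size == 0
    if packed_size == 1:  # 8-bit: nothing to pack, one element per byte
        return q.astype(np.uint8)
    q = q.reshape(*q.shape[:-1], -1, packed_size)
    out = np.zeros(q.shape[:-1], dtype=np.uint16)
    for i in range(packed_size):  # assumed low-bits-first ordering, illustration only
        out |= q[..., i].astype(np.uint16) << (bits * i)
    return out.astype(np.uint8)


vals = np.array([[1, 2, 3, 15]], dtype=np.uint8)
print(pack_rows(vals, 4).shape)  # (1, 2): two 4-bit values per byte
print(pack_rows(vals, 8).shape)  # (1, 4): one 8-bit value per byte
```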
```diff
@@ -699,21 +708,21 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
         rows, cols = b_array.shape
         kwargs["K"] = rows
         kwargs["N"] = cols
-        kwargs["bits"] = self.config.bits
+        kwargs["bits"] = bits
         kwargs["block_size"] = self.config.block_size

-        matmul_q4_node = onnx.helper.make_node(
+        matmul_q_node = onnx.helper.make_node(
             "MatMulNBits",
             inputs=input_names,
             outputs=[node.output[0]],
-            name=node.name + "_Q4" if node.name else "",
+            name=node.name + "_Q" + str(bits) if node.name else "",
             domain="com.microsoft",
             **kwargs,
         )

         logger.info(f"complete quantization of {node.name} ...")

-        return [matmul_q4_node]
+        return [matmul_q_node]


 def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
@@ -761,7 +770,7 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
                 packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
             )
         else:
-            # QDQ format only support 4 bits quantization
+            assert qbits == 4, "QDQ format only support 4 bits quantization"
             packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
             zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
             scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype)
@@ -1095,14 +1104,13 @@ def quantize_awq(self, model: ModelProto | str) -> ModelProto:
         return quantized_model


-# TODO(fajin): change class name
-class MatMul4BitsQuantizer:
+class MatMulNBitsQuantizer:
     """
     Target node:    QOperator node:        QDQ nodes:
     MatMul          MatMulNBits            DeQuantizeLinear -> MatMul
     Gather          GatherBlockQuantized   Gather, Gather, Gather (optional) -> DequantizeLinear

-    Perform 4b quantization of constant weights for target nodes.
+    Perform 4/8 bits quantization of constant weights for target nodes.
     If algo_config.quant_format is QOperator:
       - nodes are replaced by the corresponding QOperator nodes.
       - quantized weights are stored in the contrib ops.
@@ -1114,6 +1122,7 @@ class MatMul4BitsQuantizer:
     Note:
       - for quantized gather, the memory usage of "DequantizeLinear + Gather" is the same as the original Gather
        during runtime. Therefor it is not recommended.
+      - when a node is in nodes_to_exclude, and the node configuration in algo_config.customized_weight_config will be ignored.
     """

     def __init__(
@@ -1148,8 +1157,13 @@ def __init__(
                 quant_format=quant_format,
                 op_types_to_quantize=op_types_to_quantize,
                 quant_axes=quant_axes,
+                bits=4,  # default to 4 bits
             )
+
         self.algo_config = algo_config
+        if hasattr(self.algo_config, "bits"):
+            assert self.algo_config.bits in [4, 8], "Only support 4 or 8 bits quantization"
+
         if algo_config.algorithm == "HQQ":
             self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
         elif algo_config.algorithm == "DEFAULT":
@@ -1511,7 +1525,7 @@ def parse_args():
     else:
         raise ValueError(f"Unsupported quantization method: {args.quant_method}")

-    quant = MatMul4BitsQuantizer(
+    quant = MatMulNBitsQuantizer(
         model=model,
         accuracy_level=args.accuracy_level,
         nodes_to_exclude=args.nodes_to_exclude,
```

Diff for: onnxruntime/python/tools/quantization/quantize.py

+2-2
```diff
@@ -929,11 +929,11 @@ def quantize(
             )
     else:
         # training package doesn't has quantize_matmul_4bits, avoid global import
-        from .matmul_4bits_quantizer import MatMul4BitsQuantizer, WeightOnlyQuantConfig
+        from .matmul_nbits_quantizer import MatMulNBitsQuantizer, WeightOnlyQuantConfig

         if isinstance(quant_config, WeightOnlyQuantConfig):
             model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load(model_input)
-            quant = MatMul4BitsQuantizer(model, algo_config=quant_config)
+            quant = MatMulNBitsQuantizer(model, algo_config=quant_config)
             quant.process()
             quant.model.save_model_to_file(model_output, True)
         else:
```
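The change above keeps the generic entry point working: `quantize()` routes any `WeightOnlyQuantConfig` to the renamed `MatMulNBitsQuantizer`. A hedged sketch of that path (file names are placeholders; the positional `model_input, model_output, quant_config` form is assumed from the surrounding function):

```python
# Hedged sketch: weight-only quantization via the generic quantize() entry point.
from onnxruntime.quantization import quantize
from onnxruntime.quantization.matmul_nbits_quantizer import DefaultWeightOnlyQuantConfig

algo_config = DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=True)
# WeightOnlyQuantConfig instances are dispatched to MatMulNBitsQuantizer (see diff above).
quantize("model_fp32.onnx", "model_int4.onnx", algo_config)
```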

Diff for: onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py

+8-3
```diff
@@ -31,8 +31,13 @@
 from packaging import version
 from transformers import AutoConfig, AutoModelForCausalLM

+from onnxruntime import __version__ as ort_version
 from onnxruntime import quantization as ort_quantization
-from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
+
+if version.parse(ort_version) < version.parse("1.22.0"):
+    from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer as MatMulNBitsQuantizer
+else:
+    from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer

 torch_export_onnx_opset_version = 14
 logger = logging.getLogger("")
@@ -714,7 +719,7 @@ def get_args():
         required=False,
         default=32,
         type=int,
-        help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
+        help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py for details.",
     )

     blockwise_group.add_argument(
@@ -1025,7 +1030,7 @@ def main():
     for fp_path, int4_path in zip(old_paths, new_paths, strict=False):
         if os.path.exists(fp_path):
             model = onnx.load_model(fp_path, load_external_data=True)
-            quant = MatMul4BitsQuantizer(
+            quant = MatMulNBitsQuantizer(
                 model=model,
                 block_size=args.block_size,
                 is_symmetric=True,
```
Diff for: onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py

+9-3
```diff
@@ -14,9 +14,15 @@
 from benchmark_helper import Precision
 from fusion_options import AttentionOpType
 from onnx_model import OnnxModel
+from packaging import version
 from transformers import AutoConfig, AutoModelForCausalLM

-from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
+from onnxruntime import __version__ as ort_version
+
+if version.parse(ort_version) < version.parse("1.22.0"):
+    from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer as MatMulNBitsQuantizer
+else:
+    from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer


 class ConvertPhi2ToONNX:
@@ -160,7 +166,7 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
             return
         else:
             assert self.precision == Precision.INT4
-            quant = MatMul4BitsQuantizer(
+            quant = MatMulNBitsQuantizer(
                 model=optimizer.model,
                 block_size=self.block_size,
                 is_symmetric=True,
@@ -351,7 +357,7 @@ def parse_arguments():
         required=False,
         default=16,
         type=int,
-        help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
+        help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py for details.",
     )

     parser.add_argument(
```

Diff for: onnxruntime/test/python/quantization/test_op_matmul_4bits.py

+8-8
```diff
@@ -195,17 +195,17 @@ def quant_test(
         )

         # Quantize fp32 model to int4 model
-        from onnxruntime.quantization import matmul_4bits_quantizer
+        from onnxruntime.quantization import matmul_nbits_quantizer

         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
-        quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
+        quant_config = matmul_nbits_quantizer.DefaultWeightOnlyQuantConfig(
             block_size=block_size,
             is_symmetric=is_symmetric,
             quant_format=quant_format,
             op_types_to_quantize=op_types_to_quantize,
             quant_axes=quant_axes,
         )
-        quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config)
+        quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(model, algo_config=quant_config)
         quant.process()
         quant.model.save_model_to_file(model_int4_path, False)

@@ -260,21 +260,21 @@ def quant_test_with_algo(
         )

         # Quantize fp32 model to int4 model
-        from onnxruntime.quantization import matmul_4bits_quantizer
+        from onnxruntime.quantization import matmul_nbits_quantizer

         algo_config = None
         if algorithm == "RTN":
             # test RTN algorithm
-            algo_config = matmul_4bits_quantizer.RTNWeightOnlyQuantConfig()
+            algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig()
         elif algorithm == "GPTQ":
             # test GPTQ algorithm
-            algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
+            algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
         elif algorithm == "HQQ":
             # test HQQ algorithm
-            algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size)
+            algo_config = matmul_nbits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size)

         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
-        quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
+        quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
         quant.process()
         quant.model.save_model_to_file(model_int4_path, False)
```
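Since this commit also extends HQQWeightOnlyQuantizer to 8 bits, an 8-bit HQQ run mirrors the 4-bit test above. A hedged sketch (it assumes HQQWeightOnlyQuantConfig accepts a `bits` argument, which the HQQ path reads back as `self.config.bits` in the diff; paths are placeholders):

```python
# Hedged sketch: 8-bit HQQ weight-only quantization with the renamed module.
import onnx

from onnxruntime.quantization.matmul_nbits_quantizer import (
    HQQWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
)

model = onnx.load("model_fp32.onnx")  # placeholder path
algo_config = HQQWeightOnlyQuantConfig(block_size=128, bits=8)  # bits kwarg assumed supported
quant = MatMulNBitsQuantizer(model, algo_config=algo_config)
quant.process()
quant.model.save_model_to_file("model_int8_hqq.onnx", False)
```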
