 import logging
 import os

+import onnx
 import torch
 from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
 from whisper_chain import chain_model
 from whisper_encoder import WhisperEncoder
 from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper

-from onnxruntime import quantization
+from onnxruntime.quantization.matmul_nbits_quantizer import (
+    KQuantWeightOnlyQuantConfig,
+    MatMulNBitsQuantizer,
+    QuantFormat,
+)

 logger = logging.getLogger("")

@@ -94,8 +99,8 @@ def parse_arguments(argv=None):
         required=False,
         type=Precision,
         default=Precision.FLOAT32,
-        choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
-        help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
+        choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4],
+        help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8/int4 for quantization",
     )

     conversion_args.add_argument(
@@ -289,28 +294,20 @@ def parse_arguments(argv=None):
     ###################################

     quant_args.add_argument(
-        "--quantize_embedding_layer",
-        required=False,
-        action="store_true",
-        help="Quantize MatMul, GEMM, and Gather.",
-    )
-    quant_args.set_defaults(quantize_embedding_layer=False)
-
-    quant_args.add_argument(
-        "--quantize_per_channel",
+        "--accuracy_level",
+        default=0,
         required=False,
-        action="store_true",
-        help="Quantize weights per each channel.",
+        type=int,
+        help="Accuracy level of the 4-bit quantized MatMul computation.",
     )
-    quant_args.set_defaults(quantize_per_channel=False)

     quant_args.add_argument(
-        "--quantize_reduce_range",
+        "--quantize_symmetric",
         required=False,
         action="store_true",
-        help="Quantize weights with 7 bits.",
+        help="Quantize weights symmetrically",
     )
-    quant_args.set_defaults(quantize_reduce_range=False)
+    quant_args.set_defaults(quantize_symmetric=False)

     args = parser.parse_args(argv)

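Note: accuracy_level here maps to the accuracy_level attribute of the MatMulNBits contrib op, which selects the compute type of the quantized MatMul (0 = unset, 1 = fp32, 2 = fp16, 3 = bf16, 4 = int8); the default of 0 leaves the choice to the kernel.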
@@ -323,6 +320,22 @@ def parse_arguments(argv=None):
     return args


+# quant_method is reserved for mixed precision in the future
+def make_quant_algo_config(precision, quant_method: str, matmul_nodes=None):
+    customized_weight_config = {}
+    quant_algo_config = None
+
+    # k_quant is used for int8 as well: override every MatMul to 8-bit weights
+    if precision == Precision.INT8:
+        for node_name in matmul_nodes:
+            customized_weight_config[node_name] = {"bits": 8}
+        quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+    else:
+        quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+
+    return quant_algo_config
+
+
 def export_onnx_models(
     model_name_or_path,
     model_impl,
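As a quick illustration of what the new make_quant_algo_config helper returns (a sketch, not part of the patch; it assumes an onnxruntime build that ships KQuantWeightOnlyQuantConfig and reuses the Precision enum from benchmark_helper; the node names are placeholders):

```python
from benchmark_helper import Precision

# INT8: each listed MatMul is pinned to 8-bit weights through the
# per-node override dict consumed by the k_quant algorithm.
# ("MatMul_0"/"MatMul_1" are placeholder node names.)
cfg_int8 = make_quant_algo_config(Precision.INT8, "k_quant", ["MatMul_0", "MatMul_1"])

# INT4: the override dict stays empty, so k_quant's 4-bit defaults apply.
cfg_int4 = make_quant_algo_config(Precision.INT4, "k_quant")
```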
@@ -340,19 +353,21 @@ def export_onnx_models(
     output_qk: bool = False,
     overwrite: bool = False,
     use_int32_inputs: bool = True,
-    quantize_embedding_layer: bool = False,
-    quantize_per_channel: bool = False,
-    quantize_reduce_range: bool = False,
+    accuracy_level: int = 0,
+    quantize_symmetric: bool = False,
     provider: str = "cpu",
 ):
     device = torch.device("cuda" if use_gpu else "cpu")
+    if not use_gpu:
+        accuracy_level = 4  # force accuracy level 4 (int8 compute) for CPU EP
+    use_fp16_inputs = precision == Precision.FLOAT16 or (precision in (Precision.INT8, Precision.INT4) and use_gpu)

     models = WhisperHelper.load_model(
         model_name_or_path,
         model_impl,
         cache_dir,
         device,
-        torch.float16 if precision == Precision.FLOAT16 else torch.float32,
+        torch.float16 if use_fp16_inputs else torch.float32,
         merge_encoder_and_decoder_init,
         no_beam_search_op,
         output_qk,
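To spell out the dtype rule introduced above: fp16 weights are used when exporting pure fp16, or when int8/int4 weight-only quantization will be applied to a GPU-exported (fp16) base model; on CPU the base model stays fp32. A standalone worked check (not part of the patch):

```python
# Standalone check of the use_fp16_inputs rule (strings stand in for the
# Precision enum so the snippet runs without the repo's helpers).
for precision, use_gpu, expected in [
    ("fp32", False, False),
    ("fp16", False, True),
    ("int8", True, True),    # GPU: quantize an fp16 base model
    ("int4", False, False),  # CPU: quantize an fp32 base model
]:
    use_fp16_inputs = precision == "fp16" or (precision in ("int8", "int4") and use_gpu)
    assert use_fp16_inputs == expected
```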
@@ -384,7 +399,7 @@ def export_onnx_models(
                 PROVIDERS[provider],
                 verbose,
                 use_external_data_format,
-                use_fp16_inputs=(precision == Precision.FLOAT16),
+                use_fp16_inputs=use_fp16_inputs,
                 use_int32_inputs=use_int32_inputs,
                 use_encoder_hidden_states=(name == "decoder_init"),
                 use_kv_cache_inputs=(name == "decoder"),
@@ -430,27 +445,43 @@ def export_onnx_models(
             model.verify_onnx(
                 onnx_path,
                 PROVIDERS[provider],
-                use_fp16_inputs=(precision == Precision.FLOAT16),
+                use_fp16_inputs=use_fp16_inputs,
             )
         else:
             model.verify_onnx(
                 onnx_path,
                 PROVIDERS[provider],
-                use_fp16_inputs=(precision == Precision.FLOAT16),
+                use_fp16_inputs=use_fp16_inputs,
                 use_int32_inputs=use_int32_inputs,
             )

-        if precision == Precision.INT8:
-            quantization.quantize_dynamic(
-                onnx_path,
+        if precision in (Precision.INT8, Precision.INT4):
+            onnx_model = onnx.load(onnx_path, load_external_data=True)
+            matmul_nodes = [node.name for node in onnx_model.graph.node if node.op_type == "MatMul"]
+            quant_algo_config = make_quant_algo_config(precision, "k_quant", matmul_nodes)
+
+            quant = MatMulNBitsQuantizer(
+                model=onnx_model,
+                block_size=32,
+                is_symmetric=quantize_symmetric,
+                accuracy_level=accuracy_level,
+                quant_format=QuantFormat.QOperator,
+                op_types_to_quantize=("MatMul",),
+                algo_config=quant_algo_config,
+            )
+            quant.process()
+            if os.path.exists(output_path):
+                os.remove(output_path)
+            if os.path.exists(output_path + ".data"):
+                os.remove(output_path + ".data")
+            onnx.save_model(
+                quant.model.model,
                 output_path,
-                op_types_to_quantize=(
-                    ["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
-                ),
-                use_external_data_format=use_external_data_format,
-                per_channel=quantize_per_channel,
-                reduce_range=quantize_reduce_range,
-                extra_options={"MatMulConstBOnly": True},
+                save_as_external_data=True,
+                all_tensors_to_one_file=True,
+                location=os.path.basename(output_path) + ".data",
+                size_threshold=0,
+                convert_attribute=False,
             )
         else:
             logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
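Taken on its own, the new weight-only quantization path can be reproduced on any exported ONNX model roughly as follows. This is a minimal sketch under the same assumptions as the patch (an onnxruntime build whose matmul_nbits_quantizer module exposes these classes); model.onnx and model_int4.onnx are placeholder paths:

```python
import onnx
from onnxruntime.quantization.matmul_nbits_quantizer import (
    KQuantWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
    QuantFormat,
)

# Load with external data so initializers are available for quantization.
onnx_model = onnx.load("model.onnx", load_external_data=True)

# 4-bit weight-only quantization of MatMul weights via the k_quant algorithm.
quant = MatMulNBitsQuantizer(
    model=onnx_model,
    block_size=32,
    is_symmetric=False,
    accuracy_level=4,  # int8 compute, as the patch forces for the CPU EP
    quant_format=QuantFormat.QOperator,
    op_types_to_quantize=("MatMul",),
    algo_config=KQuantWeightOnlyQuantConfig(customized_weight_config={}),
)
quant.process()

# quant.model is an OnnxModel wrapper; .model is the underlying ModelProto.
onnx.save_model(
    quant.model.model,
    "model_int4.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="model_int4.onnx.data",
    size_threshold=0,
)
```

Saving with save_as_external_data=True and size_threshold=0 pushes all tensors into a single sidecar .data file, which is why the patch deletes any stale output_path and output_path + ".data" before writing.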
@@ -493,9 +524,8 @@ def main(argv=None):
         args.output_cross_qk,
         args.overwrite,
         not args.use_int64_inputs,
-        args.quantize_embedding_layer,
-        args.quantize_per_channel,
-        args.quantize_reduce_range,
+        args.accuracy_level,
+        args.quantize_symmetric,
         args.provider,
     )
