
Commit ed7847c

Fix quantization in Whisper model export (#26353)
### Description
Fix quantization in Whisper model export

### Motivation and Context
As titled.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 8ab27d9 commit ed7847c

7 files changed (+135, -55 lines)

onnxruntime/python/tools/transformers/models/whisper/README.md

Lines changed: 30 additions & 3 deletions
@@ -75,6 +75,24 @@ $ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --o
 $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision fp16 --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
 ```
 
+Export + Quantize for INT8 CUDA
+```
+# From source:
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
+```
+
+Export + Quantize for INT8 CPU
+```
+# From source:
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk
+```
+
 ## Exporting Whisper with Beam Search
 
 There are several ways to export Whisper with beam search.

@@ -143,13 +161,22 @@ $ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --o
 $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda
 ```
 
-Export + Quantize for INT8
+Export + Quantize for INT8 CUDA
+```
+# From source:
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --use_gpu --provider cuda
+
+# From wheel:
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --use_gpu --provider cuda
+```
+
+Export + Quantize for INT8 CPU
 ```
 # From source:
-$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_embedding_layer
+$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --provider cpu
 
 # From wheel:
-$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_embedding_layer
+$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --provider cpu
 ```
 
 Note: INT8 CPU is not compatible with `--output_cross_qk`.
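
As a quick sanity check after either export, the quantized model can be loaded with ONNX Runtime. A minimal sketch follows; the filename under `whisper-turbo/` is a placeholder, since the actual name depends on the export options used:

```python
# Minimal sanity-check sketch: load the exported INT8 model and list its inputs/outputs.
# The model filename below is a placeholder, not the exact name produced by convert_to_onnx.
import onnxruntime as ort

model_path = "whisper-turbo/whisper-large-v3-turbo_beamsearch.onnx"  # placeholder path
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

session = ort.InferenceSession(model_path, providers=providers)
print([i.name for i in session.get_inputs()])
print([o.name for o in session.get_outputs()])
```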

onnxruntime/python/tools/transformers/models/whisper/benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -449,7 +449,7 @@ def parse_args():
         type=str,
         required=True,
         default="fp32",
-        choices=["int8", "fp16", "fp32"],
+        choices=["int4", "int8", "fp16", "fp32"],
         help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
     )

@@ -579,7 +579,7 @@ def main():
     config = WhisperConfig.from_pretrained(args.model_name)
     processor = WhisperProcessor.from_pretrained(args.model_name)
     target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
-    use_fp16 = args.precision == "fp16"
+    use_fp16 = args.precision == "fp16" or (args.precision in {"int8", "int4"} and args.device != "cpu")
 
     setattr(args, "processor", processor)  # noqa: B010
     setattr(args, "target_device", target_device)  # noqa: B010
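
The `use_fp16` change above is the behavioral core of this file's edit: INT8/INT4 benchmarks targeting GPU keep fp16 graph inputs. An illustrative sketch of that decision is below; the helper name and standalone framing are assumptions, not part of benchmark.py:

```python
# Illustrative sketch only: how precision/device map to the input dtype fed to the model.
# `select_input_dtype` is a hypothetical helper, not defined in benchmark.py.
import numpy as np

def select_input_dtype(precision: str, device: str):
    # fp16 inputs are used for fp16 models, and for int8/int4 models running on GPU.
    use_fp16 = precision == "fp16" or (precision in {"int8", "int4"} and device != "cpu")
    return np.float16 if use_fp16 else np.float32

# Example: int8 on CUDA still feeds fp16 features; int8 on CPU stays fp32.
assert select_input_dtype("int8", "cuda") == np.float16
assert select_input_dtype("int8", "cpu") == np.float32
```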

onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def get_args():
         "--precision",
         type=str,
         required=True,
-        choices=["int8", "fp16", "fp32"],
+        choices=["int4", "int8", "fp16", "fp32"],
         help="Precision to run model",
     )

onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py

Lines changed: 68 additions & 38 deletions
@@ -8,13 +8,18 @@
 import logging
 import os
 
+import onnx
 import torch
 from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
 from whisper_chain import chain_model
 from whisper_encoder import WhisperEncoder
 from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper
 
-from onnxruntime import quantization
+from onnxruntime.quantization.matmul_nbits_quantizer import (
+    KQuantWeightOnlyQuantConfig,
+    MatMulNBitsQuantizer,
+    QuantFormat,
+)
 
 logger = logging.getLogger("")
 

@@ -94,8 +99,8 @@ def parse_arguments(argv=None):
         required=False,
         type=Precision,
         default=Precision.FLOAT32,
-        choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
-        help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
+        choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4],
+        help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8/int4 for quantization",
     )
 
     conversion_args.add_argument(

@@ -289,28 +294,20 @@ def parse_arguments(argv=None):
     ###################################
 
     quant_args.add_argument(
-        "--quantize_embedding_layer",
-        required=False,
-        action="store_true",
-        help="Quantize MatMul, GEMM, and Gather.",
-    )
-    quant_args.set_defaults(quantize_embedding_layer=False)
-
-    quant_args.add_argument(
-        "--quantize_per_channel",
+        "--accuracy_level",
+        default=0,
         required=False,
-        action="store_true",
-        help="Quantize weights per each channel.",
+        type=int,
+        help="Accuracy level of the 4-bit quantized MatMul computation.",
     )
-    quant_args.set_defaults(quantize_per_channel=False)
 
     quant_args.add_argument(
-        "--quantize_reduce_range",
+        "--quantize_symmetric",
         required=False,
         action="store_true",
-        help="Quantize weights with 7 bits.",
+        help="Quantize weights symmetrically",
     )
-    quant_args.set_defaults(quantize_reduce_range=False)
+    quant_args.set_defaults(quantize_symmetric=False)
 
     args = parser.parse_args(argv)
 

@@ -323,6 +320,22 @@ def parse_arguments(argv=None):
     return args
 
 
+# quant_method is reserved for mixed precision in future
+def make_quant_algo_config(precision, quant_method: str, matmul_nodes=None):
+    customized_weight_config = {}
+    quant_algo_config = None
+
+    # need to use k_quant for int8
+    if precision == Precision.INT8:
+        for node_name in matmul_nodes:
+            customized_weight_config[node_name] = {"bits": 8}
+        quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+    else:
+        quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+
+    return quant_algo_config
+
+
 def export_onnx_models(
     model_name_or_path,
     model_impl,

@@ -340,19 +353,21 @@
     output_qk: bool = False,
     overwrite: bool = False,
     use_int32_inputs: bool = True,
-    quantize_embedding_layer: bool = False,
-    quantize_per_channel: bool = False,
-    quantize_reduce_range: bool = False,
+    accuracy_level: int = 0,
+    quantize_symmetric: bool = False,
     provider: str = "cpu",
 ):
     device = torch.device("cuda" if use_gpu else "cpu")
+    if not use_gpu:
+        accuracy_level = 4  # change to 4 for CPU EP
+    use_fp16_inputs = precision == Precision.FLOAT16 or (precision in (Precision.INT8, Precision.INT4) and use_gpu)
 
     models = WhisperHelper.load_model(
         model_name_or_path,
         model_impl,
         cache_dir,
         device,
-        torch.float16 if precision == Precision.FLOAT16 else torch.float32,
+        torch.float16 if use_fp16_inputs else torch.float32,
         merge_encoder_and_decoder_init,
         no_beam_search_op,
         output_qk,

@@ -384,7 +399,7 @@
             PROVIDERS[provider],
             verbose,
             use_external_data_format,
-            use_fp16_inputs=(precision == Precision.FLOAT16),
+            use_fp16_inputs=use_fp16_inputs,
             use_int32_inputs=use_int32_inputs,
             use_encoder_hidden_states=(name == "decoder_init"),
             use_kv_cache_inputs=(name == "decoder"),

@@ -430,27 +445,43 @@
                 model.verify_onnx(
                     onnx_path,
                     PROVIDERS[provider],
-                    use_fp16_inputs=(precision == Precision.FLOAT16),
+                    use_fp16_inputs=use_fp16_inputs,
                 )
             else:
                 model.verify_onnx(
                     onnx_path,
                     PROVIDERS[provider],
-                    use_fp16_inputs=(precision == Precision.FLOAT16),
+                    use_fp16_inputs=use_fp16_inputs,
                     use_int32_inputs=use_int32_inputs,
                 )
 
-            if precision == Precision.INT8:
-                quantization.quantize_dynamic(
-                    onnx_path,
+            if precision in (Precision.INT8, Precision.INT4):
+                onnx_model = onnx.load(onnx_path, load_external_data=True)
+                matmul_nodes = [node.name for node in onnx_model.graph.node if node.op_type == "MatMul"]
+                quant_algo_config = make_quant_algo_config(precision, "k_quant", matmul_nodes)
+
+                quant = MatMulNBitsQuantizer(
+                    model=onnx_model,
+                    block_size=32,
+                    is_symmetric=quantize_symmetric,
+                    accuracy_level=accuracy_level,
+                    quant_format=QuantFormat.QOperator,
+                    op_types_to_quantize=("MatMul",),
+                    algo_config=quant_algo_config,
+                )
+                quant.process()
+                if os.path.exists(output_path):
+                    os.remove(output_path)
+                if os.path.exists(output_path + ".data"):
+                    os.remove(output_path + ".data")
+                onnx.save_model(
+                    quant.model.model,
                     output_path,
-                    op_types_to_quantize=(
-                        ["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
-                    ),
-                    use_external_data_format=use_external_data_format,
-                    per_channel=quantize_per_channel,
-                    reduce_range=quantize_reduce_range,
-                    extra_options={"MatMulConstBOnly": True},
+                    save_as_external_data=True,
+                    all_tensors_to_one_file=True,
+                    location=os.path.basename(output_path) + ".data",
+                    size_threshold=0,
+                    convert_attribute=False,
                 )
         else:
             logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")

@@ -493,9 +524,8 @@ def main(argv=None):
         args.output_cross_qk,
         args.overwrite,
         not args.use_int64_inputs,
-        args.quantize_embedding_layer,
-        args.quantize_per_channel,
-        args.quantize_reduce_range,
+        args.accuracy_level,
+        args.quantize_symmetric,
         args.provider,
     )
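
Taken together, the new export path replaces dynamic quantization with a weight-only MatMul quantization pass. The sketch below mirrors the diff as a self-contained script for applying the same pass to an already-exported model; the file paths are placeholders, and the `is_symmetric`/`accuracy_level` values are example choices, not fixed defaults:

```python
# Standalone sketch of the weight-only quantization pass introduced above.
# Paths are placeholders; block_size, quant_format, and the k_quant config mirror the diff.
import os

import onnx
from onnxruntime.quantization.matmul_nbits_quantizer import (
    KQuantWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
    QuantFormat,
)

onnx_path = "whisper-turbo/whisper_decoder.onnx"          # placeholder input model
output_path = "whisper-turbo/whisper_decoder_int8.onnx"   # placeholder output model

onnx_model = onnx.load(onnx_path, load_external_data=True)

# k_quant with per-node 8-bit overrides, as make_quant_algo_config does for INT8.
matmul_nodes = [n.name for n in onnx_model.graph.node if n.op_type == "MatMul"]
algo_config = KQuantWeightOnlyQuantConfig(
    customized_weight_config={name: {"bits": 8} for name in matmul_nodes}
)

quantizer = MatMulNBitsQuantizer(
    model=onnx_model,
    block_size=32,
    is_symmetric=True,      # corresponds to --quantize_symmetric
    accuracy_level=4,       # the diff forces 4 for the CPU EP; 0 is the flag's default
    quant_format=QuantFormat.QOperator,
    op_types_to_quantize=("MatMul",),
    algo_config=algo_config,
)
quantizer.process()

# Save with external data, matching how the exporter writes the quantized model.
onnx.save_model(
    quantizer.model.model,
    output_path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=os.path.basename(output_path) + ".data",
    size_threshold=0,
)
```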

onnxruntime/python/tools/transformers/models/whisper/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-torch>=2.7.0
+torch==2.7.0
 transformers==4.52.3
 openai-whisper==20240927
 ffmpeg-python

onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py

Lines changed: 10 additions & 7 deletions
@@ -54,16 +54,19 @@ def chain_model(args):
     config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
     tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
 
+    use_fp16_inputs = args.precision == Precision.FLOAT16 or (
+        args.precision in (Precision.INT8, Precision.INT4) and args.use_gpu
+    )
     # Create inputs/outputs for WhisperBeamSearch op
-    temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
+    temperature_name = "temperature_fp16" if use_fp16_inputs else "temperature"
     beam_inputs = [
-        "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features",
+        "input_features_fp16" if use_fp16_inputs else "input_features",
         "max_length",
         "min_length",
         "num_beams",
         "num_return_sequences",
-        "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty",
-        "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty",
+        "length_penalty_fp16" if use_fp16_inputs else "length_penalty",
+        "repetition_penalty_fp16" if use_fp16_inputs else "repetition_penalty",
         "vocab_mask" if args.use_vocab_mask else "",
         "prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
         "",  # attention mask

@@ -74,8 +77,8 @@
         temperature_name if args.use_temperature else "",
     ]
 
-    sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores"
-    scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores"
+    sequence_scores_name = "sequence_scores_fp16" if use_fp16_inputs else "sequence_scores"
+    scores_name = "scores_fp16" if use_fp16_inputs else "scores"
     beam_outputs = [
         "sequences",
         sequence_scores_name if args.output_sequence_scores else "",

@@ -85,7 +88,7 @@
     ]
 
     graph_nodes = []
-    if args.precision == Precision.FLOAT16:
+    if use_fp16_inputs:
         input_features_cast_node = helper.make_node(
             "Cast",
             inputs=["input_features"],
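
The branch guarded by `use_fp16_inputs` builds a Cast node so fp32 `input_features` can feed an fp16 graph. The diff cuts off mid-call; for context, the standard `onnx.helper` pattern for such a cast looks like the sketch below (the node name is illustrative, not the exact one used in whisper_chain.py):

```python
# Minimal sketch of inserting a Cast-to-float16 node with onnx.helper.
# Tensor and node names here are illustrative; whisper_chain.py uses its own names.
from onnx import TensorProto, helper

cast_node = helper.make_node(
    "Cast",
    inputs=["input_features"],
    outputs=["input_features_fp16"],
    name="CastInputFeaturesToFp16",  # illustrative node name
    to=TensorProto.FLOAT16,
)
```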
