Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions examples/quantization/brevitas/quantize_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def main(args):

# Export to ONNX through optimum.exporters.
export_manager = StdQCDQONNXManager
export_manager.change_weight_export(export_weight_q_node=True)
if args.qdq_weights:
export_manager.change_weight_export(export_weight_q_node=True)

with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=export_manager):
onnx_export_from_model(
quantized_model,
Expand Down Expand Up @@ -154,11 +156,17 @@ def main(args):
default="auto",
help='Device to run the example on (e.g., "cpu", "cuda:0", "auto"). "auto" will automatically select the device using HuggingFace Accelerate (choices: [%(choices)s], default: %(default)s).',
)
parser.add_argument(
"--qdq-weights",
action="store_true",
default=False,
help="In the ONNX export, save quantized weights as float32 and insert an additional QuantizeLinear node, TensorRT style (default: %(default)s).",
)
parser.add_argument(
"--onnx-output-path",
type=str,
default="llm_quantized_onnx",
help="Location to store the output ONNX model (default: %(default)s)",
help="Location to store the output ONNX model (default: %(default)s).",
)

args = parser.parse_args()
Expand Down
49 changes: 37 additions & 12 deletions tests/brevitas/test_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from pathlib import Path
from typing import Dict

import numpy as np
import onnx
import torch
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode
from onnx import numpy_helper
from parameterized import parameterized
from testing_utils import SUPPORTED_MODELS_TINY, VALIDATE_EXPORT_ON_SHAPES, get_quantized_model

Expand Down Expand Up @@ -49,14 +51,28 @@ def _get_models_to_test(export_models_dict: Dict, library_name: str = "transform
library_name=library_name,
)

models_to_test.append((f"{model_type}_{task}", model_type, model_name, task, onnx_config_constructor))
models_to_test.append(
(f"{model_type}_{task}_DQ", model_type, model_name, task, onnx_config_constructor, False)
)
models_to_test.append(
(f"{model_type}_{task}_QDQ", model_type, model_name, task, onnx_config_constructor, True)
)
return sorted(models_to_test)


def export_and_validate(
model: torch.nn.Module, task: str, export_output_dir: str, onnx_config_class_constructor, shapes_to_validate: Dict
model: torch.nn.Module,
task: str,
export_output_dir: str,
onnx_config_class_constructor,
shapes_to_validate: Dict,
qdq_weights: bool,
):
with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=StdQCDQONNXManager):
export_manager = StdQCDQONNXManager
if qdq_weights:
export_manager.change_weight_export(export_weight_q_node=True)

with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=export_manager):
library_name = TasksManager._infer_library_from_model(model)
framework = "pt"
dtype = get_parameter_dtype(model) if framework == "pt" else model.dtype
Expand Down Expand Up @@ -121,12 +137,7 @@ def export_and_validate(
class TestOnnxExport(unittest.TestCase):
@parameterized.expand(_get_models_to_test(SUPPORTED_MODELS_TINY))
def test_dynamic_quantization(
self,
test_name,
model_type,
model_name,
task,
onnx_config_class_constructor,
self, test_name, model_type, model_name, task, onnx_config_class_constructor, qdq_weights: bool
):
model = get_quantized_model(
model_name,
Expand All @@ -144,10 +155,24 @@ def test_dynamic_quantization(
export_output_dir=tmpdir,
onnx_config_class_constructor=onnx_config_class_constructor,
shapes_to_validate=VALIDATE_EXPORT_ON_SHAPES,
qdq_weights=qdq_weights,
)

onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))

for node in onnx_model.graph.node:
# Check that we have MatmulInteger, etc.
pass
if qdq_weights:
for node in onnx_model.graph.node:
if node.op_type == "Constant":
for attrib in node.attribute:
new_array = numpy_helper.to_array(attrib.t)
if len(new_array.shape) >= 2 and new_array.dtype in [np.uint8, np.int8]:
break
else:
self.assertTrue(False, "Did not find an int8/uint8 serialized weight")
else:
for node in onnx_model.graph.node:
if node.op_type == "Constant":
for attrib in node.attribute:
new_array = numpy_helper.to_array(attrib.t)
if len(new_array.shape) >= 2 and new_array.dtype in [np.uint8, np.int8]:
self.assertTrue(False, "Found uint8/int8 serialized weights, but none should be present")