Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions examples/quantization/brevitas/quantize_llm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from argparse import ArgumentParser

import torch
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode

from optimum.amd import BrevitasQuantizationConfig, BrevitasQuantizer
from optimum.amd.brevitas.accelerate_utils import calc_cpu_device_map, calc_gpu_device_map, offload_model, remove_hooks
from optimum.amd.brevitas.data_utils import compute_perplexity, get_dataset_for_model
from optimum.exporters.onnx import onnx_export_from_model
from optimum.amd.brevitas.export import export_quantized_model
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from optimum.amd.brevitas.export import export_quantized_model
from optimum.amd.brevitas.export import export_to_onnx

Can we keep the word "ONNX" in the name to make it explicit? Other name suggestions: quantized_model_to_onnx or save_quantized_model_as_onnx.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opted to keep the name as similar as possible to the original, so it became:
onnx_export_from_quantized_model

from transformers import AutoTokenizer


Expand Down Expand Up @@ -80,16 +76,7 @@ def main(args):
quantized_model = quantized_model.to("cpu")

# Export to ONNX through optimum.exporters.
export_manager = StdQCDQONNXManager
export_manager.change_weight_export(export_weight_q_node=True)
with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=export_manager):
onnx_export_from_model(
quantized_model,
args.onnx_output_path,
task="text-generation-with-past",
do_validation=False,
no_post_process=True,
)
export_quantized_model(quantized_model, args.onnx_output_path)
return return_val


Expand Down
10 changes: 10 additions & 0 deletions optimum/amd/brevitas/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode

from optimum.exporters.onnx import onnx_export_from_model


def export_quantized_model(quantized_model, path, task="text-generation-with-past"):
    """Export a Brevitas-quantized model to ONNX through optimum's exporter.

    For the duration of the export, gradient tracking is disabled and the
    Brevitas quantization proxies are switched into export mode using the
    standard QCDQ ONNX manager, so quantization ops are emitted as ONNX nodes.

    Args:
        quantized_model: the Brevitas-quantized torch model to export.
        path: destination path for the exported ONNX model.
        task: optimum task name passed through to the exporter
            (default: "text-generation-with-past").
    """
    # NOTE(review): the inlined code this helper replaces also called
    # StdQCDQONNXManager.change_weight_export(export_weight_q_node=True)
    # before exporting; this helper does not — confirm the omission is intentional.
    manager = StdQCDQONNXManager
    with torch.no_grad():
        with brevitas_proxy_export_mode(quantized_model, export_manager=manager):
            onnx_export_from_model(
                quantized_model,
                path,
                task=task,
                # Validation and post-processing are skipped for quantized graphs.
                do_validation=False,
                no_post_process=True,
            )
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
EXTRAS_REQUIRE = {
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
"brevitas": ["brevitas", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"],
"brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"],
}

setup(
Expand Down