Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions docs/source/brevitas/usage_guide.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,10 @@ Brevitas models can be exported to ONNX using Optimum:

```python
import torch
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode
from optimum.amd.brevitas.export import onnx_export_from_quantized_model

# Export to ONNX through optimum.exporters.
with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=StdQCDQONNXManager):
onnx_export_from_model(
model, "llm_quantized_onnx", task="text-generation-with-past", do_validation=False, no_post_process=True
)
onnx_export_from_quantized_model(model, "llm_quantized_onnx")
```

## Complete example
Expand Down
17 changes: 2 additions & 15 deletions examples/quantization/brevitas/quantize_llm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from argparse import ArgumentParser

import torch
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode

from optimum.amd import BrevitasQuantizationConfig, BrevitasQuantizer
from optimum.amd.brevitas.accelerate_utils import calc_cpu_device_map, calc_gpu_device_map, offload_model, remove_hooks
from optimum.amd.brevitas.data_utils import compute_perplexity, get_dataset_for_model
from optimum.exporters.onnx import onnx_export_from_model
from optimum.amd.brevitas.export import onnx_export_from_quantized_model
from transformers import AutoTokenizer


Expand Down Expand Up @@ -80,16 +76,7 @@ def main(args):
quantized_model = quantized_model.to("cpu")

# Export to ONNX through optimum.exporters.
export_manager = StdQCDQONNXManager
export_manager.change_weight_export(export_weight_q_node=True)
with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=export_manager):
onnx_export_from_model(
quantized_model,
args.onnx_output_path,
task="text-generation-with-past",
do_validation=False,
no_post_process=True,
)
onnx_export_from_quantized_model(quantized_model, args.onnx_output_path)
return return_val


Expand Down
51 changes: 51 additions & 0 deletions optimum/amd/brevitas/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode

from optimum.exporters.onnx import onnx_export_from_model
from optimum.exporters.onnx.base import OnnxConfig
from transformers.modeling_utils import PreTrainedModel


def onnx_export_from_quantized_model(
quantized_model: Union["PreTrainedModel"],
output: Union[str, Path],
opset: Optional[int] = None,
optimize: Optional[str] = None,
monolith: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
fn_get_submodels: Optional[Callable] = None,
_variant: str = "default",
preprocessors: List = None,
device: str = "cpu",
no_dynamic_axes: bool = False,
task: str = "text-generation-with-past",
use_subprocess: bool = False,
do_constant_folding: bool = True,
**kwargs_shapes,
):
with torch.no_grad(), brevitas_proxy_export_mode(quantized_model, export_manager=StdQCDQONNXManager):
onnx_export_from_model(
quantized_model,
output,
opset=opset,
monolith=monolith,
optimize=optimize,
model_kwargs=model_kwargs,
custom_onnx_configs=custom_onnx_configs,
fn_get_submodels=fn_get_submodels,
_variant=_variant,
preprocessors=preprocessors,
device=device,
no_dynamic_axes=no_dynamic_axes,
use_subprocess=use_subprocess,
do_constant_folding=do_constant_folding,
task=task,
do_validation=False,
no_post_process=True,
**kwargs_shapes,
)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
EXTRAS_REQUIRE = {
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
"brevitas": ["brevitas", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"],
"brevitas": ["brevitas", "torch>=2.2", "datasets>=2.17", "onnx", "onnxruntime", "accelerate"],
}

setup(
Expand Down