26 commits
8f64181
docs (graph): typo fix in residual handler
nickfraser May 8, 2025
49af8e8
Fix (graph): update method call for updating node kwargs
nickfraser May 8, 2025
6df6a0a
Feat (ex/finn): Added basic mobilenet v2 PTQ example
nickfraser May 8, 2025
958dbf0
Feat (ex/finn): Added conda file for environment setup
nickfraser May 8, 2025
08b06fc
fix (ex/finn): Convert adaptive avg pool to average pool
nickfraser May 8, 2025
6aee06c
Removed Conv1d map & shared bias
nickfraser May 8, 2025
fe4d8a3
Used calib_loader to get example output
nickfraser May 8, 2025
42b2723
Refactor to make more readable
nickfraser May 8, 2025
f11f215
Cleanup code
nickfraser May 8, 2025
44ee63d
Added a few different quantization options
nickfraser May 8, 2025
3685605
Added GPTQ, removed BN calibration
nickfraser May 8, 2025
6a5806e
Reset to default config
nickfraser May 8, 2025
8233f7a
docs (ex/finn): Added basic README
nickfraser May 8, 2025
eb45dee
feat (ex/finn): Added GPFQ
nickfraser May 8, 2025
51afa4b
docs (ex/finn): Added note about GPFQ
nickfraser May 8, 2025
1eff6f0
fix (ex/finn): Adjust GPTQ settings
nickfraser May 9, 2025
15faf96
ex (finn): switched weight quantizer to MSE
nickfraser May 13, 2025
9758589
feat (ex/finn): Switched to QuantReLU6
nickfraser May 13, 2025
805c3ab
fix (ex/finn): Update GPTQ defaults
nickfraser May 13, 2025
46e2dd8
fix (ex/finn): Set act_eq_alpha correctly
nickfraser May 13, 2025
79dfb43
Fix (ex/finn): Don't run final (redundant) validation when verbose=True
nickfraser May 13, 2025
3c6fbe4
fix (ex/finn): changed default value for act_eq_alpha=0.7
nickfraser May 13, 2025
d469db7
Fix (ex/finn): harden weight scales before applying GPTQ, GPFQ
nickfraser May 13, 2025
32a74be
Fis (ex/finn): Update default PTQ settings
nickfraser May 14, 2025
d6c265f
Docs (ex/finn): Updated instructions in PTQ section
nickfraser May 14, 2025
e191ea6
[ex/finn/mnv2] Updated suggested setting for PTQ
nickfraser May 27, 2025
2 changes: 1 addition & 1 deletion src/brevitas/graph/quantize.py
@@ -371,7 +371,7 @@ def quantize(
graph_model = act_handler(graph_model, layer_map=quant_act_map)
graph_model = add_output_quant_handler(
graph_model, quant_identity_map, quant_act_map, unsigned_act_tuple)
-# The call to esidual_handler has to be performed before layer_handler
+# The call to residual_handler has to be performed before layer_handler
# so that all requantization steps are correctly inserted and aligned.
graph_model = residual_handler(
graph_model, quant_identity_map, quant_act_map, unsigned_act_tuple, align_input_quant)
2 changes: 1 addition & 1 deletion src/brevitas/graph/utils.py
@@ -71,7 +71,7 @@ def maybe_replace_node(n: Node) -> Node:
new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
assert isinstance(new_args, tuple)
assert isinstance(new_kwargs, dict)
-use_node._Node__update_args_kwargs(new_args, new_kwargs)
+use_node._update_args_kwargs(new_args, new_kwargs)
return to_process


119 changes: 119 additions & 0 deletions src/brevitas_examples/finn_mobilenetv2/README.md
@@ -0,0 +1,119 @@
# Programmatic Quantization of MobileNetv2 for FINN

This tutorial / demo shows how to do programmatic quantization of a MobileNetv2 model for [FINN](https://github.com/xilinx/finn).
It includes applying several post-training quantization (PTQ) algorithms to the model in order to maintain / recover model accuracy.
In the future, this may be extended to include quantization-aware training (QAT) if requested.

## Demo Overview

The demo shows 3 main aspects of Brevitas's quantization flow:
- insertion of quantization nodes for inference-only compute;
- maintaining / recovering accuracy with PTQ; and
- export to [QONNX](https://github.com/fastmachinelearning/qonnx) for further processing with FINN.

## Environment Setup

If [miniforge](https://github.com/conda-forge/miniforge) is installed, the environment can be set up as follows:

```bash
mamba env create -n brevitas_finn_demo -f conda/brevitas_finn.yml
conda activate brevitas_finn_demo
pip install --no-deps /path/to/brevitas
```
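
To confirm that Brevitas is importable inside the new environment, a quick check (a sketch, assuming `brevitas.__version__` is defined as in recent releases) is:

```python
# Sanity check that Brevitas imports inside the freshly created environment
import brevitas
print(brevitas.__version__)
```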

## Running the Demo

By default, the demo quantizes the model weights and activations to 8 bits and can be run as follows:

```bash
python finn_mobilenetv2.py
```

which should achieve approximately 71.61% accuracy (from a baseline of 72.01%).
Running the script produces a QONNX file `quant_mobilenet_v2.onnx`, which can be viewed with [Netron](https://github.com/lutzroeder/netron).
You should notice the following about the output model:
- it contains Quant nodes for the weights of every Conv2d / Linear layer;
- the activation before any Conv2d / Linear layer also has a Quant node;
- Quant nodes before eltwise additions (or concatenations) have the same scale factors; and
- Batchnorm layers are left in the network _without_ merging them into surrounding layers.

The above properties allow many models to be consumed by FINN's frontend.
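
As a quick sanity check of these properties outside of Netron, a minimal sketch (assuming the `onnx` Python package is available in the environment) can count the relevant nodes in the exported model:

```python
# Count Quant and BatchNormalization nodes in the exported QONNX model
import onnx

qonnx_model = onnx.load("quant_mobilenet_v2.onnx")
ops = [node.op_type for node in qonnx_model.graph.node]
print("Quant nodes:", ops.count("Quant"))
print("BatchNormalization nodes (should still be present):", ops.count("BatchNormalization"))
```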

## Modifying the Demo

There are several ways to modify the demo including:
- changing bit-widths of weights / activations; and
- recovering accuracy with PTQ algorithms.

### Changing the Bitwidth of Conv2d Layers

In Brevitas, there are multiple ways to achieve this.
You'll find a few suggested options commented out in `finn_mobilenetv2.py`,
just before the call to `quantize_finn()`.
For each modification, we strongly recommend viewing the resulting QONNX model in Netron to see how the model's compute graph has changed.

We reproduce and explain the code snippets below:

#### Snippet 1

```python
finn_quant_maps = default_quantize_maps_finn()
finn_quant_maps["compute_layer_map"][nn.Conv2d][1]['weight_bit_width'] = 4 # 1. Override Conv2d weights to have 4-bits
model = quantize_finn(model, **finn_quant_maps)
```

Adding / overriding the `nn.Conv2d` entry in the `compute_layer_map` argument to `quantize_finn` with `weight_bit_width=4`,
as expected, sets the bit-width of the weights of all `Conv2d` layers to 4.
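
As a quick check that the override took effect (a sketch, assuming the standard Brevitas `quant_weight()` API on quantized layers), you can print the weight bit-width of each quantized convolution after calling `quantize_finn`:

```python
# Hypothetical check: print the weight bit-width of every quantized Conv2d
import brevitas.nn as qnn

for name, module in model.named_modules():
    if isinstance(module, qnn.QuantConv2d):
        print(name, int(module.quant_weight().bit_width))
```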

#### Snippet 2

```python
finn_quant_maps = default_quantize_maps_finn()
finn_quant_maps["compute_layer_map"][nn.Conv2d][1]['weight_bit_width'] = lambda module: 8 if module.groups != 1 else 4 # 2. Groupwise Conv2ds @ 8-bits, the rest @ 4-bits
model = quantize_finn(model, **finn_quant_maps)
```

Alternatively, `weight_bit_width` and other parameters can be _lambda functions_ of the module instance itself.
In this case, groupwise convolutions will remain at 8-bits, while all other convolutions will be at 4-bits.
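
The same pattern works with any attribute of the module instance; for example, a hypothetical variant (not one of the options commented in the script) could key the bit-width on the number of output channels instead:

```python
# Hypothetical variant: keep narrow layers at 8-bits, quantize wider layers to 4-bits
finn_quant_maps = default_quantize_maps_finn()
finn_quant_maps["compute_layer_map"][nn.Conv2d][1]['weight_bit_width'] = (
    lambda module: 8 if module.out_channels <= 32 else 4)
model = quantize_finn(model, **finn_quant_maps)
```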

#### Snippet 3

```python
finn_quant_maps = default_quantize_maps_finn()
finn_quant_maps["compute_layer_map"][nn.Conv2d][1]['weight_bit_width'] = lambda module, name: 8 if module.groups != 1 or name == "features.0.0" else 4 # 3. Keep first conv in 8-bits otherwise same as above
model = quantize_finn(model, **finn_quant_maps)
```

If the signature of the lambda function has 2 arguments, the module's name will be passed along with the module itself.
This means that any function of the module or its name can be used to determine the quantization parameters.
In this case, the very first convolution (named `"features.0.0"` in PyTorch) also remains at 8-bits,
along with the groupwise convolutions, while the rest are quantized to 4-bits.
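
If you are unsure which layer names are available for such a name-based rule, a plain PyTorch loop over the float model (before quantization) lists them:

```python
# List Conv2d layer names (e.g. "features.0.0") to use in a name-based lambda
import torch.nn as nn

for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(name, module.in_channels, module.out_channels, module.groups)
```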

### Retaining and Recovering the Accuracy

You'll notice that reducing the weight bit-width of several layers in MobileNetv2 significantly reduces its accuracy.
However, most of this accuracy can be recovered by applying 1 or more PTQ algorithms to the model.
The demo has the following parameters to control which PTQ algorithms are applied to the model:

```python
act_eq = False # Apply act equalization
act_eq_alpha = 0.5 # [0.0 -> 1.0] Intuition: higher makes weights easier to quantize, lower makes the activations easier to quantize
act_eq_add_mul_node = False # Add extra elementwise mul nodes before activation quantization. If True, lower `alpha` seems to work better (`alpha=0.175`)
bias_corr = False # Apply bias correction
gptq = False # Apply GPTQ
gpfq = False # Apply GPFQ
```

These flags enable the application of:
- [activation equalization](https://arxiv.org/abs/2211.10438);
- [bias correction](https://arxiv.org/abs/1906.04721);
- [GPTQ](https://arxiv.org/abs/2210.17323); and
- [GPFQ](https://arxiv.org/abs/2201.11113).

We leave the explanation of these techniques to their respective papers,
but a good starting point is to set `act_eq=True`, `gpfq=True`.
After that, finding the combination of PTQ flags / settings that maximises accuracy is left to the user.
If `act_eq_add_mul_node=True`, the compute graph will be augmented to include a channelwise multiplication before many activation quantization functions,
which may help to increase accuracy at the cost of passing that complexity to downstream tools (i.e., FINN).
GPTQ and GPFQ likely cannot be applied at the same time.
Brevitas has many more PTQ algorithms not included here; please see the [imagenet](../imagenet_classification) and [LLM](../llm) examples to see how they are applied.
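
For reference, the suggested starting point above corresponds to the following flag settings in `finn_mobilenetv2.py` (a sketch; the combination that maximises accuracy for your chosen bit-widths is left to experimentation):

```python
# Suggested starting configuration for recovering accuracy at reduced bit-widths
act_eq = True
act_eq_alpha = 0.5
act_eq_add_mul_node = False
bias_corr = False
gptq = False
gpfq = True
```
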
10 changes: 10 additions & 0 deletions src/brevitas_examples/finn_mobilenetv2/conda/brevitas_finn.yml
@@ -0,0 +1,10 @@
name: brevitas_finn
channels:
- conda-forge
dependencies:
- python=3.9
- setuptools<70.0
- pip:
- brevitas[finn-integration] @ git+https://github.com/Xilinx/brevitas.git@dev
- torchvision
- tqdm
113 changes: 113 additions & 0 deletions src/brevitas_examples/finn_mobilenetv2/finn_mobilenetv2.py
@@ -0,0 +1,113 @@

import torch
import torch.nn as nn

from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

import brevitas.onnx as bo

from graph.target.finn import default_quantize_maps_finn
from graph.target.finn import preprocess_for_finn_quantize
from graph.target.finn import quantize_finn

from utils import get_dataloader
from utils import test
from quant_utils import apply_act_equalization
from quant_utils import apply_bias_correction
from quant_utils import apply_gpfq
from quant_utils import apply_gptq
from quant_utils import calibrate

# Global settings
batch_size=200
subset_size=None # Use 'None' if you want to use the entire dataset
device="cuda:0"
verbose=False # Validate model after every step

act_eq = False # Apply act equalization - not very useful for MNv2
act_eq_alpha = 0.5 # [0.0 -> 1.0] Intuition: higher makes weights easier to quantize, lower makes the activations easier to quantize
act_eq_add_mul_node = False # Add extra elementwise mul nodes before activation quantization. If True, lower `alpha` seems to work better (`alpha=0.175`)
bias_corr = False # Apply bias correction
gptq = False # Apply GPTQ
gpfq = False # Apply GPFQ

# Configure datasets
imagenet_datadir = "imagenet_symlink"
calib_loader = get_dataloader(f"{imagenet_datadir}", "calibration", batch_size=batch_size, subset_size=subset_size)
valid_loader = get_dataloader(f"{imagenet_datadir}", "val", batch_size=batch_size, subset_size=subset_size)

# Load model
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
model.to(device=device)
model.eval()
# Test float model
print("Float Validation Accuracy")
results = test(model, valid_loader)

# Modify network to assist the quantization process
x = next(iter(calib_loader))[0].to(device=device) # Resolve the shape of the AveragePool
model = preprocess_for_finn_quantize(model, x=x)
# Test preprocessed model
if verbose:
print("Preprocessed Validation Accuracy")
results = test(model, valid_loader)

# Pre-quantization transformations
if act_eq:
print(f"Applying Activation Equalization (alpha={act_eq_alpha},add_mul_node={act_eq_add_mul_node}):")
apply_act_equalization(calib_loader, model, alpha=act_eq_alpha, add_mul_node=act_eq_add_mul_node)
if verbose:
print("Equalized Model Validation Accuracy")
results = test(model, valid_loader)

# Quantize Model
finn_quant_maps = default_quantize_maps_finn()
#finn_quant_maps["compute_layer_map"][nn.Conv2d][1]['weight_bit_width'] = 4 # 1. Override Conv2d weights to have 4-bits
#finn_quant_maps["compute_layer_map"][nn.Conv2d][1]['weight_bit_width'] = lambda module: 8 if module.groups != 1 else 4 # 2. Groupwise Conv2ds @ 8-bits, the rest @ 4-bits
#finn_quant_maps["compute_layer_map"][nn.Conv2d][1]['weight_bit_width'] = lambda module, name: 8 if module.groups != 1 or name == "features.0.0" else 4 # 3. Keep first conv in 8-bits otherwise same as above
model = quantize_finn(model, **finn_quant_maps)
model.to(device=device) # TODO: fix this

# Post-quantization transformations
print("Applying activation calibration:")
calibrate(calib_loader, model)
if verbose:
print("Quantized Model Validation Accuracy")
results = test(model, valid_loader)

if gptq:
print("Applying GPTQ:")
apply_gptq(calib_loader, model)
if verbose:
print("Quantized Model Validation Accuracy")
results = test(model, valid_loader)

if gpfq:
print("Applying GPFQ:")
apply_gpfq(calib_loader, model)
if verbose:
print("Quantized Model Validation Accuracy")
results = test(model, valid_loader)

if bias_corr:
print("Applying Bias Correction:")
apply_bias_correction(calib_loader, model)
if verbose:
print("Quantized Model Validation Accuracy")
results = test(model, valid_loader)

# Test Quantized model
if not verbose: # Not required, since model accuracy is already measured after each step
print("Quantized Model Validation Accuracy")
results = test(model, valid_loader)

# Export model to QONNX
with torch.no_grad():
bo.export_qonnx(
model,
(x),
"quant_mobilenet_v2.onnx",
do_constant_folding=True,
input_names=['x'],
opset_version=17,
)
107 changes: 107 additions & 0 deletions src/brevitas_examples/finn_mobilenetv2/graph/target/finn.py
@@ -0,0 +1,107 @@

from torch.fx import symbolic_trace
import torch.nn as nn
import torch.nn.functional as F

from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool
from brevitas.graph.quantize import preprocess_for_quantize
from brevitas.graph.quantize import quantize
from brevitas.graph.quantize import UNSIGNED_ACT_TUPLE
from brevitas.graph.standardize import MeanMethodToAdaptiveAvgPool2d
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8WeightPerChannelFloatMSE
from brevitas.quant import Int32Bias
from brevitas.quant import Uint8ActPerTensorFloat
from brevitas.quant import Uint8ActPerTensorFloatMaxInit
import brevitas.nn as qnn
from brevitas_examples.finn_mobilenetv2.nn.target.finn import QuantReLU6

SHARED_WEIGHT_QUANT = Int8WeightPerChannelFloatMSE
SHARED_BIAS_QUANT = Int32Bias
SHARED_UNSIGNED_ACT_QUANT = Uint8ActPerTensorFloat
SHARED_SIGNED_ACT_QUANT = Int8ActPerTensorFloat
SHARED_RELU6_QUANT = Uint8ActPerTensorFloatMaxInit

FINN_COMPUTE_LAYER_MAP = {
nn.Conv2d: (
qnn.QuantConv2d,
{
'weight_quant': SHARED_WEIGHT_QUANT,
'return_quant_tensor': True}),
nn.Linear: (
qnn.QuantLinear,
{
'weight_quant': SHARED_WEIGHT_QUANT,
'bias_quant': SHARED_BIAS_QUANT,
'return_quant_tensor': True}),}

FINN_QUANT_ACT_MAP = {
nn.ReLU:
(qnn.QuantReLU, {
'act_quant': SHARED_UNSIGNED_ACT_QUANT, 'return_quant_tensor': True}),
nn.ReLU6: (
QuantReLU6, {
'act_quant': SHARED_UNSIGNED_ACT_QUANT,
'return_quant_tensor': True}),}

FINN_QUANT_IDENTITY_MAP = {
'signed':
(qnn.QuantIdentity, {
'act_quant': SHARED_SIGNED_ACT_QUANT, 'return_quant_tensor': True}),
'unsigned': (
qnn.QuantIdentity, {
'act_quant': SHARED_UNSIGNED_ACT_QUANT, 'return_quant_tensor': True}),}


def default_quantize_maps_finn():
return {
"quant_identity_map": FINN_QUANT_IDENTITY_MAP,
"compute_layer_map": FINN_COMPUTE_LAYER_MAP,
"quant_act_map": FINN_QUANT_ACT_MAP,
"unsigned_act_tuple": UNSIGNED_ACT_TUPLE,
}


def preprocess_for_finn_quantize(
model,
*model_args,
trace_model=True,
relu6_to_relu=False,
equalize_iters=0,
equalize_merge_bias=False,
merge_bn=False,
equalize_bias_shrinkage='vaiq',
equalize_scale_computation='maxabs',
**model_kwargs):
training_state = model.training
model.eval()

if trace_model:
model = symbolic_trace(model)
model = MeanMethodToAdaptiveAvgPool2d().apply(model)
model = TorchFunctionalToModule(fn_to_module_map=((F.adaptive_avg_pool2d, nn.AdaptiveAvgPool2d),)).apply(model)
model = AdaptiveAvgPoolToAvgPool().apply(model, *model_args, **model_kwargs)
model = preprocess_for_quantize(
model,
False,
relu6_to_relu,
equalize_iters,
equalize_merge_bias,
merge_bn,
equalize_bias_shrinkage,
equalize_scale_computation)
model.train(training_state)
return model


def quantize_finn(
graph_model,
quant_identity_map=FINN_QUANT_IDENTITY_MAP,
compute_layer_map=FINN_COMPUTE_LAYER_MAP,
quant_act_map=FINN_QUANT_ACT_MAP,
unsigned_act_tuple=UNSIGNED_ACT_TUPLE,
requantize_layer_handler_output=True):
return quantize(
graph_model,
quant_identity_map=quant_identity_map,
compute_layer_map=compute_layer_map,
quant_act_map=quant_act_map,
unsigned_act_tuple=unsigned_act_tuple,
requantize_layer_handler_output=requantize_layer_handler_output)