
Commit 791d444

Tracin, zhangqi3, and fanyunqian authored
[Update] Conv3d && Fix bugs. (#146)
* [Update] Conv3d && Fix bugs.
* [REQ] Update torch version to 1.10.0.
* [REQ] Update torchvision version to 0.11.1.
* [Update] Update observer.
* [Test] Update test.

Co-authored-by: zhangqi3 <[email protected]>
Co-authored-by: fanyunqian <[email protected]>
1 parent 55c304b · commit 791d444

11 files changed: +64 −59 lines changed

.github/workflows/python-package-conda.yml renamed to .github/workflows/lint-and-test.yml

+5 −2

@@ -1,9 +1,9 @@
-name: Lint and test.
+name: Lint and test
 
 on: [push]
 
 jobs:
-  build-linux:
+  Lint-and-test:
     runs-on: ubuntu-latest
     strategy:
       max-parallel: 5
@@ -22,6 +22,9 @@ jobs:
       run: |
         conda install flake8
         flake8 .
+    - name: Install onnxruntime and onnxsim
+      run:
+        pip install onnxruntime onnx-simplifier
     - name: Install Protobuf
       run:
         conda install protobuf=3.20.1

mqbench/custom_quantizer/model_quantizer.py

−4

@@ -263,10 +263,6 @@ def _convert(self, module, mapping=None, inplace=False, scope=''):
             if not isinstance(mod, _FusedModule):
                 self._convert(mod, mapping, True, new_scope)
             reassign[name] = swap_module(mod, mapping, {})
-            if isinstance(mod, torch.nn.ConvTranspose2d):
-                if hasattr(reassign[name], "weight_fake_quant") and reassign[name].weight_fake_quant.ch_axis != -1:
-                    reassign[name].weight_fake_quant.ch_axis = 1
-                    reassign[name].weight_fake_quant.activation_post_process.ch_axis = 1
         for key, value in reassign.items():
             module._modules[key] = value

mqbench/custom_quantizer/onnx_qnn_quantizer.py

+3 −4

@@ -7,8 +7,8 @@
 from torch.quantization.utils import get_combined_dict
 
 
-import mqbench.nn as qnn
-import mqbench.nn.intrinsic as qnni
+import mqbench.nn as qnn
+import mqbench.nn.intrinsic as qnni
 from mqbench.utils.registry import register_model_quantizer
 from mqbench.prepare_by_platform import BackendType
 from mqbench.custom_quantizer import ModelQuantizer
@@ -57,7 +57,6 @@ def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Di
             get_default_qat_module_mappings(), additional_qat_module_mapping)
         # There is no QLinearFC in ONNX for now.
         del(all_mappings[torch.nn.modules.linear.Linear])
-        del(all_mappings[torch.nn.modules.linear._LinearWithBias])
         del(all_mappings[torch.nn.intrinsic.modules.fused.LinearReLU])
         del(all_mappings[qnni.modules.fused.LinearBn1d])
         root = self._convert(root, all_mappings, inplace=True)
@@ -97,4 +96,4 @@ def implicit_merge_patterns(self) -> list:
         # In reversed order!
         return [
             (torch.nn.ReLU, operator.add)
-        ]
+        ]
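Note: the deleted `_LinearWithBias` entry no longer exists in the torch 1.10 release this commit pins, so an unconditional `del` would raise. The commit simply drops the line; a hedged alternative (not what the commit does) would be to guard version-dependent deletions, as in this minimal sketch where `safe_del` is a hypothetical helper:

import torch
import torch.nn.qat as nnqat

def safe_del(mapping, owner, attr):
    cls = getattr(owner, attr, None)  # the class may not exist in this torch version
    if cls is not None:
        mapping.pop(cls, None)

all_mappings = {torch.nn.Linear: nnqat.Linear}
safe_del(all_mappings, torch.nn.modules.linear, 'Linear')
safe_del(all_mappings, torch.nn.modules.linear, '_LinearWithBias')  # absent in torch >= 1.10, silently skipped
print(all_mappings)  # -> {}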

mqbench/custom_quantizer/tensorrt_quantizer.py

−3

@@ -128,9 +128,6 @@ def _find_act_quants(self, model: GraphModule) -> List:
             if node.op == "call_function" and node.target == operator.add and \
                     self._is_skiped_add(node, modules, input_node_list):
                 continue
-            if node.op == "call_function" and node.target == operator.add:
-                import pdb
-                pdb.set_trace()
             for _node in input_node_list:
                 if self._is_implicit_merge(modules, (node, _node)):
                     logger.info("Implicit merge: {} + {}".format(_node.name, node.name))

mqbench/deploy/deploy_tengine.py

+8 −4

@@ -1,10 +1,6 @@
 import os
 from collections import OrderedDict
 
-import onnx
-from onnx import numpy_helper
-from onnxsim import simplify
-
 from ..utils.logger import logger
 from .deploy_linear import (
     LinearQuantizer_process,
@@ -20,6 +16,14 @@
     get_constant_inputs
 )
 
+import onnx
+from onnx import numpy_helper
+try:
+    from onnxsim import simplify
+except ModuleNotFoundError:
+    logger.warn('onnxsim not found, if you want to use deploy_tengine, please install it.')
+
+
 
 class Tengine_process(LinearQuantizer_process):
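Note: onnx-simplifier becomes an optional dependency here; the import is moved below the logger and wrapped in try/except, so importing deploy_tengine no longer fails when onnxsim is absent. A minimal, self-contained sketch of that optional-import pattern (the maybe_simplify helper and logger name are hypothetical, not part of the commit):

import logging

logger = logging.getLogger("deploy_example")

try:
    from onnxsim import simplify  # optional; only needed when simplification is requested
except ModuleNotFoundError:
    simplify = None
    logger.warning("onnxsim not found; install it to enable ONNX graph simplification.")


def maybe_simplify(onnx_model):
    """Simplify the model with onnxsim when available, otherwise return it unchanged."""
    if simplify is None:
        return onnx_model
    simplified_model, check_ok = simplify(onnx_model)
    return simplified_model if check_ok else onnx_model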

mqbench/fuser_method_mappings.py

+2 −1

@@ -33,6 +33,7 @@ def __init__(self, quantizer: QuantizerCls, node: Node):
         self.conv_node = node
         self.conv = quantizer.modules[self.conv_node.target]
 
+
 def fuse_linear_bn(linear, bn):
     r"""Given the linear and bn modules, fuses them and returns the fused module
 
@@ -83,7 +84,6 @@ def fuse_deconv_bn_relu(deconv, bn, relu):
         return qnni.ConvTransposeReLU2d(fuse_deconv_bn_eval(deconv, bn), relu)
 
 
-
 def fuse_conv_freezebn(conv, bn):
     assert(bn.training is False), "Freezebn must be eval."
 
@@ -100,6 +100,7 @@ def fuse_conv_freezebn(conv, bn):
     else:
         return nn.utils.fuse_conv_bn_eval(conv, bn)
 
+
 def fuse_conv_freezebn_relu(conv, bn, relu):
     assert(conv.training == relu.training and bn.training is False), "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
     fused_module : Optional[Type[nn.Sequential]] = None

mqbench/fusion_method.py

+29 −5

@@ -42,20 +42,37 @@ def convert_qnniqat_linearbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvFreezebn2d)
 @register_convert_function(nniqat.ConvBn2d)
+@register_convert_function(nniqat.ConvBn3d)
 def convert_nniqat_convbn(model, fused_node):
+    """nniqat.ConvBn2d ----> nn.Conv2d ----> nniqat.Conv2d
+    """
+    fused_module_class_map = {
+        qnniqat.ConvFreezebn2d: torch.nn.Conv2d,
+        qnniqat.ConvFreezebnReLU2d: torch.nn.Conv2d,
+        nniqat.ConvBn2d: torch.nn.Conv2d,
+        nniqat.ConvBnReLU2d: torch.nn.Conv2d,
+        nniqat.ConvBn3d: torch.nn.Conv3d,
+        nniqat.ConvBnReLU3d: torch.nn.Conv3d,
+    }
+    fused_qat_module_class_map = {
+        torch.nn.Conv2d: torch.nn.qat.Conv2d,
+        torch.nn.Conv3d: torch.nn.qat.Conv3d,
+    }
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
     # Create a Conv2d from FusedModule.
-    conv = torch.nn.Conv2d(fused_module.in_channels, fused_module.out_channels, fused_module.kernel_size,
-                           fused_module.stride, fused_module.padding, fused_module.dilation,
-                           fused_module.groups, fused_module.bias is not None, fused_module.padding_mode)
+    conv = fused_module_class_map[type(fused_module)](fused_module.in_channels, fused_module.out_channels,
+                                                      fused_module.kernel_size, fused_module.stride,
+                                                      fused_module.padding, fused_module.dilation,
+                                                      fused_module.groups, fused_module.bias is not None,
+                                                      fused_module.padding_mode)
     conv.weight = fused_module.weight
     if fused_module.bias is not None:
         conv.bias = fused_module.bias
     fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn)
     # We need nn.qat.conv here to export weight quantize node.
     fused_conv.qconfig = fused_module.qconfig
-    fused_conv = torch.nn.qat.Conv2d.from_float(fused_conv)
+    fused_conv = fused_qat_module_class_map[type(conv)].from_float(fused_conv)
     # Attach weight fake quantize params.
     fused_conv.weight_fake_quant = fused_module.weight_fake_quant
     conv_parent_name, conv_name = _parent_name(fused_node.target)
@@ -64,7 +81,8 @@ def convert_nniqat_convbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvFreezebnReLU2d)
 @register_convert_function(nniqat.ConvBnReLU2d)
-def convert_nniqat_convbnrelu(model, fused_node):
+@register_convert_function(nniqat.ConvBnReLU3d)
+def convert_nniqat_convbnrelu(model, fused_node):
     convert_nniqat_convbn(model, fused_node)
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
@@ -196,6 +214,9 @@ def convert_qnniqat_deconvbnrelu(model, fused_node):
 
 @register_convert_function(qnniqat.ConvBn2d)
 def convert_qnniqat_convbn(model, fused_node):
+    """mqbench.nn.intrinsic.qat module add bias quant.
+    That is the difference between torch.nn.intrinsic.qat module.
+    """
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
     # Create a Conv2d from FusedModule.
@@ -222,6 +243,9 @@ def convert_qnniqat_convbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvBnReLU2d)
 def convert_qnniqat_convbnrelu(model, fused_node):
+    """mqbench.nn.intrinsic.qat module add bias quant.
+    That is the difference between torch.nn.intrinsic.qat module.
+    """
     convert_qnniqat_convbn(model, fused_node)
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
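Note: convert_nniqat_convbn now dispatches on the fused module's type, so one routine rebuilds both 2d and 3d conv-bn patterns. A standalone sketch of that dispatch idea, assuming torch >= 1.10 where nniqat.ConvBn3d and nn.qat.Conv3d exist (helper name is illustrative, not the commit's code):

import torch
import torch.nn.intrinsic.qat as nniqat
import torch.nn.qat as nnqat

# Map a fused QAT class to the plain conv class it should be rebuilt as,
# then map that plain class to its nn.qat counterpart.
plain_conv_for = {
    nniqat.ConvBn2d: torch.nn.Conv2d,
    nniqat.ConvBnReLU2d: torch.nn.Conv2d,
    nniqat.ConvBn3d: torch.nn.Conv3d,
    nniqat.ConvBnReLU3d: torch.nn.Conv3d,
}
qat_conv_for = {
    torch.nn.Conv2d: nnqat.Conv2d,
    torch.nn.Conv3d: nnqat.Conv3d,
}


def conv_classes_for(fused_module):
    """Return (plain conv class, qat conv class) for a fused conv-bn module."""
    plain_cls = plain_conv_for[type(fused_module)]
    return plain_cls, qat_conv_for[plain_cls]
    # e.g. conv_classes_for(some_convbn3d) -> (torch.nn.Conv3d, nnqat.Conv3d)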

mqbench/observer.py

+9 −28

@@ -1,7 +1,7 @@
 import math
 from functools import partial
 from typing import Tuple
-from copy import deepcopy
+
 import torch
 from torch.quantization.observer import _ObserverBase
 
@@ -28,12 +28,14 @@ class ObserverBase(_ObserverBase):
     def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                  reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False,
                  factory_kwargs=None):
-        factory_kwargs = deepcopy(factory_kwargs)
-        self.not_calc_quant_min_max = factory_kwargs.pop('not_calc_quant_min_max', False) if isinstance(factory_kwargs, dict) else False
+        # Since torch 1.10, function calculate_qmin_qmax is not a member function of observer,
+        # but import from utils. It is hard to control. We use try...except here.
+        stored_min, sotred_max = quant_min, quant_max
+        if quant_max is not None and quant_min is not None and (quant_max - quant_min + 1 > 256):
+            quant_min, quant_max = -128, 127
         super(ObserverBase, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max)
-        # for compatibility with 1.10, prevent the value of self.quant_min,self.quant_max being modified
-        self.quant_min = quant_min
-        self.quant_max = quant_max
+        self.quant_min = stored_min
+        self.quant_max = sotred_max
         self.quant_min, self.quant_max = self._calculate_qmin_qmax()
         self.ch_axis = ch_axis
         self.pot_scale = pot_scale
@@ -79,28 +81,7 @@ def _calculate_qmin_qmax(self) -> Tuple[int, int]:
         observer datatype and if range is reduced.
         """
         if self.has_customized_qrange:
-            # This initialization here is to be resolve TorchScript compilation issues and allow
-            # using of refinement to decouple initial_qmin and initial_qmax from quantization range.
-            # The actual values of initial_qmin and initial_qmax will be reset below.
-            initial_quant_min, initial_quant_max = 0, 255
-            # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the
-            # attribute from Optional valid integers for use, based on TorchScript's requirements.
-            custom_quant_min, custom_quant_max = self.quant_min, self.quant_max
-            if custom_quant_min is not None and custom_quant_max is not None:
-                initial_quant_min, initial_quant_max = (
-                    custom_quant_min,
-                    custom_quant_max,
-                )
-
-            qrange_len = initial_quant_max - initial_quant_min + 1
-            if is_symmetric_quant(self.qscheme):
-                quant_min, quant_max = -qrange_len // 2, qrange_len // 2 - 1
-            else:
-                quant_min, quant_max = 0, qrange_len - 1
-            if self.reduce_range:
-                quant_min, quant_max = quant_min // 2, quant_max // 2
-            if self.not_calc_quant_min_max:
-                quant_min, quant_max = self.quant_min, self.quant_max
+            quant_min, quant_max = self.quant_min, self.quant_max
         else:
            # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
            if self.dtype == torch.qint8:
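Note: the __init__ change works around the stricter range validation in torch 1.10's _ObserverBase by stashing the requested quant_min/quant_max, temporarily clamping to an 8-bit range for the super().__init__ call when the requested range exceeds 256 levels, and restoring the requested values afterwards. A hedged, self-contained sketch of that idea (the dummy _StrictBase stands in for torch's base class; names are illustrative):

class _StrictBase:
    def __init__(self, quant_min, quant_max):
        if quant_min is not None and quant_max is not None:
            # Mimics the base-class check that rejects customized ranges wider than 256 levels.
            assert quant_max - quant_min + 1 <= 256, "base class only accepts <= 256 levels"
        self.quant_min, self.quant_max = quant_min, quant_max


class WideRangeObserver(_StrictBase):
    def __init__(self, quant_min=None, quant_max=None):
        requested_min, requested_max = quant_min, quant_max
        # Temporarily clamp so the strict base-class check passes.
        if quant_min is not None and quant_max is not None and quant_max - quant_min + 1 > 256:
            quant_min, quant_max = -128, 127
        super().__init__(quant_min, quant_max)
        # Restore the range the caller actually asked for.
        self.quant_min, self.quant_max = requested_min, requested_max


obs = WideRangeObserver(quant_min=-32768, quant_max=32767)  # 16-bit range
print(obs.quant_min, obs.quant_max)                         # -32768 32767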

requirements.txt

+2 −3

@@ -1,4 +1,3 @@
-torch==1.8.1
-torchvision==0.9.1
-onnx-simplifier
+torch==1.10.0
+torchvision==0.11.1
 onnx

test/model/test_model.py

+4 −3

@@ -11,11 +11,12 @@
 
 class TestQuantizeModel(unittest.TestCase):
     def test_model_ppl(self):
-        exclude_list = ['googlenet', 'deeplabv3_mobilenet_v3_large', 'inception_v3', 'lraspp_mobilenet_v3_large',
-                        'mobilenet_v3_large', 'mobilenet_v3_small']
+        test_model_list = ['alexnet', 'deeplabv3_resnet50', 'densenet121', 'fcn_resnet50', 'mnasnet0_5',
+                           'mobilenet_v2', 'resnet18', 'resnext50_32x4d', 'shufflenet_v2_x0_5', 'squeezenet1_0',
+                           'vgg11', 'vgg11_bn', 'wide_resnet50_2', 'regnet_x_400mf']
         entrypoints = torch.hub.list(GITHUB_RES, force_reload=False)
         for entrypoint in entrypoints:
-            if entrypoint in exclude_list:
+            if entrypoint not in test_model_list:
                 continue
             logger.info(f'testing {entrypoint}')
             if 'deeplab' in entrypoint or 'fcn' in entrypoint:
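Note: the test switches from an exclude list to an explicit allowlist, so models newly added to the torchvision hub are not pulled into CI by accident. A hedged, standalone sketch of the same filtering (the two-model list is illustrative only, and running it needs network access to the torchvision hub):

import torch

GITHUB_RES = 'pytorch/vision:v0.11.1'
test_model_list = ['resnet18', 'mobilenet_v2']  # illustrative subset, not the full CI list

for entrypoint in torch.hub.list(GITHUB_RES, force_reload=False):
    if entrypoint not in test_model_list:
        continue  # skip everything that is not explicitly allowlisted
    model = torch.hub.load(GITHUB_RES, entrypoint, pretrained=False)
    print(entrypoint, sum(p.numel() for p in model.parameters()))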

test/version.py

+2 −2

@@ -1,2 +1,2 @@
-TORCHVISION_VERSION = 'v0.9.0'
-GITHUB_RES = 'pytorch/vision:{}'.format(TORCHVISION_VERSION)
+TORCHVISION_VERSION = 'v0.11.1'
+GITHUB_RES = 'pytorch/vision:{}'.format(TORCHVISION_VERSION)

0 commit comments