
Commit 19d0bbe

Authored by PannenetsF and fanyunqian
[Fix] onnx_qnn backend's output needs a dequant. (#126)
* [Fix] onnx_qnn backend's output needs a dequant.
* [Test] add unit test for onnx_qnn

Co-authored-by: fanyunqian <[email protected]>
1 parent 4e9f2dc commit 19d0bbe

File tree: 3 files changed (+165, -3 lines)
  mqbench/deploy/deploy_onnx_qnn.py
  test/backend/test_backend.py
  test/backend/test_model/unet.py

mqbench/deploy/deploy_onnx_qnn.py

Lines changed: 2 additions & 2 deletions

@@ -210,15 +210,15 @@ def search_and_replace_input(next_node, name, new_name):
             if prev_node != 'INPUT_TOKEN' and prev_node.op_type in self.qlinear_op_type and \
                     next_node != 'OUTPUT_TOKEN' and next_node.op_type in self.qlinear_op_type:
                 search_and_replace_input(next_node, node.output[0], node.input[0])
-            elif prev_node != 'INPUT_TOKEN' and prev_node.op_type in self.qlinear_op_type:
+            elif prev_node != 'INPUT_TOKEN' and prev_node.op_type in self.qlinear_op_type and \
+                    next_node == 'OUTPUT_TOKEN':
                 if dequantize_node is None:
                     output_value_info = [f'{node.output[0]}_DequantizeLinear']
                     dequantize_node = onnx.helper.make_node("DequantizeLinear",
                                                             node.input[0:3],
                                                             output_value_info,
                                                             ('input' if prev_node == 'INPUT_TOKEN' else prev_node.name) + '_dequantized')
                     self.onnx_model.insert_node_purely(dequantize_node)
-                search_and_replace_input(next_node, node.output[0], dequantize_node.output[0])
             else:
                 if quantize_node is None:
                     output_value_info = [f'{node.output[0]}_QuantizeLinear']
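The change above narrows the elif so it only fires when the consumer of this node's output is the graph output ('OUTPUT_TOKEN'), and it drops the input rewrite that no longer applies in that case; the DequantizeLinear built inside the branch is what converts the quantized tensor back to float before it leaves the graph. A minimal sketch of that node construction (the helper name make_output_dequant is illustrative and not part of the commit; the arguments mirror the diff above):

import onnx

def make_output_dequant(node, producer_name):
    # node.input[0:3] carry the quantized tensor, its scale and its zero point,
    # the same three inputs passed to DequantizeLinear in the diff above.
    output_value_info = [f'{node.output[0]}_DequantizeLinear']
    return onnx.helper.make_node("DequantizeLinear",
                                 node.input[0:3],
                                 output_value_info,
                                 producer_name + '_dequantized')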

test/backend/test_backend.py

Lines changed: 15 additions & 1 deletion

@@ -5,6 +5,7 @@
 from mqbench.convert_deploy import convert_deploy
 from mqbench.utils.state import enable_calibration, enable_quantization
 
+from .test_model.unet import UNet
 from ..version import GITHUB_RES
 
 
@@ -117,7 +118,7 @@ def test_quantize_vitis(self):
         else:
             pass
 
-    def test_quantize_onnxqnn(self):
+    def test_quantize_onnxqnn_1(self):
         model_to_quantize = torch.hub.load(GITHUB_RES, 'resnet18', pretrained=False)
         dummy_input = torch.randn(2, 3, 224, 224, device='cpu')
         model_to_quantize.train()
@@ -130,6 +131,19 @@ def test_quantize_onnxqnn(self):
         model_prepared.eval()
         convert_deploy(model_prepared, BackendType.ONNX_QNN, {'x': [1, 3, 224, 224]}, model_name='resnet18_onnx_qnn.onnx')
 
+    def test_quantize_onnxqnn_2(self):
+        model_to_quantize = UNet(3, 2)
+        dummy_input = torch.randn(2, 3, 224, 224, device='cpu')
+        model_to_quantize.train()
+        model_prepared = prepare_by_platform(model_to_quantize, BackendType.ONNX_QNN)
+        enable_calibration(model_prepared)
+        model_prepared(dummy_input)
+        enable_quantization(model_prepared)
+        loss = model_prepared(dummy_input).sum()
+        loss.backward()
+        model_prepared.eval()
+        convert_deploy(model_prepared, BackendType.ONNX_QNN, {'x': [1, 3, 224, 224]}, model_name='resnet18_onnx_qnn.onnx')
+
     def test_quantize_ppl_cuda(self):
         import numpy as np
         model_to_quantize = torch.hub.load(GITHUB_RES, 'resnet18', pretrained=False)
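The new test_quantize_onnxqnn_2 runs the same prepare, calibrate, quantize and deploy flow on a UNet, adding a second architecture with concatenations and upsampling to the ONNX_QNN deploy coverage. A rough way to confirm the property this commit restores is to load the deployed ONNX file and check which ops feed the graph outputs; this is only a sketch, and the file name below is an assumption about what convert_deploy actually writes out:

import onnx

# Assumed output path; convert_deploy may add its own suffix or directory.
model = onnx.load('resnet18_onnx_qnn.onnx')
graph = model.graph

output_names = {o.name for o in graph.output}
producers = [n.op_type for n in graph.node if output_names & set(n.output)]
print(producers)  # after this fix, the graph outputs should be fed by DequantizeLinear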

test/backend/test_model/unet.py

Lines changed: 148 additions & 0 deletions (new file)

# Copyright (c) 2022 Carl Zeiss AG – All Rights Reserved.
# ZEISS, ZEISS.com are registered trademarks of Carl Zeiss AG

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['UNet']

class UNet(nn.Module):
    def __init__(
        self, num_channels, num_classes, depth=4, initial_filter_count=64, bilinear=True
    ):
        super(UNet, self).__init__()

        self.num_channels = num_channels
        self.num_classes = num_classes
        self.depth = depth
        self.initial_filter_count = initial_filter_count
        self.bilinear = bilinear

        factor = 2 if bilinear else 1

        filter_count = initial_filter_count

        encoder_blocks = []
        encoder_blocks.append(DoubleConv(num_channels, filter_count))
        for d in range(depth):
            if d < depth - 1:
                encoder_blocks.append(Down(filter_count, 2 * filter_count))
            else:
                encoder_blocks.append(Down(filter_count, (2 * filter_count) // factor))
            filter_count *= 2
        self.encoder_blocks = nn.Sequential(*encoder_blocks)

        decoder_blocks = []
        for d in range(depth):
            if d < depth - 1:
                decoder_blocks.append(
                    Up(filter_count, filter_count // 2 // factor, bilinear)
                )
            else:
                decoder_blocks.append(Up(filter_count, filter_count // 2, bilinear))
            filter_count //= 2
        self.decoder_blocks = nn.Sequential(*decoder_blocks)

        self.outc = OutputConvolution(filter_count, num_classes)

    def forward(self, x):
        xs = []
        for encoder_block in self.encoder_blocks:
            x = encoder_block(x)
            xs.append(x)

        xs.reverse()
        xs = xs[1:]

        for decoder_block, x_skip in zip(self.decoder_blocks, xs):
            x = decoder_block(x, x_skip)

        logits = self.outc(x)

        return logits


class DoubleConv(nn.Module):
    """Module combining Conv -> BN -> ReLU -> Conv -> BN -> ReLU."""

    def __init__(
        self, num_input_channels, num_output_channels, num_middle_channels=None
    ):
        super().__init__()

        if not num_middle_channels:
            num_middle_channels = num_output_channels

        self.double_conv = nn.Sequential(
            nn.Conv2d(
                num_input_channels,
                num_middle_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                num_middle_channels,
                num_output_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Module combining downscaling and DoubleConvolution."""

    def __init__(self, num_input_channels, num_output_channels):
        super().__init__()

        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2), DoubleConv(num_input_channels, num_output_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Module combining upscaling and DoubleConvolution."""

    def __init__(self, num_input_channels, num_output_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(
                num_input_channels, num_output_channels, num_input_channels // 2
            )
        else:
            self.up = nn.ConvTranspose2d(
                num_input_channels, num_input_channels // 2, kernel_size=2, stride=2
            )
            self.conv = DoubleConv(num_input_channels, num_output_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)

        return self.conv(x)


class OutputConvolution(nn.Module):
    def __init__(self, num_input_channels, num_output_channels):
        super(OutputConvolution, self).__init__()

        self.conv = nn.Conv2d(num_input_channels, num_output_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
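For reference, with depth=4 and bilinear upsampling the model above downsamples by a factor of 16 and then restores the input resolution, concatenating encoder features along the channel dimension on the way up, so an input of shape (N, 3, 224, 224) yields logits of shape (N, 2, 224, 224) for UNet(3, 2). A quick sanity check of those shapes (not part of the commit; the import path is assumed from the repo layout shown above):

import torch
from test.backend.test_model.unet import UNet  # import path assumed from the layout above

model = UNet(3, 2)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = model(x)

print(y.shape)  # expected: torch.Size([1, 2, 224, 224])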
