Commit 3d87295

Merge pull request #224 from fastmachinelearning/feature/extract_conv_quant_bias
Extract quantized biases for Conv/ConvTranspose
2 parents 752850f + d1dc557

File tree

2 files changed: +334 -2 lines

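
As a usage sketch (not part of this commit; the model file name below is hypothetical), the extended transformation is applied like any other QONNX graph transformation:

    from qonnx.core.modelwrapper import ModelWrapper
    from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv

    # hypothetical input: a model whose Conv/ConvTranspose bias input is
    # produced by a standalone Quant/IntQuant/BipolarQuant node
    model = ModelWrapper("model_with_quant_bias.onnx")
    model = model.transform(ExtractBiasFromConv())

    # the bias is now applied by a separate Add node behind the Conv node
    assert len(model.get_nodes_by_op_type("Add")) > 0
    model.save("model_bias_extracted.onnx")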

src/qonnx/transformation/extract_conv_bias.py

Lines changed: 21 additions & 2 deletions
@@ -49,8 +49,16 @@ def apply(self, model):
                     # Extract bias
                     bias = model.get_initializer(n.input[2])
                     if bias is None:
-                        warnings.warn(f"Could not extract bias from node {n}")
-                        continue
+                        # check whether the bias is quantized: the initializer is
+                        # then empty, but the input comes from a Quant node
+                        producer = model.find_producer(n.input[2])
+                        # only continue with the extraction if the producer is a
+                        # Quant node without predecessors
+                        if not (
+                            producer.op_type in ["Quant", "IntQuant", "BipolarQuant"]
+                            and not model.find_direct_predecessors(producer)
+                        ):
+                            warnings.warn(f"Could not extract bias from node {n}")
+                            continue
 
                     # Insert bias as Add node behind the Conv node
                     out_shape = model.get_tensor_shape(n.output[0])
@@ -62,6 +70,17 @@ def apply(self, model):
                         add_shape[1] = bias_shape[0]
                     if bias is not None:
                         model.set_initializer(n.input[2], bias.reshape(add_shape))
+                    else:
+                        # if the bias comes from a Quant node, reshape the Quant
+                        # node's parameters instead
+                        quant_param = model.get_initializer(producer.input[0])
+                        model.set_initializer(producer.input[0], quant_param.reshape(add_shape))
+                        quant_scale = model.get_initializer(producer.input[1])
+                        if quant_scale.shape != (1,):
+                            model.set_initializer(producer.input[1], quant_scale.reshape(add_shape))
+                        quant_zpt = model.get_initializer(producer.input[2])
+                        if quant_zpt.shape != (1,):
+                            model.set_initializer(producer.input[2], quant_zpt.reshape(add_shape))
+                        model.set_tensor_shape(producer.output[0], add_shape)
 
                     act_add_tensor = helper.make_tensor_value_info(
                         model.make_new_valueinfo_name(),
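
To see why the Quant parameters need reshaping, here is a minimal NumPy sketch (shapes are illustrative, assuming the usual NCHW layout) mirroring the add_shape logic above: a per-channel bias of shape (C,) must become (1, C, 1, 1) so the inserted Add broadcasts along the channel axis.

    import numpy as np

    out_shape = (1, 64, 111, 111)      # example Conv output shape (NCHW)
    bias = np.random.randn(64).astype(np.float32)

    add_shape = [1] * len(out_shape)   # [1, 1, 1, 1]
    add_shape[1] = bias.shape[0]       # [1, 64, 1, 1] -> per-channel broadcast

    conv_out = np.zeros(out_shape, dtype=np.float32)
    assert (conv_out + bias.reshape(add_shape)).shape == out_shape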
Lines changed: 313 additions & 0 deletions
@@ -0,0 +1,313 @@
# Copyright (c) 2025 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of qonnx nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytest

import numpy as np
import onnx.helper as oh
from onnx import TensorProto

import qonnx.core.onnx_exec as oxe
from qonnx.core.datatype import DataType
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model

# Helper function to generate valid parameter combinations, with an option to include 'dw'
def generate_params(include_dw=True):
    params = []
    biases = ["float", None, "int_quant", "bp_quant"]
    scales = ["per_tensor", "per_channel"]
    zero_points = ["per_tensor", "per_channel"]

    dw_options = [True, False] if include_dw else [None]

    for dw in dw_options:
        for bias in biases:
            if bias in ["float", None]:
                # Ignore scale and zero_point for this bias
                params.append((dw, bias, None, None) if include_dw else (bias, None, None))
            else:
                # Include all combinations of scale and zero_point for other biases
                for scale in scales:
                    for zero_point in zero_points:
                        if include_dw:
                            params.append((dw, bias, scale, zero_point))
                        else:
                            params.append((bias, scale, zero_point))
    return params

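
# For reference, a sketch of what this parametrization yields:
# generate_params(include_dw=False) produces 10 tuples:
#   ("float", None, None), (None, None, None), and every combination of
#   ("int_quant" | "bp_quant") with ("per_tensor" | "per_channel") scale
#   and ("per_tensor" | "per_channel") zero-point.
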
69+
@pytest.mark.parametrize("dw, bias, scale, zero_point", generate_params(include_dw=True))
70+
def test_extract_conv_bias(dw, bias, scale, zero_point):
71+
ishape = (1, 32, 111, 111)
72+
if dw is True:
73+
group = ishape[1]
74+
out_channels = ishape[1]
75+
kernel_size = 3
76+
padding = 1
77+
stride = 1
78+
w_shape = (32, 1, 3, 3)
79+
80+
else:
81+
group = 1
82+
out_channels = 64
83+
kernel_size = 1
84+
padding = 0
85+
stride = 1
86+
w_shape = (64, 32, 1, 1)
87+
88+
wdt = idt = odt = DataType["FLOAT32"]
89+
90+
# set up onnx model
91+
inp = oh.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
92+
outp = oh.make_tensor_value_info("outp", TensorProto.FLOAT, [ishape[0], out_channels, ishape[2], ishape[3]])
93+
94+
W = oh.make_tensor_value_info("W", TensorProto.FLOAT, w_shape)
95+
96+
if bias is not None:
97+
bias_shape = (out_channels,)
98+
if scale is not None and "per_channel" in scale:
99+
scale_shape = (out_channels,)
100+
elif scale is not None and "per_tensor" in scale:
101+
scale_shape = (1,)
102+
if scale is not None and "per_channel" in zero_point:
103+
zpt_shape = (out_channels,)
104+
elif scale is not None and "per_tensor" in zero_point:
105+
zpt_shape = (1,)
        B = oh.make_tensor_value_info("B", TensorProto.FLOAT, bias_shape)

    cnv_node = oh.make_node(
        "Conv",
        inputs=["inp", "W"] if not bias else ["inp", "W", "B"],
        outputs=["outp"],
        kernel_shape=[kernel_size, kernel_size],
        pads=[padding, padding, padding, padding],
        strides=[stride, stride],
        group=group,
    )
    nodes = [cnv_node]
    value_info = [W] if not bias else [W, B]
    # if the bias isn't quantized, we can directly wire up the Conv layer;
    # otherwise an additional Quant node needs to be inserted
    if bias not in ["float", None]:
        if "bp" in bias:
            optype = "BipolarQuant"
        elif "int" in bias:
            optype = "IntQuant"
        # inputs to Quant node
        param0 = oh.make_tensor_value_info("param0", TensorProto.FLOAT, bias_shape)
        param1 = oh.make_tensor_value_info("param1", TensorProto.FLOAT, scale_shape)
        param2 = oh.make_tensor_value_info("param2", TensorProto.FLOAT, zpt_shape)
        value_info.append(param0)
        value_info.append(param1)
        value_info.append(param2)
        if "int" in bias:
            param3 = oh.make_tensor_value_info("param3", TensorProto.FLOAT, [1])
            value_info.append(param3)
        quant_node = oh.make_node(
            optype,
            domain="qonnx.custom_op.general",
            inputs=["param0", "param1", "param2", "param3"] if "int" in bias else ["param0", "param1", "param2"],
            outputs=["B"],
            narrow=0,
            rounding_mode="ROUND",
            signed=1,
        )
        nodes.append(quant_node)
    graph = oh.make_graph(
        nodes=nodes,
        name="cnv_graph",
        inputs=[inp],
        outputs=[outp],
        value_info=value_info,
    )

    model = qonnx_make_model(graph, producer_name="test-cnv-model")
    model = ModelWrapper(model)
    model.set_tensor_datatype("inp", idt)
    model.set_tensor_datatype("outp", odt)
    model.set_tensor_datatype("W", wdt)

    w_tensor = gen_finn_dt_tensor(wdt, w_shape)

    if bias is not None:
        b_tensor = gen_finn_dt_tensor(DataType["FLOAT32"], bias_shape)
        # set B tensor directly or set first input of quant node
        if bias != "float":
            model.set_initializer("param0", b_tensor)
            scale = gen_finn_dt_tensor(DataType["FLOAT32"], scale_shape)
            model.set_initializer("param1", scale)
            zpt = gen_finn_dt_tensor(DataType["FLOAT32"], zpt_shape)
            model.set_initializer("param2", zpt)
            if "int" in bias:
                model.set_initializer("param3", 8 * np.ones(1))
        else:
            model.set_initializer("B", b_tensor)

    model.set_initializer("W", w_tensor)
    model = model.transform(InferShapes())

    input_tensor = gen_finn_dt_tensor(idt, ishape)
    output_dict = oxe.execute_onnx(model, {model.graph.input[0].name: input_tensor})
    expected = output_dict[model.graph.output[0].name]

    model = model.transform(ExtractBiasFromConv())

    if bias is not None:
        assert len(model.get_nodes_by_op_type("Add")) > 0, "Bias wasn't extracted into add node"

    output_dict = oxe.execute_onnx(model, {model.graph.input[0].name: input_tensor})
    produced = output_dict[model.graph.output[0].name]

    # check that produced and expected outputs match closely (fp calculation)
    assert np.isclose(produced, expected, atol=1e-3).all()


@pytest.mark.parametrize("bias, scale, zero_point", generate_params(include_dw=False))
196+
def test_extract_conv_transpose_bias(bias, scale, zero_point):
197+
ishape = (1, 32, 111, 111)
198+
group = 1
199+
out_channels = 64
200+
kernel_size = 1
201+
padding = 0
202+
stride = 1
203+
w_shape = (32, 64, 1, 1)
204+
205+
wdt = idt = odt = DataType["FLOAT32"]
206+
207+
# Set up ONNX model
208+
inp = oh.make_tensor_value_info("inp", TensorProto.FLOAT, ishape)
209+
outp_shape = (ishape[0], out_channels, ishape[2], ishape[3])
210+
outp = oh.make_tensor_value_info("outp", TensorProto.FLOAT, outp_shape)
211+
212+
W = oh.make_tensor_value_info("W", TensorProto.FLOAT, w_shape)
213+
214+
if bias is not None:
215+
bias_shape = (out_channels,)
216+
if scale is not None and "per_channel" in scale:
217+
scale_shape = (out_channels,)
218+
elif scale is not None and "per_tensor" in scale:
219+
scale_shape = (1,)
220+
        if zero_point is not None and "per_channel" in zero_point:
            zpt_shape = (out_channels,)
        elif zero_point is not None and "per_tensor" in zero_point:
            zpt_shape = (1,)

        B = oh.make_tensor_value_info("B", TensorProto.FLOAT, bias_shape)

    cnv_node = oh.make_node(
        "ConvTranspose",
        inputs=["inp", "W"] if not bias else ["inp", "W", "B"],
        outputs=["outp"],
        kernel_shape=[kernel_size, kernel_size],
        pads=[padding, padding, padding, padding],
        strides=[stride, stride],
        group=group,
    )
    nodes = [cnv_node]
    value_info = [W] if not bias else [W, B]

    # If the bias isn't quantized, we can directly wire up the ConvTranspose layer;
    # otherwise, an additional Quant node needs to be inserted
    if bias not in ["float", None]:
        if "bp" in bias:
            optype = "BipolarQuant"
        elif "int" in bias:
            optype = "IntQuant"
        # Inputs to Quant node
        param0 = oh.make_tensor_value_info("param0", TensorProto.FLOAT, bias_shape)
        param1 = oh.make_tensor_value_info("param1", TensorProto.FLOAT, scale_shape)
        param2 = oh.make_tensor_value_info("param2", TensorProto.FLOAT, zpt_shape)
        value_info.append(param0)
        value_info.append(param1)
        value_info.append(param2)
        if "int" in bias:
            param3 = oh.make_tensor_value_info("param3", TensorProto.FLOAT, [1])
            value_info.append(param3)
        quant_node = oh.make_node(
            optype,
            domain="qonnx.custom_op.general",
            inputs=["param0", "param1", "param2", "param3"] if "int" in bias else ["param0", "param1", "param2"],
            outputs=["B"],
            narrow=0,
            rounding_mode="ROUND",
            signed=1,
        )
        nodes.append(quant_node)

    graph = oh.make_graph(
        nodes=nodes,
        name="cnv_transpose_graph",
        inputs=[inp],
        outputs=[outp],
        value_info=value_info,
    )

    model = qonnx_make_model(graph, producer_name="test-cnv-transpose-model")
    model = ModelWrapper(model)
    model.set_tensor_datatype("inp", idt)
    model.set_tensor_datatype("outp", odt)
    model.set_tensor_datatype("W", wdt)

    w_tensor = gen_finn_dt_tensor(wdt, w_shape)

    if bias is not None:
        b_tensor = gen_finn_dt_tensor(DataType["FLOAT32"], bias_shape)
        # Set B tensor directly or set first input of quant node
        if bias != "float":
            model.set_initializer("param0", b_tensor)
            scale = gen_finn_dt_tensor(DataType["FLOAT32"], scale_shape)
            model.set_initializer("param1", scale)
            zpt = gen_finn_dt_tensor(DataType["FLOAT32"], zpt_shape)
            model.set_initializer("param2", zpt)
            if "int" in bias:
                model.set_initializer("param3", 8 * np.ones(1))
        else:
            model.set_initializer("B", b_tensor)

    model.set_initializer("W", w_tensor)
    model = model.transform(InferShapes())

    input_tensor = gen_finn_dt_tensor(idt, ishape)
    output_dict = oxe.execute_onnx(model, {model.graph.input[0].name: input_tensor})
    expected = output_dict[model.graph.output[0].name]

    model = model.transform(ExtractBiasFromConv())

    if bias is not None:
        assert len(model.get_nodes_by_op_type("Add")) > 0, "Bias wasn't extracted into add node"

    output_dict = oxe.execute_onnx(model, {model.graph.input[0].name: input_tensor})
    produced = output_dict[model.graph.output[0].name]

    # Check that produced and expected outputs match closely (fp calculation)
    assert np.isclose(produced, expected, atol=1e-3).all()
