Commit d1dc557

[Transformation] Handle zero point when extracting bias, add test
1 parent 070b6cd commit d1dc557

File tree: 2 files changed, +57 additions, -26 deletions

src/qonnx/transformation/extract_conv_bias.py

Lines changed: 3 additions & 0 deletions
@@ -77,6 +77,9 @@ def apply(self, model):
                     quant_scale = model.get_initializer(producer.input[1])
                     if quant_scale.shape != (1,):
                         model.set_initializer(producer.input[1], quant_scale.reshape(add_shape))
+                    quant_zpt = model.get_initializer(producer.input[2])
+                    if quant_zpt.shape != (1,):
+                        model.set_initializer(producer.input[2], quant_zpt.reshape(add_shape))
                     model.set_tensor_shape(producer.output[0], add_shape)
 
                     act_add_tensor = helper.make_tensor_value_info(
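
Context for the fix: in the Quant node producing the extracted bias, input[1] is the scale and input[2] the zero point. The transform already reshaped a per-channel scale to add_shape (e.g. (1, C, 1, 1) for NCHW) so it broadcasts correctly against the standalone Add, but a per-channel zero point of shape (C,) was left untouched and would broadcast along the wrong axis. A minimal NumPy sketch of that failure mode (illustrative shapes only, not qonnx code):

    import numpy as np

    C = 4
    add_shape = (1, C, 1, 1)  # bias shape after extraction into a standalone Add

    bias = np.arange(C, dtype=np.float32).reshape(add_shape)
    zpt = np.ones(C, dtype=np.float32)  # per-channel zero point, shape (C,)

    # (1, C, 1, 1) combined with (C,) broadcasts to (1, C, 1, C): silently wrong
    assert (bias - zpt).shape == (1, C, 1, C)

    # after the fix the zero point is reshaped to add_shape; result stays (1, C, 1, 1)
    assert (bias - zpt.reshape(add_shape)).shape == add_shape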

tests/transformation/test_extract_conv_bias.py

Lines changed: 54 additions & 26 deletions
@@ -41,13 +41,33 @@
 from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model
 
 
-# depthwise or channelwise
-@pytest.mark.parametrize("dw", [True, False])
-# conv bias
-@pytest.mark.parametrize(
-    "bias", ["float", "int_quant_per_tensor", "int_quant_per_channel", "bp_quant_per_tensor", "bp_quant_per_channel", None]
-)
-def test_extract_conv_bias(dw, bias):
+# Helper function to generate valid parameter combinations, with an option to include 'dw'
+def generate_params(include_dw=True):
+    params = []
+    biases = ["float", None, "int_quant", "bp_quant"]
+    scales = ["per_tensor", "per_channel"]
+    zero_points = ["per_tensor", "per_channel"]
+
+    dw_options = [True, False] if include_dw else [None]
+
+    for dw in dw_options:
+        for bias in biases:
+            if bias in ["float", None]:
+                # Ignore scale and zero_point for this bias
+                params.append((dw, bias, None, None) if include_dw else (bias, None, None))
+            else:
+                # Include all combinations of scale and zero_point for other biases
+                for scale in scales:
+                    for zero_point in zero_points:
+                        if include_dw:
+                            params.append((dw, bias, scale, zero_point))
+                        else:
+                            params.append((bias, scale, zero_point))
+    return params
+
+
+@pytest.mark.parametrize("dw, bias, scale, zero_point", generate_params(include_dw=True))
+def test_extract_conv_bias(dw, bias, scale, zero_point):
     ishape = (1, 32, 111, 111)
     if dw is True:
         group = ishape[1]
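
For reference, a quick sanity check of the combinations the helper yields; the counts follow directly from the loop structure above (a standalone sketch, assuming generate_params is in scope):

    params = generate_params(include_dw=True)
    # 2 dw options x (2 unquantized biases + 2 quantized biases x 2 scales x 2 zero points)
    assert len(params) == 2 * (2 + 2 * 2 * 2)  # 20 combinations
    assert (True, "float", None, None) in params
    assert (False, "int_quant", "per_channel", "per_tensor") in params

    # without dw the tuples have three elements, as consumed by the ConvTranspose test
    assert ("bp_quant", "per_tensor", "per_channel") in generate_params(include_dw=False)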
@@ -75,10 +95,14 @@ def test_extract_conv_bias(dw, bias):
 
     if bias is not None:
         bias_shape = (out_channels,)
-        if "quant_per_channel" in bias:
+        if scale is not None and "per_channel" in scale:
             scale_shape = (out_channels,)
-        elif "quant_per_tensor" in bias:
+        elif scale is not None and "per_tensor" in scale:
             scale_shape = (1,)
+        if scale is not None and "per_channel" in zero_point:
+            zpt_shape = (out_channels,)
+        elif scale is not None and "per_tensor" in zero_point:
+            zpt_shape = (1,)
         B = oh.make_tensor_value_info("B", TensorProto.FLOAT, bias_shape)
 
     cnv_node = oh.make_node(
@@ -94,15 +118,15 @@ def test_extract_conv_bias(dw, bias):
     value_info = [W] if not bias else [W, B]
     # if the bias isn't quantized, we can directly wire up the Conv layer
     # otherwise an additional Quant node needs to be inserted
-    if bias is not None and "quant" in bias:
+    if bias not in ["float", None]:
         if "bp" in bias:
             optype = "BipolarQuant"
         elif "int" in bias:
             optype = "IntQuant"
         # inputs to Quant node
         param0 = oh.make_tensor_value_info("param0", TensorProto.FLOAT, bias_shape)
         param1 = oh.make_tensor_value_info("param1", TensorProto.FLOAT, scale_shape)
-        param2 = oh.make_tensor_value_info("param2", TensorProto.FLOAT, [1])
+        param2 = oh.make_tensor_value_info("param2", TensorProto.FLOAT, zpt_shape)
         value_info.append(param0)
         value_info.append(param1)
         value_info.append(param2)
@@ -138,11 +162,12 @@ def test_extract_conv_bias(dw, bias):
     if bias is not None:
         b_tensor = gen_finn_dt_tensor(DataType["FLOAT32"], bias_shape)
         # set B tensor directly or set first input of quant node
-        if "quant" in bias:
+        if bias != "float":
             model.set_initializer("param0", b_tensor)
-            scale = gen_finn_dt_tensor(DataType["FLOAT32"], bias_shape)
+            scale = gen_finn_dt_tensor(DataType["FLOAT32"], scale_shape)
             model.set_initializer("param1", scale)
-            model.set_initializer("param2", np.zeros(1))
+            zpt = gen_finn_dt_tensor(DataType["FLOAT32"], zpt_shape)
+            model.set_initializer("param2", zpt)
             if "int" in bias:
                 model.set_initializer("param3", 8 * np.ones(1))
         else:
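
Worth noting why this initializer change matters: the old test always set param2 (the zero point) to np.zeros(1), so the zero-point handling added to extract_conv_bias.py above was never exercised. A small sketch of the distinction, using the same qonnx helpers the test imports (the zpt_shape value here is hypothetical):

    import numpy as np
    from qonnx.core.datatype import DataType
    from qonnx.util.basic import gen_finn_dt_tensor

    zpt_shape = (64,)  # e.g. per-channel zero point for 64 output channels

    old_zpt = np.zeros(1)  # trivial shape (1,): skipped by the `shape != (1,)` guard
    new_zpt = gen_finn_dt_tensor(DataType["FLOAT32"], zpt_shape)  # random, per-channel

    assert old_zpt.shape == (1,)
    assert new_zpt.shape == zpt_shape  # forces the new reshape path in the transform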
@@ -167,11 +192,8 @@ def test_extract_conv_bias(dw, bias):
     assert np.isclose(produced, expected, atol=1e-3).all()
 
 
-# conv transpose bias
-@pytest.mark.parametrize(
-    "bias", ["float", "int_quant_per_tensor", "int_quant_per_channel", "bp_quant_per_tensor", "bp_quant_per_channel", None]
-)
-def test_extract_conv_transpose_bias(bias):
+@pytest.mark.parametrize("bias, scale, zero_point", generate_params(include_dw=False))
+def test_extract_conv_transpose_bias(bias, scale, zero_point):
     ishape = (1, 32, 111, 111)
     group = 1
     out_channels = 64
@@ -191,10 +213,15 @@ def test_extract_conv_transpose_bias(bias):
 
     if bias is not None:
         bias_shape = (out_channels,)
-        if "quant_per_channel" in bias:
+        if scale is not None and "per_channel" in scale:
             scale_shape = (out_channels,)
-        elif "quant_per_tensor" in bias:
+        elif scale is not None and "per_tensor" in scale:
             scale_shape = (1,)
+        if scale is not None and "per_channel" in zero_point:
+            zpt_shape = (out_channels,)
+        elif scale is not None and "per_tensor" in zero_point:
+            zpt_shape = (1,)
+
         B = oh.make_tensor_value_info("B", TensorProto.FLOAT, bias_shape)
 
     cnv_node = oh.make_node(
@@ -211,15 +238,15 @@ def test_extract_conv_transpose_bias(bias):
 
     # If the bias isn't quantized, we can directly wire up the ConvTranspose layer
     # Otherwise, an additional Quant node needs to be inserted
-    if bias is not None and "quant" in bias:
+    if bias not in ["float", None]:
         if "bp" in bias:
             optype = "BipolarQuant"
         elif "int" in bias:
             optype = "IntQuant"
         # Inputs to Quant node
         param0 = oh.make_tensor_value_info("param0", TensorProto.FLOAT, bias_shape)
         param1 = oh.make_tensor_value_info("param1", TensorProto.FLOAT, scale_shape)
-        param2 = oh.make_tensor_value_info("param2", TensorProto.FLOAT, [1])
+        param2 = oh.make_tensor_value_info("param2", TensorProto.FLOAT, zpt_shape)
         value_info.append(param0)
         value_info.append(param1)
         value_info.append(param2)
@@ -256,11 +283,12 @@ def test_extract_conv_transpose_bias(bias):
     if bias is not None:
         b_tensor = gen_finn_dt_tensor(DataType["FLOAT32"], bias_shape)
         # Set B tensor directly or set first input of quant node
-        if "quant" in bias:
+        if bias != "float":
             model.set_initializer("param0", b_tensor)
-            scale = gen_finn_dt_tensor(DataType["FLOAT32"], bias_shape)
+            scale = gen_finn_dt_tensor(DataType["FLOAT32"], scale_shape)
             model.set_initializer("param1", scale)
-            model.set_initializer("param2", np.zeros(1))
+            zpt = gen_finn_dt_tensor(DataType["FLOAT32"], zpt_shape)
+            model.set_initializer("param2", zpt)
             if "int" in bias:
                 model.set_initializer("param3", 8 * np.ones(1))
         else:
