4141from qonnx .util .basic import gen_finn_dt_tensor , qonnx_make_model
4242
4343
44- # depthwise or channelwise
45- @pytest .mark .parametrize ("dw" , [True , False ])
46- # conv bias
47- @pytest .mark .parametrize (
48- "bias" , ["float" , "int_quant_per_tensor" , "int_quant_per_channel" , "bp_quant_per_tensor" , "bp_quant_per_channel" , None ]
49- )
50- def test_extract_conv_bias (dw , bias ):
44+ # Helper function to generate valid parameter combinations, with an option to include 'dw'
45+ def generate_params (include_dw = True ):
46+ params = []
47+ biases = ["float" , None , "int_quant" , "bp_quant" ]
48+ scales = ["per_tensor" , "per_channel" ]
49+ zero_points = ["per_tensor" , "per_channel" ]
50+
51+ dw_options = [True , False ] if include_dw else [None ]
52+
53+ for dw in dw_options :
54+ for bias in biases :
55+ if bias in ["float" , None ]:
56+ # Ignore scale and zero_point for this bias
57+ params .append ((dw , bias , None , None ) if include_dw else (bias , None , None ))
58+ else :
59+ # Include all combinations of scale and zero_point for other biases
60+ for scale in scales :
61+ for zero_point in zero_points :
62+ if include_dw :
63+ params .append ((dw , bias , scale , zero_point ))
64+ else :
65+ params .append ((bias , scale , zero_point ))
66+ return params
67+
68+
69+ @pytest .mark .parametrize ("dw, bias, scale, zero_point" , generate_params (include_dw = True ))
70+ def test_extract_conv_bias (dw , bias , scale , zero_point ):
5171 ishape = (1 , 32 , 111 , 111 )
5272 if dw is True :
5373 group = ishape [1 ]
@@ -75,10 +95,14 @@ def test_extract_conv_bias(dw, bias):
7595
7696 if bias is not None :
7797 bias_shape = (out_channels ,)
78- if "quant_per_channel " in bias :
98+ if scale is not None and "per_channel " in scale :
7999 scale_shape = (out_channels ,)
80- elif "quant_per_tensor " in bias :
100+ elif scale is not None and "per_tensor " in scale :
81101 scale_shape = (1 ,)
102+ if scale is not None and "per_channel" in zero_point :
103+ zpt_shape = (out_channels ,)
104+ elif scale is not None and "per_tensor" in zero_point :
105+ zpt_shape = (1 ,)
82106 B = oh .make_tensor_value_info ("B" , TensorProto .FLOAT , bias_shape )
83107
84108 cnv_node = oh .make_node (
@@ -94,15 +118,15 @@ def test_extract_conv_bias(dw, bias):
94118 value_info = [W ] if not bias else [W , B ]
95119 # if the bias isn't quantized, we can directly wire up the Conv layer
96120 # otherwise an additional Quant node needs to be inserted
97- if bias is not None and "quant" in bias :
121+ if bias not in [ "float" , None ] :
98122 if "bp" in bias :
99123 optype = "BipolarQuant"
100124 elif "int" in bias :
101125 optype = "IntQuant"
102126 # inputs to Quant node
103127 param0 = oh .make_tensor_value_info ("param0" , TensorProto .FLOAT , bias_shape )
104128 param1 = oh .make_tensor_value_info ("param1" , TensorProto .FLOAT , scale_shape )
105- param2 = oh .make_tensor_value_info ("param2" , TensorProto .FLOAT , [ 1 ] )
129+ param2 = oh .make_tensor_value_info ("param2" , TensorProto .FLOAT , zpt_shape )
106130 value_info .append (param0 )
107131 value_info .append (param1 )
108132 value_info .append (param2 )
@@ -138,11 +162,12 @@ def test_extract_conv_bias(dw, bias):
138162 if bias is not None :
139163 b_tensor = gen_finn_dt_tensor (DataType ["FLOAT32" ], bias_shape )
140164 # set B tensor directly or set first input of quant node
141- if "quant" in bias :
165+ if bias != "float" :
142166 model .set_initializer ("param0" , b_tensor )
143- scale = gen_finn_dt_tensor (DataType ["FLOAT32" ], bias_shape )
167+ scale = gen_finn_dt_tensor (DataType ["FLOAT32" ], scale_shape )
144168 model .set_initializer ("param1" , scale )
145- model .set_initializer ("param2" , np .zeros (1 ))
169+ zpt = gen_finn_dt_tensor (DataType ["FLOAT32" ], zpt_shape )
170+ model .set_initializer ("param2" , zpt )
146171 if "int" in bias :
147172 model .set_initializer ("param3" , 8 * np .ones (1 ))
148173 else :
@@ -167,11 +192,8 @@ def test_extract_conv_bias(dw, bias):
167192 assert np .isclose (produced , expected , atol = 1e-3 ).all ()
168193
169194
170- # conv transpose bias
171- @pytest .mark .parametrize (
172- "bias" , ["float" , "int_quant_per_tensor" , "int_quant_per_channel" , "bp_quant_per_tensor" , "bp_quant_per_channel" , None ]
173- )
174- def test_extract_conv_transpose_bias (bias ):
195+ @pytest .mark .parametrize ("bias, scale, zero_point" , generate_params (include_dw = False ))
196+ def test_extract_conv_transpose_bias (bias , scale , zero_point ):
175197 ishape = (1 , 32 , 111 , 111 )
176198 group = 1
177199 out_channels = 64
@@ -191,10 +213,15 @@ def test_extract_conv_transpose_bias(bias):
191213
192214 if bias is not None :
193215 bias_shape = (out_channels ,)
194- if "quant_per_channel " in bias :
216+ if scale is not None and "per_channel " in scale :
195217 scale_shape = (out_channels ,)
196- elif "quant_per_tensor " in bias :
218+ elif scale is not None and "per_tensor " in scale :
197219 scale_shape = (1 ,)
220+ if scale is not None and "per_channel" in zero_point :
221+ zpt_shape = (out_channels ,)
222+ elif scale is not None and "per_tensor" in zero_point :
223+ zpt_shape = (1 ,)
224+
198225 B = oh .make_tensor_value_info ("B" , TensorProto .FLOAT , bias_shape )
199226
200227 cnv_node = oh .make_node (
@@ -211,15 +238,15 @@ def test_extract_conv_transpose_bias(bias):
211238
212239 # If the bias isn't quantized, we can directly wire up the ConvTranspose layer
213240 # Otherwise, an additional Quant node needs to be inserted
214- if bias is not None and "quant" in bias :
241+ if bias not in [ "float" , None ] :
215242 if "bp" in bias :
216243 optype = "BipolarQuant"
217244 elif "int" in bias :
218245 optype = "IntQuant"
219246 # Inputs to Quant node
220247 param0 = oh .make_tensor_value_info ("param0" , TensorProto .FLOAT , bias_shape )
221248 param1 = oh .make_tensor_value_info ("param1" , TensorProto .FLOAT , scale_shape )
222- param2 = oh .make_tensor_value_info ("param2" , TensorProto .FLOAT , [ 1 ] )
249+ param2 = oh .make_tensor_value_info ("param2" , TensorProto .FLOAT , zpt_shape )
223250 value_info .append (param0 )
224251 value_info .append (param1 )
225252 value_info .append (param2 )
@@ -256,11 +283,12 @@ def test_extract_conv_transpose_bias(bias):
256283 if bias is not None :
257284 b_tensor = gen_finn_dt_tensor (DataType ["FLOAT32" ], bias_shape )
258285 # Set B tensor directly or set first input of quant node
259- if "quant" in bias :
286+ if bias != "float" :
260287 model .set_initializer ("param0" , b_tensor )
261- scale = gen_finn_dt_tensor (DataType ["FLOAT32" ], bias_shape )
288+ scale = gen_finn_dt_tensor (DataType ["FLOAT32" ], scale_shape )
262289 model .set_initializer ("param1" , scale )
263- model .set_initializer ("param2" , np .zeros (1 ))
290+ zpt = gen_finn_dt_tensor (DataType ["FLOAT32" ], zpt_shape )
291+ model .set_initializer ("param2" , zpt )
264292 if "int" in bias :
265293 model .set_initializer ("param3" , 8 * np .ones (1 ))
266294 else :
0 commit comments