@@ -1105,7 +1105,7 @@ enum xnn_status xnn_define_convolution_2d(
1105
1105
return status ;
1106
1106
}
1107
1107
1108
- const struct xnn_value * input_value = & subgraph -> values [input_id ];
1108
+ struct xnn_value * input_value = & subgraph -> values [input_id ];
1109
1109
status = xnn_subgraph_check_input_type_dense (xnn_node_type_convolution_2d , input_id , input_value );
1110
1110
if (status != xnn_status_success ) {
1111
1111
return status ;
@@ -1299,15 +1299,19 @@ enum xnn_status xnn_define_convolution_2d(
1299
1299
}
1300
1300
}
1301
1301
}
1302
- if (input_value -> datatype == output_value -> datatype ) {
1303
- const bool unit_subsampling = (subsampling_width | subsampling_height ) == 1 ;
1304
- const size_t kernel_size = kernel_height * kernel_width ;
1305
- if (groups == 1 && kernel_size == 1 && unit_subsampling && !any_padding ) {
1306
- // Check if the convolution can take the vmulcaddc path.
1307
- if (group_input_channels + group_output_channels > 2 ) {
1308
- return xnn_define_fully_connected (subgraph , output_min , output_max ,
1309
- input_id , filter_id , bias_id , output_id , /*flags=*/ 0 );
1302
+ const bool unit_subsampling = (subsampling_width | subsampling_height ) == 1 ;
1303
+ const size_t kernel_size = kernel_height * kernel_width ;
1304
+ if (groups == 1 && kernel_size == 1 && unit_subsampling && !any_padding ) {
1305
+ // Check if the convolution can take the vmulcaddc path.
1306
+ if (group_input_channels + group_output_channels > 2 ) {
1307
+ if (input_value -> datatype == xnn_datatype_qdint8 ) {
1308
+ // Dynammically quantized tensors for fully connected ops are quantized
1309
+ // per-channel, not per-batch.
1310
+ input_value -> quantization .num_nonbatch_dims = 1 ;
1311
+ input_value -> quantization .dynamic_params_size = xnn_tensor_get_dynamic_quant_param_size (input_value );
1310
1312
}
1313
+ return xnn_define_fully_connected (subgraph , output_min , output_max ,
1314
+ input_id , filter_id , bias_id , output_id , /*flags=*/ 0 );
1311
1315
}
1312
1316
}
1313
1317
struct xnn_node * node = xnn_subgraph_new_node (subgraph );
0 commit comments