@@ -42,20 +42,37 @@ def convert_qnniqat_linearbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvFreezebn2d)
 @register_convert_function(nniqat.ConvBn2d)
+@register_convert_function(nniqat.ConvBn3d)
 def convert_nniqat_convbn(model, fused_node):
+    """nniqat.ConvBn2d ----> nn.Conv2d ----> nn.qat.Conv2d
+    """
+    fused_module_class_map = {
+        qnniqat.ConvFreezebn2d: torch.nn.Conv2d,
+        qnniqat.ConvFreezebnReLU2d: torch.nn.Conv2d,
+        nniqat.ConvBn2d: torch.nn.Conv2d,
+        nniqat.ConvBnReLU2d: torch.nn.Conv2d,
+        nniqat.ConvBn3d: torch.nn.Conv3d,
+        nniqat.ConvBnReLU3d: torch.nn.Conv3d,
+    }
+    fused_qat_module_class_map = {
+        torch.nn.Conv2d: torch.nn.qat.Conv2d,
+        torch.nn.Conv3d: torch.nn.qat.Conv3d,
+    }
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
     # Create a Conv2d from FusedModule.
-    conv = torch.nn.Conv2d(fused_module.in_channels, fused_module.out_channels, fused_module.kernel_size,
-                           fused_module.stride, fused_module.padding, fused_module.dilation,
-                           fused_module.groups, fused_module.bias is not None, fused_module.padding_mode)
+    conv = fused_module_class_map[type(fused_module)](fused_module.in_channels, fused_module.out_channels,
+                                                      fused_module.kernel_size, fused_module.stride,
+                                                      fused_module.padding, fused_module.dilation,
+                                                      fused_module.groups, fused_module.bias is not None,
+                                                      fused_module.padding_mode)
     conv.weight = fused_module.weight
     if fused_module.bias is not None:
         conv.bias = fused_module.bias
     fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn)
     # We need nn.qat.conv here to export weight quantize node.
     fused_conv.qconfig = fused_module.qconfig
-    fused_conv = torch.nn.qat.Conv2d.from_float(fused_conv)
+    fused_conv = fused_qat_module_class_map[type(conv)].from_float(fused_conv)
     # Attach weight fake quantize params.
     fused_conv.weight_fake_quant = fused_module.weight_fake_quant
     conv_parent_name, conv_name = _parent_name(fused_node.target)
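The `fuse_conv_bn_eval` call above folds the batch-norm statistics into the conv weights. As a reference for review, here is a minimal sketch of the eval-mode folding it performs, assuming an affine BatchNorm with tracked running stats; `fold_bn_into_conv` is an illustrative helper, not part of this patch:

```python
import torch

def fold_bn_into_conv(conv, bn):
    # Per-output-channel scale: gamma / sqrt(running_var + eps).
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = torch.nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                            conv.stride, conv.padding, conv.dilation,
                            conv.groups, bias=True, padding_mode=conv.padding_mode)
    # W' = W * scale, broadcast over the output-channel axis.
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    # b' = (b - running_mean) * scale + beta.
    bias = conv.bias.data if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias
    return fused
```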
@@ -64,7 +81,8 @@ def convert_nniqat_convbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvFreezebnReLU2d)
 @register_convert_function(nniqat.ConvBnReLU2d)
-def convert_nniqat_convbnrelu(model, fused_node):
+@register_convert_function(nniqat.ConvBnReLU3d)
+def convert_nniqat_convbnrelu(model, fused_node):
     convert_nniqat_convbn(model, fused_node)
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
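Both converters rely on `from_float` to wrap the folded conv in a QAT module so that a weight fake-quantize node is exported. A minimal usage sketch of that conversion path; the qconfig choice here is illustrative, not the one MQBench assigns:

```python
import torch
from torch.quantization import get_default_qat_qconfig

conv = torch.nn.Conv2d(3, 8, kernel_size=3)
conv.qconfig = get_default_qat_qconfig('fbgemm')  # from_float requires a qconfig
qat_conv = torch.nn.qat.Conv2d.from_float(conv)
# The QAT module now carries a weight_fake_quant submodule, which the
# exporter turns into an explicit weight quantize node.
print(qat_conv.weight_fake_quant)
```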
@@ -196,6 +214,9 @@ def convert_qnniqat_deconvbnrelu(model, fused_node):
 
 @register_convert_function(qnniqat.ConvBn2d)
 def convert_qnniqat_convbn(model, fused_node):
+    """mqbench.nn.intrinsic.qat modules add bias quant; that is what
+    distinguishes them from the torch.nn.intrinsic.qat modules.
+    """
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
     # Create a Conv2d from FusedModule.
@@ -222,6 +243,9 @@ def convert_qnniqat_convbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvBnReLU2d)
 def convert_qnniqat_convbnrelu(model, fused_node):
+    """mqbench.nn.intrinsic.qat modules add bias quant; that is what
+    distinguishes them from the torch.nn.intrinsic.qat modules.
+    """
     convert_qnniqat_convbn(model, fused_node)
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
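For context on the docstrings added in the last two hunks: the qnniqat variants fake-quantize the bias as well, not just the weight. A hypothetical sketch of what a bias fake-quantizer could look like; the `bias_fake_quant` name and the qint8 settings are assumptions for illustration, not MQBench's actual configuration:

```python
import torch
from torch.quantization import FakeQuantize, MovingAverageMinMaxObserver

# Hypothetical bias fake-quantizer; MQBench's real settings may differ.
bias_fake_quant = FakeQuantize(observer=MovingAverageMinMaxObserver,
                               quant_min=-128, quant_max=127,
                               dtype=torch.qint8,
                               qscheme=torch.per_tensor_symmetric)

bias = torch.randn(8)
fq_bias = bias_fake_quant(bias)  # bias passes through its own quantize node
```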