8
8
9
9
import mqbench .nn .intrinsic as qnni
10
10
import mqbench .nn .intrinsic .qat as qnniqat
11
+ from mqbench .utils import getitem2node
11
12
from mqbench .utils .logger import logger
12
13
from mqbench .utils .registry import register_model_quantizer
13
14
from mqbench .prepare_by_platform import BackendType
@@ -57,7 +58,6 @@ def module_type_to_quant_output(self) -> tuple:
57
58
# Linear
58
59
torch .nn .qat .modules .linear .Linear ,
59
60
# Pooling
60
- torch .nn .modules .pooling .MaxPool2d ,
61
61
torch .nn .modules .pooling .AvgPool2d ,
62
62
torch .nn .modules .pooling .AdaptiveAvgPool2d ,
63
63
# BN
@@ -77,20 +77,12 @@ def function_type_to_quant_output(self) -> List:
77
77
torch .cat ,
78
78
torch .nn .functional .adaptive_avg_pool2d ,
79
79
torch .nn .functional .avg_pool2d ,
80
- torch .nn .functional .max_pool2d ,
81
80
torch .nn .functional .relu ,
82
81
torch .nn .functional .conv2d ,
83
82
torch .nn .functional .linear ,
84
83
torch .nn .functional .interpolate ,
85
84
]
86
85
87
- @property
88
- def function_type_not_to_quant_alone (self ) -> List :
89
- return [
90
- operator .getitem ,
91
- ]
92
-
93
-
94
86
def prepare (self , model : GraphModule , qconfig ):
95
87
model = _fuse_fx (model , self .extra_fuse_dict )
96
88
model = self ._weight_quant (model , qconfig )
@@ -106,7 +98,7 @@ def _find_act_quants(self, model: GraphModule) -> List:
106
98
if hasattr (self , 'node_need_to_quantize_output' ):
107
99
return self .node_need_to_quantize_output
108
100
self .node_need_to_quantize_output = []
109
- getitem2node = self . _get_items_for_the_graph (model )
101
+ g2node = getitem2node (model )
110
102
for node in nodes :
111
103
if (node .op == "call_module" and node .target in self .exclude_module_name ) or \
112
104
((node .op == 'call_function' or node .op == 'call_method' ) and
@@ -116,61 +108,10 @@ def _find_act_quants(self, model: GraphModule) -> List:
116
108
if (node .op == "call_module" and isinstance (modules [node .target ], self .module_type_to_quant_output )) or \
117
109
((node .op == 'call_function' or node .op == 'call_method' ) and
118
110
node .target in self .function_type_to_quant_output ):
119
- input_node_list = self ._flatten_args (node .args )
120
- for _node in input_node_list :
121
- if isinstance (_node , torch .fx .node .Node ):
122
- if self ._is_implicit_merge (modules , (node , _node )):
123
- logger .info ("Implicit merge: {} + {}" .format (_node .name , node .name ))
124
- continue
125
- if (_node .op == 'placeholder' ) or \
126
- ((_node .op == 'call_function' or _node .op == 'call_method' ) and
127
- _node .target in self .function_type_not_to_quant_alone ):
128
- if _node not in getitem2node :
129
- self .node_need_to_quantize_output .append (_node )
130
- logger .info (f'Add { _node .name } /{ _node .target } /{ _node .op } to input quantize' )
131
111
self .node_need_to_quantize_output .append (node )
132
112
logger .info (f'Add { node .name } to output quantize' )
133
113
return self .node_need_to_quantize_output
134
114
135
- def _get_items_for_the_graph (self , model : GraphModule ) -> dict :
136
- def _update_getitem_path (getitem_args_dict ):
137
- for node in getitem_args_dict :
138
- args_list = getitem_args_dict [node ]
139
- while args_list [0 ] in getitem_args_dict :
140
- args_list = getitem_args_dict [args_list [0 ]] + args_list [1 :]
141
- getitem_args_dict [node ] = args_list
142
- return getitem_args_dict
143
-
144
- def _getitem_from_args (args , original_args_dict ):
145
- ret = original_args_dict
146
- for a in args :
147
- try :
148
- ret = ret [a ]
149
- except (IndexError , KeyError ):
150
- return {}
151
- return ret
152
- nodes = list (model .graph .nodes )
153
- # the getitem's call graph
154
- getitem_args_dict = {}
155
- # the dict used in the model
156
- original_key_dict = {}
157
- getitem2node = {}
158
- for node in nodes :
159
- # update the getitems
160
- if node .target in [operator .getitem ]:
161
- getitem_args_dict [node ] = list (node .args )
162
- getitem_args_dict = _update_getitem_path (getitem_args_dict )
163
- elif node .target == 'update' :
164
- if node .args [0 ] not in original_key_dict :
165
- original_key_dict [node .args [0 ]] = {}
166
- original_key_dict [node .args [0 ]].update (node .args [1 ])
167
- for node in getitem_args_dict :
168
- val = _getitem_from_args (getitem_args_dict [node ], original_key_dict )
169
- if isinstance (val , torch .fx .node .Node ):
170
- getitem2node [node ] = val
171
- return getitem2node
172
-
173
-
174
115
def _find_input_quants (self , model ) -> List :
175
116
node_need_to_quantize_weight = []
176
117
nodes = list (model .graph .nodes )
@@ -179,7 +120,6 @@ def _find_input_quants(self, model) -> List:
179
120
node_need_to_quantize_weight .append (list (node .users )[0 ])
180
121
return node_need_to_quantize_weight
181
122
182
-
183
123
def _find_weight_quants (self , model ) -> List :
184
124
node_need_to_quantize_weight = []
185
125
nodes = list (model .graph .nodes )
@@ -219,4 +159,4 @@ def _set_quant_type(self, model: GraphModule) -> NoReturn:
219
159
next_op = module_dict [node .target ]
220
160
if isinstance (next_op , TqtFakeQuantize ):
221
161
next_op .set_quant_type ('input' )
222
- logger .info (f'{ node .target } has been set to quant type <input>' )
162
+ logger .info (f'{ node .target } has been set to quant type <input>' )
0 commit comments