Commit bff9ecc

fanyunqianTracin authored and fanyunqian committed
[Fix] update vitis quantizer
1 parent 3c7bfcf · commit bff9ecc

1 file changed: 3 additions, 63 deletions

mqbench/custom_quantizer/vitis_quantizer.py (+3, -63)
@@ -8,6 +8,7 @@
 
 import mqbench.nn.intrinsic as qnni
 import mqbench.nn.intrinsic.qat as qnniqat
+from mqbench.utils import getitem2node
 from mqbench.utils.logger import logger
 from mqbench.utils.registry import register_model_quantizer
 from mqbench.prepare_by_platform import BackendType
@@ -57,7 +58,6 @@ def module_type_to_quant_output(self) -> tuple:
             # Linear
             torch.nn.qat.modules.linear.Linear,
             # Pooling
-            torch.nn.modules.pooling.MaxPool2d,
             torch.nn.modules.pooling.AvgPool2d,
             torch.nn.modules.pooling.AdaptiveAvgPool2d,
             # BN
@@ -77,20 +77,12 @@ def function_type_to_quant_output(self) -> List:
             torch.cat,
             torch.nn.functional.adaptive_avg_pool2d,
             torch.nn.functional.avg_pool2d,
-            torch.nn.functional.max_pool2d,
             torch.nn.functional.relu,
             torch.nn.functional.conv2d,
             torch.nn.functional.linear,
             torch.nn.functional.interpolate,
         ]
 
-    @property
-    def function_type_not_to_quant_alone(self) -> List:
-        return [
-            operator.getitem,
-        ]
-
-
     def prepare(self, model: GraphModule, qconfig):
         model = _fuse_fx(model, self.extra_fuse_dict)
         model = self._weight_quant(model, qconfig)
@@ -106,7 +98,7 @@ def _find_act_quants(self, model: GraphModule) -> List:
         if hasattr(self, 'node_need_to_quantize_output'):
             return self.node_need_to_quantize_output
         self.node_need_to_quantize_output = []
-        getitem2node = self._get_items_for_the_graph(model)
+        g2node = getitem2node(model)
         for node in nodes:
             if (node.op == "call_module" and node.target in self.exclude_module_name) or \
                ((node.op == 'call_function' or node.op == 'call_method') and
@@ -116,61 +108,10 @@ def _find_act_quants(self, model: GraphModule) -> List:
             if (node.op == "call_module" and isinstance(modules[node.target], self.module_type_to_quant_output)) or \
                ((node.op == 'call_function' or node.op == 'call_method') and
                node.target in self.function_type_to_quant_output):
-                input_node_list = self._flatten_args(node.args)
-                for _node in input_node_list:
-                    if isinstance(_node, torch.fx.node.Node):
-                        if self._is_implicit_merge(modules, (node, _node)):
-                            logger.info("Implicit merge: {} + {}".format(_node.name, node.name))
-                            continue
-                        if (_node.op == 'placeholder') or \
-                            ((_node.op == 'call_function' or _node.op == 'call_method') and
-                            _node.target in self.function_type_not_to_quant_alone):
-                            if _node not in getitem2node:
-                                self.node_need_to_quantize_output.append(_node)
-                                logger.info(f'Add {_node.name}/{_node.target}/{_node.op} to input quantize')
                 self.node_need_to_quantize_output.append(node)
                 logger.info(f'Add {node.name} to output quantize')
         return self.node_need_to_quantize_output
 
-    def _get_items_for_the_graph(self, model: GraphModule) -> dict:
-        def _update_getitem_path(getitem_args_dict):
-            for node in getitem_args_dict:
-                args_list = getitem_args_dict[node]
-                while args_list[0] in getitem_args_dict:
-                    args_list = getitem_args_dict[args_list[0]] + args_list[1:]
-                getitem_args_dict[node] = args_list
-            return getitem_args_dict
-
-        def _getitem_from_args(args, original_args_dict):
-            ret = original_args_dict
-            for a in args:
-                try:
-                    ret = ret[a]
-                except (IndexError, KeyError):
-                    return {}
-            return ret
-        nodes = list(model.graph.nodes)
-        # the getitem's call graph
-        getitem_args_dict = {}
-        # the dict used in the model
-        original_key_dict = {}
-        getitem2node = {}
-        for node in nodes:
-            # update the getitems
-            if node.target in [operator.getitem]:
-                getitem_args_dict[node] = list(node.args)
-                getitem_args_dict = _update_getitem_path(getitem_args_dict)
-            elif node.target == 'update':
-                if node.args[0] not in original_key_dict:
-                    original_key_dict[node.args[0]] = {}
-                original_key_dict[node.args[0]].update(node.args[1])
-        for node in getitem_args_dict:
-            val = _getitem_from_args(getitem_args_dict[node], original_key_dict)
-            if isinstance(val, torch.fx.node.Node):
-                getitem2node[node] = val
-        return getitem2node
-
-
     def _find_input_quants(self, model) -> List:
         node_need_to_quantize_weight = []
         nodes = list(model.graph.nodes)
@@ -179,7 +120,6 @@ def _find_input_quants(self, model) -> List:
                 node_need_to_quantize_weight.append(list(node.users)[0])
         return node_need_to_quantize_weight
 
-
     def _find_weight_quants(self, model) -> List:
         node_need_to_quantize_weight = []
         nodes = list(model.graph.nodes)
@@ -219,4 +159,4 @@ def _set_quant_type(self, model: GraphModule) -> NoReturn:
                 next_op = module_dict[node.target]
                 if isinstance(next_op, TqtFakeQuantize):
                     next_op.set_quant_type('input')
-                    logger.info(f'{node.target} has been set to quant type <input>')
+                    logger.info(f'{node.target} has been set to quant type <input>')
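
For context, the helper deleted above resolved operator.getitem nodes in the traced FX graph back to the torch.fx nodes they ultimately read, which _find_act_quants used to decide whether a getitem input still needed its own quantizer. Below is a minimal standalone sketch of that same logic, reconstructed from the removed code; it assumes the shared mqbench.utils.getitem2node utility that now replaces _get_items_for_the_graph behaves equivalently, and the name resolve_getitems is only illustrative.

import operator

import torch.fx


def resolve_getitems(model: torch.fx.GraphModule) -> dict:
    # getitem node -> flattened access path [source_node, key, key, ...]
    getitem_paths = {}
    # dict-producing node -> keys recorded through its 'update' call_method nodes
    written_keys = {}
    for node in model.graph.nodes:
        if node.target is operator.getitem:
            path = list(node.args)
            # Collapse chained getitems so the path starts at the original source node.
            while path[0] in getitem_paths:
                path = getitem_paths[path[0]] + path[1:]
            getitem_paths[node] = path
        elif node.target == 'update':
            written_keys.setdefault(node.args[0], {}).update(node.args[1])

    getitem2node = {}
    for node, (source, *keys) in getitem_paths.items():
        val = written_keys.get(source, {})
        for key in keys:
            try:
                val = val[key]
            except (IndexError, KeyError, TypeError):
                val = {}
                break
        if isinstance(val, torch.fx.Node):
            # This getitem ultimately reads a concrete fx node; record the mapping.
            getitem2node[node] = val
    return getitem2node

With the mapping centralized in mqbench.utils, the per-quantizer copy is reduced to the single call g2node = getitem2node(model) shown in the diff.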
