Skip to content

Commit 535d750

Browse files
authored
Support Add, Relu in InsertQuantizeOpOnDtypeMismatch (#166)
* Support Add, Relu in InsertQuantizeOpOnDtypeMismatch This adds Add, Relu in InsertQuantizeOpOnDtypeMismatch. TICO-DCO-1.0-Signed-off-by: Hyukjin Jeong <hj1.jeong@samsung.com>
1 parent bede9c4 commit 535d750

2 files changed

Lines changed: 87 additions & 1 deletion

File tree

test/quantization/pass/test_insert_quantize_on_dtype_mismatch.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
2222
from tico.serialize.quant_param import QPARAM_KEY, QuantParam
2323

24+
from test.modules.op.add import SimpleAdd
2425
from test.modules.op.bmm import SimpleBatchMatMul
2526
from test.modules.op.linear import SimpleLinear
2627
from test.modules.op.mul import SimpleMulWithTensor
2728
from test.modules.op.permute import SimplePermute
29+
from test.modules.op.relu import SimpleRelu
2830
from test.modules.op.reshape import ReshapeTorchAPI
2931

3032

@@ -212,3 +214,25 @@ def test_i16o8(self):
212214
desired_dtype="uint8",
213215
)
214216
self.run_test()
217+
218+
219+
class ReluTest(InsertQuantizeOnDtypeMismatchTest):
220+
def test_i16o8(self):
221+
self.setup(
222+
SimpleRelu(),
223+
torch.ops.aten.relu.default,
224+
input_dtype="int16",
225+
desired_dtype="int16",
226+
)
227+
self.run_test()
228+
229+
230+
class AddTest(InsertQuantizeOnDtypeMismatchTest):
231+
def test_i16o8(self):
232+
self.setup(
233+
SimpleAdd(),
234+
torch.ops.aten.add.Tensor,
235+
input_dtype="int16",
236+
desired_dtype="int16",
237+
)
238+
self.run_test()

tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
from tico.utils.trace_decorators import trace_graph_diff_on_pass
3030
from tico.utils.utils import quant_min_max, set_new_meta_val
3131
from tico.utils.validate_args_kwargs import (
32+
AddTensorArgs,
3233
BmmArgs,
3334
LinearArgs,
3435
MulTensorArgs,
3536
PermuteArgs,
37+
ReluArgs,
3638
ReshapeArgs,
3739
)
3840

@@ -77,7 +79,7 @@ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
7779
max_ = u8_scale * (255 - u8_zerop)
7880
min_ = u8_scale * (-u8_zerop)
7981

80-
abs_max = max([max_, min_], key=abs)
82+
abs_max = abs(max([max_, min_], key=abs))
8183
s16_scale = abs_max / 32767
8284
s16_zerop = 0
8385

@@ -210,6 +212,42 @@ def _insert_quantize_op_after(node):
210212
logger.debug(
211213
f"quantize_per_tensor.default is inserted after {node.name}."
212214
)
215+
else:
216+
raise NotYetSupportedError(
217+
f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
218+
)
219+
220+
elif node.target == torch.ops.aten.add.Tensor:
221+
add_args = AddTensorArgs(*node.args, **node.kwargs)
222+
x = add_args.input
223+
y = add_args.other
224+
225+
if not isinstance(x, torch.fx.Node):
226+
continue
227+
if not isinstance(y, torch.fx.Node):
228+
continue
229+
230+
if QPARAM_KEY not in x.meta:
231+
continue
232+
if QPARAM_KEY not in y.meta:
233+
continue
234+
if QPARAM_KEY not in node.meta:
235+
continue
236+
237+
if qparam_dtype(x) == qparam_dtype(node):
238+
continue
239+
240+
if qparam_dtype(x) != qparam_dtype(y):
241+
continue
242+
243+
if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
244+
quantize = _insert_quantize_op_after(node)
245+
246+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
247+
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
248+
logger.debug(
249+
f"quantize_per_tensor.default is inserted after {node.name}."
250+
)
213251
else:
214252
raise NotYetSupportedError("Unsupported dtype")
215253

@@ -335,6 +373,30 @@ def _insert_quantize_op_after(node):
335373
else:
336374
raise NotYetSupportedError("Unsupported dtype")
337375

376+
elif node.target == torch.ops.aten.relu.default:
377+
relu_args = ReluArgs(*node.args, **node.kwargs)
378+
inp = relu_args.input
379+
380+
if QPARAM_KEY not in inp.meta:
381+
continue
382+
383+
if QPARAM_KEY not in node.meta:
384+
continue
385+
386+
if qparam_dtype(inp) == qparam_dtype(node):
387+
continue
388+
389+
if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
390+
quantize = _insert_quantize_op_after(node)
391+
392+
quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
393+
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
394+
logger.debug(
395+
f"quantize_per_tensor.default is inserted after {node.name}."
396+
)
397+
else:
398+
raise NotYetSupportedError("Unsupported dtype")
399+
338400
# TODO Support more ops.
339401

340402
graph.eliminate_dead_code()

0 commit comments

Comments
 (0)