Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/quantization/pass/test_insert_quantize_on_dtype_mismatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
from tico.serialize.quant_param import QPARAM_KEY, QuantParam

from test.modules.op.add import SimpleAdd
from test.modules.op.bmm import SimpleBatchMatMul
from test.modules.op.linear import SimpleLinear
from test.modules.op.mul import SimpleMulWithTensor
from test.modules.op.permute import SimplePermute
from test.modules.op.relu import SimpleRelu
from test.modules.op.reshape import ReshapeTorchAPI


Expand Down Expand Up @@ -212,3 +214,25 @@ def test_i16o8(self):
desired_dtype="uint8",
)
self.run_test()


class ReluTest(InsertQuantizeOnDtypeMismatchTest):
    """Exercise the dtype-mismatch pass on ``torch.ops.aten.relu.default``.

    Uses ``SimpleRelu`` as the module under test and relies on the inherited
    ``setup`` / ``run_test`` helpers for the actual checks.
    """

    def test_i16o8(self):
        # NOTE(review): the name reads as "int16 in, uint8 out", while
        # desired_dtype is "int16" -- confirm against the meaning of
        # ``desired_dtype`` in the base class.
        dtype_config = {"input_dtype": "int16", "desired_dtype": "int16"}
        module = SimpleRelu()
        self.setup(module, torch.ops.aten.relu.default, **dtype_config)
        self.run_test()


class AddTest(InsertQuantizeOnDtypeMismatchTest):
    """Exercise the dtype-mismatch pass on ``torch.ops.aten.add.Tensor``.

    Uses ``SimpleAdd`` as the module under test and relies on the inherited
    ``setup`` / ``run_test`` helpers for the actual checks.
    """

    def test_i16o8(self):
        # NOTE(review): the name reads as "int16 in, uint8 out", while
        # desired_dtype is "int16" -- confirm against the meaning of
        # ``desired_dtype`` in the base class.
        dtype_config = {"input_dtype": "int16", "desired_dtype": "int16"}
        module = SimpleAdd()
        self.setup(module, torch.ops.aten.add.Tensor, **dtype_config)
        self.run_test()
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
from tico.utils.trace_decorators import trace_graph_diff_on_pass
from tico.utils.utils import quant_min_max, set_new_meta_val
from tico.utils.validate_args_kwargs import (
AddTensorArgs,
BmmArgs,
LinearArgs,
MulTensorArgs,
PermuteArgs,
ReluArgs,
ReshapeArgs,
)

Expand Down Expand Up @@ -77,7 +79,7 @@ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
max_ = u8_scale * (255 - u8_zerop)
min_ = u8_scale * (-u8_zerop)

abs_max = max([max_, min_], key=abs)
abs_max = abs(max([max_, min_], key=abs))

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line fixes a bug. Without this change, scale can be a negative value.

s16_scale = abs_max / 32767
s16_zerop = 0

Expand Down Expand Up @@ -210,6 +212,42 @@ def _insert_quantize_op_after(node):
logger.debug(
f"quantize_per_tensor.default is inserted after {node.name}."
)
else:
raise NotYetSupportedError(
f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
)

elif node.target == torch.ops.aten.add.Tensor:
add_args = AddTensorArgs(*node.args, **node.kwargs)
x = add_args.input
y = add_args.other

if not isinstance(x, torch.fx.Node):
continue
if not isinstance(y, torch.fx.Node):
continue

if QPARAM_KEY not in x.meta:
continue
if QPARAM_KEY not in y.meta:
continue
if QPARAM_KEY not in node.meta:
continue

if qparam_dtype(x) == qparam_dtype(node):
continue

if qparam_dtype(x) != qparam_dtype(y):
continue

if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
quantize = _insert_quantize_op_after(node)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it okay to skip checking qparam_dtype(y)? What if one of x and y gets folded in another pass and their dtypes end up differing?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable. I will add a check for that case.

quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
logger.debug(
f"quantize_per_tensor.default is inserted after {node.name}."
)
else:
raise NotYetSupportedError("Unsupported dtype")

Expand Down Expand Up @@ -335,6 +373,30 @@ def _insert_quantize_op_after(node):
else:
raise NotYetSupportedError("Unsupported dtype")

elif node.target == torch.ops.aten.relu.default:
relu_args = ReluArgs(*node.args, **node.kwargs)
inp = relu_args.input

if QPARAM_KEY not in inp.meta:
continue

if QPARAM_KEY not in node.meta:
continue

if qparam_dtype(inp) == qparam_dtype(node):
continue

if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
quantize = _insert_quantize_op_after(node)

quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
logger.debug(
f"quantize_per_tensor.default is inserted after {node.name}."
)
else:
raise NotYetSupportedError("Unsupported dtype")

# TODO Support more ops.

graph.eliminate_dead_code()
Expand Down