-
Notifications
You must be signed in to change notification settings - Fork 29
Support Add, Relu in InsertQuantizeOpOnDtypeMismatch #166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,10 +29,12 @@ | |
| from tico.utils.trace_decorators import trace_graph_diff_on_pass | ||
| from tico.utils.utils import quant_min_max, set_new_meta_val | ||
| from tico.utils.validate_args_kwargs import ( | ||
| AddTensorArgs, | ||
| BmmArgs, | ||
| LinearArgs, | ||
| MulTensorArgs, | ||
| PermuteArgs, | ||
| ReluArgs, | ||
| ReshapeArgs, | ||
| ) | ||
|
|
||
|
|
@@ -77,7 +79,7 @@ def _u8_to_i16(qparam: QuantParam) -> QuantParam: | |
| max_ = u8_scale * (255 - u8_zerop) | ||
| min_ = u8_scale * (-u8_zerop) | ||
|
|
||
| abs_max = max([max_, min_], key=abs) | ||
| abs_max = abs(max([max_, min_], key=abs)) | ||
| s16_scale = abs_max / 32767 | ||
| s16_zerop = 0 | ||
|
|
||
|
|
@@ -210,6 +212,42 @@ def _insert_quantize_op_after(node): | |
| logger.debug( | ||
| f"quantize_per_tensor.default is inserted after {node.name}." | ||
| ) | ||
| else: | ||
| raise NotYetSupportedError( | ||
| f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}" | ||
| ) | ||
|
|
||
| elif node.target == torch.ops.aten.add.Tensor: | ||
| add_args = AddTensorArgs(*node.args, **node.kwargs) | ||
| x = add_args.input | ||
| y = add_args.other | ||
|
|
||
| if not isinstance(x, torch.fx.Node): | ||
| continue | ||
| if not isinstance(y, torch.fx.Node): | ||
| continue | ||
|
|
||
| if QPARAM_KEY not in x.meta: | ||
| continue | ||
| if QPARAM_KEY not in y.meta: | ||
| continue | ||
| if QPARAM_KEY not in node.meta: | ||
| continue | ||
|
|
||
| if qparam_dtype(x) == qparam_dtype(node): | ||
| continue | ||
|
|
||
| if qparam_dtype(x) != qparam_dtype(y): | ||
| continue | ||
|
|
||
| if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8": | ||
| quantize = _insert_quantize_op_after(node) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it okay to skip checking qparam_dtype(y)? What if one of x and y gets folded in another pass and their dtypes end up differing?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds reasonable. I will add a check for that case. |
||
| quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY]) | ||
| node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY]) | ||
| logger.debug( | ||
| f"quantize_per_tensor.default is inserted after {node.name}." | ||
| ) | ||
| else: | ||
| raise NotYetSupportedError("Unsupported dtype") | ||
|
|
||
|
|
@@ -335,6 +373,30 @@ def _insert_quantize_op_after(node): | |
| else: | ||
| raise NotYetSupportedError("Unsupported dtype") | ||
|
|
||
| elif node.target == torch.ops.aten.relu.default: | ||
| relu_args = ReluArgs(*node.args, **node.kwargs) | ||
| inp = relu_args.input | ||
|
|
||
| if QPARAM_KEY not in inp.meta: | ||
| continue | ||
|
|
||
| if QPARAM_KEY not in node.meta: | ||
| continue | ||
|
|
||
| if qparam_dtype(inp) == qparam_dtype(node): | ||
| continue | ||
|
|
||
| if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8": | ||
| quantize = _insert_quantize_op_after(node) | ||
|
|
||
| quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY]) | ||
| node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY]) | ||
| logger.debug( | ||
| f"quantize_per_tensor.default is inserted after {node.name}." | ||
| ) | ||
| else: | ||
| raise NotYetSupportedError("Unsupported dtype") | ||
|
|
||
| # TODO Support more ops. | ||
|
|
||
| graph.eliminate_dead_code() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line fixes a bug. Without this change, scale can be a negative value.