Skip to content

Commit 14d3486

Browse files
lucylqZonglin Peng
authored and
Zonglin Peng
committed
type checking for xnnpack
Differential Revision: D68350539 Pull Request resolved: pytorch#7741
1 parent daf57c2 commit 14d3486

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

backends/xnnpack/_passes/tag_implicit_q_dq_pass.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def is_supported_quant_op(self, node: torch.fx.Node) -> bool:
9090

9191
# Weight and Input should both be quantized
9292
if op_name == exir_ops.edge.aten.convolution.default.name():
93-
return is_dequant(node.args[1])
93+
if isinstance(node.args[1], torch.fx.Node):
94+
# pyre-ignore Incompatible parameter type [6]: is_dequant expects Node
95+
return is_dequant(node.args[1])
9496

9597
return op_name in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET
9698

backends/xnnpack/operators/node_visitor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -560,11 +560,11 @@ def get_serialized_buffer_index(
560560
# which should be used for depthwise/transpose convolution weights for XNNPACK
561561
shape = const_val.shape
562562
const_val = const_val.reshape(
563-
(groups, const_val.shape[0] // groups) + const_val.shape[1:]
563+
(groups, const_val.shape[0] // groups) + tuple(const_val.shape[1:])
564564
)
565565
const_val = const_val.permute((0, 2, 1) + tuple(range(3, const_val.dim())))
566566
const_val = const_val.reshape(
567-
(shape[1] * groups, shape[0] // groups) + shape[2:]
567+
(shape[1] * groups, shape[0] // groups) + tuple(shape[2:])
568568
).contiguous()
569569

570570
if convert_to_nhwc:

0 commit comments

Comments
 (0)