From 595d0e7bfd5cecf2c5b184b0a44e342f6b9756aa Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Fri, 17 Jan 2025 13:41:58 -0800 Subject: [PATCH] type checking for xnnpack Summary: Fix pyre errors in difftrain D68003735 Reviewed By: Gasoonjia Differential Revision: D68350539 --- backends/xnnpack/_passes/tag_implicit_q_dq_pass.py | 4 +++- backends/xnnpack/operators/node_visitor.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py b/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py index 3c6345e28a2..dc488081025 100644 --- a/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py +++ b/backends/xnnpack/_passes/tag_implicit_q_dq_pass.py @@ -90,7 +90,9 @@ def is_supported_quant_op(self, node: torch.fx.Node) -> bool: # Weight and Input should both be quantized if op_name == exir_ops.edge.aten.convolution.default.name(): - return is_dequant(node.args[1]) + if isinstance(node.args[1], torch.fx.Node): + # pyre-ignore Incompatible parameter type [6]: is_dequant expects Node + return is_dequant(node.args[1]) return op_name in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index e0871089ec8..0a825a94bef 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -560,11 +560,11 @@ def get_serialized_buffer_index( # which should be used for depthwise/transpose convolution weights for XNNPACK shape = const_val.shape const_val = const_val.reshape( - (groups, const_val.shape[0] // groups) + const_val.shape[1:] + (groups, const_val.shape[0] // groups) + tuple(const_val.shape[1:]) ) const_val = const_val.permute((0, 2, 1) + tuple(range(3, const_val.dim()))) const_val = const_val.reshape( - (shape[1] * groups, shape[0] // groups) + shape[2:] + (shape[1] * groups, shape[0] // groups) + tuple(shape[2:]) ).contiguous() if convert_to_nhwc: