
Commit 36bdc16

Add view_copy/static_reshape support to XNNPACK delegate

1 parent e78ed83

11 files changed: +588 −12 lines

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py (+23 −1)

@@ -8,7 +8,7 @@

 import torch
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
-from executorch.backends.xnnpack.utils.utils import is_param_node
+from executorch.backends.xnnpack.utils.utils import get_input_node, is_param_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult

@@ -77,6 +77,21 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
     # is done
     PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE"

+    def is_view_dim_order_invariant(self, node: torch.fx.Node) -> bool:
+        # View must be done in NCHW dim order if channel or batch is changed,
+        # or if rank is not 4.
+        in_shape = get_input_node(node, 0).meta["val"].shape
+        out_shape = node.meta["val"].shape
+
+        if len(in_shape) != 4 or len(out_shape) != 4:
+            return False
+
+        # Are batch and channel modified? If so, return False.
+        if in_shape[0] != out_shape[0] or in_shape[1] != out_shape[1]:
+            return False
+
+        return True
+
     def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = True

@@ -93,6 +108,13 @@ def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc

     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
+        # Views depend on whether batch or channel are modified.
+        if (
+            node.target == exir_ops.edge.aten.view_copy.default
+            and not self.is_view_dim_order_invariant(node)
+        ):
+            return True
+
         return node.target in self.memory_sensitive_ops_nchw

     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
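
For intuition, here is a standalone sketch of the same invariance rule applied to plain shape tuples rather than torch.fx nodes (the helper name and example shapes are illustrative, not part of the commit):

def view_is_dim_order_invariant(in_shape, out_shape):
    # Mirrors is_view_dim_order_invariant above: a view can run directly on
    # an NHWC tensor only if rank stays 4 and batch (dim 0) and channel
    # (dim 1) are untouched; otherwise the pass forces NCHW inputs around it.
    if len(in_shape) != 4 or len(out_shape) != 4:
        return False
    return in_shape[0] == out_shape[0] and in_shape[1] == out_shape[1]

assert view_is_dim_order_invariant((1, 3, 4, 6), (1, 3, 24, 1))      # only H/W change
assert not view_is_dim_order_invariant((1, 3, 4, 6), (1, 12, 6, 1))  # channel changes
assert not view_is_dim_order_invariant((1, 3, 4, 6), (3, 24))        # rank changes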

backends/xnnpack/operators/__init__.py (+1)

@@ -49,4 +49,5 @@
     op_static_resize_bilinear_2d,
     op_sub,
     op_to_copy,
+    op_view_copy,
 )

backends/xnnpack/operators/op_skip_ops.py (−10)

@@ -77,16 +77,6 @@ class OpTCopyDefault(OpSkipOps):
     target = "aten.t_copy.default"


-@register_node_visitor
-class OpViewCopyDefault(OpSkipOps):
-    """
-    currently, do nothing if node is view_copy.default
-    need to handle this later on, currently view it as one of skip ops
-    """
-
-    target = "aten.view_copy.default"
-
-
 @register_node_visitor
 class OpSymSizeInt(OpSkipOps):
     """
backends/xnnpack/operators/op_view_copy.py (new file, +96)

@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNGraph,
+    XNNStaticReshape,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+)
+
+
+@register_node_visitor
+class ViewCopyVisitor(NodeVisitor):
+    target = "aten.view_copy.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
+
+        input_node = get_input_node(node, 0)
+
+        # input
+        input_id = vals_to_ids[input_node]
+
+        # output
+        output_id = vals_to_ids[node]
+
+        # input shape
+        check_or_raise(
+            "val" in input_node.meta,
+            "Missing val in tensor metadata for input when serializing XNNStaticReshape",
+        )
+
+        # output shape
+        check_or_raise(
+            "val" in node.meta,
+            "Missing val in tensor metadata for output when serializing XNNStaticReshape",
+        )
+
+        new_shape = node.args[1]
+        check_or_raise(
+            all(isinstance(d, int) for d in new_shape),
+            "Symbolic reshape parameter is not supported in XNNStaticReshape",
+        )
+
+        # PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0.
+        new_shape = tuple(d if d != -1 else 0 for d in new_shape)
+
+        # Handle NCHW dim order - if this op is in NCHW order, we need to
+        # permute the view shape correspondingly.
+        if "XNN_NHWC_NODE" in node.meta:
+            check_or_raise(len(new_shape) == 4, "Invalid NCHW shape")
+            new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]
+
+        num_dynamic_dims = sum(1 for d in new_shape if d == 0)
+
+        check_or_raise(
+            num_dynamic_dims <= 1,
+            "XNNPACK reshape only supports 1 dynamic dimension.",
+        )
+
+        ser_node = XNode(
+            xnode_union=XNNStaticReshape(
+                num_dims=len(new_shape),
+                new_shape=new_shape,
+                input_id=input_id,
+                output_id=output_id,
+                flags=0,
+            ),
+            debug_handle=debug_handle,
+        )
+        xnn_graph.xnodes.append(ser_node)
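
For reference, the shape rewriting in define_node reduces to the standalone sketch below. PERM_NCHW_TO_NHWC is assumed to be [0, 2, 3, 1] (the usual NCHW-to-NHWC permutation); the helper name is hypothetical:

PERM_NCHW_TO_NHWC = [0, 2, 3, 1]  # assumed value of the imported constant

def to_xnnpack_shape(new_shape, is_nhwc_node):
    # PyTorch marks the inferred dim with -1; XNNPACK's static_reshape uses 0.
    shape = tuple(d if d != -1 else 0 for d in new_shape)
    # An NHWC-tagged node sees its data in NHWC order, so the target shape
    # must be permuted the same way before serialization.
    if is_nhwc_node:
        assert len(shape) == 4
        shape = tuple(shape[PERM_NCHW_TO_NHWC[i]] for i in range(4))
    return shape

print(to_xnnpack_shape((1, 8, -1, 4), is_nhwc_node=True))  # (1, 0, 4, 8)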

backends/xnnpack/partition/config/__init__.py (+2)

@@ -47,6 +47,7 @@
     SquareRootConfig,
     SubConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
     BatchNormConfig,

@@ -100,6 +101,7 @@
     SquareRootConfig,
     SubConfig,
     UpsampleBilinear2dConfig,
+    ViewCopyConfig,
     # Quant/Dequant Op Configs
     QuantizedPerTensorConfig,
     DeQuantizedPerTensorConfig,

backends/xnnpack/partition/config/generic_node_configs.py (+26)

@@ -336,6 +336,32 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.upsample_bilinear2d.vec


+class ViewCopyConfig(GenericNodePartitionerConfig):
+    target_name = "view_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK's static_reshape only supports 1 dynamic dimension
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        new_shape = node.args[1]
+        if not all(isinstance(n, int) for n in new_shape):
+            why(node, reason="symbolic reshape is not supported")
+            return False
+
+        dynamic_dim_count = sum(1 for d in new_shape if d == -1)
+        if dynamic_dim_count > 1:
+            why(node, reason="only a single dynamic dimension is supported")
+            return False
+
+        return True
+
+
 class FloorConfig(GenericNodePartitionerConfig):
     target_name = "floor.default"
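
A minimal end-to-end sanity check of the new partitioner path might look like the sketch below, using the standard ExecuTorch export flow (the module and shapes are illustrative, and API details may vary by ExecuTorch version):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class ViewModule(torch.nn.Module):
    def forward(self, x):
        # Rank-4 view that keeps batch and channel fixed, so it is
        # dim-order invariant and can stay in the NHWC subgraph.
        return x.view(2, 3, 16, 1)

ep = torch.export.export(ViewModule(), (torch.randn(2, 3, 4, 4),))
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
# With this commit, view_copy should appear inside the delegated XNNPACK
# partition instead of being treated as a skip op.
print(edge.exported_program().graph_module)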
