
Commit d9c50c9
Support (un)squeeze in XNN delegate via conversion to view
1 parent 36bdc16 · commit d9c50c9
7 files changed: +457 -0 lines changed
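The change works because squeeze and unsqueeze do not move data: when the output shape is known (up to one dynamic dimension), both are plain reshapes and can be expressed as views. A minimal sketch of that equivalence in eager PyTorch (shapes are illustrative, not taken from the commit):

```python
import torch

x = torch.randn(1, 2, 1, 4, 1)

# squeeze drops size-1 dims; with a static shape it is just a reshape
squeezed = torch.squeeze(x, (0, 2, 4))
assert torch.equal(squeezed, x.view(2, 4))

# unsqueeze inserts a size-1 dim; likewise expressible as a view
unsqueezed = torch.unsqueeze(squeezed, 0)
assert torch.equal(unsqueezed, squeezed.view(1, 2, 4))
```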

backends/xnnpack/_passes/__init__.py (+4)

```diff
@@ -12,6 +12,9 @@
 from executorch.backends.xnnpack._passes.conv1d_unsqueeze_pass import (
     Conv1dUnsqueezePass,
 )
+from executorch.backends.xnnpack._passes.convert_squeeze_to_view_pass import (
+    ConvertSqueezeToViewPass,
+)
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
 from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
 from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
@@ -67,6 +70,7 @@ def __init__(
             DecomposeConcatenate,
             RemoveGetItemPass,
             Conv1dUnsqueezePass,
+            ConvertSqueezeToViewPass,
             PReLUReshapePass,
             ChannelsLastTaggedReshapePass,
             TagImplicitQDqPass,
```
backends/xnnpack/_passes/convert_squeeze_to_view_pass.py (new file, +69)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import check_or_raise
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from torch.fx.experimental.symbolic_shapes import has_free_symbols


class ConvertSqueezeToViewPass(XNNPACKPass):
    """
    This pass is used to convert squeeze and unsqueeze nodes into view_copy.
    This allows them to be subsequently lowered as static_reshape ops.
    """

    SUPPORTED_OPS = [
        exir_ops.edge.aten.squeeze_copy.dim,
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.unsqueeze_copy.default,
    ]

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        node_list = list(graph.nodes)
        for node in node_list:
            if node.op == "call_function":
                if node.target in self.SUPPORTED_OPS:
                    out_shape = node.meta["val"].shape

                    # Replace up to one dynamic dimension with -1 (inferred dim).
                    new_shape = []
                    dynamic_dim_count = 0
                    for d in out_shape:
                        if has_free_symbols(d):
                            new_shape.append(-1)
                            dynamic_dim_count += 1
                        else:
                            new_shape.append(d)

                    # This constraint should be enforced by the partitioner.
                    check_or_raise(
                        dynamic_dim_count <= 1,
                        "XNN supports only one dynamic dimension",
                    )

                    with graph_module.graph.inserting_after(node):
                        view_node = graph_module.graph.create_node(
                            "call_function",
                            target=exir_ops.edge.aten.view_copy.default,
                            args=(node.args[0], new_shape),
                            kwargs=node.kwargs,
                        )

                        node.replace_all_uses_with(view_node)
                        graph_module.graph.erase_node(node)

        graph_module.recompile()
        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
```
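A note on the -1 in the pass above: when one output dimension is symbolic, the emitted view carries -1 in that position so the reshape infers the size at runtime. A hedged eager-mode sketch of the same shape rewriting (the helper name rewrite_shape is illustrative, not from the commit):

```python
import torch


def rewrite_shape(out_shape, is_dynamic):
    """Mimic the pass: keep static dims, emit -1 for the single dynamic dim."""
    new_shape = []
    dynamic_dim_count = 0
    for d, dyn in zip(out_shape, is_dynamic):
        if dyn:
            new_shape.append(-1)
            dynamic_dim_count += 1
        else:
            new_shape.append(d)
    assert dynamic_dim_count <= 1, "XNN supports only one dynamic dimension"
    return new_shape


# unsqueeze(x, 1) on a batch-dynamic input: output (batch, 1, 3) becomes [-1, 1, 3]
x = torch.randn(5, 3)  # batch happens to be 5 at runtime
shape = rewrite_shape((5, 1, 3), (True, False, False))
assert shape == [-1, 1, 3]
assert torch.equal(x.view(shape), x.unsqueeze(1))
```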

backends/xnnpack/partition/config/__init__.py (+4)

```diff
@@ -45,7 +45,9 @@
     SliceCopyConfig,
     SoftmaxConfig,
     SquareRootConfig,
+    SqueezeCopyConfig,
     SubConfig,
+    UnsqueezeCopyConfig,
     UpsampleBilinear2dConfig,
     ViewCopyConfig,
 )
@@ -99,7 +101,9 @@
     SliceCopyConfig,
     SoftmaxConfig,
     SquareRootConfig,
+    SqueezeCopyConfig,
     SubConfig,
+    UnsqueezeCopyConfig,
     UpsampleBilinear2dConfig,
     ViewCopyConfig,
     # Quant/Dequant Op Configs
```
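For reference, these registered configs are what the XNNPACK partitioner enumerates when deciding which nodes to delegate. A short sketch, assuming XnnpackPartitioner's config_precisions keyword, of constructing a partitioner that exercises the new FP32-only configs:

```python
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# SqueezeCopyConfig and UnsqueezeCopyConfig declare FP32 support only,
# so an FP32-restricted partitioner still picks them up.
partitioner = XnnpackPartitioner(config_precisions=ConfigPrecisionType.FP32)
```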

backends/xnnpack/partition/config/generic_node_configs.py (+53)

```diff
@@ -21,6 +21,7 @@
 )
 from executorch.exir.backend.utils import is_shape_dynamic, WhyNoPartition
 from torch.export import ExportedProgram
+from torch.fx.experimental.symbolic_shapes import has_free_symbols
 
 logger = logging.getLogger(__name__)
 why = WhyNoPartition(logger=logger)
@@ -314,6 +315,31 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.max_pool2d.default
 
 
+class SqueezeCopyConfig(GenericNodePartitionerConfig):
+    target_name = "squeeze_copy.dims"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
+        return torch.ops.aten.squeeze_copy.default
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK's static_reshape only supports 1 dynamic dimension
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        new_shape = node.meta["val"].shape
+        dynamic_dim_count = sum(1 for d in new_shape if has_free_symbols(d))
+        if dynamic_dim_count > 1:
+            why(node, reason="only a single dynamic dimension is supported")
+            return False
+
+        return True
+
+
 class UpsampleBilinear2dConfig(GenericNodePartitionerConfig):
     target_name = "upsample_bilinear2d.vec"
 
@@ -336,6 +362,33 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.upsample_bilinear2d.vec
 
 
+class UnsqueezeCopyConfig(GenericNodePartitionerConfig):
+    target_name = "unsqueeze_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
+        return torch.ops.aten.unsqueeze_copy.default
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK's static_reshape only supports 1 dynamic dimension
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        new_shape = node.meta["val"].shape
+        dynamic_dim_count = sum(
+            1 for d in new_shape if not isinstance(d, int) and has_free_symbols(d)
+        )
+        if dynamic_dim_count > 1:
+            why(node, reason="only a single dynamic dimension is supported")
+            return False
+
+        return True
+
+
 class ViewCopyConfig(GenericNodePartitionerConfig):
     target_name = "view_copy.default"
```
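Both check_constraints implementations above count symbolic dimensions via has_free_symbols on the shapes recorded in node.meta["val"]. A small sketch, assuming a recent torch.export, of how symbolic and static dims show up on exported FakeTensors (the module and dim names here are illustrative):

```python
import torch
from torch.export import Dim, export
from torch.fx.experimental.symbolic_shapes import has_free_symbols


class Unsqueeze(torch.nn.Module):
    def forward(self, x):
        return x.unsqueeze(1)


ep = export(
    Unsqueeze(),
    (torch.randn(2, 3),),
    dynamic_shapes={"x": {0: Dim("batch", min=2, max=8)}},
)
for node in ep.graph.nodes:
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor):  # FakeTensors carry the symbolic shapes
        dyn = sum(
            1 for d in val.shape if not isinstance(d, int) and has_free_symbols(d)
        )
        print(node.name, tuple(val.shape), "dynamic dims:", dyn)
```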

backends/xnnpack/test/ops/test_squeeze.py (new file, +125)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Export, Tester
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import Dim


class TestSqueeze(unittest.TestCase):
    class Squeeze(torch.nn.Module):
        def __init__(self, dims):
            super().__init__()
            self.dims = dims

        def forward(self, x):
            return torch.squeeze(x, self.dims)

    def test_fp32_squeeze(self):
        inputs = (torch.randn(1, 2, 1, 4, 1),)
        squeeze_dims = (0, 2, 4)

        for dims in squeeze_dims:
            (
                Tester(self.Squeeze(dims), inputs)
                .export()
                .check_node_count(
                    {
                        torch.ops.aten.squeeze.dim: 1,
                    }
                )
                .to_edge_transform_and_lower()
                .check_node_count(
                    {
                        exir_ops.edge.aten.squeeze_copy.dim: 0,
                        exir_ops.edge.aten.view_copy.default: 0,
                        torch.ops.higher_order.executorch_call_delegate: 1,
                    }
                )
                .run_method_and_compare_outputs()
            )

    def test_fp16_squeeze(self):
        inputs = (torch.randn(1, 2, 1, 4, 1).to(torch.float16),)
        squeeze_dims = (0, 2, 4)

        for dims in squeeze_dims:
            (
                Tester(self.Squeeze(dims), inputs)
                .export()
                .check_node_count(
                    {
                        torch.ops.aten.squeeze.dim: 1,
                    }
                )
                .to_edge_transform_and_lower()
                .check_node_count(
                    {
                        exir_ops.edge.aten.squeeze_copy.dim: 0,
                        exir_ops.edge.aten.view_copy.default: 0,
                        torch.ops.higher_order.executorch_call_delegate: 1,
                    }
                )
                .run_method_and_compare_outputs()
            )

    def test_fp32_squeeze_dynamic(self):
        inputs = (torch.randn(1, 2, 1, 4, 1),)
        squeeze_dims = (0, 2, 4)
        dynamic_shapes = {"x": {1: Dim("x_1", min=1, max=10)}}

        for dims in squeeze_dims:
            (
                Tester(self.Squeeze(dims), inputs)
                .export(Export(dynamic_shapes=dynamic_shapes))
                .check_node_count(
                    {
                        torch.ops.aten.squeeze.dim: 1,
                    }
                )
                .to_edge_transform_and_lower()
                .check_node_count(
                    {
                        exir_ops.edge.aten.squeeze_copy.dim: 0,
                        exir_ops.edge.aten.view_copy.default: 0,
                        torch.ops.higher_order.executorch_call_delegate: 1,
                    }
                )
                .run_method_and_compare_outputs()
            )

    def test_fp32_squeeze_unsupported_dynamism(self):
        inputs = (torch.randn(1, 2, 1, 4, 1),)
        squeeze_dims = (0, 2, 4)
        # Only one dynamic dimension is supported.
        dynamic_shapes = {
            "x": {
                1: Dim("x_1", min=1, max=10),
                3: Dim("x_3", min=1, max=10),
            }
        }

        for dims in squeeze_dims:
            (
                Tester(self.Squeeze(dims), inputs)
                .export(Export(dynamic_shapes=dynamic_shapes))
                .check_node_count(
                    {
                        torch.ops.aten.squeeze.dim: 1,
                    }
                )
                .to_edge_transform_and_lower()
                .check_node_count(
                    {
                        exir_ops.edge.aten.squeeze_copy.dims: 1,
                        torch.ops.higher_order.executorch_call_delegate: 0,
                    }
                )
                .run_method_and_compare_outputs()
            )
```
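Beyond the Tester-based unit tests, the feature is exercised through the ordinary lowering flow. A rough end-to-end sketch, assuming the public to_edge_transform_and_lower and XnnpackPartitioner APIs (model and shapes are illustrative):

```python
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from torch.export import export


class Squeeze(torch.nn.Module):
    def forward(self, x):
        return torch.squeeze(x, (0, 2))


ep = export(Squeeze(), (torch.randn(1, 2, 1, 4),))
# With this commit, the squeeze becomes a view_copy and lowers into the
# XNNPACK delegate as a static_reshape.
lowered = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
print(lowered.exported_program().graph_module.code)
```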
