
Commit bbc17ca

Support (un)squeeze in XNN delegate via conversion to view
1 parent 36bdc16 commit bbc17ca

7 files changed: +417 -0 lines
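The core observation behind the change: squeeze and unsqueeze only drop or insert size-1 dimensions and never move data, so with static shapes each one is exactly expressible as a reshape/view. A quick plain-PyTorch illustration (mine, not part of the diff):

import torch

x = torch.randn(1, 2, 1, 4, 1)

# squeeze(x, 0) drops the leading size-1 dim; a view with that dim removed
# yields an identical tensor.
assert torch.equal(torch.squeeze(x, 0), x.view(2, 1, 4, 1))

# unsqueeze inserts a size-1 dim; again just a view of the same data.
y = torch.randn(2, 4)
assert torch.equal(torch.unsqueeze(y, 1), y.view(2, 1, 4))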

backends/xnnpack/_passes/__init__.py

+4 lines

@@ -12,6 +12,9 @@
 from executorch.backends.xnnpack._passes.conv1d_unsqueeze_pass import (
     Conv1dUnsqueezePass,
 )
+from executorch.backends.xnnpack._passes.convert_squeeze_to_view_pass import (
+    ConvertSqueezeToViewPass,
+)
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
 from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass
 from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import (
@@ -67,6 +70,7 @@ def __init__(
             DecomposeConcatenate,
             RemoveGetItemPass,
             Conv1dUnsqueezePass,
+            ConvertSqueezeToViewPass,
             PReLUReshapePass,
             ChannelsLastTaggedReshapePass,
             TagImplicitQDqPass,
backends/xnnpack/_passes/convert_squeeze_to_view_pass.py (new file)

+75 lines

@@ -0,0 +1,75 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Optional
+
+import torch
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
+from executorch.backends.xnnpack.utils.utils import check_or_raise, get_param_tensor, is_param_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+from torch._ops import OpOverload
+from torch.fx.experimental.symbolic_shapes import has_free_symbols
+
+
+class ConvertSqueezeToViewPass(XNNPACKPass):
+    """
+    This pass is used to convert squeeze and unsqueeze nodes into view_copy.
+    This allows them to be subsequently lowered as static_reshape ops.
+    """
+
+    SUPPORTED_OPS = [
+        exir_ops.edge.aten.squeeze_copy.dim,
+        exir_ops.edge.aten.squeeze_copy.dims,
+        exir_ops.edge.aten.unsqueeze_copy.default,
+    ]
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        node_list = list(graph.nodes)
+        for node in node_list:
+            if node.op == "call_function":
+                if node.target in self.SUPPORTED_OPS:
+                    out_shape = node.meta["val"].shape
+
+                    # Replace up to one dynamic dimension with -1 (inferred dim).
+                    new_shape = []
+                    dynamic_dim_count = 0
+                    for d in out_shape:
+                        if has_free_symbols(d):
+                            new_shape.append(-1)
+                            dynamic_dim_count += 1
+                        else:
+                            new_shape.append(d)
+
+                    # This constraint should be enforced by the partitioner.
+                    check_or_raise(
+                        dynamic_dim_count <= 1,
+                        "XNN supports only one dynamic dimension",
+                    )
+
+                    with graph_module.graph.inserting_after(node):
+                        view_node = graph_module.graph.create_node(
+                            "call_function",
+                            target=exir_ops.edge.aten.view_copy.default,
+                            args=(node.args[0], new_shape),
+                            kwargs=node.kwargs,
+                        )
+
+                        node.replace_all_uses_with(view_node)
+                        graph_module.graph.erase_node(node)
+
+        graph_module.recompile()
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
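Why the pass writes -1 for a dynamic dimension, and why at most one is allowed: view/reshape can infer a single unknown dimension from the total element count, but two unknowns would be ambiguous. A minimal sketch (plain PyTorch, illustrative only):

import torch

x = torch.randn(1, 5, 1, 4)  # pretend dim 1 is dynamic at runtime

# Squeezing dims 0 and 2 gives shape (5, 4). With dim 1 dynamic, the pass
# emits view_copy with target shape [-1, 4]; the -1 is inferred from numel.
assert x.view(-1, 4).shape == (5, 4)

# view() rejects more than one -1, which mirrors the check_or_raise above
# and XNNPACK's static_reshape limit of one dynamic dimension.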

backends/xnnpack/partition/config/__init__.py

+4 lines

@@ -45,7 +45,9 @@
     SliceCopyConfig,
     SoftmaxConfig,
     SquareRootConfig,
+    SqueezeCopyConfig,
     SubConfig,
+    UnsqueezeCopyConfig,
     UpsampleBilinear2dConfig,
     ViewCopyConfig,
 )
@@ -99,7 +101,9 @@
     SliceCopyConfig,
     SoftmaxConfig,
     SquareRootConfig,
+    SqueezeCopyConfig,
     SubConfig,
+    UnsqueezeCopyConfig,
     UpsampleBilinear2dConfig,
     ViewCopyConfig,
     # Quant/Dequant Op Configs

backends/xnnpack/partition/config/generic_node_configs.py

+53 lines

@@ -21,6 +21,7 @@
 )
 from executorch.exir.backend.utils import is_shape_dynamic, WhyNoPartition
 from torch.export import ExportedProgram
+from torch.fx.experimental.symbolic_shapes import has_free_symbols

 logger = logging.getLogger(__name__)
 why = WhyNoPartition(logger=logger)

@@ -314,6 +315,31 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.max_pool2d.default


+class SqueezeCopyConfig(GenericNodePartitionerConfig):
+    target_name = "squeeze_copy.dims"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
+        return torch.ops.aten.squeeze_copy.default
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK's static_reshape only supports 1 dynamic dimension
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        new_shape = node.meta["val"].shape
+        dynamic_dim_count = sum(1 for d in new_shape if has_free_symbols(d))
+        if dynamic_dim_count > 1:
+            why(node, reason="only a single dynamic dimension is supported")
+            return False
+
+        return True
+
+
 class UpsampleBilinear2dConfig(GenericNodePartitionerConfig):
     target_name = "upsample_bilinear2d.vec"

@@ -336,6 +362,33 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
         return torch.ops.aten.upsample_bilinear2d.vec


+class UnsqueezeCopyConfig(GenericNodePartitionerConfig):
+    target_name = "unsqueeze_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
+        return torch.ops.aten.unsqueeze_copy.default
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        XNNPACK's static_reshape only supports 1 dynamic dimension
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        new_shape = node.meta["val"].shape
+        dynamic_dim_count = sum(
+            1 for d in new_shape if not isinstance(d, int) and has_free_symbols(d)
+        )
+        if dynamic_dim_count > 1:
+            why(node, reason="only a single dynamic dimension is supported")
+            return False
+
+        return True
+
+
 class ViewCopyConfig(GenericNodePartitionerConfig):
     target_name = "view_copy.default"
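The constraint check counts symbolic dimensions on the node's fake-tensor output. A rough standalone sketch of the same counting, assuming the standard torch.export API (the module M and dim name "b" are mine, not from the commit):

import torch
from torch.export import Dim, export
from torch.fx.experimental.symbolic_shapes import has_free_symbols

class M(torch.nn.Module):
    def forward(self, x):
        return torch.unsqueeze(x, 0)

ep = export(M(), (torch.randn(2, 4),), dynamic_shapes={"x": {0: Dim("b", min=2, max=8)}})

for node in ep.graph.nodes:
    if node.op == "call_function":
        shape = node.meta["val"].shape
        # Ints are static; SymInts carrying free symbols are dynamic.
        dynamic = sum(1 for d in shape if not isinstance(d, int) and has_free_symbols(d))
        print(node.target, tuple(shape), "dynamic dims:", dynamic)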

backends/xnnpack/test/ops/test_squeeze.py (new file)

+107 lines

@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Export, Tester
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import Dim
+
+
+class TestSqueeze(unittest.TestCase):
+    class Squeeze(torch.nn.Module):
+        def __init__(self, dims):
+            super().__init__()
+            self.dims = dims
+
+        def forward(self, x):
+            return torch.squeeze(x, self.dims)
+
+    def test_fp32_squeeze(self):
+        inputs = (torch.randn(1, 2, 1, 4, 1),)
+        squeeze_dims = (0, 2, 4)
+
+        for dims in squeeze_dims:
+            (
+                Tester(self.Squeeze(dims), inputs)
+                .export()
+                .check_node_count({
+                    torch.ops.aten.squeeze.dim: 1,
+                })
+                .to_edge_transform_and_lower()
+                .check_node_count({
+                    exir_ops.edge.aten.squeeze_copy.dim: 0,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                })
+                .run_method_and_compare_outputs()
+            )
+
+    def test_fp16_squeeze(self):
+        inputs = (torch.randn(1, 2, 1, 4, 1).to(torch.float16),)
+        squeeze_dims = (0, 2, 4)
+
+        for dims in squeeze_dims:
+            (
+                Tester(self.Squeeze(dims), inputs)
+                .export()
+                .check_node_count({
+                    torch.ops.aten.squeeze.dim: 1,
+                })
+                .to_edge_transform_and_lower()
+                .check_node_count({
+                    exir_ops.edge.aten.squeeze_copy.dim: 0,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                })
+                .run_method_and_compare_outputs()
+            )
+
+    def test_fp32_squeeze_dynamic(self):
+        inputs = (torch.randn(1, 2, 1, 4, 1),)
+        squeeze_dims = (0, 2, 4)
+        dynamic_shapes = {"x": {1: Dim("x_1", min=1, max=10)}}
+
+        for dims in squeeze_dims:
+            (
+                Tester(self.Squeeze(dims), inputs)
+                .export(Export(dynamic_shapes=dynamic_shapes))
+                .check_node_count({
+                    torch.ops.aten.squeeze.dim: 1,
+                })
+                .to_edge_transform_and_lower()
+                .check_node_count({
+                    exir_ops.edge.aten.squeeze_copy.dim: 0,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                })
+                .run_method_and_compare_outputs()
+            )
+
+    def test_fp32_squeeze_unsupported_dynamism(self):
+        inputs = (torch.randn(1, 2, 1, 4, 1),)
+        squeeze_dims = (0, 2, 4)
+        # Only one dynamic dimension is supported.
+        dynamic_shapes = {
+            "x": {
+                1: Dim("x_1", min=1, max=10),
+                3: Dim("x_3", min=1, max=10),
+            }
+        }
+
+        for dims in squeeze_dims:
+            (
+                Tester(self.Squeeze(dims), inputs)
+                .export(Export(dynamic_shapes=dynamic_shapes))
+                .check_node_count({
+                    torch.ops.aten.squeeze.dim: 1,
+                })
+                .to_edge_transform_and_lower()
+                .check_node_count({
+                    exir_ops.edge.aten.squeeze_copy.dims: 1,
+                    torch.ops.higher_order.executorch_call_delegate: 0,
+                })
+                .run_method_and_compare_outputs()
+            )
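For reference, a sketch of the flow these tests drive, written against the standard ExecuTorch entry points rather than the Tester harness (the exact calls are my assumption, not shown in this diff):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class Squeeze(torch.nn.Module):
    def forward(self, x):
        return torch.squeeze(x, 0)

ep = torch.export.export(Squeeze(), (torch.randn(1, 2, 4),))
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])

# After lowering, the squeeze should be absorbed into the delegate payload:
# expect an executorch_call_delegate node and no squeeze_copy/view_copy.
print(edge.exported_program().graph_module.code)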
backends/xnnpack/test/ops/test_unsqueeze.py (new file)

+101 lines

@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.tester import Export, Tester
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import Dim
+
+
+class TestUnsqueeze(unittest.TestCase):
+    class Unsqueeze(torch.nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
+
+        def forward(self, x):
+            return torch.unsqueeze(x, self.dim)
+
+    def test_fp32_unsqueeze(self):
+        inputs = (torch.randn(1, 2, 4),)
+        for dim in range(len(inputs[0].shape)):
+            (
+                Tester(self.Unsqueeze(dim), inputs)
+                .export()
+                .check_node_count({
+                    torch.ops.aten.unsqueeze.default: 1,
+                })
+                .to_edge_transform_and_lower()
+                .check_node_count({
+                    exir_ops.edge.aten.unsqueeze_copy.default: 0,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                })
+                .run_method_and_compare_outputs()
+            )
+
+    def test_fp16_unsqueeze(self):
+        inputs = (torch.randn(1, 2, 4).to(torch.float16),)
+        for dim in range(len(inputs[0].shape)):
+            (
+                Tester(self.Unsqueeze(dim), inputs)
+                .export()
+                .check_node_count({
+                    torch.ops.aten.unsqueeze.default: 1,
+                })
+                .to_edge_transform_and_lower()
+                .check_node_count({
+                    exir_ops.edge.aten.unsqueeze_copy.default: 0,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                })
+                .run_method_and_compare_outputs()
+            )
+
+    def test_fp32_unsqueeze_dynamic(self):
+        inputs = (torch.randn(1, 2, 4),)
+        dynamic_shapes = {"x": {1: Dim("x_1", min=1, max=10)}}
+
+        for dim in range(len(inputs[0].shape)):
+            (
+                Tester(self.Unsqueeze(dim), inputs)
+                .export(Export(dynamic_shapes=dynamic_shapes))
+                .check_node_count({
+                    torch.ops.aten.unsqueeze.default: 1,
+                })
+                .to_edge_transform_and_lower()
+                .check_node_count({
+                    exir_ops.edge.aten.unsqueeze_copy.default: 0,
+                    exir_ops.edge.aten.view_copy.default: 0,
+                    torch.ops.higher_order.executorch_call_delegate: 1,
+                })
+                .run_method_and_compare_outputs()
+            )
+
+    def test_fp32_unsqueeze_unsupported_dynamism(self):
+        inputs = (torch.randn(1, 2, 4),)
+        # Only one dynamic dimension is supported.
+        dynamic_shapes = {
+            "x": {
+                1: Dim("x_1", min=1, max=10),
+                2: Dim("x_2", min=1, max=10),
+            }
+        }
+
+        for dim in range(len(inputs[0].shape)):
+            (
+                Tester(self.Unsqueeze(dim), inputs)
+                .export(Export(dynamic_shapes=dynamic_shapes))
+                .check_node_count({
+                    torch.ops.aten.unsqueeze.default: 1,
+                })
+                .to_edge_transform_and_lower()
+                .check_node_count({
+                    exir_ops.edge.aten.unsqueeze_copy.default: 1,
+                    torch.ops.higher_order.executorch_call_delegate: 0,
+                })
+                .run_method_and_compare_outputs()
+            )
