diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index 36a7833dca0..09776e6164a 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -12,6 +12,9 @@ from executorch.backends.xnnpack._passes.conv1d_unsqueeze_pass import ( Conv1dUnsqueezePass, ) +from executorch.backends.xnnpack._passes.convert_squeeze_to_view_pass import ( + ConvertSqueezeToViewPass, +) from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass from executorch.backends.xnnpack._passes.convert_to_sdpa import ConvertToSDPAPass from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import ( @@ -67,6 +70,7 @@ def __init__( DecomposeConcatenate, RemoveGetItemPass, Conv1dUnsqueezePass, + ConvertSqueezeToViewPass, PReLUReshapePass, ChannelsLastTaggedReshapePass, TagImplicitQDqPass, diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py index 89a44f303df..5ab2ee2f547 100644 --- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py +++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py @@ -8,7 +8,7 @@ import torch from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass -from executorch.backends.xnnpack.utils.utils import is_param_node +from executorch.backends.xnnpack.utils.utils import get_input_node, is_param_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassResult @@ -77,6 +77,21 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass): # is done PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE" + def is_view_dim_order_invariant(self, node: torch.fx.Node) -> bool: + # View must be done in NCHW dim order if channel or batch is changed, + # or if rank is not 4. + in_shape = get_input_node(node, 0).meta["val"].shape + out_shape = node.meta["val"].shape + + if len(in_shape) != 4 or len(out_shape) != 4: + return False + + # Are batch and channel modified? If so, return false. + if in_shape[0] != out_shape[0] or in_shape[1] != out_shape[1]: + return False + + return True + def mark_as_nhwc_node(self, node: torch.fx.Node) -> None: node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = True @@ -93,6 +108,13 @@ def requires_nhwc_input(self, node: torch.fx.Node) -> bool: return node.target in self.memory_sensitive_ops_nhwc def requires_nchw_inputs(self, node: torch.fx.Node) -> bool: + # Views depend on whether batch or channel are modified. + if ( + node.target == exir_ops.edge.aten.view_copy.default + and not self.is_view_dim_order_invariant(node) + ): + return True + return node.target in self.memory_sensitive_ops_nchw def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool: diff --git a/backends/xnnpack/_passes/convert_squeeze_to_view_pass.py b/backends/xnnpack/_passes/convert_squeeze_to_view_pass.py new file mode 100644 index 00000000000..362b30d572a --- /dev/null +++ b/backends/xnnpack/_passes/convert_squeeze_to_view_pass.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import torch +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass +from executorch.backends.xnnpack.utils.utils import check_or_raise +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult +from torch.fx.experimental.symbolic_shapes import has_free_symbols + + +class ConvertSqueezeToViewPass(XNNPACKPass): + """ + This pass is used to convert squeeze and unsqueeze nodes into view_copy. + This allows them to be subsequently lowered as static_reshape ops. + """ + + SUPPORTED_OPS = [ + exir_ops.edge.aten.squeeze_copy.dim, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + ] + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + node_list = list(graph.nodes) + for node in node_list: + if node.op == "call_function": + if node.target in self.SUPPORTED_OPS: + out_shape = node.meta["val"].shape + + # Replace up to one dynamic dimension with -1 (inferred dim). + new_shape = [] + dynamic_dim_count = 0 + for d in out_shape: + if has_free_symbols(d): + new_shape.append(-1) + dynamic_dim_count += 1 + else: + new_shape.append(d) + + # This constraint should be enforced by the partitioner. + check_or_raise( + dynamic_dim_count <= 1, + "XNN supports only one dynamic dimension", + ) + + with graph_module.graph.inserting_after(node): + view_node = graph_module.graph.create_node( + "call_function", + target=exir_ops.edge.aten.view_copy.default, + args=(node.args[0], new_shape), + kwargs=node.kwargs, + ) + + node.replace_all_uses_with(view_node) + graph_module.graph.erase_node(node) + + graph_module.recompile() + # Since we are overriding "call", we need to call the parent's "call" + # to retrace the graph and regenerate metadata + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index b2653a5fdc7..0382a6fc059 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -49,4 +49,5 @@ op_static_resize_bilinear_2d, op_sub, op_to_copy, + op_view_copy, ) diff --git a/backends/xnnpack/operators/op_skip_ops.py b/backends/xnnpack/operators/op_skip_ops.py index 6597c0568e3..d649df50ecf 100644 --- a/backends/xnnpack/operators/op_skip_ops.py +++ b/backends/xnnpack/operators/op_skip_ops.py @@ -77,16 +77,6 @@ class OpTCopyDefault(OpSkipOps): target = "aten.t_copy.default" -@register_node_visitor -class OpViewCopyDefault(OpSkipOps): - """ - currently, do nothing if node is view_copy.default - need to handle this later on, currently view it as one of skip ops - """ - - target = "aten.view_copy.default" - - @register_node_visitor class OpSymSizeInt(OpSkipOps): """ diff --git a/backends/xnnpack/operators/op_view_copy.py b/backends/xnnpack/operators/op_view_copy.py new file mode 100644 index 00000000000..5a8bf342eab --- /dev/null +++ b/backends/xnnpack/operators/op_view_copy.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNGraph, + XNNStaticReshape, + XNode, +) +from executorch.backends.xnnpack.utils.utils import ( + check_or_raise, + get_input_node, + PERM_NCHW_TO_NHWC, +) + + +@register_node_visitor +class ViewCopyVisitor(NodeVisitor): + target = "aten.view_copy.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + input_node = get_input_node(node, 0) + + # input + input_id = vals_to_ids[input_node] + + # output + output_id = vals_to_ids[node] + + # input shape + check_or_raise( + "val" in input_node.meta, + "Missing val in tensor metadata for input when serializing XNNStaticReshape", + ) + + # output shape + check_or_raise( + "val" in node.meta, + "Missing val in tensor metadata for output when serializing XNNStaticReshape", + ) + + new_shape = node.args[1] + check_or_raise( + all(isinstance(d, int) for d in new_shape), + "Symbolic reshape parameter is not supported in XNNStaticReshape", + ) + + # PyTorch uses -1 for inferred dims, whereas XNNPACK expects 0. + new_shape = tuple(d if d != -1 else 0 for d in new_shape) + + # Handle NHWC dim order - if this node is tagged to run in NHWC (channels-last), + # permute the target shape from NCHW to NHWC accordingly. + if "XNN_NHWC_NODE" in node.meta: + check_or_raise(len(new_shape) == 4, "Invalid NCHW shape") + new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)] + + num_dynamic_dims = sum(1 for d in new_shape if d == 0) + + check_or_raise( + num_dynamic_dims <= 1, + "XNNPACK reshape only supports 1 dynamic dimension.", + ) + + ser_node = XNode( + xnode_union=XNNStaticReshape( + num_dims=len(new_shape), + new_shape=new_shape, + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index ed105dc1f53..e453631dfde 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -45,8 +45,11 @@ SliceCopyConfig, SoftmaxConfig, SquareRootConfig, + SqueezeCopyConfig, SubConfig, + UnsqueezeCopyConfig, UpsampleBilinear2dConfig, + ViewCopyConfig, ) from executorch.backends.xnnpack.partition.config.node_configs import ( BatchNormConfig, @@ -98,8 +101,11 @@ SliceCopyConfig, SoftmaxConfig, SquareRootConfig, + SqueezeCopyConfig, SubConfig, + UnsqueezeCopyConfig, UpsampleBilinear2dConfig, + ViewCopyConfig, # Quant/Dequant Op Configs QuantizedPerTensorConfig, DeQuantizedPerTensorConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index dbcb5c92035..0ccc4717696 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -21,6 +21,7 @@ ) from executorch.exir.backend.utils import is_shape_dynamic, WhyNoPartition from torch.export import ExportedProgram +from torch.fx.experimental.symbolic_shapes import has_free_symbols logger = logging.getLogger(__name__) why = WhyNoPartition(logger=logger) @@ -314,6 +315,31 @@ def 
get_original_aten(self) -> Optional[torch._ops.OpOverload]: return torch.ops.aten.max_pool2d.default +class SqueezeCopyConfig(GenericNodePartitionerConfig): + target_name = "squeeze_copy.dims" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + def get_original_aten(self) -> Optional[torch._ops.OpOverload]: + return torch.ops.aten.squeeze_copy.default + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + XNNPACK's static_reshape only supports 1 dynamic dimension + """ + if not self.check_common_constraints(node, ep): + return False + + new_shape = node.meta["val"].shape + dynamic_dim_count = sum(1 for d in new_shape if has_free_symbols(d)) + if dynamic_dim_count > 1: + why(node, reason="only a single dynamic dimension is supported") + return False + + return True + + class UpsampleBilinear2dConfig(GenericNodePartitionerConfig): target_name = "upsample_bilinear2d.vec" @@ -336,6 +362,59 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]: return torch.ops.aten.upsample_bilinear2d.vec +class UnsqueezeCopyConfig(GenericNodePartitionerConfig): + target_name = "unsqueeze_copy.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + def get_original_aten(self) -> Optional[torch._ops.OpOverload]: + return torch.ops.aten.unsqueeze_copy.default + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + XNNPACK's static_reshape only supports 1 dynamic dimension + """ + if not self.check_common_constraints(node, ep): + return False + + new_shape = node.meta["val"].shape + dynamic_dim_count = sum( + 1 for d in new_shape if not isinstance(d, int) and has_free_symbols(d) + ) + if dynamic_dim_count > 1: + why(node, reason="only a single dynamic dimension is supported") + return False + + return True + + +class ViewCopyConfig(GenericNodePartitionerConfig): + target_name = "view_copy.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + XNNPACK's static_reshape only supports 1 dynamic dimension + """ + if not self.check_common_constraints(node, ep): + return False + + new_shape = node.args[1] + if not all(isinstance(n, int) for n in new_shape): + why(node, reason="symbolic reshape is not supported") + return False + + dynamic_dim_count = sum(1 for d in new_shape if d == -1) + if dynamic_dim_count > 1: + why(node, reason="only a single dynamic dimension is supported") + return False + + return True + + class FloorConfig(GenericNodePartitionerConfig): target_name = "floor.default" diff --git a/backends/xnnpack/test/ops/test_squeeze.py b/backends/xnnpack/test/ops/test_squeeze.py new file mode 100644 index 00000000000..e611683ff88 --- /dev/null +++ b/backends/xnnpack/test/ops/test_squeeze.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Export, Tester +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import Dim + + +class TestSqueeze(unittest.TestCase): + class Squeeze(torch.nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, x): + return torch.squeeze(x, self.dims) + + def test_fp32_squeeze(self): + inputs = (torch.randn(1, 2, 1, 4, 1),) + squeeze_dims = (0, 2, 4) + + for dims in squeeze_dims: + ( + Tester(self.Squeeze(dims), inputs) + .export() + .check_node_count( + { + torch.ops.aten.squeeze.dim: 1, + } + ) + .to_edge_transform_and_lower() + .check_node_count( + { + exir_ops.edge.aten.squeeze_copy.dim: 0, + exir_ops.edge.aten.view_copy.default: 0, + torch.ops.higher_order.executorch_call_delegate: 1, + } + ) + .run_method_and_compare_outputs() + ) + + def test_fp16_squeeze(self): + inputs = (torch.randn(1, 2, 1, 4, 1).to(torch.float16),) + squeeze_dims = (0, 2, 4) + + for dims in squeeze_dims: + ( + Tester(self.Squeeze(dims), inputs) + .export() + .check_node_count( + { + torch.ops.aten.squeeze.dim: 1, + } + ) + .to_edge_transform_and_lower() + .check_node_count( + { + exir_ops.edge.aten.squeeze_copy.dim: 0, + exir_ops.edge.aten.view_copy.default: 0, + torch.ops.higher_order.executorch_call_delegate: 1, + } + ) + .run_method_and_compare_outputs() + ) + + def test_fp32_squeeze_dynamic(self): + inputs = (torch.randn(1, 2, 1, 4, 1),) + squeeze_dims = (0, 2, 4) + dynamic_shapes = {"x": {1: Dim("x_1", min=1, max=10)}} + + for dims in squeeze_dims: + ( + Tester(self.Squeeze(dims), inputs) + .export(Export(dynamic_shapes=dynamic_shapes)) + .check_node_count( + { + torch.ops.aten.squeeze.dim: 1, + } + ) + .to_edge_transform_and_lower() + .check_node_count( + { + exir_ops.edge.aten.squeeze_copy.dim: 0, + exir_ops.edge.aten.view_copy.default: 0, + torch.ops.higher_order.executorch_call_delegate: 1, + } + ) + .run_method_and_compare_outputs() + ) + + def test_fp32_squeeze_unsupported_dynamism(self): + inputs = (torch.randn(1, 2, 1, 4, 1),) + squeeze_dims = (0, 2, 4) + # Only one dynamic dimension is supported. + dynamic_shapes = { + "x": { + 1: Dim("x_1", min=1, max=10), + 3: Dim("x_3", min=1, max=10), + } + } + + for dims in squeeze_dims: + ( + Tester(self.Squeeze(dims), inputs) + .export(Export(dynamic_shapes=dynamic_shapes)) + .check_node_count( + { + torch.ops.aten.squeeze.dim: 1, + } + ) + .to_edge_transform_and_lower() + .check_node_count( + { + exir_ops.edge.aten.squeeze_copy.dims: 1, + torch.ops.higher_order.executorch_call_delegate: 0, + } + ) + .run_method_and_compare_outputs() + ) diff --git a/backends/xnnpack/test/ops/test_unsqueeze.py b/backends/xnnpack/test/ops/test_unsqueeze.py new file mode 100644 index 00000000000..befe1c902dc --- /dev/null +++ b/backends/xnnpack/test/ops/test_unsqueeze.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Export, Tester +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import Dim + + +class TestUnsqueeze(unittest.TestCase): + class Unsqueeze(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.unsqueeze(x, self.dim) + + def test_fp32_unsqueeze(self): + inputs = (torch.randn(1, 2, 4),) + for dim in range(len(inputs[0].shape)): + ( + Tester(self.Unsqueeze(dim), inputs) + .export() + .check_node_count( + { + torch.ops.aten.unsqueeze.default: 1, + } + ) + .to_edge_transform_and_lower() + .check_node_count( + { + exir_ops.edge.aten.unsqueeze_copy.default: 0, + exir_ops.edge.aten.view_copy.default: 0, + torch.ops.higher_order.executorch_call_delegate: 1, + } + ) + .run_method_and_compare_outputs() + ) + + def test_fp16_unsqueeze(self): + inputs = (torch.randn(1, 2, 4).to(torch.float16),) + for dim in range(len(inputs[0].shape)): + ( + Tester(self.Unsqueeze(dim), inputs) + .export() + .check_node_count( + { + torch.ops.aten.unsqueeze.default: 1, + } + ) + .to_edge_transform_and_lower() + .check_node_count( + { + exir_ops.edge.aten.unsqueeze_copy.default: 0, + exir_ops.edge.aten.view_copy.default: 0, + torch.ops.higher_order.executorch_call_delegate: 1, + } + ) + .run_method_and_compare_outputs() + ) + + def test_fp32_unsqueeze_dynamic(self): + inputs = (torch.randn(1, 2, 4),) + dynamic_shapes = {"x": {1: Dim("x_1", min=1, max=10)}} + + for dim in range(len(inputs[0].shape)): + ( + Tester(self.Unsqueeze(dim), inputs) + .export(Export(dynamic_shapes=dynamic_shapes)) + .check_node_count( + { + torch.ops.aten.unsqueeze.default: 1, + } + ) + .to_edge_transform_and_lower() + .check_node_count( + { + exir_ops.edge.aten.unsqueeze_copy.default: 0, + exir_ops.edge.aten.view_copy.default: 0, + torch.ops.higher_order.executorch_call_delegate: 1, + } + ) + .run_method_and_compare_outputs() + ) + + def test_fp32_unsqueeze_unsupported_dynamism(self): + inputs = (torch.randn(1, 2, 4),) + # Only one dynamic dimension is supported. + dynamic_shapes = { + "x": { + 1: Dim("x_1", min=1, max=10), + 2: Dim("x_2", min=1, max=10), + } + } + + for dim in range(len(inputs[0].shape)): + ( + Tester(self.Unsqueeze(dim), inputs) + .export(Export(dynamic_shapes=dynamic_shapes)) + .check_node_count( + { + torch.ops.aten.unsqueeze.default: 1, + } + ) + .to_edge_transform_and_lower() + .check_node_count( + { + exir_ops.edge.aten.unsqueeze_copy.default: 1, + torch.ops.higher_order.executorch_call_delegate: 0, + } + ) + .run_method_and_compare_outputs() + ) diff --git a/backends/xnnpack/test/ops/test_view_copy.py b/backends/xnnpack/test/ops/test_view_copy.py new file mode 100644 index 00000000000..9e952996639 --- /dev/null +++ b/backends/xnnpack/test/ops/test_view_copy.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Export, Tester +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import Dim + + +class TestViewCopy(unittest.TestCase): + class View(torch.nn.Module): + def __init__(self, new_shape): + super().__init__() + self.new_shape = new_shape + + def forward(self, x): + z = x.view(self.new_shape) + return z + + def test_fp16_view_copy(self): + inputs = (torch.randn(4, 4).to(torch.float16),) + ( + Tester(self.View((2, 8)), inputs) + .export() + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_view_copy(self): + inputs = (torch.randn(4, 4),) + ( + Tester(self.View((2, 8)), inputs) + .export() + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_view_copy_inferred_dim(self): + inputs = (torch.randn(4, 4),) + ( + Tester(self.View((-1, 8)), inputs) + .export() + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_view_copy_dynamic_shape(self): + inputs = (torch.randn(4, 4, 6),) + for dynamic_dim_index in range(len(inputs[0].shape)): + dynamic_shapes = { + "x": {dynamic_dim_index: Dim("x", min=1, max=10) * 2}, + } + + # Test as min and max bounds. + test_inputs = [ + (inputs[0].clone(),), + (inputs[0].clone(),), + ] + test_inputs[0][0][dynamic_dim_index] = 2 + test_inputs[1][0][dynamic_dim_index] = 20 + + # Non-dynamic dimensions are halved in the view. + view_size = [n // 2 for n in inputs[0].shape] + view_size[dynamic_dim_index] = -1 + + tester = ( + Tester(self.View(view_size), inputs) + .export(Export(dynamic_shapes=dynamic_shapes)) + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + for test_input in test_inputs: + tester.run_method_and_compare_outputs(inputs=test_input) + + def test_fp32_view_copy_unsupported_dynamism(self): + class SymbolicView(torch.nn.Module): + def forward(self, x): + return x.view(x.shape[0] // 2, x.shape[1] * 2) + + inputs = (torch.randn(4, 4),) + dynamic_shapes = { + "x": {1: Dim("x", min=1, max=10) * 2}, + } + ( + Tester(SymbolicView(), inputs) + .export(Export(dynamic_shapes=dynamic_shapes)) + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { # Expect no delegation. 
+ torch.ops.higher_order.executorch_call_delegate: 0, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_view_copy_static_symbolic_arg(self): + class SymbolicView(torch.nn.Module): + def forward(self, x): + return x.view(x.shape[0] // 2, x.shape[1] * 2) + + inputs = (torch.randn(4, 4),) + ( + Tester(SymbolicView(), inputs) + .export() + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_view_copy_increase_rank(self): + inputs = (torch.randn(4, 4),) + ( + Tester(self.View((1, 2, 4, 2)), inputs) + .export() + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_view_copy_increase_rank_dynamic(self): + test_inputs = ( + (torch.randn(2, 4),), + (torch.randn(10, 4),), + ) + dynamic_shapes = { + "x": {0: Dim("x", min=1, max=10) * 2}, + } + inputs = (torch.randn(4, 4),) + tester = ( + Tester(self.View((1, 2, 4, -1)), inputs) + .export(Export(dynamic_shapes=dynamic_shapes)) + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + for test_input in test_inputs: + tester.run_method_and_compare_outputs(inputs=test_input) + + def test_fp32_view_copy_decrease_rank(self): + inputs = (torch.randn(4, 4),) + ( + Tester(self.View((-1)), inputs) + .export() + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_view_copy_decrease_rank_dynamic(self): + test_inputs = ( + (torch.randn(2, 2, 4),), + (torch.randn(2, 10, 4),), + ) + dynamic_shapes = { + "x": {1: Dim("x", min=1, max=10) * 2}, + } + inputs = (torch.randn(2, 4, 4),) + tester = ( + Tester(self.View((-1)), inputs) + .export(Export(dynamic_shapes=dynamic_shapes)) + .check_node_count({torch.ops.aten.view.default: 1}) + .to_edge_transform_and_lower() + .check_node_count( + { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + for test_input in test_inputs: + tester.run_method_and_compare_outputs(inputs=test_input) + + def test_fp32_view_copy_nhwc(self): + class ViewNHWC(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + y = self.conv1(x) + y = y.view(1, 3, 3, -1) + y = self.conv2(y) + return y.view(1, 3, 2, -1) + + inputs = (torch.randn(1, 3, 8, 8),) + ( + Tester(ViewNHWC(), inputs) + .export() + .dump_artifact() + .check_node_count({torch.ops.aten.view.default: 2}) + .to_edge_transform_and_lower() + .dump_artifact() + .check_node_count( 
+ { + torch.ops.higher_order.executorch_call_delegate: 1, + exir_ops.edge.aten.view_copy.default: 0, + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py index c1438b29213..5c104bb0df2 100644 --- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py +++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py @@ -14,6 +14,7 @@ OpSequencesAddConv2d, ) from executorch.backends.xnnpack.test.tester import RunPasses, Tester +from executorch.exir.dialects._ops import ops as exir_ops class TestChannelsLastTaggedReshapePass(unittest.TestCase): @@ -176,3 +177,151 @@ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self): ) .run_method_and_compare_outputs() ) + + def test_fp32_channels_last_tagged_reshape_pass_nhwc_view(self): + # View can run in NHWC because channel and batch are unchanged. + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + y = self.conv1(x) + y = y.view((1, 3, 3, -1)) + return self.conv2(y) + + inputs = (torch.randn(1, 3, 8, 8),) + ( + Tester(Model(), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_passes(self.PassStage) + .run_method_and_compare_outputs() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + exir_ops.edge.aten._to_copy.default: 2, + } + ) + ) + + def test_fp32_channels_last_tagged_reshape_pass_nchw_view_channel_modified(self): + # View cannot run in NHWC because channel and/or batch are modified. + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(6, 3, 3) + + def forward(self, x): + y = self.conv1(x) + y = y.view((1, 6, 6, -1)) + return self.conv2(y) + + inputs = (torch.randn(1, 3, 8, 8),) + ( + Tester(Model(), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_passes(self.PassStage) + .run_method_and_compare_outputs() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + exir_ops.edge.aten._to_copy.default: 4, + } + ) + ) + + def test_fp32_channels_last_tagged_reshape_pass_nchw_view_batch_modified(self): + # View cannot run in NHWC because channel and/or batch are modified. 
+ class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + y = self.conv1(x) + y = y.view((2, 3, 6, -1)) + return self.conv2(y) + + inputs = (torch.randn(1, 3, 8, 8),) + ( + Tester(Model(), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_passes(self.PassStage) + .run_method_and_compare_outputs() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 2, + exir_ops.edge.aten.view_copy.default: 1, + exir_ops.edge.aten._to_copy.default: 4, + } + ) + ) + + def test_fp32_channels_last_tagged_reshape_pass_flatten_view(self): + # View cannot run in NHWC because tensor rank changes. + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.linear1 = torch.nn.Linear(36 * 3, 1) + + def forward(self, x): + y = self.conv1(x) + y = y.view((x.shape[0], -1)) + return self.linear1(y) + + inputs = (torch.randn(1, 3, 8, 8),) + tester = ( + Tester(Model(), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 1, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_passes(self.PassStage) + .run_method_and_compare_outputs() + .check_node_count( + { + exir_ops.edge.aten.convolution.default: 1, + exir_ops.edge.aten.view_copy.default: 1, + exir_ops.edge.aten._to_copy.default: 2, + } + ) + ) + + # Verify view is not tagged. + graph = tester.get_artifact().exported_program().module().graph + view_nodes = [ + n for n in graph.nodes if n.target == exir_ops.edge.aten.view_copy.default + ] + self.assertEqual(1, len(view_nodes)) + self.assertTrue(ChannelsLastTaggedReshapePass(None).is_nchw_node(view_nodes[0])) diff --git a/backends/xnnpack/test/passes/test_convert_squeeze_to_view_pass.py b/backends/xnnpack/test/passes/test_convert_squeeze_to_view_pass.py new file mode 100644 index 00000000000..aee81dbdde4 --- /dev/null +++ b/backends/xnnpack/test/passes/test_convert_squeeze_to_view_pass.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import unittest + +import torch +from executorch.backends.xnnpack._passes.convert_squeeze_to_view_pass import ( + ConvertSqueezeToViewPass, +) +from executorch.backends.xnnpack.test.tester import RunPasses, Tester +from executorch.exir.dialects._ops import ops as exir_ops + + +class TestConvertSqueezeToView(unittest.TestCase): + PassStage = RunPasses([ConvertSqueezeToViewPass]) + + class SqueezeModel(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.squeeze(x, self.dim) + + class UnsqueezeModel(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.unsqueeze(x, self.dim) + + def test_fp32_convert_squeeze_to_view(self): + inputs = (torch.randn(1, 2, 1, 4, 1),) + squeeze_dims = (0, 2, 4) + + for dims in squeeze_dims: + ( + Tester(self.SqueezeModel(dims), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.squeeze_copy.dims: 1, + } + ) + .run_passes(self.PassStage) + .check_node_count( + { + exir_ops.edge.aten.squeeze_copy.dims: 0, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_method_and_compare_outputs() + ) + + def test_fp32_convert_unsqueeze_to_view(self): + inputs = (torch.randn(1, 2, 4),) + + for dim in range(len(inputs[0].shape)): + ( + Tester(self.UnsqueezeModel(dim), inputs) + .export() + .to_edge() + .check_node_count( + { + exir_ops.edge.aten.unsqueeze_copy.default: 1, + } + ) + .run_passes(self.PassStage) + .check_node_count( + { + exir_ops.edge.aten.unsqueeze_copy.default: 0, + exir_ops.edge.aten.view_copy.default: 1, + } + ) + .run_method_and_compare_outputs() + ) diff --git a/backends/xnnpack/test/tester/__init__.py b/backends/xnnpack/test/tester/__init__.py index f92088c72e8..de3c0f55cb7 100644 --- a/backends/xnnpack/test/tester/__init__.py +++ b/backends/xnnpack/test/tester/__init__.py @@ -13,6 +13,7 @@ Serialize, Tester, ToEdge, + ToEdgeTransformAndLower, ToExecutorch, ) @@ -22,6 +23,7 @@ Quantize, Export, ToEdge, + ToEdgeTransformAndLower, RunPasses, ToExecutorch, Serialize, diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py index 8bf810e847e..45182ab1970 100644 --- a/exir/tests/test_memory_format_ops_pass_utils.py +++ b/exir/tests/test_memory_format_ops_pass_utils.py @@ -108,6 +108,8 @@ def memory_format_test_runner( test_set.module, test_set.sample_input, strict=True ).run_decompositions({}) + print(before) + if test_set.use_xnnpack: epm = to_edge_transform_and_lower( before, diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp index f0297778744..307210a465e 100644 --- a/runtime/core/portable_type/tensor_impl.cpp +++ b/runtime/core/portable_type/tensor_impl.cpp @@ -95,7 +95,7 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef new_sizes) { case TensorShapeDynamism::STATIC: if (!std::equal(sizes_, sizes_ + dim_, new_sizes.begin())) { #ifdef ET_LOG_ENABLED - std::array old_sizes_str, new_sizes_str; + std::array old_sizes_str, new_sizes_str; executorch::runtime::sizes_to_string( old_sizes_str.data(),
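
For readers tracing the shape handling in this change, the following standalone sketch (illustration only, not part of the diff; all helper names are made up) mirrors the logic added in ConvertSqueezeToViewPass, ChannelsLastTaggedReshapePass, and ViewCopyVisitor: squeeze/unsqueeze output shapes become view_copy targets with any symbolic dimension replaced by -1, the serializer maps -1 to XNNPACK's 0 convention, and NHWC-tagged nodes get their NCHW target shape permuted. The [0, 2, 3, 1] permutation below is assumed to match PERM_NCHW_TO_NHWC.

# Standalone illustration of the shape transforms; independent of the sources above.
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]  # assumed NCHW -> NHWC index permutation


def squeeze_output_to_view_shape(out_shape):
    # Symbolic dims (modeled here as strings) become -1, PyTorch's inferred dim.
    view_shape = [-1 if isinstance(d, str) else d for d in out_shape]
    assert view_shape.count(-1) <= 1, "XNN supports only one dynamic dimension"
    return view_shape


def view_shape_to_xnn_static_reshape(view_shape, is_nhwc_tagged):
    # XNNPACK marks the inferred dimension with 0 rather than -1.
    new_shape = [0 if d == -1 else d for d in view_shape]
    if is_nhwc_tagged:
        # Tagged nodes see NHWC data, so the NCHW target shape is permuted.
        assert len(new_shape) == 4, "Invalid NCHW shape"
        new_shape = [new_shape[PERM_NCHW_TO_NHWC[n]] for n in range(4)]
    return new_shape


def is_view_dim_order_invariant(in_shape, out_shape):
    # A view may stay in NHWC only if rank stays 4 and batch/channel are unchanged.
    return (
        len(in_shape) == 4
        and len(out_shape) == 4
        and in_shape[0] == out_shape[0]
        and in_shape[1] == out_shape[1]
    )


# squeeze(x, dims=(0, 2, 4)) on a (1, 2, 1, s0, 1) tensor yields shape (2, s0):
print(squeeze_output_to_view_shape([2, "s0"]))                    # [2, -1]
print(view_shape_to_xnn_static_reshape([1, 3, 3, -1], True))      # [1, 3, 0, 3]
print(is_view_dim_order_invariant((1, 3, 6, 6), (1, 3, 3, 12)))   # True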
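
The tests above drive the new lowering path through the Tester harness. As a usage note, a minimal end-to-end sketch follows; it is not part of this diff and assumes the public entry points torch.export.export, executorch.exir.to_edge_transform_and_lower, and XnnpackPartitioner behave as in current ExecuTorch.

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class SqueezeView(torch.nn.Module):
    def forward(self, x):
        # With this change, both the squeeze and the view are expected to be
        # delegated to XNNPACK as static_reshape instead of being skipped.
        return torch.squeeze(x, 0).view(2, -1)


exported = torch.export.export(SqueezeView(), (torch.randn(1, 2, 4),))
program = to_edge_transform_and_lower(
    exported, partitioner=[XnnpackPartitioner()]
).to_executorch()
# program.buffer holds the serialized .pte ready for the ExecuTorch runtime.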