From 4d85c0fdcdf4d948de5e5ae0e891d4c046d9ecc1 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Sat, 18 Jan 2025 21:47:43 -0800 Subject: [PATCH] Update XNN delegate to handle non-decomposed upsample ops (#7770) Summary: I'm preparing to disable decomposition by default for core upsample ops in pytorch/pytorch#141791. This PR makes changes to ExecuTorch to handle non-decomposed upsample_bilinear2d. There are tests that expect the decomposition when not delegated, as well as a recomposition pass in the XNNPACK delegate which need updates to function. The changes in this PR are intended to work with both the decomposed and non-decomposed upsample ops. After the decomposition change lands in PyTorch, I will clean up the remaining ExecuTorch usages. Differential Revision: D68374352 --- .../_passes/convert_to_upsample_bilinear2d.py | 6 ++++++ backends/xnnpack/test/ops/test_bilinear2d.py | 16 ++++++++++++---- exir/emit/test/test_emit.py | 2 +- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/backends/xnnpack/_passes/convert_to_upsample_bilinear2d.py b/backends/xnnpack/_passes/convert_to_upsample_bilinear2d.py index 47bff3b99eb..a3cf2d1ac4d 100644 --- a/backends/xnnpack/_passes/convert_to_upsample_bilinear2d.py +++ b/backends/xnnpack/_passes/convert_to_upsample_bilinear2d.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import torch from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass from executorch.backends.xnnpack.partition.graphs import bilinear_2d @@ -23,6 +25,10 @@ def create_upsample_bilinear_2d( align_corners: bool, ): output = internal_match.returning_nodes[0] + if output.target == exir_ops.edge.aten.upsample_bilinear2d.vec: + # Op was not decomposed, do nothing + return + output_shape = output.meta["val"].shape output_h = output_shape[-2] output_w = output_shape[-1] diff --git a/backends/xnnpack/test/ops/test_bilinear2d.py b/backends/xnnpack/test/ops/test_bilinear2d.py index 6a194763657..24c990d6bb1 100644 --- a/backends/xnnpack/test/ops/test_bilinear2d.py +++ b/backends/xnnpack/test/ops/test_bilinear2d.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import unittest import torch @@ -131,11 +133,17 @@ def test_fp32_bilinear2d_dynamic_bilinear2d_not_partitioned(self): 3: torch.export.Dim("w", min=1, max=12), } } - ( + artifact_str = str( Tester(self.StaticResizeBilinear2dModule(), example_inputs) .export(Export(dynamic_shapes)) .to_edge_transform_and_lower() - # NOTE The decomposition is partially delegated. This will need to be replaced - # with the aten upsample op once decomp is removed. - .check("executorch_exir_dialects_edge__ops_aten_index_Tensor") + .get_artifact() + .exported_program() + ) + # NOTE The decomposition can be partially delegated. This will need to be replaced + # with the aten upsample op once decomp is removed. + self.assertTrue( + "executorch_exir_dialects_edge__ops_aten_index_Tensor" in artifact_str + or "executorch_exir_dialects_edge__ops_aten_upsample_bilinear2d_vec" + in artifact_str ) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 3fca3958feb..349a68cbd7c 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -642,7 +642,7 @@ class M(torch.nn.Module): def forward(self, x): return torch.nn.functional.interpolate(x, scale_factor=2) - x = (torch.randn(1, 1, 2, 2),) + x = (torch.randn(1, 1, 2, 2, 2),) program = ( to_edge(export(M(), x, strict=True)).to_executorch().executorch_program )