From 8e20f7614b71d05996cae0c45fe7edbf6935efd2 Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Tue, 3 Mar 2026 15:43:04 +0100 Subject: [PATCH 1/4] [TorchToTosa] add conv reshape in core lowering - Insert rank-4/5 reshapes for conv inputs/weights during TorchToTosa lowering Signed-off-by: Cathal Corbett Change-Id: Ica1b5cc265822ecd054f832908ec31bc2325c661 --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 111 +++++++++++++++++++-- test/Conversion/TorchToTosa/basic.mlir | 85 ++++++++++++++-- 2 files changed, 182 insertions(+), 14 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b34e0bf5ca61..d82e5cdcd451 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,6 +29,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include @@ -48,6 +50,11 @@ namespace mlir::torch { #include "torch-mlir/Conversion/Passes.h.inc" namespace { +struct RankTemplate { + int64_t rank; + RankedTensorType type; + Value shape; +}; // Runs an in-place inclusive prefix sum along the middle dimension (K) of // `running` using a binary lifting scheme. The input must have shape [N, K, C]. @@ -2634,14 +2641,109 @@ LogicalResult ConvertAtenOp::matchAndRewriteImpl( auto input = adaptor.getInput(); auto weight = adaptor.getWeight(); - auto inputTy = cast(input.getType()); - auto weightTy = cast(weight.getType()); auto outputTy = cast(getTypeConverter()->convertType(op.getType())); + auto inputTy = dyn_cast(input.getType()); + auto weightTy = dyn_cast(weight.getType()); if (!inputTy || !weightTy || !outputTy) return rewriter.notifyMatchFailure( op, "Input, weight and output to Convolution must be ranked tensors"); + int64_t outputRank = outputTy.getRank(); + if (outputRank != 4 && outputRank != 5) + return rewriter.notifyMatchFailure( + op, "Unimplemented: only 2D or 3D convolutions supported"); + + auto funcOp = op->getParentOfType(); + llvm::DenseMap> argToTemplates; + bool templatesBuilt = false; + DominanceInfo domInfo(funcOp); + + auto buildTemplates = [&]() { + if (templatesBuilt) + return; + templatesBuilt = true; + funcOp.walk([&](tosa::ReshapeOp reshapeOp) { + Value source = reshapeOp.getInput1(); + auto blockArg = dyn_cast(source); + if (!blockArg) + return; + + auto dstType = + dyn_cast(reshapeOp.getResult().getType()); + if (!dstType || (dstType.getRank() != 4 && dstType.getRank() != 5)) + return; + + unsigned argNumber = blockArg.getArgNumber(); + auto &templates = argToTemplates[argNumber]; + for (const auto &tmpl : templates) { + if (tmpl.rank == dstType.getRank() && tmpl.type == dstType) + return; + } + templates.push_back( + RankTemplate{dstType.getRank(), dstType, reshapeOp.getShape()}); + }); + }; + + auto normalizeOperandRank = [&](Value operand, + int64_t requiredRank) -> FailureOr { + auto rankedType = dyn_cast(operand.getType()); + if (!rankedType) + return failure(); + if (rankedType.getRank() == requiredRank) + return operand; + + auto blockArg = dyn_cast(operand); + if (!blockArg) + return failure(); + + buildTemplates(); + auto tmplIt = argToTemplates.find(blockArg.getArgNumber()); + if (tmplIt == argToTemplates.end()) + return failure(); + + const RankTemplate *match = nullptr; + for (const auto &tmpl : tmplIt->second) { + if (tmpl.rank == requiredRank) { + match = &tmpl; + break; + } + } + if (!match) + return failure(); + + Value shapeVal = match->shape; + if (auto shapeOp = shapeVal.getDefiningOp()) { + OpBuilder builder(op); + shapeVal = tosa::ConstShapeOp::create( + builder, op->getLoc(), shapeOp.getType(), shapeOp.getValues()); + } else if (!domInfo.properlyDominates(shapeVal, op)) { + return failure(); + } + + auto reshape = tosa::ReshapeOp::create(rewriter, op->getLoc(), match->type, + operand, shapeVal); + return reshape.getResult(); + }; + + if (inputTy.getRank() != outputRank) { + auto normalized = normalizeOperandRank(input, outputRank); + if (failed(normalized)) + return rewriter.notifyMatchFailure( + op, "Input rank mismatch without normalization template"); + input = *normalized; + inputTy = cast(input.getType()); + } + + if (weightTy.getRank() != outputRank) { + auto normalized = normalizeOperandRank(weight, outputRank); + if (failed(normalized)) + return rewriter.notifyMatchFailure( + op, "Weight rank mismatch without normalization template"); + weight = *normalized; + weightTy = cast(weight.getType()); + } + auto inputElemTy = inputTy.getElementType(); auto weightElemTy = weightTy.getElementType(); auto inputShape = makeShapeTorchCompatible(inputTy.getShape()); @@ -2650,16 +2752,11 @@ LogicalResult ConvertAtenOp::matchAndRewriteImpl( int64_t inputRank = inputTy.getRank(); int64_t weightRank = weightTy.getRank(); - int64_t outputRank = outputTy.getRank(); if (inputRank != weightRank || outputRank != inputRank) return rewriter.notifyMatchFailure( op, "Input, weight and output ranks must match for convolution"); - if (inputRank != 4 && inputRank != 5) - return rewriter.notifyMatchFailure( - op, "Unimplemented: only 2D or 3D convolutions supported"); - bool is3D = inputRank == 5; int64_t spatialRank = inputRank - 2; diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index f95347563fae..9d6fa9af7f1e 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt %s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK // CHECK-LABEL: func.func @torch.aten.tanh$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -13,6 +13,80 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // ----- +// CHECK-LABEL: func.func @conv2d_io_insert_reshape( +// CHECK: %[[SHAPE:.*]] = tosa.const_shape +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const" +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const" +// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]] +// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[SHAPE]] +// CHECK: %[[CONV:.*]] = tosa.conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] +func.func @conv2d_io_insert_reshape(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x1x16xf32> { + %shape = "tosa.const_shape"() {values = dense<[1, 1, 16, 16]> : tensor<4xindex>} : () -> !tosa.shape<4> + %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32> + %r1 = "tosa.reshape"(%arg1, %shape) : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32> + %conv = "tosa.conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array, stride = array, dilation = array, acc_type = f32} : (tensor<1x1x16x16xf32>, tensor<1x1x16x16xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x16xf32> + return %conv : tensor<1x1x1x16xf32> +} + +// CHECK-LABEL: func.func @depthwise_conv2d_io_insert_reshape( +// CHECK: %[[SHAPE:.*]] = tosa.const_shape +// CHECK: %[[WSHAPE:.*]] = tosa.const_shape +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const" +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const" +// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]] +// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]] +// CHECK: %[[CONV:.*]] = tosa.depthwise_conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] +func.func @depthwise_conv2d_io_insert_reshape(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x1x1xf32> { + %shape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %wshape = "tosa.const_shape"() {values = dense<[3, 3, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32> + %r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<3x3x1x1xf32> + %conv = "tosa.depthwise_conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array, stride = array, dilation = array, acc_type = f32} : (tensor<1x3x3x1xf32>, tensor<3x3x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> + return %conv : tensor<1x1x1x1xf32> +} + +// CHECK-LABEL: func.func @transpose_conv2d_io_insert_reshape( +// CHECK: %[[SHAPE:.*]] = tosa.const_shape +// CHECK: %[[WSHAPE:.*]] = tosa.const_shape +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const" +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const" +// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]] +// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]] +// CHECK: %[[CONV:.*]] = tosa.transpose_conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] +func.func @transpose_conv2d_io_insert_reshape(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>, %arg2: tensor<1xf32>) -> tensor<1x5x5x1xf32> { + %shape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %wshape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32> + %r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32> + %conv = "tosa.transpose_conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {out_pad = array, stride = array, acc_type = f32, dilation = array, pad = array} : (tensor<1x3x3x1xf32>, tensor<1x3x3x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x1xf32> + return %conv : tensor<1x5x5x1xf32> +} + +// CHECK-LABEL: func.func @conv3d_io_insert_reshape( +// CHECK: %[[SHAPE:.*]] = tosa.const_shape +// CHECK: %[[WSHAPE:.*]] = tosa.const_shape +// CHECK: %[[INPUT_ZP:.*]] = "tosa.const" +// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const" +// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]] +// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]] +// CHECK: %[[CONV:.*]] = tosa.conv3d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] +func.func @conv3d_io_insert_reshape(%arg0: tensor<64xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x4x4x4xf32> { + %shape = "tosa.const_shape"() {values = dense<[1, 1, 4, 4, 4]> : tensor<5xindex>} : () -> !tosa.shape<5> + %wshape = "tosa.const_shape"() {values = dense<[1, 1, 1, 1, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> + %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> + %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<64xf32>, !tosa.shape<5>) -> tensor<1x1x4x4x4xf32> + %r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<1xf32>, !tosa.shape<5>) -> tensor<1x1x1x1x1xf32> + %conv = "tosa.conv3d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array, stride = array, dilation = array, acc_type = f32} : (tensor<1x1x4x4x4xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x4x4x4xf32> + return %conv : tensor<1x1x4x4x4xf32> +} + // CHECK-LABEL: func.func @torch.aten.sigmoid$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor @@ -2417,8 +2491,7 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list - // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}} - %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32> + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32> // expected-error {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}} return %3 : !torch.vtensor<[1,192,35,35],f32> } @@ -2664,8 +2737,7 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6 func.func @torch.aten.index.Tensor_hacked_twin.dynamic_size(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[?,4],f32> attributes {torch.assume_strict_symbolic_shapes} { %0 = torch.prim.ListConstruct %arg1, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[1,4],si64>) -> !torch.list - // expected-error @+1 {{failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal}} - %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[?,4],f32>, !torch.list -> !torch.vtensor<[?,4],f32> + %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[?,4],f32>, !torch.list -> !torch.vtensor<[?,4],f32> // expected-error {{failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal}} return %1 : !torch.vtensor<[?,4],f32> } @@ -4552,8 +4624,7 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{ %none = torch.constant.none %cpu = torch.constant.device "cpu" %false = torch.constant.bool false - // expected-error @below {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}} - %out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32> + %out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32> // expected-error {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}} return %out : !torch.vtensor<[1,0,256],f32> } From e2b77f6e799d5b3f6036f499492bf5e64394446b Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Wed, 11 Mar 2026 08:50:41 +0100 Subject: [PATCH 2/4] Add conv2d e2e python test Change-Id: I5c0c1a5ae2d90cee500dc76247f5952b99bb48f9 Signed-off-by: Cathal Corbett --- .../torch_mlir_e2e_test/test_suite/conv.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index de9b2d4531c8..8a1bd004ed86 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -273,6 +273,43 @@ def Convolution2DModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) +# ============================================================================== + + +class Convolution2DReshapeInputsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([9], torch.float32, True), + ([9], torch.float32, True), + ([1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + input4d = torch.reshape(inputVec, (1, 1, 3, 3)) + weight4d = torch.reshape(weight, (1, 1, 3, 3)) + return torch.ops.aten.convolution( + input4d, + weight4d, + bias=bias, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Convolution2DReshapeInputsModule()) +def Convolution2DReshapeInputsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(9), tu.rand(9), tu.rand(1)) + + class Convolution2DStaticModule(torch.nn.Module): def __init__(self): super().__init__() From b648054c5f21ef25793f6448abb5b096bdeac981 Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Thu, 12 Mar 2026 14:13:45 +0100 Subject: [PATCH 3/4] Fix mlir testing Change-Id: I55752f920f8ad170b3d35c0c8bf5f8b94c4d9de0 Signed-off-by: Cathal Corbett --- test/Conversion/TorchToTosa/basic.mlir | 155 ++++++++++++++----------- 1 file changed, 90 insertions(+), 65 deletions(-) diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 9d6fa9af7f1e..f9e09029ad78 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -14,77 +14,99 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // ----- // CHECK-LABEL: func.func @conv2d_io_insert_reshape( -// CHECK: %[[SHAPE:.*]] = tosa.const_shape -// CHECK: %[[INPUT_ZP:.*]] = "tosa.const" -// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const" -// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]] -// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[SHAPE]] -// CHECK: %[[CONV:.*]] = tosa.conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] -func.func @conv2d_io_insert_reshape(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x1x16xf32> { - %shape = "tosa.const_shape"() {values = dense<[1, 1, 16, 16]> : tensor<4xindex>} : () -> !tosa.shape<4> - %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> - %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> - %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32> - %r1 = "tosa.reshape"(%arg1, %shape) : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32> - %conv = "tosa.conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array, stride = array, dilation = array, acc_type = f32} : (tensor<1x1x16x16xf32>, tensor<1x1x16x16xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x16xf32> - return %conv : tensor<1x1x1x16xf32> +// CHECK-DAG: torch_c.to_builtin_tensor %arg0 +// CHECK-DAG: torch_c.to_builtin_tensor %arg1 +// CHECK: tosa.reshape +// CHECK: tosa.reshape +// CHECK: tosa.conv2d +// CHECK-NOT: torch.aten.convolution +func.func @conv2d_io_insert_reshape(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,1,1],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int16 = torch.constant.int 16 + %false = torch.constant.bool false + %shape = torch.prim.ListConstruct %int1, %int1, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %wshape = torch.prim.ListConstruct %int1, %int1, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[256],f32>, !torch.list -> !torch.vtensor<[1,1,16,16],f32> + %weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[256],f32>, !torch.list -> !torch.vtensor<[1,1,16,16],f32> + %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %output_padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[1,1,16,16],f32>, !torch.vtensor<[1,1,16,16],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,1,1],f32> + return %conv : !torch.vtensor<[1,1,1,1],f32> } // CHECK-LABEL: func.func @depthwise_conv2d_io_insert_reshape( -// CHECK: %[[SHAPE:.*]] = tosa.const_shape -// CHECK: %[[WSHAPE:.*]] = tosa.const_shape -// CHECK: %[[INPUT_ZP:.*]] = "tosa.const" -// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const" -// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]] -// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]] -// CHECK: %[[CONV:.*]] = tosa.depthwise_conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] -func.func @depthwise_conv2d_io_insert_reshape(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x1x1xf32> { - %shape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> - %wshape = "tosa.const_shape"() {values = dense<[3, 3, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> - %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> - %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> - %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32> - %r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<3x3x1x1xf32> - %conv = "tosa.depthwise_conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array, stride = array, dilation = array, acc_type = f32} : (tensor<1x3x3x1xf32>, tensor<3x3x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> - return %conv : tensor<1x1x1x1xf32> +// CHECK-DAG: torch_c.to_builtin_tensor %arg0 +// CHECK-DAG: torch_c.to_builtin_tensor %arg1 +// CHECK: tosa.reshape +// CHECK: tosa.reshape +// CHECK: tosa.depthwise_conv2d +// CHECK-NOT: torch.aten.convolution +func.func @depthwise_conv2d_io_insert_reshape(%arg0: !torch.vtensor<[9],f32>, %arg1: !torch.vtensor<[9],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[1,3,1,1],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %false = torch.constant.bool false + %shape = torch.prim.ListConstruct %int1, %int3, %int3, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %wshape = torch.prim.ListConstruct %int3, %int1, %int3, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[1,3,3,1],f32> + %weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[3,1,3,1],f32> + %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %output_padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int3 : !torch.vtensor<[1,3,3,1],f32>, !torch.vtensor<[3,1,3,1],f32>, !torch.vtensor<[3],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,1,1],f32> + return %conv : !torch.vtensor<[1,3,1,1],f32> } // CHECK-LABEL: func.func @transpose_conv2d_io_insert_reshape( -// CHECK: %[[SHAPE:.*]] = tosa.const_shape -// CHECK: %[[WSHAPE:.*]] = tosa.const_shape -// CHECK: %[[INPUT_ZP:.*]] = "tosa.const" -// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const" -// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]] -// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]] -// CHECK: %[[CONV:.*]] = tosa.transpose_conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] -func.func @transpose_conv2d_io_insert_reshape(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>, %arg2: tensor<1xf32>) -> tensor<1x5x5x1xf32> { - %shape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> - %wshape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> - %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> - %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> - %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32> - %r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32> - %conv = "tosa.transpose_conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {out_pad = array, stride = array, acc_type = f32, dilation = array, pad = array} : (tensor<1x3x3x1xf32>, tensor<1x3x3x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x1xf32> - return %conv : tensor<1x5x5x1xf32> +// CHECK-DAG: torch_c.to_builtin_tensor %arg0 +// CHECK-DAG: torch_c.to_builtin_tensor %arg1 +// CHECK: tosa.reshape +// CHECK: tosa.reshape +// CHECK: tosa.transpose_conv2d +// CHECK-NOT: torch.aten.convolution +func.func @transpose_conv2d_io_insert_reshape(%arg0: !torch.vtensor<[9],f32>, %arg1: !torch.vtensor<[9],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,5,5],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %true = torch.constant.bool true + %shape = torch.prim.ListConstruct %int1, %int1, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %wshape = torch.prim.ListConstruct %int1, %int1, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[1,1,3,3],f32> + %weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[1,1,3,3],f32> + %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %output_padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %true, %output_padding, %int1 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,5,5],f32> + return %conv : !torch.vtensor<[1,1,5,5],f32> } // CHECK-LABEL: func.func @conv3d_io_insert_reshape( -// CHECK: %[[SHAPE:.*]] = tosa.const_shape -// CHECK: %[[WSHAPE:.*]] = tosa.const_shape -// CHECK: %[[INPUT_ZP:.*]] = "tosa.const" -// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const" -// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]] -// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]] -// CHECK: %[[CONV:.*]] = tosa.conv3d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] -func.func @conv3d_io_insert_reshape(%arg0: tensor<64xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x4x4x4xf32> { - %shape = "tosa.const_shape"() {values = dense<[1, 1, 4, 4, 4]> : tensor<5xindex>} : () -> !tosa.shape<5> - %wshape = "tosa.const_shape"() {values = dense<[1, 1, 1, 1, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> - %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> - %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> - %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<64xf32>, !tosa.shape<5>) -> tensor<1x1x4x4x4xf32> - %r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<1xf32>, !tosa.shape<5>) -> tensor<1x1x1x1x1xf32> - %conv = "tosa.conv3d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array, stride = array, dilation = array, acc_type = f32} : (tensor<1x1x4x4x4xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x4x4x4xf32> - return %conv : tensor<1x1x4x4x4xf32> +// CHECK-DAG: torch_c.to_builtin_tensor %arg0 +// CHECK-DAG: torch_c.to_builtin_tensor %arg1 +// CHECK: tosa.reshape +// CHECK: tosa.reshape +// CHECK: tosa.conv3d +// CHECK-NOT: torch.aten.convolution +func.func @conv3d_io_insert_reshape(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,4,4,4],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %false = torch.constant.bool false + %shape = torch.prim.ListConstruct %int1, %int1, %int4, %int4, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %wshape = torch.prim.ListConstruct %int1, %int1, %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[64],f32>, !torch.list -> !torch.vtensor<[1,1,4,4,4],f32> + %weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[1,1,1,1,1],f32> + %stride = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %padding = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %output_padding = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[1,1,4,4,4],f32>, !torch.vtensor<[1,1,1,1,1],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,4,4],f32> + return %conv : !torch.vtensor<[1,1,4,4,4],f32> } // CHECK-LABEL: func.func @torch.aten.sigmoid$basic( @@ -2491,7 +2513,8 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list - %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32> // expected-error {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}} + // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}} + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32> return %3 : !torch.vtensor<[1,192,35,35],f32> } @@ -2737,7 +2760,8 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6 func.func @torch.aten.index.Tensor_hacked_twin.dynamic_size(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[?,4],f32> attributes {torch.assume_strict_symbolic_shapes} { %0 = torch.prim.ListConstruct %arg1, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[1,4],si64>) -> !torch.list - %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[?,4],f32>, !torch.list -> !torch.vtensor<[?,4],f32> // expected-error {{failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal}} + // expected-error @+1 {{failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal}} + %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[?,4],f32>, !torch.list -> !torch.vtensor<[?,4],f32> return %1 : !torch.vtensor<[?,4],f32> } @@ -4624,7 +4648,8 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{ %none = torch.constant.none %cpu = torch.constant.device "cpu" %false = torch.constant.bool false - %out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32> // expected-error {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}} + // expected-error @below {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}} + %out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32> return %out : !torch.vtensor<[1,0,256],f32> } From 28854a45c2e06f6b75fd81de66f4a56422a172dc Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Thu, 12 Mar 2026 17:32:25 +0100 Subject: [PATCH 4/4] localize conv reshape inference and update tests Change-Id: I9c67bbbb030511a72ec8f4c59b3498ba4dfdb0fc Signed-off-by: Cathal Corbett --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 184 ++++++++++++++++----- test/Conversion/TorchToTosa/basic.mlir | 75 +++++---- 2 files changed, 184 insertions(+), 75 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index d82e5cdcd451..ef52150db1f9 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" @@ -27,9 +28,11 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include @@ -54,6 +57,7 @@ struct RankTemplate { int64_t rank; RankedTensorType type; Value shape; + std::optional> shapeValues; }; // Runs an in-place inclusive prefix sum along the middle dimension (K) of @@ -2655,37 +2659,113 @@ LogicalResult ConvertAtenOp::matchAndRewriteImpl( op, "Unimplemented: only 2D or 3D convolutions supported"); auto funcOp = op->getParentOfType(); - llvm::DenseMap> argToTemplates; - bool templatesBuilt = false; DominanceInfo domInfo(funcOp); - auto buildTemplates = [&]() { - if (templatesBuilt) - return; - templatesBuilt = true; - funcOp.walk([&](tosa::ReshapeOp reshapeOp) { - Value source = reshapeOp.getInput1(); - auto blockArg = dyn_cast(source); - if (!blockArg) - return; + auto peelTrivialDefs = [](Value source) -> Value { + while (true) { + if (auto unrealized = + source.getDefiningOp()) { + if (unrealized->getNumOperands() == 1) { + source = unrealized.getOperand(0); + continue; + } + } + if (auto castOp = source.getDefiningOp()) { + source = castOp.getSource(); + continue; + } + if (auto toBuiltin = + source.getDefiningOp()) { + source = toBuiltin.getOperand(); + continue; + } + break; + } + return source; + }; - auto dstType = - dyn_cast(reshapeOp.getResult().getType()); - if (!dstType || (dstType.getRank() != 4 && dstType.getRank() != 5)) + auto addTemplate = [&](SmallVectorImpl &templates, int64_t rank, + RankedTensorType type, Value shape, + std::optional> shapeValues) { + for (const auto &tmpl : templates) { + if (tmpl.rank == rank && tmpl.type == type) return; + } + templates.push_back(RankTemplate{rank, type, shape, shapeValues}); + }; - unsigned argNumber = blockArg.getArgNumber(); - auto &templates = argToTemplates[argNumber]; - for (const auto &tmpl : templates) { - if (tmpl.rank == dstType.getRank() && tmpl.type == dstType) - return; + auto collectTemplatesFromSource = + [&](Value source) -> SmallVector { + SmallVector templates; + if (!source) + return templates; + source = peelTrivialDefs(source); + + SmallVector worklist; + llvm::SmallDenseSet visited; + worklist.push_back(source); + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current).second) + continue; + + for (OpOperand &use : current.getUses()) { + Operation *user = use.getOwner(); + + if (auto reshapeOp = dyn_cast(user)) { + if (!domInfo.properlyDominates(reshapeOp.getOperation(), op)) + continue; + auto dstType = + dyn_cast(reshapeOp.getResult().getType()); + if (!dstType || (dstType.getRank() != 4 && dstType.getRank() != 5)) + continue; + addTemplate(templates, dstType.getRank(), dstType, + reshapeOp.getShape(), std::nullopt); + continue; + } + + if (auto reshapeOp = dyn_cast(user)) { + if (!domInfo.properlyDominates(reshapeOp.getOperation(), op)) + continue; + auto torchTy = + dyn_cast(reshapeOp.getResult().getType()); + if (!torchTy || !torchTy.hasSizes() || !torchTy.hasDtype()) + continue; + auto dstType = dyn_cast(torchTy.toBuiltinTensor()); + if (!dstType || (dstType.getRank() != 4 && dstType.getRank() != 5)) + continue; + SmallVector shapeValues; + for (int64_t dim : dstType.getShape()) + shapeValues.push_back(dim); + addTemplate(templates, dstType.getRank(), dstType, Value(), + shapeValues); + continue; + } + + if (auto toBuiltin = + dyn_cast(user)) { + worklist.push_back(toBuiltin.getResult()); + continue; + } + + if (auto unrealized = dyn_cast(user)) { + if (unrealized->getNumResults() == 1) + worklist.push_back(unrealized.getResult(0)); + continue; + } + + if (auto castOp = dyn_cast(user)) { + worklist.push_back(castOp.getResult()); + continue; + } } - templates.push_back( - RankTemplate{dstType.getRank(), dstType, reshapeOp.getShape()}); - }); + } + + return templates; }; - auto normalizeOperandRank = [&](Value operand, + auto normalizeOperandRank = [&](Value operand, Value torchOperand, int64_t requiredRank) -> FailureOr { auto rankedType = dyn_cast(operand.getType()); if (!rankedType) @@ -2693,30 +2773,50 @@ LogicalResult ConvertAtenOp::matchAndRewriteImpl( if (rankedType.getRank() == requiredRank) return operand; - auto blockArg = dyn_cast(operand); - if (!blockArg) + SmallVector templates = collectTemplatesFromSource(operand); + if (torchOperand) + templates.append(collectTemplatesFromSource(torchOperand)); + if (!templates.empty()) { + SmallVector deduped; + for (const auto &tmpl : templates) + addTemplate(deduped, tmpl.rank, tmpl.type, tmpl.shape, + tmpl.shapeValues); + templates.swap(deduped); + } + if (templates.empty()) return failure(); - buildTemplates(); - auto tmplIt = argToTemplates.find(blockArg.getArgNumber()); - if (tmplIt == argToTemplates.end()) - return failure(); - - const RankTemplate *match = nullptr; - for (const auto &tmpl : tmplIt->second) { - if (tmpl.rank == requiredRank) { - match = &tmpl; - break; - } + auto operandTy = dyn_cast(operand.getType()); + auto operandElemTy = operandTy.getElementType(); + std::optional operandNumElements; + if (operandTy.hasStaticShape()) + operandNumElements = operandTy.getNumElements(); + SmallVector candidates; + for (const auto &tmpl : templates) { + if (tmpl.rank != requiredRank) + continue; + if (tmpl.type.getElementType() != operandElemTy) + continue; + if (operandNumElements && tmpl.type.hasStaticShape() && + tmpl.type.getNumElements() != *operandNumElements) + continue; + candidates.push_back(&tmpl); } - if (!match) + if (candidates.empty()) return failure(); + if (candidates.size() != 1) + return failure(); + const RankTemplate *match = candidates.front(); Value shapeVal = match->shape; - if (auto shapeOp = shapeVal.getDefiningOp()) { - OpBuilder builder(op); + if (!shapeVal) { + if (!match->shapeValues) + return failure(); + shapeVal = + tosa::getTosaConstShape(rewriter, op->getLoc(), *match->shapeValues); + } else if (auto shapeOp = shapeVal.getDefiningOp()) { shapeVal = tosa::ConstShapeOp::create( - builder, op->getLoc(), shapeOp.getType(), shapeOp.getValues()); + rewriter, op->getLoc(), shapeOp.getType(), shapeOp.getValues()); } else if (!domInfo.properlyDominates(shapeVal, op)) { return failure(); } @@ -2727,7 +2827,7 @@ LogicalResult ConvertAtenOp::matchAndRewriteImpl( }; if (inputTy.getRank() != outputRank) { - auto normalized = normalizeOperandRank(input, outputRank); + auto normalized = normalizeOperandRank(input, op.getInput(), outputRank); if (failed(normalized)) return rewriter.notifyMatchFailure( op, "Input rank mismatch without normalization template"); @@ -2736,7 +2836,7 @@ LogicalResult ConvertAtenOp::matchAndRewriteImpl( } if (weightTy.getRank() != outputRank) { - auto normalized = normalizeOperandRank(weight, outputRank); + auto normalized = normalizeOperandRank(weight, op.getWeight(), outputRank); if (failed(normalized)) return rewriter.notifyMatchFailure( op, "Weight rank mismatch without normalization template"); diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index f9e09029ad78..39a4c4f049f4 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.tanh$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -14,11 +14,13 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // ----- // CHECK-LABEL: func.func @conv2d_io_insert_reshape( -// CHECK-DAG: torch_c.to_builtin_tensor %arg0 -// CHECK-DAG: torch_c.to_builtin_tensor %arg1 -// CHECK: tosa.reshape -// CHECK: tosa.reshape -// CHECK: tosa.conv2d +// CHECK-DAG: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK-DAG: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK-DAG: %[[NORM_INPUT_NHWC:.*]] = tosa.transpose %[[NORM_INPUT:.*]] {perms = array} : (tensor<1x1x16x16xf32>) -> tensor<1x16x16x1xf32> +// CHECK-DAG: %[[NORM_WEIGHT_NHWC:.*]] = tosa.transpose %[[NORM_WEIGHT:.*]] {perms = array} : (tensor<1x1x16x16xf32>) -> tensor<1x16x16x1xf32> +// CHECK-DAG: %[[NORM_INPUT]] = tosa.reshape %[[ARG0_BUILTIN]], %{{.*}} : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32> +// CHECK-DAG: %[[NORM_WEIGHT]] = tosa.reshape %[[ARG1_BUILTIN]], %{{.*}} : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32> +// CHECK: tosa.conv2d %[[NORM_INPUT_NHWC]], %[[NORM_WEIGHT_NHWC]] // CHECK-NOT: torch.aten.convolution func.func @conv2d_io_insert_reshape(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,1,1],f32> { %int0 = torch.constant.int 0 @@ -27,22 +29,25 @@ func.func @conv2d_io_insert_reshape(%arg0: !torch.vtensor<[256],f32>, %arg1: !to %false = torch.constant.bool false %shape = torch.prim.ListConstruct %int1, %int1, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %wshape = torch.prim.ListConstruct %int1, %int1, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[256],f32>, !torch.list -> !torch.vtensor<[1,1,16,16],f32> - %weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[256],f32>, !torch.list -> !torch.vtensor<[1,1,16,16],f32> + %_input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[256],f32>, !torch.list -> !torch.vtensor<[1,1,16,16],f32> + %_weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[256],f32>, !torch.list -> !torch.vtensor<[1,1,16,16],f32> %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %output_padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list - %conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[1,1,16,16],f32>, !torch.vtensor<[1,1,16,16],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,1,1],f32> + %conv = torch.aten.convolution %arg0, %arg1, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,1,1],f32> return %conv : !torch.vtensor<[1,1,1,1],f32> } // CHECK-LABEL: func.func @depthwise_conv2d_io_insert_reshape( -// CHECK-DAG: torch_c.to_builtin_tensor %arg0 -// CHECK-DAG: torch_c.to_builtin_tensor %arg1 -// CHECK: tosa.reshape -// CHECK: tosa.reshape -// CHECK: tosa.depthwise_conv2d +// CHECK-DAG: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[9],f32> -> tensor<9xf32> +// CHECK-DAG: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[9],f32> -> tensor<9xf32> +// CHECK-DAG: %[[NORM_INPUT_NHWC:.*]] = tosa.transpose %[[NORM_INPUT:.*]] {perms = array} : (tensor<1x3x3x1xf32>) -> tensor<1x3x1x3xf32> +// CHECK-DAG: %[[NORM_WEIGHT_NHWC:.*]] = tosa.transpose %[[NORM_WEIGHT:.*]] {perms = array} : (tensor<3x1x3x1xf32>) -> tensor<3x1x3x1xf32> +// CHECK-DAG: %[[NORM_WEIGHT_FOR_CONV:.*]] = tosa.reshape %[[NORM_WEIGHT_NHWC]], %{{.*}} : (tensor<3x1x3x1xf32>, !tosa.shape<4>) -> tensor<3x1x3x1xf32> +// CHECK-DAG: %[[NORM_INPUT]] = tosa.reshape %[[ARG0_BUILTIN]], %{{.*}} : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32> +// CHECK-DAG: %[[NORM_WEIGHT]] = tosa.reshape %[[ARG1_BUILTIN]], %{{.*}} : (tensor<9xf32>, !tosa.shape<4>) -> tensor<3x1x3x1xf32> +// CHECK: tosa.depthwise_conv2d %[[NORM_INPUT_NHWC]], %[[NORM_WEIGHT_FOR_CONV]] // CHECK-NOT: torch.aten.convolution func.func @depthwise_conv2d_io_insert_reshape(%arg0: !torch.vtensor<[9],f32>, %arg1: !torch.vtensor<[9],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[1,3,1,1],f32> { %int0 = torch.constant.int 0 @@ -51,22 +56,24 @@ func.func @depthwise_conv2d_io_insert_reshape(%arg0: !torch.vtensor<[9],f32>, %a %false = torch.constant.bool false %shape = torch.prim.ListConstruct %int1, %int3, %int3, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %wshape = torch.prim.ListConstruct %int3, %int1, %int3, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[1,3,3,1],f32> - %weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[3,1,3,1],f32> + %_input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[1,3,3,1],f32> + %_weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[3,1,3,1],f32> %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %output_padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list - %conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int3 : !torch.vtensor<[1,3,3,1],f32>, !torch.vtensor<[3,1,3,1],f32>, !torch.vtensor<[3],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,1,1],f32> + %conv = torch.aten.convolution %arg0, %arg1, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int3 : !torch.vtensor<[9],f32>, !torch.vtensor<[9],f32>, !torch.vtensor<[3],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,1,1],f32> return %conv : !torch.vtensor<[1,3,1,1],f32> } // CHECK-LABEL: func.func @transpose_conv2d_io_insert_reshape( -// CHECK-DAG: torch_c.to_builtin_tensor %arg0 -// CHECK-DAG: torch_c.to_builtin_tensor %arg1 -// CHECK: tosa.reshape -// CHECK: tosa.reshape -// CHECK: tosa.transpose_conv2d +// CHECK-DAG: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[9],f32> -> tensor<9xf32> +// CHECK-DAG: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[9],f32> -> tensor<9xf32> +// CHECK-DAG: %[[NORM_INPUT_NHWC:.*]] = tosa.transpose %[[NORM_INPUT:.*]] {perms = array} : (tensor<1x1x3x3xf32>) -> tensor<1x3x3x1xf32> +// CHECK-DAG: %[[NORM_WEIGHT_NHWC:.*]] = tosa.transpose %[[NORM_WEIGHT:.*]] {perms = array} : (tensor<1x1x3x3xf32>) -> tensor<1x3x3x1xf32> +// CHECK-DAG: %[[NORM_INPUT]] = tosa.reshape %[[ARG0_BUILTIN]], %{{.*}} : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x1x3x3xf32> +// CHECK-DAG: %[[NORM_WEIGHT]] = tosa.reshape %[[ARG1_BUILTIN]], %{{.*}} : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x1x3x3xf32> +// CHECK: tosa.transpose_conv2d %[[NORM_INPUT_NHWC]], %[[NORM_WEIGHT_NHWC]] // CHECK-NOT: torch.aten.convolution func.func @transpose_conv2d_io_insert_reshape(%arg0: !torch.vtensor<[9],f32>, %arg1: !torch.vtensor<[9],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,5,5],f32> { %int0 = torch.constant.int 0 @@ -75,22 +82,24 @@ func.func @transpose_conv2d_io_insert_reshape(%arg0: !torch.vtensor<[9],f32>, %a %true = torch.constant.bool true %shape = torch.prim.ListConstruct %int1, %int1, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %wshape = torch.prim.ListConstruct %int1, %int1, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[1,1,3,3],f32> - %weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[1,1,3,3],f32> + %_input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[1,1,3,3],f32> + %_weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[9],f32>, !torch.list -> !torch.vtensor<[1,1,3,3],f32> %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %output_padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list - %conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %true, %output_padding, %int1 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,5,5],f32> + %conv = torch.aten.convolution %arg0, %arg1, %arg2, %stride, %padding, %dilation, %true, %output_padding, %int1 : !torch.vtensor<[9],f32>, !torch.vtensor<[9],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,5,5],f32> return %conv : !torch.vtensor<[1,1,5,5],f32> } // CHECK-LABEL: func.func @conv3d_io_insert_reshape( -// CHECK-DAG: torch_c.to_builtin_tensor %arg0 -// CHECK-DAG: torch_c.to_builtin_tensor %arg1 -// CHECK: tosa.reshape -// CHECK: tosa.reshape -// CHECK: tosa.conv3d +// CHECK-DAG: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK-DAG: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1],f32> -> tensor<1xf32> +// CHECK-DAG: %[[NORM_INPUT_NHWC:.*]] = tosa.transpose %[[NORM_INPUT:.*]] {perms = array} : (tensor<1x1x4x4x4xf32>) -> tensor<1x4x4x4x1xf32> +// CHECK-DAG: %[[NORM_WEIGHT_NHWC:.*]] = tosa.transpose %[[NORM_WEIGHT:.*]] {perms = array} : (tensor<1x1x1x1x1xf32>) -> tensor<1x1x1x1x1xf32> +// CHECK-DAG: %[[NORM_INPUT]] = tosa.reshape %[[ARG0_BUILTIN]], %{{.*}} : (tensor<64xf32>, !tosa.shape<5>) -> tensor<1x1x4x4x4xf32> +// CHECK-DAG: %[[NORM_WEIGHT]] = tosa.reshape %[[ARG1_BUILTIN]], %{{.*}} : (tensor<1xf32>, !tosa.shape<5>) -> tensor<1x1x1x1x1xf32> +// CHECK: tosa.conv3d %[[NORM_INPUT_NHWC]], %[[NORM_WEIGHT_NHWC]] // CHECK-NOT: torch.aten.convolution func.func @conv3d_io_insert_reshape(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,4,4,4],f32> { %int0 = torch.constant.int 0 @@ -99,13 +108,13 @@ func.func @conv3d_io_insert_reshape(%arg0: !torch.vtensor<[64],f32>, %arg1: !tor %false = torch.constant.bool false %shape = torch.prim.ListConstruct %int1, %int1, %int4, %int4, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %wshape = torch.prim.ListConstruct %int1, %int1, %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[64],f32>, !torch.list -> !torch.vtensor<[1,1,4,4,4],f32> - %weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[1,1,1,1,1],f32> + %_input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[64],f32>, !torch.list -> !torch.vtensor<[1,1,4,4,4],f32> + %_weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[1,1,1,1,1],f32> %stride = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list %dilation = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list %output_padding = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[1,1,4,4,4],f32>, !torch.vtensor<[1,1,1,1,1],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,4,4],f32> + %conv = torch.aten.convolution %arg0, %arg1, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[64],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,4,4],f32> return %conv : !torch.vtensor<[1,1,4,4,4],f32> }