Skip to content

Commit 8e20f76

Browse files
committed
[TorchToTosa] add conv reshape in core lowering
- Insert rank-4/5 reshapes for conv inputs/weights during TorchToTosa lowering

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: Ica1b5cc265822ecd054f832908ec31bc2325c661
1 parent 4b48bb7 commit 8e20f76

File tree

2 files changed: +182 −14 lines

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 104 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1616
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1717
#include "mlir/IR/DialectResourceBlobManager.h"
18+
#include "mlir/IR/Dominance.h"
1819
#include "mlir/IR/Matchers.h"
1920
#include "mlir/Pass/Pass.h"
2021
#include "mlir/Transforms/DialectConversion.h"
@@ -28,6 +29,7 @@
2829
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
2930
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
3031
#include "llvm/ADT/APInt.h"
32+
#include "llvm/ADT/DenseMap.h"
3133
#include "llvm/ADT/STLExtras.h"
3234
#include "llvm/ADT/TypeSwitch.h"
3335
#include <cmath>
@@ -48,6 +50,11 @@ namespace mlir::torch {
4850
#include "torch-mlir/Conversion/Passes.h.inc"
4951

5052
namespace {
53+
struct RankTemplate {
54+
int64_t rank;
55+
RankedTensorType type;
56+
Value shape;
57+
};
5158

5259
// Runs an in-place inclusive prefix sum along the middle dimension (K) of
5360
// `running` using a binary lifting scheme. The input must have shape [N, K, C].
@@ -2634,14 +2641,109 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewriteImpl(
26342641
auto input = adaptor.getInput();
26352642
auto weight = adaptor.getWeight();
26362643

2637-
auto inputTy = cast<RankedTensorType>(input.getType());
2638-
auto weightTy = cast<RankedTensorType>(weight.getType());
26392644
auto outputTy =
26402645
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
2646+
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
2647+
auto weightTy = dyn_cast<RankedTensorType>(weight.getType());
26412648
if (!inputTy || !weightTy || !outputTy)
26422649
return rewriter.notifyMatchFailure(
26432650
op, "Input, weight and output to Convolution must be ranked tensors");
26442651

2652+
int64_t outputRank = outputTy.getRank();
2653+
if (outputRank != 4 && outputRank != 5)
2654+
return rewriter.notifyMatchFailure(
2655+
op, "Unimplemented: only 2D or 3D convolutions supported");
2656+
2657+
auto funcOp = op->getParentOfType<func::FuncOp>();
2658+
llvm::DenseMap<unsigned, SmallVector<RankTemplate>> argToTemplates;
2659+
bool templatesBuilt = false;
2660+
DominanceInfo domInfo(funcOp);
2661+
2662+
auto buildTemplates = [&]() {
2663+
if (templatesBuilt)
2664+
return;
2665+
templatesBuilt = true;
2666+
funcOp.walk([&](tosa::ReshapeOp reshapeOp) {
2667+
Value source = reshapeOp.getInput1();
2668+
auto blockArg = dyn_cast<BlockArgument>(source);
2669+
if (!blockArg)
2670+
return;
2671+
2672+
auto dstType =
2673+
dyn_cast<RankedTensorType>(reshapeOp.getResult().getType());
2674+
if (!dstType || (dstType.getRank() != 4 && dstType.getRank() != 5))
2675+
return;
2676+
2677+
unsigned argNumber = blockArg.getArgNumber();
2678+
auto &templates = argToTemplates[argNumber];
2679+
for (const auto &tmpl : templates) {
2680+
if (tmpl.rank == dstType.getRank() && tmpl.type == dstType)
2681+
return;
2682+
}
2683+
templates.push_back(
2684+
RankTemplate{dstType.getRank(), dstType, reshapeOp.getShape()});
2685+
});
2686+
};
2687+
2688+
auto normalizeOperandRank = [&](Value operand,
2689+
int64_t requiredRank) -> FailureOr<Value> {
2690+
auto rankedType = dyn_cast<RankedTensorType>(operand.getType());
2691+
if (!rankedType)
2692+
return failure();
2693+
if (rankedType.getRank() == requiredRank)
2694+
return operand;
2695+
2696+
auto blockArg = dyn_cast<BlockArgument>(operand);
2697+
if (!blockArg)
2698+
return failure();
2699+
2700+
buildTemplates();
2701+
auto tmplIt = argToTemplates.find(blockArg.getArgNumber());
2702+
if (tmplIt == argToTemplates.end())
2703+
return failure();
2704+
2705+
const RankTemplate *match = nullptr;
2706+
for (const auto &tmpl : tmplIt->second) {
2707+
if (tmpl.rank == requiredRank) {
2708+
match = &tmpl;
2709+
break;
2710+
}
2711+
}
2712+
if (!match)
2713+
return failure();
2714+
2715+
Value shapeVal = match->shape;
2716+
if (auto shapeOp = shapeVal.getDefiningOp<tosa::ConstShapeOp>()) {
2717+
OpBuilder builder(op);
2718+
shapeVal = tosa::ConstShapeOp::create(
2719+
builder, op->getLoc(), shapeOp.getType(), shapeOp.getValues());
2720+
} else if (!domInfo.properlyDominates(shapeVal, op)) {
2721+
return failure();
2722+
}
2723+
2724+
auto reshape = tosa::ReshapeOp::create(rewriter, op->getLoc(), match->type,
2725+
operand, shapeVal);
2726+
return reshape.getResult();
2727+
};
2728+
2729+
if (inputTy.getRank() != outputRank) {
2730+
auto normalized = normalizeOperandRank(input, outputRank);
2731+
if (failed(normalized))
2732+
return rewriter.notifyMatchFailure(
2733+
op, "Input rank mismatch without normalization template");
2734+
input = *normalized;
2735+
inputTy = cast<RankedTensorType>(input.getType());
2736+
}
2737+
2738+
if (weightTy.getRank() != outputRank) {
2739+
auto normalized = normalizeOperandRank(weight, outputRank);
2740+
if (failed(normalized))
2741+
return rewriter.notifyMatchFailure(
2742+
op, "Weight rank mismatch without normalization template");
2743+
weight = *normalized;
2744+
weightTy = cast<RankedTensorType>(weight.getType());
2745+
}
2746+
26452747
auto inputElemTy = inputTy.getElementType();
26462748
auto weightElemTy = weightTy.getElementType();
26472749
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
@@ -2650,16 +2752,11 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewriteImpl(
26502752

26512753
int64_t inputRank = inputTy.getRank();
26522754
int64_t weightRank = weightTy.getRank();
2653-
int64_t outputRank = outputTy.getRank();
26542755

26552756
if (inputRank != weightRank || outputRank != inputRank)
26562757
return rewriter.notifyMatchFailure(
26572758
op, "Input, weight and output ranks must match for convolution");
26582759

2659-
if (inputRank != 4 && inputRank != 5)
2660-
return rewriter.notifyMatchFailure(
2661-
op, "Unimplemented: only 2D or 3D convolutions supported");
2662-
26632760
bool is3D = inputRank == 5;
26642761
int64_t spatialRank = inputRank - 2;
26652762

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
1+
// RUN: torch-mlir-opt %s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK
22

33
// CHECK-LABEL: func.func @torch.aten.tanh$basic(
44
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -13,6 +13,80 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte
1313

1414
// -----
1515

16+
// CHECK-LABEL: func.func @conv2d_io_insert_reshape(
17+
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
18+
// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"
19+
// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"
20+
// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]]
21+
// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[SHAPE]]
22+
// CHECK: %[[CONV:.*]] = tosa.conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
23+
func.func @conv2d_io_insert_reshape(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x1x16xf32> {
24+
%shape = "tosa.const_shape"() {values = dense<[1, 1, 16, 16]> : tensor<4xindex>} : () -> !tosa.shape<4>
25+
%input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
26+
%weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
27+
%r0 = "tosa.reshape"(%arg0, %shape) : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32>
28+
%r1 = "tosa.reshape"(%arg1, %shape) : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32>
29+
%conv = "tosa.conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, acc_type = f32} : (tensor<1x1x16x16xf32>, tensor<1x1x16x16xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x16xf32>
30+
return %conv : tensor<1x1x1x16xf32>
31+
}
32+
33+
// CHECK-LABEL: func.func @depthwise_conv2d_io_insert_reshape(
34+
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
35+
// CHECK: %[[WSHAPE:.*]] = tosa.const_shape
36+
// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"
37+
// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"
38+
// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]]
39+
// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]]
40+
// CHECK: %[[CONV:.*]] = tosa.depthwise_conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
41+
func.func @depthwise_conv2d_io_insert_reshape(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x1x1xf32> {
42+
%shape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
43+
%wshape = "tosa.const_shape"() {values = dense<[3, 3, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
44+
%input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
45+
%weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
46+
%r0 = "tosa.reshape"(%arg0, %shape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32>
47+
%r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<3x3x1x1xf32>
48+
%conv = "tosa.depthwise_conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, acc_type = f32} : (tensor<1x3x3x1xf32>, tensor<3x3x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32>
49+
return %conv : tensor<1x1x1x1xf32>
50+
}
51+
52+
// CHECK-LABEL: func.func @transpose_conv2d_io_insert_reshape(
53+
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
54+
// CHECK: %[[WSHAPE:.*]] = tosa.const_shape
55+
// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"
56+
// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"
57+
// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]]
58+
// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]]
59+
// CHECK: %[[CONV:.*]] = tosa.transpose_conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
60+
func.func @transpose_conv2d_io_insert_reshape(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>, %arg2: tensor<1xf32>) -> tensor<1x5x5x1xf32> {
61+
%shape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
62+
%wshape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
63+
%input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
64+
%weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
65+
%r0 = "tosa.reshape"(%arg0, %shape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32>
66+
%r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32>
67+
%conv = "tosa.transpose_conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>} : (tensor<1x3x3x1xf32>, tensor<1x3x3x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x1xf32>
68+
return %conv : tensor<1x5x5x1xf32>
69+
}
70+
71+
// CHECK-LABEL: func.func @conv3d_io_insert_reshape(
72+
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
73+
// CHECK: %[[WSHAPE:.*]] = tosa.const_shape
74+
// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"
75+
// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"
76+
// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]]
77+
// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]]
78+
// CHECK: %[[CONV:.*]] = tosa.conv3d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
79+
func.func @conv3d_io_insert_reshape(%arg0: tensor<64xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x4x4x4xf32> {
80+
%shape = "tosa.const_shape"() {values = dense<[1, 1, 4, 4, 4]> : tensor<5xindex>} : () -> !tosa.shape<5>
81+
%wshape = "tosa.const_shape"() {values = dense<[1, 1, 1, 1, 1]> : tensor<5xindex>} : () -> !tosa.shape<5>
82+
%input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
83+
%weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
84+
%r0 = "tosa.reshape"(%arg0, %shape) : (tensor<64xf32>, !tosa.shape<5>) -> tensor<1x1x4x4x4xf32>
85+
%r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<1xf32>, !tosa.shape<5>) -> tensor<1x1x1x1x1xf32>
86+
%conv = "tosa.conv3d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>, acc_type = f32} : (tensor<1x1x4x4x4xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x4x4x4xf32>
87+
return %conv : tensor<1x1x4x4x4xf32>
88+
}
89+
1690
// CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
1791
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
1892
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@@ -2417,8 +2491,7 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc
24172491
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
24182492
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
24192493
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
2420-
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
2421-
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32>
2494+
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32> // expected-error {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
24222495
return %3 : !torch.vtensor<[1,192,35,35],f32>
24232496
}
24242497

@@ -2664,8 +2737,7 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6
26642737

26652738
func.func @torch.aten.index.Tensor_hacked_twin.dynamic_size(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[?,4],f32> attributes {torch.assume_strict_symbolic_shapes} {
26662739
%0 = torch.prim.ListConstruct %arg1, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[1,4],si64>) -> !torch.list<vtensor>
2667-
// expected-error @+1 {{failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal}}
2668-
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[?,4],f32>, !torch.list<vtensor> -> !torch.vtensor<[?,4],f32>
2740+
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[?,4],f32>, !torch.list<vtensor> -> !torch.vtensor<[?,4],f32> // expected-error {{failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal}}
26692741
return %1 : !torch.vtensor<[?,4],f32>
26702742
}
26712743

@@ -4552,8 +4624,7 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
45524624
%none = torch.constant.none
45534625
%cpu = torch.constant.device "cpu"
45544626
%false = torch.constant.bool false
4555-
// expected-error @below {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}}
4556-
%out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32>
4627+
%out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32> // expected-error {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}}
45574628
return %out : !torch.vtensor<[1,0,256],f32>
45584629
}
45594630

0 commit comments

Comments (0)