Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 104 additions & 7 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -28,6 +29,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cmath>
Expand All @@ -48,6 +50,11 @@ namespace mlir::torch {
#include "torch-mlir/Conversion/Passes.h.inc"

namespace {
// A reshape "template" harvested from the enclosing function: a tosa.reshape
// whose source is a function block argument. Templates are later replayed to
// normalize the rank of convolution operands that flow in as flat tensors.
struct RankTemplate {
  int64_t rank;          // Rank of the reshape's destination tensor type.
  RankedTensorType type; // Full destination type produced by the reshape.
  Value shape;           // Shape operand of the originating tosa.reshape.
};

// Runs an in-place inclusive prefix sum along the middle dimension (K) of
// `running` using a binary lifting scheme. The input must have shape [N, K, C].
Expand Down Expand Up @@ -2634,14 +2641,109 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewriteImpl(
auto input = adaptor.getInput();
auto weight = adaptor.getWeight();

auto inputTy = cast<RankedTensorType>(input.getType());
auto weightTy = cast<RankedTensorType>(weight.getType());
auto outputTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto weightTy = dyn_cast<RankedTensorType>(weight.getType());
if (!inputTy || !weightTy || !outputTy)
return rewriter.notifyMatchFailure(
op, "Input, weight and output to Convolution must be ranked tensors");

int64_t outputRank = outputTy.getRank();
if (outputRank != 4 && outputRank != 5)
return rewriter.notifyMatchFailure(
op, "Unimplemented: only 2D or 3D convolutions supported");

auto funcOp = op->getParentOfType<func::FuncOp>();
llvm::DenseMap<unsigned, SmallVector<RankTemplate>> argToTemplates;
bool templatesBuilt = false;
DominanceInfo domInfo(funcOp);

auto buildTemplates = [&]() {
if (templatesBuilt)
return;
templatesBuilt = true;
funcOp.walk([&](tosa::ReshapeOp reshapeOp) {
Value source = reshapeOp.getInput1();
auto blockArg = dyn_cast<BlockArgument>(source);
if (!blockArg)
return;

auto dstType =
dyn_cast<RankedTensorType>(reshapeOp.getResult().getType());
if (!dstType || (dstType.getRank() != 4 && dstType.getRank() != 5))
return;

unsigned argNumber = blockArg.getArgNumber();
auto &templates = argToTemplates[argNumber];
for (const auto &tmpl : templates) {
if (tmpl.rank == dstType.getRank() && tmpl.type == dstType)
return;
}
templates.push_back(
RankTemplate{dstType.getRank(), dstType, reshapeOp.getShape()});
});
};

auto normalizeOperandRank = [&](Value operand,
int64_t requiredRank) -> FailureOr<Value> {
auto rankedType = dyn_cast<RankedTensorType>(operand.getType());
if (!rankedType)
return failure();
if (rankedType.getRank() == requiredRank)
return operand;

auto blockArg = dyn_cast<BlockArgument>(operand);
if (!blockArg)
return failure();

buildTemplates();
auto tmplIt = argToTemplates.find(blockArg.getArgNumber());
if (tmplIt == argToTemplates.end())
return failure();

const RankTemplate *match = nullptr;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Template matching is rank-only and first-hit. If multiple reshape templates exist for the same argument/rank, the selected template can be wrong, causing incorrect reshapes and semantic drift.

for (const auto &tmpl : tmplIt->second) {
if (tmpl.rank == requiredRank) {
match = &tmpl;
break;
}
}
if (!match)
return failure();

Value shapeVal = match->shape;
if (auto shapeOp = shapeVal.getDefiningOp<tosa::ConstShapeOp>()) {
OpBuilder builder(op);
shapeVal = tosa::ConstShapeOp::create(
builder, op->getLoc(), shapeOp.getType(), shapeOp.getValues());
} else if (!domInfo.properlyDominates(shapeVal, op)) {
return failure();
}

auto reshape = tosa::ReshapeOp::create(rewriter, op->getLoc(), match->type,
operand, shapeVal);
return reshape.getResult();
};

if (inputTy.getRank() != outputRank) {
auto normalized = normalizeOperandRank(input, outputRank);
if (failed(normalized))
return rewriter.notifyMatchFailure(
op, "Input rank mismatch without normalization template");
input = *normalized;
inputTy = cast<RankedTensorType>(input.getType());
}

if (weightTy.getRank() != outputRank) {
auto normalized = normalizeOperandRank(weight, outputRank);
if (failed(normalized))
return rewriter.notifyMatchFailure(
op, "Weight rank mismatch without normalization template");
weight = *normalized;
weightTy = cast<RankedTensorType>(weight.getType());
}

auto inputElemTy = inputTy.getElementType();
auto weightElemTy = weightTy.getElementType();
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
Expand All @@ -2650,16 +2752,11 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewriteImpl(

int64_t inputRank = inputTy.getRank();
int64_t weightRank = weightTy.getRank();
int64_t outputRank = outputTy.getRank();

if (inputRank != weightRank || outputRank != inputRank)
return rewriter.notifyMatchFailure(
op, "Input, weight and output ranks must match for convolution");

if (inputRank != 4 && inputRank != 5)
return rewriter.notifyMatchFailure(
op, "Unimplemented: only 2D or 3D convolutions supported");

bool is3D = inputRank == 5;
int64_t spatialRank = inputRank - 2;

Expand Down
37 changes: 37 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,43 @@ def Convolution2DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))


# ==============================================================================


class Convolution2DReshapeInputsModule(torch.nn.Module):
    """Convolution whose input and weight arrive as flat 1-D tensors.

    Both operands are reshaped to 4-D (NCHW) inside ``forward`` before the
    convolution, exercising backends that must normalize operand ranks
    (e.g. the TorchToTosa reshape-template lowering).
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([9], torch.float32, True),
            ([9], torch.float32, True),
            ([1], torch.float32, True),
        ]
    )
    def forward(self, inputVec, weight, bias):
        # Promote the flat 9-element vectors to 1x1x3x3 NCHW tensors.
        reshaped_input = inputVec.reshape(1, 1, 3, 3)
        reshaped_weight = weight.reshape(1, 1, 3, 3)
        return torch.ops.aten.convolution(
            reshaped_input,
            reshaped_weight,
            bias=bias,
            stride=[1, 1],
            padding=[0, 0],
            dilation=[1, 1],
            transposed=False,
            output_padding=[0, 0],
            groups=1,
        )


@register_test_case(module_factory=lambda: Convolution2DReshapeInputsModule())
def Convolution2DReshapeInputsModule_basic(module, tu: TestUtils):
    # Flat 1-D inputs; the module reshapes them to 4-D before convolving.
    module.forward(tu.rand(9), tu.rand(9), tu.rand(1))


class Convolution2DStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
98 changes: 97 additions & 1 deletion test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
// RUN: torch-mlir-opt %s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK

// CHECK-LABEL: func.func @torch.aten.tanh$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
Expand All @@ -13,6 +13,102 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte

// -----

// Plain 2-D convolution whose input and weight are flat 1-D function
// arguments reshaped to 4-D in-IR; the lowering must emit tosa.reshape for
// both operands and lower the op itself to tosa.conv2d.
// CHECK-LABEL: func.func @conv2d_io_insert_reshape(
// CHECK-DAG: torch_c.to_builtin_tensor %arg0
// CHECK-DAG: torch_c.to_builtin_tensor %arg1
// CHECK: tosa.reshape
// CHECK: tosa.reshape
// CHECK: tosa.conv2d
// CHECK-NOT: torch.aten.convolution
func.func @conv2d_io_insert_reshape(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,1,1],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int16 = torch.constant.int 16
%false = torch.constant.bool false
%shape = torch.prim.ListConstruct %int1, %int1, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%wshape = torch.prim.ListConstruct %int1, %int1, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[256],f32>, !torch.list<int> -> !torch.vtensor<[1,1,16,16],f32>
%weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[256],f32>, !torch.list<int> -> !torch.vtensor<[1,1,16,16],f32>
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%output_padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[1,1,16,16],f32>, !torch.vtensor<[1,1,16,16],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1,1],f32>
return %conv : !torch.vtensor<[1,1,1,1],f32>
}

// Grouped case (groups == input channels == 3): same flat-operand reshape
// pattern, but the convolution must lower to tosa.depthwise_conv2d.
// CHECK-LABEL: func.func @depthwise_conv2d_io_insert_reshape(
// CHECK-DAG: torch_c.to_builtin_tensor %arg0
// CHECK-DAG: torch_c.to_builtin_tensor %arg1
// CHECK: tosa.reshape
// CHECK: tosa.reshape
// CHECK: tosa.depthwise_conv2d
// CHECK-NOT: torch.aten.convolution
func.func @depthwise_conv2d_io_insert_reshape(%arg0: !torch.vtensor<[9],f32>, %arg1: !torch.vtensor<[9],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[1,3,1,1],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%shape = torch.prim.ListConstruct %int1, %int3, %int3, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%wshape = torch.prim.ListConstruct %int3, %int1, %int3, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[9],f32>, !torch.list<int> -> !torch.vtensor<[1,3,3,1],f32>
%weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[9],f32>, !torch.list<int> -> !torch.vtensor<[3,1,3,1],f32>
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%output_padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int3 : !torch.vtensor<[1,3,3,1],f32>, !torch.vtensor<[3,1,3,1],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,1,1],f32>
return %conv : !torch.vtensor<[1,3,1,1],f32>
}

// Transposed case (transposed = true): flat operands reshaped to 4-D must
// still be normalized, and the op lowers to tosa.transpose_conv2d.
// CHECK-LABEL: func.func @transpose_conv2d_io_insert_reshape(
// CHECK-DAG: torch_c.to_builtin_tensor %arg0
// CHECK-DAG: torch_c.to_builtin_tensor %arg1
// CHECK: tosa.reshape
// CHECK: tosa.reshape
// CHECK: tosa.transpose_conv2d
// CHECK-NOT: torch.aten.convolution
func.func @transpose_conv2d_io_insert_reshape(%arg0: !torch.vtensor<[9],f32>, %arg1: !torch.vtensor<[9],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,5,5],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%true = torch.constant.bool true
%shape = torch.prim.ListConstruct %int1, %int1, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%wshape = torch.prim.ListConstruct %int1, %int1, %int3, %int3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[9],f32>, !torch.list<int> -> !torch.vtensor<[1,1,3,3],f32>
%weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[9],f32>, !torch.list<int> -> !torch.vtensor<[1,1,3,3],f32>
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%output_padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %true, %output_padding, %int1 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,5,5],f32>
return %conv : !torch.vtensor<[1,1,5,5],f32>
}

// 3-D case: flat operands reshaped to rank-5 tensors; the lowering must
// normalize both ranks and emit tosa.conv3d.
// CHECK-LABEL: func.func @conv3d_io_insert_reshape(
// CHECK-DAG: torch_c.to_builtin_tensor %arg0
// CHECK-DAG: torch_c.to_builtin_tensor %arg1
// CHECK: tosa.reshape
// CHECK: tosa.reshape
// CHECK: tosa.conv3d
// CHECK-NOT: torch.aten.convolution
func.func @conv3d_io_insert_reshape(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,4,4,4],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%false = torch.constant.bool false
%shape = torch.prim.ListConstruct %int1, %int1, %int4, %int4, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%wshape = torch.prim.ListConstruct %int1, %int1, %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%input_tmpl = torch.aten.reshape %arg0, %shape : !torch.vtensor<[64],f32>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32>
%weight_tmpl = torch.aten.reshape %arg1, %wshape : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[1,1,1,1,1],f32>
%stride = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%padding = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%dilation = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%output_padding = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%conv = torch.aten.convolution %input_tmpl, %weight_tmpl, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[1,1,4,4,4],f32>, !torch.vtensor<[1,1,1,1,1],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,4,4],f32>
return %conv : !torch.vtensor<[1,1,4,4,4],f32>
}

// CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
Expand Down
Loading