Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 204 additions & 7 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -26,8 +28,11 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cmath>
Expand All @@ -48,6 +53,12 @@ namespace mlir::torch {
#include "torch-mlir/Conversion/Passes.h.inc"

namespace {
// A candidate "reshape template" harvested from ops near a convolution
// operand: the target rank and tensor type of an existing reshape, plus
// whatever is needed to rebuild an equivalent tosa.reshape — either an SSA
// shape value (`shape`) or, when the shape came from a torch-level reshape
// with static sizes, the literal dimension values (`shapeValues`).
struct RankTemplate {
  // Rank of `type`; callers match this against the rank they need.
  int64_t rank;
  // Target tensor type the reshape would produce.
  RankedTensorType type;
  // SSA shape operand of the originating tosa.reshape; may be null when the
  // template was derived from a torch reshape (then `shapeValues` is set).
  Value shape;
  // Static target dims, populated only when `shape` is null.
  std::optional<SmallVector<int64_t>> shapeValues;
};

// Runs an in-place inclusive prefix sum along the middle dimension (K) of
// `running` using a binary lifting scheme. The input must have shape [N, K, C].
Expand Down Expand Up @@ -2634,14 +2645,205 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewriteImpl(
auto input = adaptor.getInput();
auto weight = adaptor.getWeight();

auto inputTy = cast<RankedTensorType>(input.getType());
auto weightTy = cast<RankedTensorType>(weight.getType());
auto outputTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
auto weightTy = dyn_cast<RankedTensorType>(weight.getType());
if (!inputTy || !weightTy || !outputTy)
return rewriter.notifyMatchFailure(
op, "Input, weight and output to Convolution must be ranked tensors");

int64_t outputRank = outputTy.getRank();
if (outputRank != 4 && outputRank != 5)
return rewriter.notifyMatchFailure(
op, "Unimplemented: only 2D or 3D convolutions supported");

auto funcOp = op->getParentOfType<func::FuncOp>();
DominanceInfo domInfo(funcOp);

auto peelTrivialDefs = [](Value source) -> Value {
while (true) {
if (auto unrealized =
source.getDefiningOp<UnrealizedConversionCastOp>()) {
if (unrealized->getNumOperands() == 1) {
source = unrealized.getOperand(0);
continue;
}
}
if (auto castOp = source.getDefiningOp<tensor::CastOp>()) {
source = castOp.getSource();
continue;
}
if (auto toBuiltin =
source.getDefiningOp<TorchConversion::ToBuiltinTensorOp>()) {
source = toBuiltin.getOperand();
continue;
}
break;
}
return source;
};

auto addTemplate = [&](SmallVectorImpl<RankTemplate> &templates, int64_t rank,
RankedTensorType type, Value shape,
std::optional<SmallVector<int64_t>> shapeValues) {
for (const auto &tmpl : templates) {
if (tmpl.rank == rank && tmpl.type == type)
return;
}
templates.push_back(RankTemplate{rank, type, shape, shapeValues});
};

auto collectTemplatesFromSource =
[&](Value source) -> SmallVector<RankTemplate> {
SmallVector<RankTemplate> templates;
if (!source)
return templates;
source = peelTrivialDefs(source);

SmallVector<Value> worklist;
llvm::SmallDenseSet<Value, 16> visited;
worklist.push_back(source);

while (!worklist.empty()) {
Value current = worklist.pop_back_val();
if (!visited.insert(current).second)
continue;

for (OpOperand &use : current.getUses()) {
Operation *user = use.getOwner();

if (auto reshapeOp = dyn_cast<tosa::ReshapeOp>(user)) {
if (!domInfo.properlyDominates(reshapeOp.getOperation(), op))
continue;
auto dstType =
dyn_cast<RankedTensorType>(reshapeOp.getResult().getType());
if (!dstType || (dstType.getRank() != 4 && dstType.getRank() != 5))
continue;
addTemplate(templates, dstType.getRank(), dstType,
reshapeOp.getShape(), std::nullopt);
continue;
}

if (auto reshapeOp = dyn_cast<Torch::AtenReshapeOp>(user)) {
if (!domInfo.properlyDominates(reshapeOp.getOperation(), op))
continue;
auto torchTy =
dyn_cast<Torch::ValueTensorType>(reshapeOp.getResult().getType());
if (!torchTy || !torchTy.hasSizes() || !torchTy.hasDtype())
continue;
auto dstType = dyn_cast<RankedTensorType>(torchTy.toBuiltinTensor());
if (!dstType || (dstType.getRank() != 4 && dstType.getRank() != 5))
continue;
SmallVector<int64_t> shapeValues;
for (int64_t dim : dstType.getShape())
shapeValues.push_back(dim);
addTemplate(templates, dstType.getRank(), dstType, Value(),
shapeValues);
continue;
}

if (auto toBuiltin =
dyn_cast<TorchConversion::ToBuiltinTensorOp>(user)) {
worklist.push_back(toBuiltin.getResult());
continue;
}

if (auto unrealized = dyn_cast<UnrealizedConversionCastOp>(user)) {
if (unrealized->getNumResults() == 1)
worklist.push_back(unrealized.getResult(0));
continue;
}

if (auto castOp = dyn_cast<tensor::CastOp>(user)) {
worklist.push_back(castOp.getResult());
continue;
}
}
}

return templates;
};

auto normalizeOperandRank = [&](Value operand, Value torchOperand,
int64_t requiredRank) -> FailureOr<Value> {
auto rankedType = dyn_cast<RankedTensorType>(operand.getType());
if (!rankedType)
return failure();
if (rankedType.getRank() == requiredRank)
return operand;

SmallVector<RankTemplate> templates = collectTemplatesFromSource(operand);
if (torchOperand)
templates.append(collectTemplatesFromSource(torchOperand));
if (!templates.empty()) {
SmallVector<RankTemplate> deduped;
for (const auto &tmpl : templates)
addTemplate(deduped, tmpl.rank, tmpl.type, tmpl.shape,
tmpl.shapeValues);
templates.swap(deduped);
}
if (templates.empty())
return failure();

auto operandTy = dyn_cast<RankedTensorType>(operand.getType());
auto operandElemTy = operandTy.getElementType();
std::optional<int64_t> operandNumElements;
if (operandTy.hasStaticShape())
operandNumElements = operandTy.getNumElements();
SmallVector<const RankTemplate *> candidates;
for (const auto &tmpl : templates) {
if (tmpl.rank != requiredRank)
continue;
if (tmpl.type.getElementType() != operandElemTy)
continue;
if (operandNumElements && tmpl.type.hasStaticShape() &&
tmpl.type.getNumElements() != *operandNumElements)
continue;
candidates.push_back(&tmpl);
}
if (candidates.empty())
return failure();
if (candidates.size() != 1)
return failure();
const RankTemplate *match = candidates.front();

Value shapeVal = match->shape;
if (!shapeVal) {
if (!match->shapeValues)
return failure();
shapeVal =
tosa::getTosaConstShape(rewriter, op->getLoc(), *match->shapeValues);
} else if (auto shapeOp = shapeVal.getDefiningOp<tosa::ConstShapeOp>()) {
shapeVal = tosa::ConstShapeOp::create(
rewriter, op->getLoc(), shapeOp.getType(), shapeOp.getValues());
} else if (!domInfo.properlyDominates(shapeVal, op)) {
return failure();
}

auto reshape = tosa::ReshapeOp::create(rewriter, op->getLoc(), match->type,
operand, shapeVal);
return reshape.getResult();
};

if (inputTy.getRank() != outputRank) {
auto normalized = normalizeOperandRank(input, op.getInput(), outputRank);
if (failed(normalized))
return rewriter.notifyMatchFailure(
op, "Input rank mismatch without normalization template");
input = *normalized;
inputTy = cast<RankedTensorType>(input.getType());
}

if (weightTy.getRank() != outputRank) {
auto normalized = normalizeOperandRank(weight, op.getWeight(), outputRank);
if (failed(normalized))
return rewriter.notifyMatchFailure(
op, "Weight rank mismatch without normalization template");
weight = *normalized;
weightTy = cast<RankedTensorType>(weight.getType());
}

auto inputElemTy = inputTy.getElementType();
auto weightElemTy = weightTy.getElementType();
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
Expand All @@ -2650,16 +2852,11 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewriteImpl(

int64_t inputRank = inputTy.getRank();
int64_t weightRank = weightTy.getRank();
int64_t outputRank = outputTy.getRank();

if (inputRank != weightRank || outputRank != inputRank)
return rewriter.notifyMatchFailure(
op, "Input, weight and output ranks must match for convolution");

if (inputRank != 4 && inputRank != 5)
return rewriter.notifyMatchFailure(
op, "Unimplemented: only 2D or 3D convolutions supported");

bool is3D = inputRank == 5;
int64_t spatialRank = inputRank - 2;

Expand Down
37 changes: 37 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,43 @@ def Convolution2DModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2))


# ==============================================================================


class Convolution2DReshapeInputsModule(torch.nn.Module):
    """Convolution whose input and weight come from reshapes of flat vectors.

    Exercises lowerings that must recover 4-D operand shapes from upstream
    reshape ops rather than from the convolution operands' declared ranks.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([9], torch.float32, True),
            ([9], torch.float32, True),
            ([1], torch.float32, True),
        ]
    )
    def forward(self, inputVec, weight, bias):
        # Turn the flat 9-element vectors into 1x1x3x3 NCHW tensors before
        # feeding them to the convolution.
        reshaped_input = torch.reshape(inputVec, (1, 1, 3, 3))
        reshaped_weight = torch.reshape(weight, (1, 1, 3, 3))
        return torch.ops.aten.convolution(
            reshaped_input,
            reshaped_weight,
            bias=bias,
            stride=[1, 1],
            padding=[0, 0],
            dilation=[1, 1],
            transposed=False,
            output_padding=[0, 0],
            groups=1,
        )


@register_test_case(module_factory=lambda: Convolution2DReshapeInputsModule())
def Convolution2DReshapeInputsModule_basic(module, tu: TestUtils):
    # Flat input/weight vectors plus a scalar bias; the module reshapes the
    # first two to 4-D internally. Random values are drawn in the same order
    # as the positional arguments.
    flat_input = tu.rand(9)
    flat_weight = tu.rand(9)
    scalar_bias = tu.rand(1)
    module.forward(flat_input, flat_weight, scalar_bias)


class Convolution2DStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
Loading
Loading