Skip to content

Commit 4b48bb7

Browse files
authored
[tosa] Add support for torch.aten.upsample_bilinear2d legalization (#4492)
1 parent 7f1d4b2 commit 4b48bb7

File tree

6 files changed

+289
-87
lines changed

6 files changed

+289
-87
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ std::optional<Value> createRoundHalfToEven(ConversionPatternRewriter &rewriter,
136136
Operation *op, Value input,
137137
RankedTensorType resultTy);
138138

139+
Value convertResizeOp(ConversionPatternRewriter &rewriter, Operation *op,
140+
const TypeConverter *typeConverter, Value input,
141+
RankedTensorType inputTy, RankedTensorType resultTy,
142+
int64_t outputHeight, int64_t outputWidth,
143+
bool alignCorners, tosa::ResizeMode mode);
144+
139145
} // namespace tosa
140146
} // namespace mlir
141147

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op,
119119
// Check if a shaped type has any dimension with size 0.
120120
bool typeHasZeroDim(ShapedType type);
121121

122+
// Compute scale/offset/border parameters for TOSA resize on one dimension.
123+
void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
124+
tosa::ResizeMode mode, int &scaleN, int &scaleD,
125+
int &offset, int &border);
126+
122127
} // namespace tosa
123128
} // namespace mlir
124129

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 109 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -7758,22 +7758,6 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewriteImpl(
77587758
"TOSA resize() takes rank==4 tensors.");
77597759

77607760
auto inputShape = inputTy.getShape();
7761-
auto inputElemTy = inputTy.getElementType();
7762-
// TOSA works in NHWC. Perform the necessary transformations.
7763-
SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
7764-
SmallVector<int64_t> transposedInputShape(
7765-
{inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
7766-
auto transposedInputTy = RankedTensorType::get(
7767-
makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
7768-
auto transposedInput =
7769-
tosa::TransposeOp::create(
7770-
rewriter, op->getLoc(),
7771-
getTypeConverter()->convertType(transposedInputTy), input,
7772-
rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
7773-
.getResult();
7774-
7775-
auto inputHeight = transposedInputShape[1];
7776-
auto inputWidth = transposedInputShape[2];
77777761

77787762
int outputHeight, outputWidth;
77797763
if (!isa<Torch::NoneType>(op.getScaleFactor().getType())) {
@@ -7783,8 +7767,8 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewriteImpl(
77837767
return rewriter.notifyMatchFailure(
77847768
op, "non-const scale_factor parameter unsupported");
77857769

7786-
outputHeight = inputHeight * scaleFactor[0];
7787-
outputWidth = inputWidth * scaleFactor[1];
7770+
outputHeight = inputShape[2] * scaleFactor[0];
7771+
outputWidth = inputShape[3] * scaleFactor[1];
77887772

77897773
} else {
77907774
if (!isa<Torch::NoneType>(op.getSize().getType()))
@@ -7841,78 +7825,13 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewriteImpl(
78417825
return rewriter.notifyMatchFailure(
78427826
op, "Application of antialias not yet supported");
78437827

7844-
SmallVector<int64_t> transposedResizedOpShape(
7845-
{inputShape[0], outputHeight, outputWidth, inputShape[1]});
7846-
auto transposedResizedOpTy = RankedTensorType::get(
7847-
makeShapeLLVMCompatible(transposedResizedOpShape), inputElemTy);
7848-
7849-
// Formatting snake_case to match TOSA spec names for readability
7850-
int scale_y_n, scale_y_d, offset_y, border_y;
7851-
int scale_x_n, scale_x_d, offset_x, border_x;
7852-
7853-
// Align corners sets the scaling ratio to (OH - 1)/(IH - 1)
7854-
// rather than OH / IH. Similarly for width.
7855-
auto normalize = [&](int input, int output, int &n, int &d, int &offset,
7856-
int &border) {
7857-
// Dimension is length 1, we are just sampling from one value.
7858-
if (input == 1) {
7859-
n = output;
7860-
d = 1;
7861-
offset = 0;
7862-
border = output - 1;
7863-
return;
7864-
}
7865-
7866-
// Apply if aligned and capable to be aligned.
7867-
bool apply_aligned = alignCorners && (output > 1);
7868-
n = apply_aligned ? (output - 1) : output;
7869-
d = apply_aligned ? (input - 1) : input;
7870-
7871-
// Simplify the scalers, make sure they are even values.
7872-
int gcd = std::gcd(n, d);
7873-
n = 2 * n / gcd;
7874-
d = 2 * d / gcd;
7875-
7876-
offset = 0;
7877-
7878-
// If nearest neighbours we need to guarantee we round up.
7879-
if (mode == tosa::ResizeMode::NEAREST_NEIGHBOR && alignCorners) {
7880-
offset += n / 2;
7881-
}
7882-
7883-
// TBD: impact of antialias parameter here ?
7884-
7885-
// We can compute this directly based on previous values.
7886-
border = d * (output - 1) - n * (input - 1) + offset;
7887-
};
7888-
7889-
normalize(inputHeight, outputHeight, scale_y_n, scale_y_d, offset_y,
7890-
border_y);
7891-
normalize(inputWidth, outputWidth, scale_x_n, scale_x_d, offset_x, border_x);
7892-
7893-
auto scale = tosa::getTosaConstShape(
7894-
rewriter, op->getLoc(), {scale_y_n, scale_y_d, scale_x_n, scale_x_d});
7895-
auto offset =
7896-
tosa::getTosaConstShape(rewriter, op->getLoc(), {offset_y, offset_x});
7897-
auto border =
7898-
tosa::getTosaConstShape(rewriter, op->getLoc(), {border_y, border_x});
7899-
7900-
auto modeAttr = tosa::ResizeModeAttr::get(rewriter.getContext(), mode);
7901-
7902-
auto resizeOpResult =
7903-
tosa::ResizeOp::create(rewriter, op->getLoc(), transposedResizedOpTy,
7904-
transposedInput, scale, offset, border, modeAttr)
7905-
.getResult();
7906-
79077828
auto resultType =
79087829
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
79097830

7910-
SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
7911-
rewriter
7912-
.replaceOpWithNewOp<tosa::TransposeOp>(
7913-
op, getTypeConverter()->convertType(resultType), resizeOpResult,
7914-
rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
7915-
.getResult();
7831+
Value resizeOp = convertResizeOp(rewriter, op, this->getTypeConverter(),
7832+
input, inputTy, resultType, outputHeight,
7833+
outputWidth, alignCorners, mode);
7834+
rewriter.replaceOp(op, {resizeOp});
79167835

79177836
return success();
79187837
}
@@ -9283,6 +9202,101 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewriteImpl(
92839202
return success();
92849203
}
92859204

9205+
// Legalization for aten.upsample_bilinear2d
9206+
template <typename AtenOpT>
9207+
class ConvertUpsampleBilinear2dForward : public OpConversionPattern<AtenOpT> {
9208+
public:
9209+
using OpConversionPattern<AtenOpT>::OpConversionPattern;
9210+
using OpAdaptor = typename AtenOpT::Adaptor;
9211+
LogicalResult
9212+
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
9213+
ConversionPatternRewriter &rewriter) const override {
9214+
Value input;
9215+
if constexpr (std::is_same<AtenOpT, AtenUpsampleBilinear2dOp>()) {
9216+
input = adaptor.getSelf();
9217+
} else if constexpr (std::is_same<AtenOpT, AtenUpsampleBilinear2dVecOp>()) {
9218+
input = adaptor.getInput();
9219+
} else {
9220+
return rewriter.notifyMatchFailure(
9221+
op, "Expected either AtenUpsampleBilinear2dOp or "
9222+
"AtenUpsampleBilinear2dVecOp");
9223+
}
9224+
9225+
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
9226+
if (!inputTy) {
9227+
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
9228+
}
9229+
if (inputTy.getRank() != 4) {
9230+
return rewriter.notifyMatchFailure(op, "TOSA resize() requires rank 4");
9231+
}
9232+
9233+
auto inputShape = inputTy.getShape();
9234+
9235+
int64_t outputHeight;
9236+
int64_t outputWidth;
9237+
9238+
if constexpr (std::is_same<AtenOpT, AtenUpsampleBilinear2dOp>()) {
9239+
SmallVector<int64_t> outputSize;
9240+
if (!matchPattern(op.getOutputSize(),
9241+
m_TorchListOfConstantInts(outputSize))) {
9242+
return rewriter.notifyMatchFailure(
9243+
op, "Non-constant output size not supported");
9244+
}
9245+
9246+
outputHeight = outputSize[0];
9247+
outputWidth = outputSize[1];
9248+
} else if constexpr (std::is_same<AtenOpT, AtenUpsampleBilinear2dVecOp>()) {
9249+
if (!isa<Torch::NoneType>(op.getOutputSize().getType())) {
9250+
SmallVector<int64_t> outputSize;
9251+
if (!matchPattern(op.getOutputSize(),
9252+
m_TorchListOfConstantInts(outputSize))) {
9253+
return rewriter.notifyMatchFailure(
9254+
op, "Non-constant output size not supported");
9255+
}
9256+
9257+
outputHeight = outputSize[0];
9258+
outputWidth = outputSize[1];
9259+
} else {
9260+
if (isa<Torch::NoneType>(op.getScaleFactors().getType())) {
9261+
return rewriter.notifyMatchFailure(
9262+
op, "Missing output size and scale factors");
9263+
}
9264+
9265+
SmallVector<double, 2> scaleFactors;
9266+
if (!matchPattern(op.getScaleFactors(),
9267+
m_TorchListOfConstantFloats(scaleFactors))) {
9268+
return rewriter.notifyMatchFailure(
9269+
op, "Non-constant scale_factors not supported");
9270+
}
9271+
9272+
// PyTorch uses floor after the scale multiplication
9273+
// https://docs.pytorch.org/docs/stable/generated/torch.nn.UpsamplingBilinear2d.html
9274+
outputHeight =
9275+
static_cast<int64_t>(std::floor(inputShape[2] * scaleFactors[0]));
9276+
outputWidth =
9277+
static_cast<int64_t>(std::floor(inputShape[3] * scaleFactors[1]));
9278+
}
9279+
}
9280+
9281+
bool alignCorners;
9282+
if (!matchPattern(op.getAlignCorners(),
9283+
m_TorchConstantBool(&alignCorners))) {
9284+
return rewriter.notifyMatchFailure(
9285+
op, "Non-constant align_corners parameter unsupported");
9286+
}
9287+
9288+
auto resultTy = cast<RankedTensorType>(
9289+
this->getTypeConverter()->convertType(op.getType()));
9290+
9291+
Value resizeOp = convertResizeOp(
9292+
rewriter, op, this->getTypeConverter(), input, inputTy, resultTy,
9293+
outputHeight, outputWidth, alignCorners, tosa::ResizeMode::BILINEAR);
9294+
rewriter.replaceOp(op, {resizeOp});
9295+
9296+
return success();
9297+
}
9298+
};
9299+
92869300
// Legalization for aten.upsample_nearest2d
92879301
template <typename AtenOpT>
92889302
class ConvertUpsampleNearest2dForward
@@ -10617,6 +10631,14 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
1061710631
INSERT_POW_OP_PATTERN(AtenPowScalarOp);
1061810632
#undef INSERT_POW_OP_PATTERN
1061910633

10634+
#define INSERT_UPSAMPLE_BILINEAR_2D_FORWARD_OP_PATTERN(AtenOp) \
10635+
illegalOps.insert(AtenOp::getOperationName()); \
10636+
patterns.add<ConvertUpsampleBilinear2dForward<AtenOp>>(typeConverter, \
10637+
context);
10638+
INSERT_UPSAMPLE_BILINEAR_2D_FORWARD_OP_PATTERN(AtenUpsampleBilinear2dOp);
10639+
INSERT_UPSAMPLE_BILINEAR_2D_FORWARD_OP_PATTERN(AtenUpsampleBilinear2dVecOp);
10640+
#undef INSERT_UPSAMPLE_BILINEAR_2D_FORWARD_OP_PATTERN
10641+
1062010642
#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \
1062110643
illegalOps.insert(AtenOp::getOperationName()); \
1062210644
patterns.addWithLabel<ConvertUpsampleNearest2dForward<AtenOp>>( \

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
1111
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
1212
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
13+
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
1314
#include "torch-mlir/Conversion/Utils/Utils.h"
1415
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
16+
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
1517

1618
#include <cstdint>
1719
#include <iterator>
@@ -1275,5 +1277,64 @@ std::optional<Value> createRoundHalfToEven(ConversionPatternRewriter &rewriter,
12751277
return selectOp.getResult();
12761278
}
12771279

1280+
Value convertResizeOp(ConversionPatternRewriter &rewriter, Operation *op,
1281+
const TypeConverter *typeConverter, Value input,
1282+
RankedTensorType inputTy, RankedTensorType resultTy,
1283+
int64_t outputHeight, int64_t outputWidth,
1284+
bool alignCorners, tosa::ResizeMode mode) {
1285+
auto inputShape = inputTy.getShape();
1286+
auto inputElemTy = inputTy.getElementType();
1287+
1288+
// TOSA works in NHWC. Perform the necessary transformations.
1289+
SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
1290+
SmallVector<int64_t> transposedInputShape(
1291+
{inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
1292+
auto transposedInputTy = RankedTensorType::get(
1293+
makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
1294+
auto transposedInput =
1295+
tosa::TransposeOp::create(
1296+
rewriter, op->getLoc(), typeConverter->convertType(transposedInputTy),
1297+
input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
1298+
.getResult();
1299+
1300+
int inputHeight = transposedInputShape[1];
1301+
int inputWidth = transposedInputShape[2];
1302+
1303+
SmallVector<int64_t> transposedResizedOpShape(
1304+
{inputShape[0], outputHeight, outputWidth, inputShape[1]});
1305+
auto transposedResizedOpTy = RankedTensorType::get(
1306+
makeShapeLLVMCompatible(transposedResizedOpShape), inputElemTy);
1307+
1308+
// Formatting snake_case to match TOSA spec names for readability
1309+
int scale_y_n, scale_y_d, offset_y, border_y;
1310+
int scale_x_n, scale_x_d, offset_x, border_x;
1311+
1312+
computeResizeParams(inputHeight, outputHeight, alignCorners, mode, scale_y_n,
1313+
scale_y_d, offset_y, border_y);
1314+
computeResizeParams(inputWidth, outputWidth, alignCorners, mode, scale_x_n,
1315+
scale_x_d, offset_x, border_x);
1316+
1317+
auto scale = tosa::getTosaConstShape(
1318+
rewriter, op->getLoc(), {scale_y_n, scale_y_d, scale_x_n, scale_x_d});
1319+
auto offset =
1320+
tosa::getTosaConstShape(rewriter, op->getLoc(), {offset_y, offset_x});
1321+
auto border =
1322+
tosa::getTosaConstShape(rewriter, op->getLoc(), {border_y, border_x});
1323+
1324+
auto modeAttr = tosa::ResizeModeAttr::get(rewriter.getContext(), mode);
1325+
1326+
auto resizeOpResult =
1327+
tosa::ResizeOp::create(rewriter, op->getLoc(), transposedResizedOpTy,
1328+
transposedInput, scale, offset, border, modeAttr)
1329+
.getResult();
1330+
1331+
SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
1332+
auto transposedResizedOp = tosa::TransposeOp::create(
1333+
rewriter, op->getLoc(), typeConverter->convertType(resultTy),
1334+
resizeOpResult, rewriter.getDenseI32ArrayAttr(nhwcToNchwDims));
1335+
1336+
return transposedResizedOp.getResult();
1337+
}
1338+
12781339
} // namespace tosa
12791340
} // namespace mlir

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "torch-mlir/Conversion/Utils/Utils.h"
1515
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1616
#include "llvm/ADT/ArrayRef.h"
17+
#include <numeric>
1718

1819
namespace mlir {
1920
namespace tosa {
@@ -585,5 +586,37 @@ bool typeHasZeroDim(ShapedType type) {
585586
return llvm::any_of(outShape, [](int64_t dim) { return dim == 0; });
586587
}
587588

589+
void computeResizeParams(int inputSize, int outputSize, bool alignCorners,
590+
tosa::ResizeMode mode, int &scaleN, int &scaleD,
591+
int &offset, int &border) {
592+
// Dimension is length 1, we are just sampling from one value.
593+
if (inputSize == 1) {
594+
scaleN = outputSize;
595+
scaleD = 1;
596+
offset = 0;
597+
border = outputSize - 1;
598+
return;
599+
}
600+
601+
// Apply if aligned and capable to be aligned.
602+
bool applyAligned = alignCorners && (outputSize > 1);
603+
scaleN = applyAligned ? (outputSize - 1) : outputSize;
604+
scaleD = applyAligned ? (inputSize - 1) : inputSize;
605+
606+
// Simplify the scalers, make sure they are even values.
607+
int gcd = std::gcd(scaleN, scaleD);
608+
scaleN = 2 * scaleN / gcd;
609+
scaleD = 2 * scaleD / gcd;
610+
611+
// If nearest neighbors we need to guarantee we round up.
612+
offset = 0;
613+
if (mode == tosa::ResizeMode::NEAREST_NEIGHBOR && alignCorners) {
614+
offset += scaleN / 2;
615+
}
616+
617+
// We can compute this directly based on previous values.
618+
border = scaleD * (outputSize - 1) - scaleN * (inputSize - 1) + offset;
619+
}
620+
588621
} // namespace tosa
589622
} // namespace mlir

0 commit comments

Comments
 (0)