1212#include " mlir/Dialect/Arith/IR/Arith.h"
1313#include " mlir/Dialect/Tensor/IR/Tensor.h"
1414#include " mlir/Dialect/Tosa/IR/TosaOps.h"
15+ #include " mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1516#include " mlir/IR/Matchers.h"
1617#include " mlir/Transforms/DialectConversion.h"
1718#include " torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
@@ -2280,9 +2281,9 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
22802281 Type &resultType,
22812282 const llvm::ArrayRef<int64_t > weightShape,
22822283 Value &input, Value &weights, Value &bias,
2283- const int64_t groups, DenseI64ArrayAttr & pads,
2284- DenseI64ArrayAttr & strides,
2285- DenseI64ArrayAttr &dilations ) {
2284+ const int64_t groups, DenseI64ArrayAttr pads,
2285+ DenseI64ArrayAttr strides, DenseI64ArrayAttr dilations ,
2286+ TypeAttr accType ) {
22862287 // Set up constants outside of loop
22872288 const int64_t sizeOfSliceInput = weightShape[1 ];
22882289 const int64_t sizeOfSliceKernel = weightShape[0 ] / groups;
@@ -2312,7 +2313,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
23122313 // Create conv
23132314 Value tempConv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(
23142315 rewriter, input.getLoc (), outputType, sliceInput, sliceWeight,
2315- sliceBias, pads, strides, dilations);
2316+ sliceBias, pads, strides, dilations, accType );
23162317 // Add value to vector
23172318 sliceValues.push_back (tempConv2D);
23182319 }
@@ -2420,6 +2421,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24202421 return rewriter.notifyMatchFailure (op,
24212422 " non-const dilation list unsupported" );
24222423
2424+ TypeAttr accType;
2425+ if (failed (tosa::getConvOpsAccType (rewriter, inputTy, weightTy, outputTy,
2426+ accType)))
2427+ return rewriter.notifyMatchFailure (
2428+ op, " failed to get accumulator type for convolution ops" );
2429+
24232430 // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
24242431 // Perform the necessary transformations.
24252432 std::optional<Value> nchwToNhwcTransposeConst =
@@ -2523,22 +2530,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25232530 // quantized input is i32, which gets rescaled down to quantized output range.
25242531 SmallVector<int64_t > outputShape = {transposedInputShape[0 ], outputHDim,
25252532 outputWDim, outputCDim};
2526-
2527- DenseI64ArrayAttr paddingAttr = rewriter.getDenseI64ArrayAttr (padding);
2528- DenseI64ArrayAttr strideAttr = rewriter.getDenseI64ArrayAttr (stride);
2529- DenseI64ArrayAttr dilationAttr = rewriter.getDenseI64ArrayAttr (dilation);
2530-
25312533 Value convOpResult;
25322534 if (groups == 1 ) {
25332535 // full convolution
25342536 auto convOpTy =
25352537 RankedTensorType::get (makeShapeLLVMCompatible (outputShape), biasElemTy);
25362538 convOpResult =
25372539 rewriter
2538- .create <tosa::Conv2DOp>(op->getLoc (),
2539- getTypeConverter ()->convertType (convOpTy),
2540- transposedInput, transformedWeight, bias,
2541- paddingAttr, strideAttr, dilationAttr)
2540+ .create <tosa::Conv2DOp>(
2541+ op->getLoc (), getTypeConverter ()->convertType (convOpTy),
2542+ transposedInput, transformedWeight, bias,
2543+ rewriter.getDenseI64ArrayAttr (padding),
2544+ rewriter.getDenseI64ArrayAttr (stride),
2545+ rewriter.getDenseI64ArrayAttr (dilation), accType)
25422546 .getResult ();
25432547 } else if (weightShape[1 ] == 1 ) {
25442548 // depthwise convolution
@@ -2548,14 +2552,18 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25482552 rewriter
25492553 .create <tosa::DepthwiseConv2DOp>(
25502554 op->getLoc (), getTypeConverter ()->convertType (convOpTy),
2551- transposedInput, transformedWeight, bias, paddingAttr,
2552- strideAttr, dilationAttr)
2555+ transposedInput, transformedWeight, bias,
2556+ rewriter.getDenseI64ArrayAttr (padding),
2557+ rewriter.getDenseI64ArrayAttr (stride),
2558+ rewriter.getDenseI64ArrayAttr (dilation), accType)
25532559 .getResult ();
25542560 } else {
25552561 // general group convolution
25562562 convOpResult = createConvInGroups (
25572563 rewriter, op, outputTy, weightShape, transposedInput, transformedWeight,
2558- bias, groups, paddingAttr, strideAttr, dilationAttr);
2564+ bias, groups, rewriter.getDenseI64ArrayAttr (padding),
2565+ rewriter.getDenseI64ArrayAttr (stride),
2566+ rewriter.getDenseI64ArrayAttr (dilation), accType);
25592567 }
25602568
25612569 std::optional<Value> nhwcToNchwTransposeConst =
@@ -4103,9 +4111,11 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
41034111 }
41044112 }
41054113
4106- auto result = rewriter.create <tosa::TileOp>(
4107- op->getLoc (), resultType, reshapedInput,
4108- rewriter.getDenseI64ArrayAttr (tileOpShape));
4114+ auto tileOpMultiples =
4115+ tosa::getTosaConstShape (rewriter, op->getLoc (), tileOpShape);
4116+
4117+ auto result = rewriter.create <tosa::TileOp>(op->getLoc (), resultType,
4118+ reshapedInput, tileOpMultiples);
41094119
41104120 rewriter.replaceOp (op, {result.getResult ()});
41114121 }
@@ -4298,9 +4308,11 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
42984308 RankedTensorType::get (makeShapeLLVMCompatible (expandedIndicesShape),
42994309 rewriter.getIntegerType (32 ));
43004310
4311+ auto tileOpMultiples =
4312+ tosa::getTosaConstShape (rewriter, op->getLoc (), tileShape);
4313+
43014314 auto expandedIndices = rewriter.create <tosa::TileOp>(
4302- op->getLoc (), tileType, reshapedIndices.getResult (),
4303- rewriter.getDenseI64ArrayAttr (tileShape));
4315+ op->getLoc (), tileType, reshapedIndices.getResult (), tileOpMultiples);
43044316
43054317 // convert torch style index and dim into tf style indices
43064318 // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
@@ -4639,17 +4651,23 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
46394651 if (needsTiling) {
46404652 auto idxType =
46414653 dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType ());
4654+
46424655 // indicesTfConcatTensors has a trailing [1] dim for the final concat.
46434656 auto maxRankMaxDimShapeTf (maxRankMaxDimShape);
46444657 maxRankMaxDimShapeTf.push_back (1 );
4658+
46454659 auto tileOpShapeTf (tileOpShape);
46464660 tileOpShapeTf.push_back (1 );
4661+
46474662 auto tileOutputTy = RankedTensorType::get (maxRankMaxDimShapeTf,
46484663 idxType.getElementType ());
46494664 auto reshapedIdxTensor = indicesTfConcatTensors[i];
4665+
4666+ auto tileOpMultiples =
4667+ tosa::getTosaConstShape (rewriter, op->getLoc (), tileOpShapeTf);
4668+
46504669 indicesTfConcatTensors[i] = rewriter.create <tosa::TileOp>(
4651- op->getLoc (), tileOutputTy, reshapedIdxTensor,
4652- rewriter.getDenseI64ArrayAttr (tileOpShapeTf));
4670+ op->getLoc (), tileOutputTy, reshapedIdxTensor, tileOpMultiples);
46534671 }
46544672
46554673 // Every index tensor now has the same rank and shape
@@ -6220,12 +6238,14 @@ class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
62206238 op->getLoc (), fillValueMatchedInputRankType, fillValue,
62216239 rewriter.getDenseI64ArrayAttr (fillValueMatchedInputRankShape));
62226240
6241+ auto tileOpMultiples =
6242+ tosa::getTosaConstShape (rewriter, op->getLoc (), outType.getShape ());
6243+
62236244 fillValueTargetTensor = rewriter.create <tosa::TileOp>(
62246245 op->getLoc (),
62256246 RankedTensorType::get (makeShapeTorchCompatible (outType.getShape ()),
62266247 fillValueElemTy),
6227- fillValueMatchedInputRankTensor.getResult (),
6228- makeShapeTorchCompatible (outType.getShape ()));
6248+ fillValueMatchedInputRankTensor.getResult (), tileOpMultiples);
62296249 } else {
62306250 if (failed (torchScalarToTosaTensor (
62316251 rewriter, op, op.getValue (), fillValueTargetTensor, outElemTy,
@@ -6376,7 +6396,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
63766396 }
63776397
63786398 DenseElementsAttr paddingAttr = DenseIntElementsAttr::get (
6379- RankedTensorType::get ({rank, 2 }, rewriter.getI64Type ()),
6399+ RankedTensorType::get ({2 * rank }, rewriter.getI64Type ()),
63806400 translatePadsList);
63816401
63826402 Value padsList1 = rewriter.create <mlir::tosa::ConstOp>(
@@ -8033,9 +8053,11 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
80338053 resultType.getElementType ()),
80348054 self, rewriter.getDenseI64ArrayAttr (resultShapeIndex1Replaced));
80358055
8056+ auto selfTileOpMultiples = tosa::getTosaConstShape (rewriter, op->getLoc (),
8057+ resultShapeIndex0Replaced);
8058+
80368059 auto selfTiled = rewriter.create <tosa::TileOp>(
8037- op->getLoc (), resultType, selfReshaped.getResult (),
8038- rewriter.getDenseI64ArrayAttr (resultShapeIndex0Replaced));
8060+ op->getLoc (), resultType, selfReshaped.getResult (), selfTileOpMultiples);
80398061
80408062 // Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]}
80418063 auto vec2Reshaped = rewriter.create <tosa::ReshapeOp>(
@@ -8044,9 +8066,11 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
80448066 resultType.getElementType ()),
80458067 vec2, rewriter.getDenseI64ArrayAttr (resultShapeIndex0Replaced));
80468068
8069+ auto vec2TileOpMultiples = tosa::getTosaConstShape (rewriter, op->getLoc (),
8070+ resultShapeIndex1Replaced);
8071+
80478072 auto vec2Tiled = rewriter.create <tosa::TileOp>(
8048- op->getLoc (), resultType, vec2Reshaped.getResult (),
8049- rewriter.getDenseI64ArrayAttr (resultShapeIndex1Replaced));
8073+ op->getLoc (), resultType, vec2Reshaped.getResult (), vec2TileOpMultiples);
80508074
80518075 auto result =
80528076 tosa::createMulOpAndCast (rewriter, op, resultType, selfTiled.getResult (),
0 commit comments