@@ -7758,22 +7758,6 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewriteImpl(
77587758 " TOSA resize() takes rank==4 tensors." );
77597759
77607760 auto inputShape = inputTy.getShape ();
7761- auto inputElemTy = inputTy.getElementType ();
7762- // TOSA works in NHWC. Perform the necessary transformations.
7763- SmallVector<int32_t > nchwToNhwcDims ({0 , 2 , 3 , 1 });
7764- SmallVector<int64_t > transposedInputShape (
7765- {inputShape[0 ], inputShape[2 ], inputShape[3 ], inputShape[1 ]});
7766- auto transposedInputTy = RankedTensorType::get (
7767- makeShapeLLVMCompatible (transposedInputShape), inputElemTy);
7768- auto transposedInput =
7769- tosa::TransposeOp::create (
7770- rewriter, op->getLoc (),
7771- getTypeConverter ()->convertType (transposedInputTy), input,
7772- rewriter.getDenseI32ArrayAttr (nchwToNhwcDims))
7773- .getResult ();
7774-
7775- auto inputHeight = transposedInputShape[1 ];
7776- auto inputWidth = transposedInputShape[2 ];
77777761
77787762 int outputHeight, outputWidth;
77797763 if (!isa<Torch::NoneType>(op.getScaleFactor ().getType ())) {
@@ -7783,8 +7767,8 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewriteImpl(
77837767 return rewriter.notifyMatchFailure (
77847768 op, " non-const scale_factor parameter unsupported" );
77857769
7786- outputHeight = inputHeight * scaleFactor[0 ];
7787- outputWidth = inputWidth * scaleFactor[1 ];
7770+ outputHeight = inputShape[ 2 ] * scaleFactor[0 ];
7771+ outputWidth = inputShape[ 3 ] * scaleFactor[1 ];
77887772
77897773 } else {
77907774 if (!isa<Torch::NoneType>(op.getSize ().getType ()))
@@ -7841,78 +7825,13 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewriteImpl(
78417825 return rewriter.notifyMatchFailure (
78427826 op, " Application of antialias not yet supported" );
78437827
7844- SmallVector<int64_t > transposedResizedOpShape (
7845- {inputShape[0 ], outputHeight, outputWidth, inputShape[1 ]});
7846- auto transposedResizedOpTy = RankedTensorType::get (
7847- makeShapeLLVMCompatible (transposedResizedOpShape), inputElemTy);
7848-
7849- // Formatting snake_case to match TOSA spec names for readability
7850- int scale_y_n, scale_y_d, offset_y, border_y;
7851- int scale_x_n, scale_x_d, offset_x, border_x;
7852-
7853- // Align corners sets the scaling ratio to (OH - 1)/(IH - 1)
7854- // rather than OH / IH. Similarly for width.
7855- auto normalize = [&](int input, int output, int &n, int &d, int &offset,
7856- int &border) {
7857- // Dimension is length 1, we are just sampling from one value.
7858- if (input == 1 ) {
7859- n = output;
7860- d = 1 ;
7861- offset = 0 ;
7862- border = output - 1 ;
7863- return ;
7864- }
7865-
7866- // Apply if aligned and capable to be aligned.
7867- bool apply_aligned = alignCorners && (output > 1 );
7868- n = apply_aligned ? (output - 1 ) : output;
7869- d = apply_aligned ? (input - 1 ) : input;
7870-
7871- // Simplify the scalers, make sure they are even values.
7872- int gcd = std::gcd (n, d);
7873- n = 2 * n / gcd;
7874- d = 2 * d / gcd;
7875-
7876- offset = 0 ;
7877-
7878- // If nearest neighbours we need to guarantee we round up.
7879- if (mode == tosa::ResizeMode::NEAREST_NEIGHBOR && alignCorners) {
7880- offset += n / 2 ;
7881- }
7882-
7883- // TBD: impact of antialias parameter here ?
7884-
7885- // We can compute this directly based on previous values.
7886- border = d * (output - 1 ) - n * (input - 1 ) + offset;
7887- };
7888-
7889- normalize (inputHeight, outputHeight, scale_y_n, scale_y_d, offset_y,
7890- border_y);
7891- normalize (inputWidth, outputWidth, scale_x_n, scale_x_d, offset_x, border_x);
7892-
7893- auto scale = tosa::getTosaConstShape (
7894- rewriter, op->getLoc (), {scale_y_n, scale_y_d, scale_x_n, scale_x_d});
7895- auto offset =
7896- tosa::getTosaConstShape (rewriter, op->getLoc (), {offset_y, offset_x});
7897- auto border =
7898- tosa::getTosaConstShape (rewriter, op->getLoc (), {border_y, border_x});
7899-
7900- auto modeAttr = tosa::ResizeModeAttr::get (rewriter.getContext (), mode);
7901-
7902- auto resizeOpResult =
7903- tosa::ResizeOp::create (rewriter, op->getLoc (), transposedResizedOpTy,
7904- transposedInput, scale, offset, border, modeAttr)
7905- .getResult ();
7906-
79077828 auto resultType =
79087829 cast<RankedTensorType>(typeConverter->convertType (op.getType ()));
79097830
7910- SmallVector<int32_t > nhwcToNchwDims ({0 , 3 , 1 , 2 });
7911- rewriter
7912- .replaceOpWithNewOp <tosa::TransposeOp>(
7913- op, getTypeConverter ()->convertType (resultType), resizeOpResult,
7914- rewriter.getDenseI32ArrayAttr (nhwcToNchwDims))
7915- .getResult ();
7831+ Value resizeOp = convertResizeOp (rewriter, op, this ->getTypeConverter (),
7832+ input, inputTy, resultType, outputHeight,
7833+ outputWidth, alignCorners, mode);
7834+ rewriter.replaceOp (op, {resizeOp});
79167835
79177836 return success ();
79187837}
@@ -9283,6 +9202,101 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewriteImpl(
92839202 return success ();
92849203}
92859204
9205+ // Legalization for aten.upsample_bilinear2d
9206+ template <typename AtenOpT>
9207+ class ConvertUpsampleBilinear2dForward : public OpConversionPattern <AtenOpT> {
9208+ public:
9209+ using OpConversionPattern<AtenOpT>::OpConversionPattern;
9210+ using OpAdaptor = typename AtenOpT::Adaptor;
9211+ LogicalResult
9212+ matchAndRewrite (AtenOpT op, OpAdaptor adaptor,
9213+ ConversionPatternRewriter &rewriter) const override {
9214+ Value input;
9215+ if constexpr (std::is_same<AtenOpT, AtenUpsampleBilinear2dOp>()) {
9216+ input = adaptor.getSelf ();
9217+ } else if constexpr (std::is_same<AtenOpT, AtenUpsampleBilinear2dVecOp>()) {
9218+ input = adaptor.getInput ();
9219+ } else {
9220+ return rewriter.notifyMatchFailure (
9221+ op, " Expected either AtenUpsampleBilinear2dOp or "
9222+ " AtenUpsampleBilinear2dVecOp" );
9223+ }
9224+
9225+ auto inputTy = dyn_cast<RankedTensorType>(input.getType ());
9226+ if (!inputTy) {
9227+ return rewriter.notifyMatchFailure (op, " Only tensor types are supported" );
9228+ }
9229+ if (inputTy.getRank () != 4 ) {
9230+ return rewriter.notifyMatchFailure (op, " TOSA resize() requires rank 4" );
9231+ }
9232+
9233+ auto inputShape = inputTy.getShape ();
9234+
9235+ int64_t outputHeight;
9236+ int64_t outputWidth;
9237+
9238+ if constexpr (std::is_same<AtenOpT, AtenUpsampleBilinear2dOp>()) {
9239+ SmallVector<int64_t > outputSize;
9240+ if (!matchPattern (op.getOutputSize (),
9241+ m_TorchListOfConstantInts (outputSize))) {
9242+ return rewriter.notifyMatchFailure (
9243+ op, " Non-constant output size not supported" );
9244+ }
9245+
9246+ outputHeight = outputSize[0 ];
9247+ outputWidth = outputSize[1 ];
9248+ } else if constexpr (std::is_same<AtenOpT, AtenUpsampleBilinear2dVecOp>()) {
9249+ if (!isa<Torch::NoneType>(op.getOutputSize ().getType ())) {
9250+ SmallVector<int64_t > outputSize;
9251+ if (!matchPattern (op.getOutputSize (),
9252+ m_TorchListOfConstantInts (outputSize))) {
9253+ return rewriter.notifyMatchFailure (
9254+ op, " Non-constant output size not supported" );
9255+ }
9256+
9257+ outputHeight = outputSize[0 ];
9258+ outputWidth = outputSize[1 ];
9259+ } else {
9260+ if (isa<Torch::NoneType>(op.getScaleFactors ().getType ())) {
9261+ return rewriter.notifyMatchFailure (
9262+ op, " Missing output size and scale factors" );
9263+ }
9264+
9265+ SmallVector<double , 2 > scaleFactors;
9266+ if (!matchPattern (op.getScaleFactors (),
9267+ m_TorchListOfConstantFloats (scaleFactors))) {
9268+ return rewriter.notifyMatchFailure (
9269+ op, " Non-constant scale_factors not supported" );
9270+ }
9271+
9272+ // PyTorch uses floor after the scale multiplication
9273+ // https://docs.pytorch.org/docs/stable/generated/torch.nn.UpsamplingBilinear2d.html
9274+ outputHeight =
9275+ static_cast <int64_t >(std::floor (inputShape[2 ] * scaleFactors[0 ]));
9276+ outputWidth =
9277+ static_cast <int64_t >(std::floor (inputShape[3 ] * scaleFactors[1 ]));
9278+ }
9279+ }
9280+
9281+ bool alignCorners;
9282+ if (!matchPattern (op.getAlignCorners (),
9283+ m_TorchConstantBool (&alignCorners))) {
9284+ return rewriter.notifyMatchFailure (
9285+ op, " Non-constant align_corners parameter unsupported" );
9286+ }
9287+
9288+ auto resultTy = cast<RankedTensorType>(
9289+ this ->getTypeConverter ()->convertType (op.getType ()));
9290+
9291+ Value resizeOp = convertResizeOp (
9292+ rewriter, op, this ->getTypeConverter (), input, inputTy, resultTy,
9293+ outputHeight, outputWidth, alignCorners, tosa::ResizeMode::BILINEAR);
9294+ rewriter.replaceOp (op, {resizeOp});
9295+
9296+ return success ();
9297+ }
9298+ };
9299+
92869300// Legalization for aten.upsample_nearest2d
92879301template <typename AtenOpT>
92889302class ConvertUpsampleNearest2dForward
@@ -10617,6 +10631,14 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
1061710631 INSERT_POW_OP_PATTERN (AtenPowScalarOp);
1061810632#undef INSERT_POW_OP_PATTERN
1061910633
10634+ #define INSERT_UPSAMPLE_BILINEAR_2D_FORWARD_OP_PATTERN (AtenOp ) \
10635+ illegalOps.insert (AtenOp::getOperationName ()); \
10636+ patterns.add <ConvertUpsampleBilinear2dForward<AtenOp>>(typeConverter, \
10637+ context);
10638+ INSERT_UPSAMPLE_BILINEAR_2D_FORWARD_OP_PATTERN (AtenUpsampleBilinear2dOp);
10639+ INSERT_UPSAMPLE_BILINEAR_2D_FORWARD_OP_PATTERN (AtenUpsampleBilinear2dVecOp);
10640+ #undef INSERT_UPSAMPLE_BILINEAR_2D_FORWARD_OP_PATTERN
10641+
1062010642#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN (AtenOp ) \
1062110643 illegalOps.insert (AtenOp::getOperationName ()); \
1062210644 patterns.addWithLabel <ConvertUpsampleNearest2dForward<AtenOp>>( \
0 commit comments