Skip to content

Commit 56f701f

Browse files
tatwaichongJerry-Ge
authored andcommitted
[mlir][tosa] Add tf.PadV2 legalization and change PadOp padding to tosa.shape
This patch contains legalization and test changes for changing TOSA PadOp's padding input to type !tosa.shape<2 * rank>, (where rank is the rank of the PadOp's input), instead of a <rank x 2> tensor. This patch also contains new legalization for the tf.PadV2 operator Signed-off-by: TatWai Chong <[email protected]> Signed-off-by: Tai Ly <[email protected]> Signed-off-by: Jerry Ge <[email protected]> Change-Id: Ic73df82a8625a0b16793b74228e2974f6767693a
1 parent ffd42b6 commit 56f701f

File tree

7 files changed

+149
-30
lines changed

7 files changed

+149
-30
lines changed

tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> {
646646
// -----
647647

648648
// CHECK-LABEL: test_pad
649-
// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1> : tensor<3x2xi32>}>
649+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
650650
// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}>
651651
// CHECK: %[[VAR1:.*]] = tosa.pad %arg0, %[[VAR0]], %[[PVAL]]
652652
func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> {
@@ -657,6 +657,19 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> {
657657

658658
// -----
659659

660+
// CHECK-LABEL: test_pad_v2
661+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor<f32>}
662+
// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 0, 0, 1, 1, 2]> : tensor<6xindex>} : () -> !tosa.shape<6>
663+
// CHECK: %[[VAL_3:.*]] = tosa.pad %[[VAL_0]], %[[VAL_2]], %[[VAL_1]]
664+
func.func @test_pad_v2(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> {
665+
%1 = "tf.Const"() {value = dense<[[1, 0], [0, 1], [1, 2]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
666+
%2 = "tf.Const"() {value = dense<-3.40282347E+38> : tensor<f32>} : () -> tensor<f32>
667+
%3 = "tf.PadV2"(%arg0, %1, %2) : (tensor<13x21x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<15x23x5xf32>
668+
func.return %3 : tensor<15x23x5xf32>
669+
}
670+
671+
// -----
672+
660673
// CHECK-LABEL: test_expand_dims
661674
// CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 13, 21, 3>}
662675
func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> {
@@ -827,7 +840,7 @@ func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
827840
// -----
828841

829842
// CHECK-LABEL: test_space_to_batch
830-
// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [0, 1], [0, 0]]> : tensor<3x2xi32>}>
843+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6>
831844
// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}>
832845
// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}>
833846
// CHECK-DAG: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[PVAL]]

tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,7 +1308,7 @@ func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> {
13081308
// -----
13091309

13101310
// CHECK-LABEL: test_pad
1311-
// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 1], {{\[}}2, 2]]> : tensor<2x2xi32>}>
1311+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[1, 1, 2, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
13121312
// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}>
13131313
// CHECK: %[[VAR1:.*]] = tosa.pad %arg0, %[[VAR0]], %[[PVAL]]
13141314
func.func @test_pad(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
@@ -1323,13 +1323,13 @@ func.func @test_pad(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
13231323
// CHECK-LABEL: test_pad_v2
13241324
// CHECK-SAME: -> tensor<1x257x9x28xf32>
13251325
func.func @test_pad_v2(%arg0: tensor<1x256x8x25xf32>) -> (tensor<*xf32>) {
1326-
// CHECK-DAG: %[[PADDING:.+]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [1, 0], [0, 1], [1, 2]]> : tensor<4x2xi32>}>
1326+
// CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 0, 0, 1, 1, 2]> : tensor<8xindex>} : () -> !tosa.shape<8>
13271327
%0 = "tfl.pseudo_const"() {value = dense<[[0, 0], [1, 0], [0, 1], [1, 2]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
13281328

13291329
// CHECK-DAG: %[[VAL:.+]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor<f32>}>
13301330
%1 = "tfl.pseudo_const"() {value = dense<-3.40282347E+38> : tensor<f32>} : () -> tensor<f32>
13311331

1332-
// CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PADDING]], %[[VAL]] : (tensor<1x256x8x25xf32>, tensor<4x2xi32>, tensor<f32>) -> tensor<1x257x9x28xf32>
1332+
// CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PADDING]], %[[VAL]] : (tensor<1x256x8x25xf32>, !tosa.shape<8>, tensor<f32>) -> tensor<1x257x9x28xf32>
13331333
%2 = "tfl.padv2"(%arg0, %0, %1) : (tensor<1x256x8x25xf32>, tensor<4x2xi32>, tensor<f32>) -> tensor<*xf32>
13341334

13351335
// CHECK: return %[[PAD]]
@@ -1675,7 +1675,7 @@ func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> {
16751675
// -----
16761676

16771677
// CHECK-LABEL: test_space_to_batch
1678-
// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{\[}}[0, 0], [0, 1], [0, 0]]> : tensor<3x2xi32>}>
1678+
// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6>
16791679
// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}>
16801680
// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}>
16811681
// CHECK-DAG: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[PVAL]]
@@ -2727,8 +2727,9 @@ func.func @test_rfft2d_crop_input(%arg0: tensor<13x21x3xf32>) -> tensor<13x2x2xc
27272727
// CHECK-LABEL: test_rfft2d_pad_input
27282728
// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32>
27292729
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
2730-
// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 11], [0, 5]]> : tensor<3x2xi32>}> : () -> tensor<3x2xi32>
2731-
// CHECK: %[[VAL_3:.*]] = tosa.pad %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<13x32x8xf32>
2730+
// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 11, 0, 5]> : tensor<6xindex>} : () -> !tosa.shape<6>
2731+
// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[13, 32, 5, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
2732+
// CHECK: %[[VAL_3:.*]] = tosa.pad %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x32x8xf32>
27322733
// CHECK: %[[VAL_4:.*]], %[[VAL_5:.*]] = tosa.rfft2d %[[VAL_3]] : (tensor<13x32x8xf32>) -> (tensor<13x32x5xf32>, tensor<13x32x5xf32>)
27332734
// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array<i64: 13, 32, 5, 1>} : (tensor<13x32x5xf32>) -> tensor<13x32x5x1xf32>
27342735
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 13, 32, 5, 1>} : (tensor<13x32x5xf32>) -> tensor<13x32x5x1xf32>
@@ -2747,8 +2748,9 @@ func.func @test_rfft2d_pad_input(%arg0: tensor<13x21x3xf32>) -> (tensor<13x32x5x
27472748
// CHECK-LABEL: test_rfft2d_crop_height_pad_width
27482749
// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32>
27492750
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
2750-
// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0], [0, 0], [0, 13]]> : tensor<3x2xi32>}> : () -> tensor<3x2xi32>
2751-
// CHECK: %[[VAL_3:.*]] = tosa.pad %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<13x21x16xf32>
2751+
// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 13]> : tensor<6xindex>} : () -> !tosa.shape<6>
2752+
// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[13, 2, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
2753+
// CHECK: %[[VAL_3:.*]] = tosa.pad %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<f32>) -> tensor<13x21x16xf32>
27522754
// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_3]] {size = array<i64: 13, 2, 16>, start = array<i64: 0, 0, 0>} : (tensor<13x21x16xf32>) -> tensor<13x2x16xf32>
27532755
// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = tosa.rfft2d %[[VAL_4]] : (tensor<13x2x16xf32>) -> (tensor<13x2x9xf32>, tensor<13x2x9xf32>
27542756
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: 13, 2, 9, 1>} : (tensor<13x2x9xf32>) -> tensor<13x2x9x1xf32>

tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ limitations under the License.
4444
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
4545
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
4646
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
47+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
4748
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
4849
#include "mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project
4950
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
@@ -947,7 +948,10 @@ std::optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
947948
if (!matchPattern(paddings_value, m_Constant(&paddings_elems)))
948949
return std::nullopt;
949950

950-
SmallVector<int32_t> a0_pad_const(2 * (input_rank));
951+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
952+
auto rank_one_shape_type = tosa::shapeType::get(rewriter.getContext(), 1);
953+
954+
SmallVector<int64_t> a0_pad_const(2 * (input_rank));
951955
SmallVector<int64_t> padded_shape(input_rank);
952956

953957
// 1. Pad based on paddings operand. No padding on the batch dimension.
@@ -989,21 +993,13 @@ std::optional<Value> convertSpaceToBatchNDOp(PatternRewriter& rewriter,
989993
padded_shape[i + block_rank + 1] = input_shape[i + block_rank + 1];
990994
}
991995

992-
RankedTensorType a0_pad_const_attr_type =
993-
tensorflow::GetTypeFromTFTensorShape({(input_rank), 2},
994-
rewriter.getIntegerType(32));
995-
996-
// Create a const op to generate the tensor type for the input padding array
997-
auto a0_pad_const_op = rewriter.create<tosa::ConstOp>(
998-
op->getLoc(), a0_pad_const_attr_type,
999-
DenseElementsAttr::get(a0_pad_const_attr_type,
1000-
llvm::ArrayRef(a0_pad_const)));
996+
Value a0_padding = mlir::tosa::getTosaConstShape(rewriter, op->getLoc(), a0_pad_const);
1001997

1002998
auto a1_pad_input_op = CreateOpAndInfer<tosa::PadOp>(
1003999
rewriter, op->getLoc(),
10041000
tensorflow::GetTypeFromTFTensorShape(padded_shape,
10051001
result_type.getElementType()),
1006-
input_value, a0_pad_const_op.getResult());
1002+
input_value, a0_padding);
10071003

10081004
// 2. Reshape the padded structure of shape padded_shape to
10091005
// [batch + padded_shape[1] / block_shape[0], block_shape[0], ...
@@ -2555,6 +2551,9 @@ std::optional<Value> convertStridedSliceOp(
25552551
}
25562552
}
25572553

2554+
SmallVector<int64_t> input_pads(
2555+
input_rank * 2); // Stores pads on either side of a dimension
2556+
25582557
// Step 0: Process the begin/end masks and build the begin/sizes for the
25592558
// first slice
25602559
SmallVector<int64_t> a1_begin(input_rank), a1_size(input_rank);
@@ -2579,9 +2578,37 @@ std::optional<Value> convertStridedSliceOp(
25792578
a1_size[i] = 1;
25802579
strides[i] = 1;
25812580
}
2581+
2582+
// Note: no padding added to dynamic dimensions
2583+
auto stride_remainder = a1_size[i] % strides[i];
2584+
if (a1_size[i] > 0 && stride_remainder != 0) {
2585+
input_pads[2 * i] = 0; // No padding at beginning of dimension
2586+
auto pad_up_value = strides[i] - stride_remainder;
2587+
input_pads[2 * i + 1] = pad_up_value; // Pad end of dimension up to the
2588+
// next multiple of strides[i]
2589+
a1_size[i] += pad_up_value;
2590+
} else {
2591+
input_pads[2 * i] = 0;
2592+
input_pads[2 * i + 1] = 0;
2593+
}
2594+
}
2595+
2596+
// Step 0.5: Add tosa.Pad if required
2597+
const bool need_padding =
2598+
llvm::any_of(input_pads, [](int64_t i) { return i != 0; });
2599+
if (need_padding) {
2600+
Value a0_padding = getTosaConstShape(rewriter, op->getLoc(), input_pads);
2601+
2602+
auto a0_pad_input_op = CreateOpAndInfer<tosa::PadOp>(
2603+
rewriter, op->getLoc(),
2604+
tensorflow::GetTypeFromTFTensorShape(a1_size,
2605+
result_type.getElementType()),
2606+
input_value, a0_padding);
2607+
2608+
input_value =
2609+
a0_pad_input_op.getResult(); // overwrite input_value parameter
25822610
}
25832611

2584-
// Step 1: Slice the input array
25852612
auto a1_slice_op = CreateOpAndInfer<tosa::SliceOp>(
25862613
rewriter, op->getLoc(),
25872614
tensorflow::GetTypeFromTFTensorShape(a1_size, element_type), input_value,

tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525

2626
#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project
2727
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
28+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
2829
#include "mlir/Support/LLVM.h" // from @llvm-project
2930
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
3031
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
@@ -123,6 +124,7 @@ DECL_CONVERT_OP(StridedSlice);
123124
DECL_CONVERT_OP(Less);
124125
DECL_CONVERT_OP(LessEqual);
125126
DECL_CONVERT_OP(Pad);
127+
DECL_CONVERT_OP(PadV2);
126128
DECL_CONVERT_OP(MirrorPad);
127129
DECL_CONVERT_OP(ResizeBilinear);
128130
DECL_CONVERT_OP(ResizeNearestNeighbor);
@@ -1770,14 +1772,46 @@ LogicalResult ConvertTFPadOp::matchAndRewrite(Operation* op,
17701772
// Not a ranked tensor output
17711773
if (!output_type) return failure();
17721774

1773-
auto pad_op = CreateOpAndInfer<tosa::PadOp>(rewriter, op->getLoc(),
1774-
output_type, tf_pad_op.getInput(),
1775-
tf_pad_op.getPaddings());
1775+
SmallVector<int64_t> padding_vals;
1776+
if (failed(getVectorFromValue64(tf_pad_op.getPaddings(), padding_vals))) {
1777+
return rewriter.notifyMatchFailure(op, "paddings is not a constant value");
1778+
}
1779+
1780+
Value padding = mlir::tosa::getTosaConstShape(rewriter, op->getLoc(), padding_vals);
1781+
1782+
auto pad_op = CreateOpAndInfer<tosa::PadOp>(
1783+
rewriter, op->getLoc(), output_type, tf_pad_op.getInput(), padding);
17761784

17771785
rewriter.replaceOp(op, {pad_op.getResult()});
17781786
return success();
17791787
}
17801788

1789+
LogicalResult ConvertTFPadV2Op::matchAndRewrite(
1790+
Operation* op, PatternRewriter& rewriter) const {
1791+
auto tf_pad_op = cast<TF::PadV2Op>(op);
1792+
1793+
RankedTensorType output_type =
1794+
tf_pad_op.getResult().getType().dyn_cast<RankedTensorType>();
1795+
if (!output_type) {
1796+
return rewriter.notifyMatchFailure(op, "output type not a ranked tensor");
1797+
}
1798+
1799+
Value input = tf_pad_op.getInput();
1800+
Value constant_value = tf_pad_op.getConstantValues();
1801+
1802+
SmallVector<int64_t> padding_vals;
1803+
if (failed(getVectorFromValue64(tf_pad_op.getPaddings(), padding_vals))) {
1804+
return rewriter.notifyMatchFailure(op, "paddings is not a constant value");
1805+
}
1806+
1807+
Value padding = mlir::tosa::getTosaConstShape(rewriter, op->getLoc(), padding_vals);
1808+
1809+
CreateReplaceOpAndInfer<tosa::PadOp>(rewriter, op, tf_pad_op.getType(), input,
1810+
padding, constant_value);
1811+
1812+
return success();
1813+
}
1814+
17811815
LogicalResult ConvertTFMirrorPadOp::matchAndRewrite(
17821816
Operation* op, PatternRewriter& rewriter) const {
17831817
auto tf_mirrorpad_op = cast<TF::MirrorPadOp>(op);
@@ -2513,6 +2547,7 @@ void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) {
25132547
patterns.add<ConvertTFLessOp>(ctx);
25142548
patterns.add<ConvertTFLessEqualOp>(ctx);
25152549
patterns.add<ConvertTFPadOp>(ctx);
2550+
patterns.add<ConvertTFPadV2Op>(ctx);
25162551
patterns.add<ConvertTFMirrorPadOp>(ctx);
25172552
patterns.add<ConvertTFResizeBilinearOp>(ctx);
25182553
patterns.add<ConvertTFResizeNearestNeighborOp>(ctx);

tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ limitations under the License.
3434
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
3535
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
3636
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
37+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
3738
#include "mlir/IR/Block.h" // from @llvm-project
3839
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
3940
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
@@ -3031,9 +3032,14 @@ LogicalResult ConvertTFLPadOp::matchAndRewrite(
30313032
// Not a ranked tensor output
30323033
if (!output_type) return failure();
30333034

3035+
SmallVector<int64_t> padding_vals;
3036+
if (failed(getVectorFromValue64(tfl_pad_op.getPadding(), padding_vals))) {
3037+
return rewriter.notifyMatchFailure(op, "padding is not a constant value");
3038+
}
3039+
Value padding = getTosaConstShape(rewriter, op->getLoc(), padding_vals);
3040+
30343041
auto pad_op = CreateOpAndInfer<tosa::PadOp>(
3035-
rewriter, op->getLoc(), output_type, tfl_pad_op.getInput(),
3036-
tfl_pad_op.getPadding());
3042+
rewriter, op->getLoc(), output_type, tfl_pad_op.getInput(), padding);
30373043

30383044
rewriter.replaceOp(op, {pad_op.getResult()});
30393045
return success();
@@ -3076,9 +3082,15 @@ LogicalResult ConvertTFLPadV2Op::matchAndRewrite(
30763082
auto tfl_pad_op = cast<TFL::PadV2Op>(op);
30773083

30783084
Value input = tfl_pad_op.getInput();
3079-
Value padding = tfl_pad_op.getPadding();
30803085
Value constant_value = tfl_pad_op.getConstantValues();
30813086

3087+
SmallVector<int64_t> padding_vals;
3088+
if (failed(getVectorFromValue64(tfl_pad_op.getPadding(), padding_vals))) {
3089+
return rewriter.notifyMatchFailure(op, "padding is not a constant value");
3090+
}
3091+
3092+
Value padding = getTosaConstShape(rewriter, op->getLoc(), padding_vals);
3093+
30823094
CreateReplaceOpAndInfer<tosa::PadOp>(rewriter, op, tfl_pad_op.getType(),
30833095
input, padding, constant_value);
30843096

tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,9 @@ Value getTosaConstShape(PatternRewriter& rewriter, Operation* op,
577577

578578
// Create a vector from a 32-bit value tensor. Returns the size of
579579
// the new vector or -1 on error.
580+
// Populate a int32_t vector from a val tensor
581+
// return failure if val is not a constant value
582+
// return success otherwise
580583
LogicalResult getVectorFromValue32(Value val, SmallVectorImpl<int32_t>& vec) {
581584
int i = 0;
582585

@@ -594,6 +597,26 @@ LogicalResult getVectorFromValue32(Value val, SmallVectorImpl<int32_t>& vec) {
594597
return success();
595598
}
596599

600+
// Populate a int64_t vector from a val tensor
601+
// return failure if val is not a constant value
602+
// return success otherwise
603+
LogicalResult getVectorFromValue64(Value val, SmallVectorImpl<int64_t>& vec) {
604+
int i = 0;
605+
606+
ElementsAttr elems;
607+
608+
vec.clear();
609+
610+
if (!matchPattern(val, m_Constant(&elems))) return failure();
611+
612+
for (auto idx : elems.getValues<IntegerAttr>()) {
613+
vec.push_back(static_cast<int64_t>(idx.getInt()));
614+
i++;
615+
}
616+
617+
return success();
618+
}
619+
597620
// Calculates the TOSA padding values based on TF operators padded with
598621
// SAME/VALID.
599622
//

tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,17 @@ Value getTosaConstTensorScalarInt(ImplicitLocOpBuilder& builder, Type type,
137137
Value getTosaConstShape(PatternRewriter& rewriter, Operation* op,
138138
llvm::ArrayRef<int64_t> values);
139139

140-
// Create a vector from a 32-bit value tensor. Returns vector size on success
141-
// or -1 on error.
140+
141+
// Populate a int32_t vector from a val tensor
142+
// return failure if val is not a constant value
143+
// return success otherwise
142144
LogicalResult getVectorFromValue32(Value val, SmallVectorImpl<int32_t>& vec);
143145

146+
// Populate a int64_t vector from a val tensor
147+
// return failure if val is not a constant value
148+
// return success otherwise
149+
LogicalResult getVectorFromValue64(Value val, SmallVectorImpl<int64_t>& vec);
150+
144151
// Calculates the TOSA padding values based on TF operators padded with
145152
// SAME/VALID.
146153
bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad,

0 commit comments

Comments
 (0)