Skip to content

Commit ffc3648

Browse files
committed
Convert stablehlo.dynamic_slice into tosa.slice when indexes are const
1 parent 4a6ea0d commit ffc3648

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

stablehlo/conversions/tosa/tests/unary.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,28 @@ func.func @slice_rank_seven(%arg : tensor<2x3x4x5x6x7x8xf32>) -> tensor<1x2x3x4x
114114
return %0 : tensor<1x2x3x4x5x6x7xf32>
115115
}
116116

117+
// CHECK-LABEL: @dynamic_slice_constant_start
118+
func.func @dynamic_slice_constant_start(%arg : tensor<10x10xf32>) -> tensor<2x3xf32> {
119+
// CHECK-DAG: %[[START:.*]] = tosa.const_shape {values = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
120+
// CHECK-DAG: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
121+
// CHECK: tosa.slice %arg0, %[[START]], %[[SIZE]]
122+
%start0 = "stablehlo.constant"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
123+
%start1 = "stablehlo.constant"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
124+
%0 = "stablehlo.dynamic_slice"(%arg, %start0, %start1) {
125+
slice_sizes = array<i64: 2, 3>
126+
} : (tensor<10x10xf32>, tensor<i32>, tensor<i32>) -> tensor<2x3xf32>
127+
return %0 : tensor<2x3xf32>
128+
}
129+
130+
// CHECK-LABEL: @dynamic_slice_runtime_start
131+
func.func @dynamic_slice_runtime_start(%arg0 : tensor<10x10xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<2x3xf32> {
132+
// CHECK: stablehlo.dynamic_slice
133+
%0 = "stablehlo.dynamic_slice"(%arg0, %arg1, %arg2) {
134+
slice_sizes = array<i64: 2, 3>
135+
} : (tensor<10x10xf32>, tensor<i32>, tensor<i32>) -> tensor<2x3xf32>
136+
return %0 : tensor<2x3xf32>
137+
}
138+
117139
// CHECK-LABEL: @tanh
118140
func.func @tanh(%arg : tensor<10xf32>) -> tensor<10xf32> {
119141
// CHECK: tosa.tanh

stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,45 @@ struct ConvertStablehloFloatDivideOp
558558
}
559559
};
560560

561+
struct ConvertStablehloDynamicSliceOp
562+
: public OpRewritePattern<stablehlo::DynamicSliceOp> {
563+
using OpRewritePattern<stablehlo::DynamicSliceOp>::OpRewritePattern;
564+
565+
LogicalResult matchAndRewrite(stablehlo::DynamicSliceOp op,
566+
PatternRewriter& rewriter) const override {
567+
auto operandType = dyn_cast<RankedTensorType>(op.getOperand().getType());
568+
if (!operandType) {
569+
return rewriter.notifyMatchFailure(op, "expected ranked tensor type");
570+
}
571+
572+
if (operandType.getRank() < 1) {
573+
return rewriter.notifyMatchFailure(
574+
op, "tosa.slice requires input tensor of at least rank 1");
575+
}
576+
577+
SmallVector<int64_t> startIndices;
578+
for (Value startIndex : op.getStartIndices()) {
579+
DenseIntElementsAttr startAttr;
580+
if (auto constOp = startIndex.getDefiningOp<stablehlo::ConstantOp>()) {
581+
startAttr = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
582+
}
583+
584+
if (!startAttr || startAttr.getNumElements() != 1) {
585+
return rewriter.notifyMatchFailure(
586+
op, "tosa.slice requires constant start indices");
587+
}
588+
589+
startIndices.push_back((*startAttr.value_begin<APInt>()).getSExtValue());
590+
}
591+
592+
rewriter.replaceOpWithNewOp<tosa::SliceOp>(
593+
op, op.getType(), op.getOperand(),
594+
getTosaConstShape(rewriter, op.getLoc(), startIndices),
595+
getTosaConstShape(rewriter, op.getLoc(), op.getSliceSizes()));
596+
return success();
597+
}
598+
};
599+
561600
LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
562601
RewritePatternSet patternList(ctx);
563602
populateGeneratedPDLLPatterns(patternList);
@@ -573,6 +612,8 @@ LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
573612
patternList.addWithLabel<ConvertStablehloReduceOp>({"StablehloReduce"}, ctx);
574613
patternList.addWithLabel<ConvertStablehloReturnOp>({"StablehloReturn"}, ctx);
575614
patternList.addWithLabel<ConvertStablehloSliceOp>({"StablehloSlice"}, ctx);
615+
patternList.addWithLabel<ConvertStablehloDynamicSliceOp>(
616+
{"StablehloDynamicSlice"}, ctx);
576617
patternList.addWithLabel<ConvertStablehloTransposeOp>({"StablehloTranspose"},
577618
ctx);
578619
patternList.addWithLabel<ConvertStablehloWhileOp>({"StablehloWhile"}, ctx);

0 commit comments

Comments
 (0)