@@ -558,6 +558,45 @@ struct ConvertStablehloFloatDivideOp
558558 }
559559};
560560
561+ struct ConvertStablehloDynamicSliceOp
562+ : public OpRewritePattern<stablehlo::DynamicSliceOp> {
563+ using OpRewritePattern<stablehlo::DynamicSliceOp>::OpRewritePattern;
564+
565+ LogicalResult matchAndRewrite (stablehlo::DynamicSliceOp op,
566+ PatternRewriter& rewriter) const override {
567+ auto operandType = dyn_cast<RankedTensorType>(op.getOperand ().getType ());
568+ if (!operandType) {
569+ return rewriter.notifyMatchFailure (op, " expected ranked tensor type" );
570+ }
571+
572+ if (operandType.getRank () < 1 ) {
573+ return rewriter.notifyMatchFailure (
574+ op, " tosa.slice requires input tensor of at least rank 1" );
575+ }
576+
577+ SmallVector<int64_t > startIndices;
578+ for (Value startIndex : op.getStartIndices ()) {
579+ DenseIntElementsAttr startAttr;
580+ if (auto constOp = startIndex.getDefiningOp <stablehlo::ConstantOp>()) {
581+ startAttr = dyn_cast<DenseIntElementsAttr>(constOp.getValue ());
582+ }
583+
584+ if (!startAttr || startAttr.getNumElements () != 1 ) {
585+ return rewriter.notifyMatchFailure (
586+ op, " tosa.slice requires constant start indices" );
587+ }
588+
589+ startIndices.push_back ((*startAttr.value_begin <APInt>()).getSExtValue ());
590+ }
591+
592+ rewriter.replaceOpWithNewOp <tosa::SliceOp>(
593+ op, op.getType (), op.getOperand (),
594+ getTosaConstShape (rewriter, op.getLoc (), startIndices),
595+ getTosaConstShape (rewriter, op.getLoc (), op.getSliceSizes ()));
596+ return success ();
597+ }
598+ };
599+
561600LogicalResult StablehloLegalizeToTosaPass::initialize (MLIRContext* ctx) {
562601 RewritePatternSet patternList (ctx);
563602 populateGeneratedPDLLPatterns (patternList);
@@ -573,6 +612,8 @@ LogicalResult StablehloLegalizeToTosaPass::initialize(MLIRContext* ctx) {
573612 patternList.addWithLabel <ConvertStablehloReduceOp>({" StablehloReduce" }, ctx);
574613 patternList.addWithLabel <ConvertStablehloReturnOp>({" StablehloReturn" }, ctx);
575614 patternList.addWithLabel <ConvertStablehloSliceOp>({" StablehloSlice" }, ctx);
615+ patternList.addWithLabel <ConvertStablehloDynamicSliceOp>(
616+ {" StablehloDynamicSlice" }, ctx);
576617 patternList.addWithLabel <ConvertStablehloTransposeOp>({" StablehloTranspose" },
577618 ctx);
578619 patternList.addWithLabel <ConvertStablehloWhileOp>({" StablehloWhile" }, ctx);
0 commit comments