@@ -393,6 +393,54 @@ class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
393393 return success ();
394394 }
395395};
396+ class RecomposeRepeatInterleave : public OpRewritePattern <AtenRepeatInterleaveTensorOp> {
397+ public:
398+ using OpRewritePattern::OpRewritePattern;
399+ LogicalResult matchAndRewrite (AtenRepeatInterleaveTensorOp op,
400+ PatternRewriter &rewriter) const override {
401+ if (!op.getOutputSize ().getDefiningOp <ConstantNoneOp>())
402+ return failure ();
403+
404+ auto repeatsTy = dyn_cast<BaseTensorType>(op.getRepeats ().getType ());
405+ if (!repeatsTy || !repeatsTy.areAllSizesKnown () || repeatsTy.getSizes ().size () != 1 ) {
406+ return rewriter.notifyMatchFailure (
407+ op,
408+ " Expected 1d tensor with static shape" );
409+ }
410+ auto numElements = repeatsTy.getSizes ()[0 ];
411+
412+ auto broadcast = op.getRepeats ().getDefiningOp <AtenBroadcastToOp>();
413+ if (!broadcast){
414+ return rewriter.notifyMatchFailure (
415+ op,
416+ " Expected broadcast op defining repeat_interleave input" );
417+ }
418+
419+ auto fill = broadcast.getSelf ().getDefiningOp <AtenFillScalarOp>();
420+ if (!fill){
421+ return rewriter.notifyMatchFailure (
422+ op,
423+ " Expected fill op defining broadcast/repeat_interleave input" );
424+ }
425+
426+ int64_t fillValue;
427+ if (!matchPattern (fill.getValue (),
428+ m_TorchConstantInt (&fillValue))) {
429+ return rewriter.notifyMatchFailure (
430+ op,
431+ " Expected fill value of fill.Scalar to be an integer constant" );
432+ }
433+
434+ auto outputSize = rewriter.create <Torch::ConstantIntOp>(
435+ op->getLoc (), rewriter.getI64IntegerAttr (fillValue * numElements));
436+ rewriter.replaceOpWithNewOp <AtenRepeatInterleaveTensorOp>(op, op.getType (), op.getRepeats (), outputSize);
437+
438+ if (op.getResult ().use_empty ())
439+ rewriter.eraseOp (op);
440+ return success ();
441+ }
442+ };
443+
396444} // namespace
397445
398446namespace {
@@ -412,6 +460,7 @@ class RecomposeComplexOpsPass
412460 patterns.add <RecomposeUnbindGetItem>(context);
413461 patterns.add <RecomposeSplitTensorPrimListUnpackOp>(context);
414462 patterns.add <RecomposeChunkListUnpack>(context);
463+ patterns.add <RecomposeRepeatInterleave>(context);
415464
416465 GreedyRewriteConfig config;
417466 config.useTopDownTraversal = true ;
0 commit comments