Skip to content

Commit 4b6d89a

Browse files
authored
Derive output_size of repeat_interleave when inputs are broadcast(fill(x)) (#109)
1 parent 198c510 commit 4b6d89a

4 files changed

Lines changed: 82 additions & 0 deletions

File tree

e2e_testing/xfail_sets.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"Conv1dNoPaddingTransposeModule_basic",
1919
"Conv1dNoPaddingGroupModule_basic",
2020
"RepeatInterleaveStaticModule_basic",
21+
"RepeatInterleaveFillModule_basic",
2122
# tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
2223
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic"
2324
}
@@ -277,6 +278,7 @@
277278
"ScatterValueIntModule_basic",
278279
# ERROR: Unsupported: dynamic shape operator: aten.repeat_interleave.Tensor
279280
"RepeatInterleaveModule_basic",
281+
"RepeatInterleaveFillModule_basic",
280282

281283
# failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal
282284
"Conv1dNoPaddingModule_basic",
@@ -1226,6 +1228,7 @@
12261228
"ChunkListUnpack_Module_basic",
12271229
"ChunkListUnpackUneven_Module_basic",
12281230
"RepeatInterleaveStaticModule_basic",
1231+
"RepeatInterleaveFillModule_basic",
12291232
}
12301233

12311234
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@@ -1490,5 +1493,6 @@
14901493
"ScatterValueFloatModule_basic",
14911494
"ScatterValueIntModule_basic",
14921495
"RepeatInterleaveModule_basic",
1496+
"RepeatInterleaveFillModule_basic",
14931497
"Im2ColModule_basic",
14941498
}

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,4 +491,11 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
491491
auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
492492
return backendLegalOpsSet.contains(opName);
493493
});
494+
495+
// TODO: We need this for TOSA; other backends might be fine with this op
496+
// having a dynamic sized output tensor.
497+
target.addDynamicallyLegalOp<AtenRepeatInterleaveTensorOp>(
498+
[](AtenRepeatInterleaveTensorOp op) {
499+
return op.getOutputSize().getDefiningOp<ConstantIntOp>();
500+
});
494501
}

lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,54 @@ class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
393393
return success();
394394
}
395395
};
396+
class RecomposeRepeatInterleave : public OpRewritePattern<AtenRepeatInterleaveTensorOp> {
397+
public:
398+
using OpRewritePattern::OpRewritePattern;
399+
LogicalResult matchAndRewrite(AtenRepeatInterleaveTensorOp op,
400+
PatternRewriter &rewriter) const override {
401+
if (!op.getOutputSize().getDefiningOp<ConstantNoneOp>())
402+
return failure();
403+
404+
auto repeatsTy = dyn_cast<BaseTensorType>(op.getRepeats().getType());
405+
if (!repeatsTy || !repeatsTy.areAllSizesKnown() || repeatsTy.getSizes().size() != 1) {
406+
return rewriter.notifyMatchFailure(
407+
op,
408+
"Expected 1d tensor with static shape");
409+
}
410+
auto numElements = repeatsTy.getSizes()[0];
411+
412+
auto broadcast = op.getRepeats().getDefiningOp<AtenBroadcastToOp>();
413+
if (!broadcast){
414+
return rewriter.notifyMatchFailure(
415+
op,
416+
"Expected broadcast op defining repeat_interleave input");
417+
}
418+
419+
auto fill = broadcast.getSelf().getDefiningOp<AtenFillScalarOp>();
420+
if (!fill){
421+
return rewriter.notifyMatchFailure(
422+
op,
423+
"Expected fill op defining broadcast/repeat_interleave input");
424+
}
425+
426+
int64_t fillValue;
427+
if (!matchPattern(fill.getValue(),
428+
m_TorchConstantInt(&fillValue))) {
429+
return rewriter.notifyMatchFailure(
430+
op,
431+
"Expected fill value of fill.Scalar to be an integer constant");
432+
}
433+
434+
auto outputSize = rewriter.create<Torch::ConstantIntOp>(
435+
op->getLoc(), rewriter.getI64IntegerAttr(fillValue * numElements));
436+
rewriter.replaceOpWithNewOp<AtenRepeatInterleaveTensorOp>(op, op.getType(), op.getRepeats(), outputSize);
437+
438+
if (op.getResult().use_empty())
439+
rewriter.eraseOp(op);
440+
return success();
441+
}
442+
};
443+
396444
} // namespace
397445

398446
namespace {
@@ -412,6 +460,7 @@ class RecomposeComplexOpsPass
412460
patterns.add<RecomposeUnbindGetItem>(context);
413461
patterns.add<RecomposeSplitTensorPrimListUnpackOp>(context);
414462
patterns.add<RecomposeChunkListUnpack>(context);
463+
patterns.add<RecomposeRepeatInterleave>(context);
415464

416465
GreedyRewriteConfig config;
417466
config.useTopDownTraversal = true;

python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,28 @@ def forward(self, x):
14801480
def RepeatInterleaveModule_basic(module, tu: TestUtils):
14811481
module.forward(torch.tensor([3, 1, 2, 4], dtype=torch.int))
14821482

1483+
# ==============================================================================
1484+
class RepeatInterleaveFillModule(torch.nn.Module):
1485+
1486+
def __init__(self):
1487+
super().__init__()
1488+
1489+
@export
1490+
@annotate_args([
1491+
None,
1492+
([1], torch.int, True),
1493+
])
1494+
def forward(self, x):
1495+
x = torch.ops.aten.fill_(x, 2)
1496+
x = torch.ops.aten.expand(x, [16])
1497+
return torch.ops.aten.repeat_interleave(x)
1498+
1499+
1500+
@register_test_case(module_factory=lambda: RepeatInterleaveFillModule())
1501+
def RepeatInterleaveFillModule_basic(module, tu: TestUtils):
1502+
module.forward(torch.tensor([1], dtype=torch.int))
1503+
1504+
14831505
# ==============================================================================
14841506

14851507
class RepeatInterleaveStaticModule(torch.nn.Module):

0 commit comments

Comments
 (0)