|
| 1 | +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" |
| 2 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 3 | +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 4 | +#include "mlir/Dialect/Func/IR/FuncOps.h" |
| 5 | +#include "mlir/Pass/Pass.h" |
| 6 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 7 | + |
| 8 | +using namespace mlir; |
| 9 | + |
| 10 | +/// Matches scf -> cf loop pattern: |
| 11 | +/// %cmp = arith.cmpi slt, %lhs, %const : i32 |
| 12 | +/// cf.cond_br %cmp, ^loop, ^exit |
| 13 | +/// |
| 14 | +/// and annotates the `cf.cond_br` with loop iteration count. |
| 15 | + |
| 16 | +namespace { |
| 17 | + |
| 18 | +struct AnnotateLoopCondBrPattern : public OpRewritePattern<cf::CondBranchOp> { |
| 19 | + using OpRewritePattern<cf::CondBranchOp>::OpRewritePattern; |
| 20 | + |
| 21 | + LogicalResult matchAndRewrite(cf::CondBranchOp condBrOp, |
| 22 | + PatternRewriter &rewriter) const override { |
| 23 | + |
| 24 | + auto cmpOp = condBrOp.getCondition().getDefiningOp<arith::CmpIOp>(); |
| 25 | + if (!cmpOp) |
| 26 | + return failure(); |
| 27 | + |
| 28 | + if (cmpOp.getPredicate() != arith::CmpIPredicate::slt) |
| 29 | + return failure(); |
| 30 | + |
| 31 | + auto rhsOp = cmpOp.getRhs().getDefiningOp<arith::ConstantOp>(); |
| 32 | + if (!rhsOp) |
| 33 | + return failure(); |
| 34 | + |
| 35 | + auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(rhsOp.getValue()); |
| 36 | + if (!intAttr) |
| 37 | + return failure(); |
| 38 | + |
| 39 | + int64_t loopCount = intAttr.getInt(); |
| 40 | + |
| 41 | + rewriter.setInsertionPointAfter(condBrOp); |
| 42 | + condBrOp->setAttr("loop_count", rewriter.getI64IntegerAttr(loopCount)); |
| 43 | + |
| 44 | + return success(); |
| 45 | + } |
| 46 | +}; |
| 47 | + |
| 48 | +/// A pass that runs on a function and looks for any cf.cond_br that |
| 49 | +/// implements a loop condition (in the pattern above), then annotates |
| 50 | +/// it with `loop_count`. |
| 51 | +struct AnnotateLoopCountPass |
| 52 | + : public PassWrapper<AnnotateLoopCountPass, OperationPass<func::FuncOp>> { |
| 53 | + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateLoopCountPass) |
| 54 | + |
| 55 | + StringRef getArgument() const final { return "annotate-loop-count"; } |
| 56 | + StringRef getDescription() const final { |
| 57 | + return "Annotate cf.cond_br ops with loop iteration counts."; |
| 58 | + } |
| 59 | + |
| 60 | + void runOnOperation() override { |
| 61 | + auto function = getOperation(); |
| 62 | + |
| 63 | + RewritePatternSet patterns(&getContext()); |
| 64 | + patterns.add<AnnotateLoopCondBrPattern>(&getContext()); |
| 65 | + |
| 66 | + if (failed(applyPatternsAndFoldGreedily(function.getBody(), |
| 67 | + std::move(patterns)))) { |
| 68 | + signalPassFailure(); |
| 69 | + } |
| 70 | + } |
| 71 | +}; |
| 72 | + |
| 73 | +} // namespace |
| 74 | + |
| 75 | +std::unique_ptr<Pass> createAnnotateLoopCountPass() { |
| 76 | + return std::make_unique<AnnotateLoopCountPass>(); |
| 77 | +} |
0 commit comments