From 3aeb3e0e9868eb4dec6acde2abc79de35251b68c Mon Sep 17 00:00:00 2001 From: Vinicius Couto Espindola <34522047+sitio-couto@users.noreply.github.com> Date: Mon, 8 Jan 2024 21:54:39 -0300 Subject: [PATCH] [CIR][Transforms][NFC] Refactor MergeCleanups pass (#384) Breaks the pass into smaller more manageable rewrites. --- .../CIR/Dialect/Transforms/MergeCleanups.cpp | 326 +++++++----------- 1 file changed, 123 insertions(+), 203 deletions(-) diff --git a/clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp b/clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp index f295361140a9..822ce6f4bb2c 100644 --- a/clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp +++ b/clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp @@ -7,248 +7,168 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" - -#include "clang/CIR/Dialect/IR/CIRDialect.h" -#include "clang/CIR/Dialect/Passes.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - -#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "clang/CIR/Dialect/IR/CIRDialect.h" +#include "clang/CIR/Dialect/Passes.h" using namespace mlir; using namespace cir; -namespace { - -template -struct SimplifyRetYieldBlocks : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult replaceScopeLikeOp(PatternRewriter &rewriter, - ScopeLikeOpTy scopeLikeOp) const; - - SimplifyRetYieldBlocks(mlir::MLIRContext *context) - : OpRewritePattern(context, /*benefit=*/1) {} - - mlir::LogicalResult - checkAndRewriteRegion(mlir::Region &r, - mlir::PatternRewriter &rewriter) const { - auto &blocks = r.getBlocks(); - - if (blocks.size() <= 1) - return failure(); - - // Rewrite something like this: - // - // cir.if %2 { - // %3 = cir.const(3 : i32) : i32 - // cir.br ^bb1 - // ^bb1: // pred: ^bb0 - // cir.return %3 : i32 - // } - // - // to this: - // - // cir.if %2 { - // %3 = cir.const(3 : i32) : i32 - // cir.return %3 : i32 - // } - // - SmallPtrSet candidateBlocks; - for (Block &block : blocks) { - if (block.isEntryBlock()) - continue; - - auto yieldVars = block.getOps(); - for (cir::YieldOp yield : yieldVars) - candidateBlocks.insert(yield.getOperation()->getBlock()); +//===----------------------------------------------------------------------===// +// Rewrite patterns +//===----------------------------------------------------------------------===// - auto retVars = block.getOps(); - for (cir::ReturnOp ret : retVars) - candidateBlocks.insert(ret.getOperation()->getBlock()); - } +namespace { - auto changed = mlir::failure(); - for (auto *mergeSource : candidateBlocks) { - if (!(mergeSource->hasNoSuccessors() && mergeSource->hasOneUse())) - continue; - auto *mergeDest = mergeSource->getSinglePredecessor(); - if (!mergeDest || mergeDest->getNumSuccessors() != 1) - continue; - rewriter.eraseOp(mergeDest->getTerminator()); - rewriter.mergeBlocks(mergeSource, mergeDest); - changed = mlir::success(); +/// Removes branches between two blocks if it is the only branch. +/// +/// From: +/// ^bb0: +/// cir.br ^bb1 +/// ^bb1: // pred: ^bb0 +/// cir.return +/// +/// To: +/// ^bb0: +/// cir.return +struct RemoveRedudantBranches : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BrOp op, + PatternRewriter &rewriter) const final { + Block *block = op.getOperation()->getBlock(); + Block *dest = op.getDest(); + + // Single edge between blocks: merge it. + if (block->getNumSuccessors() == 1 && + dest->getSinglePredecessor() == block) { + rewriter.eraseOp(op); + rewriter.mergeBlocks(dest, block); + return success(); } - return changed; + return failure(); } +}; - mlir::LogicalResult - checkAndRewriteLoopCond(mlir::Region &condRegion, - mlir::PatternRewriter &rewriter) const { - SmallVector opsToSimplify; - condRegion.walk([&](Operation *op) { - if (isa(op)) - opsToSimplify.push_back(op); - }); - - // Blocks should only contain one "yield" operation. - auto trivialYield = [&](Block *b) { - if (&b->front() != &b->back()) - return false; - return isa(b->getTerminator()); - }; - - if (opsToSimplify.size() != 1) - return failure(); - BrCondOp brCondOp = cast(opsToSimplify[0]); +/// Merges basic blocks of trivial conditional branches. This is useful when a +/// the condition of conditional branch is a constant and the destinations of +/// the conditional branch both have only one predecessor. +/// +/// From: +/// ^bb0: +/// %0 = cir.const(#true) : !cir.bool +/// cir.brcond %0 ^bb1, ^bb2 +/// ^bb1: // pred: ^bb0 +/// cir.yield continue +/// ^bb2: // pred: ^bb0 +/// cir.yield +/// +/// To: +/// ^bb0: +/// cir.yield continue +/// +struct MergeTrivialConditionalBranches : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(BrCondOp op) const final { + return success(isa(op.getCond().getDefiningOp()) && + op.getDestFalse()->hasOneUse() && + op.getDestTrue()->hasOneUse()); + } - // TODO: leverage SCCP to get improved results. - auto cstOp = dyn_cast(brCondOp.getCond().getDefiningOp()); - if (!cstOp || !cstOp.getValue().isa() || - !trivialYield(brCondOp.getDestTrue()) || - !trivialYield(brCondOp.getDestFalse())) - return failure(); + /// Replace conditional branch with unconditional branch. + void rewrite(BrCondOp op, PatternRewriter &rewriter) const final { + auto constOp = llvm::cast(op.getCond().getDefiningOp()); + bool cond = constOp.getValue().cast().getValue(); + Block *block = op.getOperation()->getBlock(); - // If the condition is constant, no need to use brcond, just yield - // properly, "yield" for false and "yield continue" for true. - auto boolAttr = cstOp.getValue().cast(); - auto *falseBlock = brCondOp.getDestFalse(); - auto *trueBlock = brCondOp.getDestTrue(); - auto *currBlock = brCondOp.getOperation()->getBlock(); - if (boolAttr.getValue()) { - rewriter.eraseOp(opsToSimplify[0]); - rewriter.mergeBlocks(trueBlock, currBlock); - falseBlock->erase(); + rewriter.eraseOp(op); + if (cond) { + rewriter.mergeBlocks(op.getDestTrue(), block); + rewriter.eraseBlock(op.getDestFalse()); } else { - rewriter.eraseOp(opsToSimplify[0]); - rewriter.mergeBlocks(falseBlock, currBlock); - trueBlock->erase(); + rewriter.mergeBlocks(op.getDestFalse(), block); + rewriter.eraseBlock(op.getDestTrue()); } - if (cstOp.use_empty()) - rewriter.eraseOp(cstOp); - return success(); - } - - mlir::LogicalResult - matchAndRewrite(ScopeLikeOpTy op, - mlir::PatternRewriter &rewriter) const override { - return replaceScopeLikeOp(rewriter, op); } }; -// Specialize the template to account for the different build signatures for -// IfOp, ScopeOp, FuncOp, SwitchOp, LoopOp. -template <> -mlir::LogicalResult -SimplifyRetYieldBlocks::replaceScopeLikeOp(PatternRewriter &rewriter, - IfOp ifOp) const { - auto regionChanged = mlir::failure(); - if (checkAndRewriteRegion(ifOp.getThenRegion(), rewriter).succeeded()) - regionChanged = mlir::success(); - if (checkAndRewriteRegion(ifOp.getElseRegion(), rewriter).succeeded()) - regionChanged = mlir::success(); - return regionChanged; -} +struct RemoveEmptyScope : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -template <> -mlir::LogicalResult -SimplifyRetYieldBlocks::replaceScopeLikeOp(PatternRewriter &rewriter, - ScopeOp scopeOp) const { - // Scope region empty: just remove scope. - if (scopeOp.getRegion().empty()) { - rewriter.eraseOp(scopeOp); - return mlir::success(); + LogicalResult match(ScopeOp op) const final { + return success(op.getRegion().empty() || + (op.getRegion().getBlocks().size() == 1 && + op.getRegion().front().empty())); } - // Scope region non-empty: clean it up. - if (checkAndRewriteRegion(scopeOp.getRegion(), rewriter).succeeded()) - return mlir::success(); - - return mlir::failure(); -} - -template <> -mlir::LogicalResult SimplifyRetYieldBlocks::replaceScopeLikeOp( - PatternRewriter &rewriter, cir::FuncOp funcOp) const { - auto regionChanged = mlir::failure(); - if (checkAndRewriteRegion(funcOp.getRegion(), rewriter).succeeded()) - regionChanged = mlir::success(); - return regionChanged; -} + void rewrite(ScopeOp op, PatternRewriter &rewriter) const final { + rewriter.eraseOp(op); + } +}; -template <> -mlir::LogicalResult SimplifyRetYieldBlocks::replaceScopeLikeOp( - PatternRewriter &rewriter, cir::SwitchOp switchOp) const { - auto regionChanged = mlir::failure(); +struct RemoveEmptySwitch : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - // Empty switch statement: just remove it. - if (!switchOp.getCases().has_value() || switchOp.getCases()->empty()) { - rewriter.eraseOp(switchOp); - return mlir::success(); + LogicalResult match(SwitchOp op) const final { + return success(op.getRegions().empty()); } - // Non-empty switch statement: clean it up. - for (auto &r : switchOp.getRegions()) { - if (checkAndRewriteRegion(r, rewriter).succeeded()) - regionChanged = mlir::success(); + void rewrite(SwitchOp op, PatternRewriter &rewriter) const final { + rewriter.eraseOp(op); } - return regionChanged; -} - -template <> -mlir::LogicalResult SimplifyRetYieldBlocks::replaceScopeLikeOp( - PatternRewriter &rewriter, cir::LoopOp loopOp) const { - auto regionChanged = mlir::failure(); - if (checkAndRewriteRegion(loopOp.getBody(), rewriter).succeeded()) - regionChanged = mlir::success(); - if (checkAndRewriteLoopCond(loopOp.getCond(), rewriter).succeeded()) - regionChanged = mlir::success(); - return regionChanged; -} +}; -void getMergeCleanupsPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add, SimplifyRetYieldBlocks, - SimplifyRetYieldBlocks, - SimplifyRetYieldBlocks, - SimplifyRetYieldBlocks>(context); -} +//===----------------------------------------------------------------------===// +// MergeCleanupsPass +//===----------------------------------------------------------------------===// struct MergeCleanupsPass : public MergeCleanupsBase { - MergeCleanupsPass() = default; + using MergeCleanupsBase::MergeCleanupsBase; + + // The same operation rewriting done here could have been performed + // by CanonicalizerPass (adding hasCanonicalizer for target Ops and + // implementing the same from above in CIRDialects.cpp). However, it's + // currently too aggressive for static analysis purposes, since it might + // remove things where a diagnostic can be generated. + // + // FIXME: perhaps we can add one more mode to GreedyRewriteConfig to + // disable this behavior. void runOnOperation() override; }; -// The same operation rewriting done here could have been performed -// by CanonicalizerPass (adding hasCanonicalizer for target Ops and implementing -// the same from above in CIRDialects.cpp). However, it's currently too -// aggressive for static analysis purposes, since it might remove things where -// a diagnostic can be generated. -// -// FIXME: perhaps we can add one more mode to GreedyRewriteConfig to -// disable this behavior. -void MergeCleanupsPass::runOnOperation() { - auto op = getOperation(); - mlir::RewritePatternSet patterns(&getContext()); - getMergeCleanupsPatterns(patterns, &getContext()); - FrozenRewritePatternSet frozenPatterns(std::move(patterns)); +void populateMergeCleanupPatterns(RewritePatternSet &patterns) { + // clang-format off + patterns.add< + RemoveRedudantBranches, + MergeTrivialConditionalBranches, + RemoveEmptyScope, + RemoveEmptySwitch + >(patterns.getContext()); + // clang-format on +} - SmallVector opsToSimplify; - op->walk([&](Operation *op) { - if (isa( - op)) - opsToSimplify.push_back(op); +void MergeCleanupsPass::runOnOperation() { + // Collect rewrite patterns. + RewritePatternSet patterns(&getContext()); + populateMergeCleanupPatterns(patterns); + + // Collect operations to apply patterns. + SmallVector ops; + getOperation()->walk([&](Operation *op) { + if (isa(op)) + ops.push_back(op); }); - for (auto *o : opsToSimplify) { - bool erase = false; - (void)applyOpPatternsAndFold(o, frozenPatterns, GreedyRewriteConfig(), - &erase); - } + // Apply patterns. + if (applyOpPatternsAndFold(ops, std::move(patterns)).failed()) + signalPassFailure(); } + } // namespace std::unique_ptr mlir::createMergeCleanupsPass() {