Skip to content

Commit

Permalink
[CIR][Transforms][NFC] Refactor MergeCleanups pass (llvm#384)
Browse files Browse the repository at this point in the history
Breaks the pass into smaller more manageable rewrites.
  • Loading branch information
sitio-couto authored and lanza committed Oct 1, 2024
1 parent eb1f98f commit 7a2be67
Showing 1 changed file with 123 additions and 203 deletions.
326 changes: 123 additions & 203 deletions clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,248 +7,168 @@
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"

using namespace mlir;
using namespace cir;

namespace {

template <typename ScopeLikeOpTy>
struct SimplifyRetYieldBlocks : public mlir::OpRewritePattern<ScopeLikeOpTy> {
using OpRewritePattern<ScopeLikeOpTy>::OpRewritePattern;
mlir::LogicalResult replaceScopeLikeOp(PatternRewriter &rewriter,
ScopeLikeOpTy scopeLikeOp) const;

SimplifyRetYieldBlocks(mlir::MLIRContext *context)
: OpRewritePattern<ScopeLikeOpTy>(context, /*benefit=*/1) {}

mlir::LogicalResult
checkAndRewriteRegion(mlir::Region &r,
mlir::PatternRewriter &rewriter) const {
auto &blocks = r.getBlocks();

if (blocks.size() <= 1)
return failure();

// Rewrite something like this:
//
// cir.if %2 {
// %3 = cir.const(3 : i32) : i32
// cir.br ^bb1
// ^bb1: // pred: ^bb0
// cir.return %3 : i32
// }
//
// to this:
//
// cir.if %2 {
// %3 = cir.const(3 : i32) : i32
// cir.return %3 : i32
// }
//
SmallPtrSet<mlir::Block *, 4> candidateBlocks;
for (Block &block : blocks) {
if (block.isEntryBlock())
continue;

auto yieldVars = block.getOps<cir::YieldOp>();
for (cir::YieldOp yield : yieldVars)
candidateBlocks.insert(yield.getOperation()->getBlock());
//===----------------------------------------------------------------------===//
// Rewrite patterns
//===----------------------------------------------------------------------===//

auto retVars = block.getOps<cir::ReturnOp>();
for (cir::ReturnOp ret : retVars)
candidateBlocks.insert(ret.getOperation()->getBlock());
}
namespace {

auto changed = mlir::failure();
for (auto *mergeSource : candidateBlocks) {
if (!(mergeSource->hasNoSuccessors() && mergeSource->hasOneUse()))
continue;
auto *mergeDest = mergeSource->getSinglePredecessor();
if (!mergeDest || mergeDest->getNumSuccessors() != 1)
continue;
rewriter.eraseOp(mergeDest->getTerminator());
rewriter.mergeBlocks(mergeSource, mergeDest);
changed = mlir::success();
/// Removes branches between two blocks if it is the only branch.
///
/// From:
/// ^bb0:
/// cir.br ^bb1
/// ^bb1: // pred: ^bb0
/// cir.return
///
/// To:
/// ^bb0:
/// cir.return
struct RemoveRedudantBranches : public OpRewritePattern<BrOp> {
using OpRewritePattern<BrOp>::OpRewritePattern;

LogicalResult matchAndRewrite(BrOp op,
PatternRewriter &rewriter) const final {
Block *block = op.getOperation()->getBlock();
Block *dest = op.getDest();

// Single edge between blocks: merge it.
if (block->getNumSuccessors() == 1 &&
dest->getSinglePredecessor() == block) {
rewriter.eraseOp(op);
rewriter.mergeBlocks(dest, block);
return success();
}

return changed;
return failure();
}
};

mlir::LogicalResult
checkAndRewriteLoopCond(mlir::Region &condRegion,
mlir::PatternRewriter &rewriter) const {
SmallVector<Operation *> opsToSimplify;
condRegion.walk([&](Operation *op) {
if (isa<cir::BrCondOp>(op))
opsToSimplify.push_back(op);
});

// Blocks should only contain one "yield" operation.
auto trivialYield = [&](Block *b) {
if (&b->front() != &b->back())
return false;
return isa<YieldOp>(b->getTerminator());
};

if (opsToSimplify.size() != 1)
return failure();
BrCondOp brCondOp = cast<cir::BrCondOp>(opsToSimplify[0]);
/// Merges basic blocks of trivial conditional branches. This is useful when a
/// the condition of conditional branch is a constant and the destinations of
/// the conditional branch both have only one predecessor.
///
/// From:
/// ^bb0:
/// %0 = cir.const(#true) : !cir.bool
/// cir.brcond %0 ^bb1, ^bb2
/// ^bb1: // pred: ^bb0
/// cir.yield continue
/// ^bb2: // pred: ^bb0
/// cir.yield
///
/// To:
/// ^bb0:
/// cir.yield continue
///
struct MergeTrivialConditionalBranches : public OpRewritePattern<BrCondOp> {
using OpRewritePattern<BrCondOp>::OpRewritePattern;

LogicalResult match(BrCondOp op) const final {
return success(isa<ConstantOp>(op.getCond().getDefiningOp()) &&
op.getDestFalse()->hasOneUse() &&
op.getDestTrue()->hasOneUse());
}

// TODO: leverage SCCP to get improved results.
auto cstOp = dyn_cast<cir::ConstantOp>(brCondOp.getCond().getDefiningOp());
if (!cstOp || !cstOp.getValue().isa<mlir::cir::BoolAttr>() ||
!trivialYield(brCondOp.getDestTrue()) ||
!trivialYield(brCondOp.getDestFalse()))
return failure();
/// Replace conditional branch with unconditional branch.
void rewrite(BrCondOp op, PatternRewriter &rewriter) const final {
auto constOp = llvm::cast<ConstantOp>(op.getCond().getDefiningOp());
bool cond = constOp.getValue().cast<cir::BoolAttr>().getValue();
Block *block = op.getOperation()->getBlock();

// If the condition is constant, no need to use brcond, just yield
// properly, "yield" for false and "yield continue" for true.
auto boolAttr = cstOp.getValue().cast<mlir::cir::BoolAttr>();
auto *falseBlock = brCondOp.getDestFalse();
auto *trueBlock = brCondOp.getDestTrue();
auto *currBlock = brCondOp.getOperation()->getBlock();
if (boolAttr.getValue()) {
rewriter.eraseOp(opsToSimplify[0]);
rewriter.mergeBlocks(trueBlock, currBlock);
falseBlock->erase();
rewriter.eraseOp(op);
if (cond) {
rewriter.mergeBlocks(op.getDestTrue(), block);
rewriter.eraseBlock(op.getDestFalse());
} else {
rewriter.eraseOp(opsToSimplify[0]);
rewriter.mergeBlocks(falseBlock, currBlock);
trueBlock->erase();
rewriter.mergeBlocks(op.getDestFalse(), block);
rewriter.eraseBlock(op.getDestTrue());
}
if (cstOp.use_empty())
rewriter.eraseOp(cstOp);
return success();
}

mlir::LogicalResult
matchAndRewrite(ScopeLikeOpTy op,
mlir::PatternRewriter &rewriter) const override {
return replaceScopeLikeOp(rewriter, op);
}
};

// Specialize the template to account for the different build signatures for
// IfOp, ScopeOp, FuncOp, SwitchOp, LoopOp.
template <>
mlir::LogicalResult
SimplifyRetYieldBlocks<IfOp>::replaceScopeLikeOp(PatternRewriter &rewriter,
IfOp ifOp) const {
auto regionChanged = mlir::failure();
if (checkAndRewriteRegion(ifOp.getThenRegion(), rewriter).succeeded())
regionChanged = mlir::success();
if (checkAndRewriteRegion(ifOp.getElseRegion(), rewriter).succeeded())
regionChanged = mlir::success();
return regionChanged;
}
struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
using OpRewritePattern<ScopeOp>::OpRewritePattern;

template <>
mlir::LogicalResult
SimplifyRetYieldBlocks<ScopeOp>::replaceScopeLikeOp(PatternRewriter &rewriter,
ScopeOp scopeOp) const {
// Scope region empty: just remove scope.
if (scopeOp.getRegion().empty()) {
rewriter.eraseOp(scopeOp);
return mlir::success();
LogicalResult match(ScopeOp op) const final {
return success(op.getRegion().empty() ||
(op.getRegion().getBlocks().size() == 1 &&
op.getRegion().front().empty()));
}

// Scope region non-empty: clean it up.
if (checkAndRewriteRegion(scopeOp.getRegion(), rewriter).succeeded())
return mlir::success();

return mlir::failure();
}

template <>
mlir::LogicalResult SimplifyRetYieldBlocks<cir::FuncOp>::replaceScopeLikeOp(
PatternRewriter &rewriter, cir::FuncOp funcOp) const {
auto regionChanged = mlir::failure();
if (checkAndRewriteRegion(funcOp.getRegion(), rewriter).succeeded())
regionChanged = mlir::success();
return regionChanged;
}
void rewrite(ScopeOp op, PatternRewriter &rewriter) const final {
rewriter.eraseOp(op);
}
};

template <>
mlir::LogicalResult SimplifyRetYieldBlocks<cir::SwitchOp>::replaceScopeLikeOp(
PatternRewriter &rewriter, cir::SwitchOp switchOp) const {
auto regionChanged = mlir::failure();
struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
using OpRewritePattern<SwitchOp>::OpRewritePattern;

// Empty switch statement: just remove it.
if (!switchOp.getCases().has_value() || switchOp.getCases()->empty()) {
rewriter.eraseOp(switchOp);
return mlir::success();
LogicalResult match(SwitchOp op) const final {
return success(op.getRegions().empty());
}

// Non-empty switch statement: clean it up.
for (auto &r : switchOp.getRegions()) {
if (checkAndRewriteRegion(r, rewriter).succeeded())
regionChanged = mlir::success();
void rewrite(SwitchOp op, PatternRewriter &rewriter) const final {
rewriter.eraseOp(op);
}
return regionChanged;
}

template <>
mlir::LogicalResult SimplifyRetYieldBlocks<cir::LoopOp>::replaceScopeLikeOp(
PatternRewriter &rewriter, cir::LoopOp loopOp) const {
auto regionChanged = mlir::failure();
if (checkAndRewriteRegion(loopOp.getBody(), rewriter).succeeded())
regionChanged = mlir::success();
if (checkAndRewriteLoopCond(loopOp.getCond(), rewriter).succeeded())
regionChanged = mlir::success();
return regionChanged;
}
};

void getMergeCleanupsPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyRetYieldBlocks<IfOp>, SimplifyRetYieldBlocks<ScopeOp>,
SimplifyRetYieldBlocks<cir::FuncOp>,
SimplifyRetYieldBlocks<cir::SwitchOp>,
SimplifyRetYieldBlocks<cir::LoopOp>>(context);
}
//===----------------------------------------------------------------------===//
// MergeCleanupsPass
//===----------------------------------------------------------------------===//

struct MergeCleanupsPass : public MergeCleanupsBase<MergeCleanupsPass> {
MergeCleanupsPass() = default;
using MergeCleanupsBase::MergeCleanupsBase;

// The same operation rewriting done here could have been performed
// by CanonicalizerPass (adding hasCanonicalizer for target Ops and
// implementing the same from above in CIRDialects.cpp). However, it's
// currently too aggressive for static analysis purposes, since it might
// remove things where a diagnostic can be generated.
//
// FIXME: perhaps we can add one more mode to GreedyRewriteConfig to
// disable this behavior.
void runOnOperation() override;
};

// The same operation rewriting done here could have been performed
// by CanonicalizerPass (adding hasCanonicalizer for target Ops and implementing
// the same from above in CIRDialects.cpp). However, it's currently too
// aggressive for static analysis purposes, since it might remove things where
// a diagnostic can be generated.
//
// FIXME: perhaps we can add one more mode to GreedyRewriteConfig to
// disable this behavior.
void MergeCleanupsPass::runOnOperation() {
auto op = getOperation();
mlir::RewritePatternSet patterns(&getContext());
getMergeCleanupsPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
RemoveRedudantBranches,
MergeTrivialConditionalBranches,
RemoveEmptyScope,
RemoveEmptySwitch
>(patterns.getContext());
// clang-format on
}

SmallVector<Operation *> opsToSimplify;
op->walk([&](Operation *op) {
if (isa<cir::IfOp, cir::ScopeOp, cir::FuncOp, cir::SwitchOp, cir::LoopOp>(
op))
opsToSimplify.push_back(op);
void MergeCleanupsPass::runOnOperation() {
// Collect rewrite patterns.
RewritePatternSet patterns(&getContext());
populateMergeCleanupPatterns(patterns);

// Collect operations to apply patterns.
SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp>(op))
ops.push_back(op);
});

for (auto *o : opsToSimplify) {
bool erase = false;
(void)applyOpPatternsAndFold(o, frozenPatterns, GreedyRewriteConfig(),
&erase);
}
// Apply patterns.
if (applyOpPatternsAndFold(ops, std::move(patterns)).failed())
signalPassFailure();
}

} // namespace

std::unique_ptr<Pass> mlir::createMergeCleanupsPass() {
Expand Down

0 comments on commit 7a2be67

Please sign in to comment.