Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][Transforms][NFC] Refactor MergeCleanups pass #384

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 123 additions & 203 deletions clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,248 +7,168 @@
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"

using namespace mlir;
using namespace cir;

namespace {

template <typename ScopeLikeOpTy>
struct SimplifyRetYieldBlocks : public mlir::OpRewritePattern<ScopeLikeOpTy> {
using OpRewritePattern<ScopeLikeOpTy>::OpRewritePattern;
mlir::LogicalResult replaceScopeLikeOp(PatternRewriter &rewriter,
ScopeLikeOpTy scopeLikeOp) const;

SimplifyRetYieldBlocks(mlir::MLIRContext *context)
: OpRewritePattern<ScopeLikeOpTy>(context, /*benefit=*/1) {}

mlir::LogicalResult
checkAndRewriteRegion(mlir::Region &r,
mlir::PatternRewriter &rewriter) const {
auto &blocks = r.getBlocks();

if (blocks.size() <= 1)
return failure();

// Rewrite something like this:
//
// cir.if %2 {
// %3 = cir.const(3 : i32) : i32
// cir.br ^bb1
// ^bb1: // pred: ^bb0
// cir.return %3 : i32
// }
//
// to this:
//
// cir.if %2 {
// %3 = cir.const(3 : i32) : i32
// cir.return %3 : i32
// }
//
SmallPtrSet<mlir::Block *, 4> candidateBlocks;
for (Block &block : blocks) {
if (block.isEntryBlock())
continue;

auto yieldVars = block.getOps<cir::YieldOp>();
for (cir::YieldOp yield : yieldVars)
candidateBlocks.insert(yield.getOperation()->getBlock());
//===----------------------------------------------------------------------===//
// Rewrite patterns
//===----------------------------------------------------------------------===//

auto retVars = block.getOps<cir::ReturnOp>();
for (cir::ReturnOp ret : retVars)
candidateBlocks.insert(ret.getOperation()->getBlock());
}
namespace {

auto changed = mlir::failure();
for (auto *mergeSource : candidateBlocks) {
if (!(mergeSource->hasNoSuccessors() && mergeSource->hasOneUse()))
continue;
auto *mergeDest = mergeSource->getSinglePredecessor();
if (!mergeDest || mergeDest->getNumSuccessors() != 1)
continue;
rewriter.eraseOp(mergeDest->getTerminator());
rewriter.mergeBlocks(mergeSource, mergeDest);
changed = mlir::success();
/// Removes branches between two blocks if it is the only branch.
///
/// From:
/// ^bb0:
/// cir.br ^bb1
/// ^bb1: // pred: ^bb0
/// cir.return
///
/// To:
/// ^bb0:
/// cir.return
struct RemoveRedudantBranches : public OpRewritePattern<BrOp> {
using OpRewritePattern<BrOp>::OpRewritePattern;

LogicalResult matchAndRewrite(BrOp op,
PatternRewriter &rewriter) const final {
Block *block = op.getOperation()->getBlock();
Block *dest = op.getDest();

// Single edge between blocks: merge it.
if (block->getNumSuccessors() == 1 &&
dest->getSinglePredecessor() == block) {
rewriter.eraseOp(op);
rewriter.mergeBlocks(dest, block);
return success();
}

return changed;
return failure();
}
};

mlir::LogicalResult
checkAndRewriteLoopCond(mlir::Region &condRegion,
mlir::PatternRewriter &rewriter) const {
SmallVector<Operation *> opsToSimplify;
condRegion.walk([&](Operation *op) {
if (isa<cir::BrCondOp>(op))
opsToSimplify.push_back(op);
});

// Blocks should only contain one "yield" operation.
auto trivialYield = [&](Block *b) {
if (&b->front() != &b->back())
return false;
return isa<YieldOp>(b->getTerminator());
};

if (opsToSimplify.size() != 1)
return failure();
BrCondOp brCondOp = cast<cir::BrCondOp>(opsToSimplify[0]);
/// Merges basic blocks of trivial conditional branches. This is useful when a
/// the condition of conditional branch is a constant and the destinations of
/// the conditional branch both have only one predecessor.
///
/// From:
/// ^bb0:
/// %0 = cir.const(#true) : !cir.bool
/// cir.brcond %0 ^bb1, ^bb2
/// ^bb1: // pred: ^bb0
/// cir.yield continue
/// ^bb2: // pred: ^bb0
/// cir.yield
///
/// To:
/// ^bb0:
/// cir.yield continue
///
struct MergeTrivialConditionalBranches : public OpRewritePattern<BrCondOp> {
using OpRewritePattern<BrCondOp>::OpRewritePattern;

LogicalResult match(BrCondOp op) const final {
return success(isa<ConstantOp>(op.getCond().getDefiningOp()) &&
op.getDestFalse()->hasOneUse() &&
op.getDestTrue()->hasOneUse());
}

// TODO: leverage SCCP to get improved results.
auto cstOp = dyn_cast<cir::ConstantOp>(brCondOp.getCond().getDefiningOp());
if (!cstOp || !cstOp.getValue().isa<mlir::cir::BoolAttr>() ||
!trivialYield(brCondOp.getDestTrue()) ||
!trivialYield(brCondOp.getDestFalse()))
return failure();
/// Replace conditional branch with unconditional branch.
void rewrite(BrCondOp op, PatternRewriter &rewriter) const final {
auto constOp = llvm::cast<ConstantOp>(op.getCond().getDefiningOp());
bool cond = constOp.getValue().cast<cir::BoolAttr>().getValue();
Block *block = op.getOperation()->getBlock();

// If the condition is constant, no need to use brcond, just yield
// properly, "yield" for false and "yield continue" for true.
auto boolAttr = cstOp.getValue().cast<mlir::cir::BoolAttr>();
auto *falseBlock = brCondOp.getDestFalse();
auto *trueBlock = brCondOp.getDestTrue();
auto *currBlock = brCondOp.getOperation()->getBlock();
if (boolAttr.getValue()) {
rewriter.eraseOp(opsToSimplify[0]);
rewriter.mergeBlocks(trueBlock, currBlock);
falseBlock->erase();
rewriter.eraseOp(op);
if (cond) {
rewriter.mergeBlocks(op.getDestTrue(), block);
rewriter.eraseBlock(op.getDestFalse());
} else {
rewriter.eraseOp(opsToSimplify[0]);
rewriter.mergeBlocks(falseBlock, currBlock);
trueBlock->erase();
rewriter.mergeBlocks(op.getDestFalse(), block);
rewriter.eraseBlock(op.getDestTrue());
}
if (cstOp.use_empty())
rewriter.eraseOp(cstOp);
return success();
}

mlir::LogicalResult
matchAndRewrite(ScopeLikeOpTy op,
mlir::PatternRewriter &rewriter) const override {
return replaceScopeLikeOp(rewriter, op);
}
};

// Specialize the template to account for the different build signatures for
// IfOp, ScopeOp, FuncOp, SwitchOp, LoopOp.
template <>
mlir::LogicalResult
SimplifyRetYieldBlocks<IfOp>::replaceScopeLikeOp(PatternRewriter &rewriter,
IfOp ifOp) const {
auto regionChanged = mlir::failure();
if (checkAndRewriteRegion(ifOp.getThenRegion(), rewriter).succeeded())
regionChanged = mlir::success();
if (checkAndRewriteRegion(ifOp.getElseRegion(), rewriter).succeeded())
regionChanged = mlir::success();
return regionChanged;
}
struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
using OpRewritePattern<ScopeOp>::OpRewritePattern;

template <>
mlir::LogicalResult
SimplifyRetYieldBlocks<ScopeOp>::replaceScopeLikeOp(PatternRewriter &rewriter,
ScopeOp scopeOp) const {
// Scope region empty: just remove scope.
if (scopeOp.getRegion().empty()) {
rewriter.eraseOp(scopeOp);
return mlir::success();
LogicalResult match(ScopeOp op) const final {
return success(op.getRegion().empty() ||
(op.getRegion().getBlocks().size() == 1 &&
op.getRegion().front().empty()));
}

// Scope region non-empty: clean it up.
if (checkAndRewriteRegion(scopeOp.getRegion(), rewriter).succeeded())
return mlir::success();

return mlir::failure();
}

template <>
mlir::LogicalResult SimplifyRetYieldBlocks<cir::FuncOp>::replaceScopeLikeOp(
PatternRewriter &rewriter, cir::FuncOp funcOp) const {
auto regionChanged = mlir::failure();
if (checkAndRewriteRegion(funcOp.getRegion(), rewriter).succeeded())
regionChanged = mlir::success();
return regionChanged;
}
void rewrite(ScopeOp op, PatternRewriter &rewriter) const final {
rewriter.eraseOp(op);
}
};

template <>
mlir::LogicalResult SimplifyRetYieldBlocks<cir::SwitchOp>::replaceScopeLikeOp(
PatternRewriter &rewriter, cir::SwitchOp switchOp) const {
auto regionChanged = mlir::failure();
struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
using OpRewritePattern<SwitchOp>::OpRewritePattern;

// Empty switch statement: just remove it.
if (!switchOp.getCases().has_value() || switchOp.getCases()->empty()) {
rewriter.eraseOp(switchOp);
return mlir::success();
LogicalResult match(SwitchOp op) const final {
return success(op.getRegions().empty());
}

// Non-empty switch statement: clean it up.
for (auto &r : switchOp.getRegions()) {
if (checkAndRewriteRegion(r, rewriter).succeeded())
regionChanged = mlir::success();
void rewrite(SwitchOp op, PatternRewriter &rewriter) const final {
rewriter.eraseOp(op);
}
return regionChanged;
}

template <>
mlir::LogicalResult SimplifyRetYieldBlocks<cir::LoopOp>::replaceScopeLikeOp(
PatternRewriter &rewriter, cir::LoopOp loopOp) const {
auto regionChanged = mlir::failure();
if (checkAndRewriteRegion(loopOp.getBody(), rewriter).succeeded())
regionChanged = mlir::success();
if (checkAndRewriteLoopCond(loopOp.getCond(), rewriter).succeeded())
regionChanged = mlir::success();
return regionChanged;
}
};

void getMergeCleanupsPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyRetYieldBlocks<IfOp>, SimplifyRetYieldBlocks<ScopeOp>,
SimplifyRetYieldBlocks<cir::FuncOp>,
SimplifyRetYieldBlocks<cir::SwitchOp>,
SimplifyRetYieldBlocks<cir::LoopOp>>(context);
}
//===----------------------------------------------------------------------===//
// MergeCleanupsPass
//===----------------------------------------------------------------------===//

struct MergeCleanupsPass : public MergeCleanupsBase<MergeCleanupsPass> {
MergeCleanupsPass() = default;
using MergeCleanupsBase::MergeCleanupsBase;

// The same operation rewriting done here could have been performed
// by CanonicalizerPass (adding hasCanonicalizer for target Ops and
// implementing the same from above in CIRDialects.cpp). However, it's
// currently too aggressive for static analysis purposes, since it might
// remove things where a diagnostic can be generated.
//
// FIXME: perhaps we can add one more mode to GreedyRewriteConfig to
// disable this behavior.
void runOnOperation() override;
};

// The same operation rewriting done here could have been performed
// by CanonicalizerPass (adding hasCanonicalizer for target Ops and implementing
// the same from above in CIRDialects.cpp). However, it's currently too
// aggressive for static analysis purposes, since it might remove things where
// a diagnostic can be generated.
//
// FIXME: perhaps we can add one more mode to GreedyRewriteConfig to
// disable this behavior.
void MergeCleanupsPass::runOnOperation() {
auto op = getOperation();
mlir::RewritePatternSet patterns(&getContext());
getMergeCleanupsPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
RemoveRedudantBranches,
MergeTrivialConditionalBranches,
RemoveEmptyScope,
RemoveEmptySwitch
>(patterns.getContext());
// clang-format on
}

SmallVector<Operation *> opsToSimplify;
op->walk([&](Operation *op) {
if (isa<cir::IfOp, cir::ScopeOp, cir::FuncOp, cir::SwitchOp, cir::LoopOp>(
op))
opsToSimplify.push_back(op);
void MergeCleanupsPass::runOnOperation() {
// Collect rewrite patterns.
RewritePatternSet patterns(&getContext());
populateMergeCleanupPatterns(patterns);

// Collect operations to apply patterns.
SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp>(op))
ops.push_back(op);
});

for (auto *o : opsToSimplify) {
bool erase = false;
(void)applyOpPatternsAndFold(o, frozenPatterns, GreedyRewriteConfig(),
&erase);
}
// Apply patterns.
if (applyOpPatternsAndFold(ops, std::move(patterns)).failed())
signalPassFailure();
}

} // namespace

std::unique_ptr<Pass> mlir::createMergeCleanupsPass() {
Expand Down