Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][IR] Implement loop's conditional operation #391

Merged
merged 2 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,25 @@ def TernaryOp : CIR_Op<"ternary",
}];
}

//===----------------------------------------------------------------------===//
// ConditionOp
//===----------------------------------------------------------------------===//

def ConditionOp : CIR_Op<"condition", [
Terminator,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface,
["getSuccessorRegions"]>
]> {
let summary = "Loop continuation condition.";
let description = [{
The `cir.condition` termintes loop's conditional regions. It takes a single
`cir.bool` operand. if the operand is true, the loop continues, otherwise
it terminates.
}];
let arguments = (ins CIR_BoolType:$condition);
let assemblyFormat = " `(` $condition `)` attr-dict ";
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,11 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::CopyOp>(dst.getLoc(), dst, src);
}

/// Create a loop condition.
mlir::cir::ConditionOp createCondition(mlir::Value condition) {
return create<mlir::cir::ConditionOp>(condition.getLoc(), condition);
}

mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
mlir::Value src, mlir::Value len) {
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);
Expand Down
32 changes: 4 additions & 28 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,26 +650,6 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
}

static mlir::LogicalResult buildLoopCondYield(mlir::OpBuilder &builder,
mlir::Location loc,
mlir::Value cond) {
mlir::Block *trueBB = nullptr, *falseBB = nullptr;
{
mlir::OpBuilder::InsertionGuard guard(builder);
trueBB = builder.createBlock(builder.getBlock()->getParent());
builder.create<mlir::cir::YieldOp>(loc, YieldOpKind::Continue);
}
{
mlir::OpBuilder::InsertionGuard guard(builder);
falseBB = builder.createBlock(builder.getBlock()->getParent());
builder.create<mlir::cir::YieldOp>(loc);
}

assert((trueBB && falseBB) && "expected both blocks to exist");
builder.create<mlir::cir::BrCondOp>(loc, cond, trueBB, falseBB);
return mlir::success();
}

mlir::LogicalResult
CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
ArrayRef<const Attr *> ForAttrs) {
Expand Down Expand Up @@ -703,8 +683,7 @@ CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
assert(!UnimplementedFeature::createProfileWeightsForLoop());
assert(!UnimplementedFeature::emitCondLikelihoodViaExpectIntrinsic());
mlir::Value condVal = evaluateExprAsBool(S.getCond());
if (buildLoopCondYield(b, loc, condVal).failed())
loopRes = mlir::failure();
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
Expand Down Expand Up @@ -786,8 +765,7 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
loc, boolTy,
mlir::cir::BoolAttr::get(b.getContext(), boolTy, true));
}
if (buildLoopCondYield(b, loc, condVal).failed())
loopRes = mlir::failure();
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
Expand Down Expand Up @@ -850,8 +828,7 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
// expression compares unequal to 0. The condition must be a
// scalar type.
mlir::Value condVal = evaluateExprAsBool(S.getCond());
if (buildLoopCondYield(b, loc, condVal).failed())
loopRes = mlir::failure();
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
Expand Down Expand Up @@ -910,8 +887,7 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
// expression compares unequal to 0. The condition must be a
// scalar type.
condVal = evaluateExprAsBool(S.getCond());
if (buildLoopCondYield(b, loc, condVal).failed())
loopRes = mlir::failure();
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
Expand Down
47 changes: 28 additions & 19 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,30 @@ void AllocaOp::build(::mlir::OpBuilder &odsBuilder,
odsState.addTypes(addr);
}

//===----------------------------------------------------------------------===//
// ConditionOp
//===-----------------------------------------------------------------------===//

//===----------------------------------
// BranchOpTerminatorInterface Methods

void ConditionOp::getSuccessorRegions(
ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
auto loopOp = cast<LoopOp>(getOperation()->getParentOp());

// TODO(cir): The condition value may be folded to a constant, narrowing
// down its list of possible successors.
// Condition may branch to the body or to the parent op.
regions.emplace_back(&loopOp.getBody(), loopOp.getBody().getArguments());
regions.emplace_back(loopOp->getResults());
}

MutableOperandRange
ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
// No values are yielded to the successor region.
return MutableOperandRange(getOperation(), 0, 0);
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1303,26 +1327,11 @@ void LoopOp::getSuccessorRegions(mlir::RegionBranchPoint point,
llvm::SmallVector<Region *> LoopOp::getLoopRegions() { return {&getBody()}; }

LogicalResult LoopOp::verify() {
// Cond regions should only terminate with plain 'cir.yield' or
// 'cir.yield continue'.
auto terminateError = [&]() {
return emitOpError() << "cond region must be terminated with "
"'cir.yield' or 'cir.yield continue'";
};
if (getCond().empty())
return emitOpError() << "cond region must not be empty";

auto &blocks = getCond().getBlocks();
for (Block &block : blocks) {
if (block.empty())
continue;
auto &op = block.back();
if (isa<BrCondOp>(op))
continue;
if (!isa<YieldOp>(op))
terminateError();
auto y = cast<YieldOp>(op);
if (!(y.isPlain() || y.isContinue()))
terminateError();
}
if (!llvm::isa<ConditionOp>(getCond().back().getTerminator()))
return emitOpError() << "cond region terminate with 'cir.condition'";

return success();
}
Expand Down
44 changes: 0 additions & 44 deletions clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,49 +54,6 @@ struct RemoveRedudantBranches : public OpRewritePattern<BrOp> {
}
};

/// Merges basic blocks of trivial conditional branches. This is useful when a
/// the condition of conditional branch is a constant and the destinations of
/// the conditional branch both have only one predecessor.
///
/// From:
/// ^bb0:
/// %0 = cir.const(#true) : !cir.bool
/// cir.brcond %0 ^bb1, ^bb2
/// ^bb1: // pred: ^bb0
/// cir.yield continue
/// ^bb2: // pred: ^bb0
/// cir.yield
///
/// To:
/// ^bb0:
/// cir.yield continue
///
struct MergeTrivialConditionalBranches : public OpRewritePattern<BrCondOp> {
using OpRewritePattern<BrCondOp>::OpRewritePattern;

LogicalResult match(BrCondOp op) const final {
return success(isa<ConstantOp>(op.getCond().getDefiningOp()) &&
op.getDestFalse()->hasOneUse() &&
op.getDestTrue()->hasOneUse());
}

/// Replace conditional branch with unconditional branch.
void rewrite(BrCondOp op, PatternRewriter &rewriter) const final {
auto constOp = llvm::cast<ConstantOp>(op.getCond().getDefiningOp());
bool cond = constOp.getValue().cast<cir::BoolAttr>().getValue();
Block *block = op.getOperation()->getBlock();

rewriter.eraseOp(op);
if (cond) {
rewriter.mergeBlocks(op.getDestTrue(), block);
rewriter.eraseBlock(op.getDestFalse());
} else {
rewriter.mergeBlocks(op.getDestFalse(), block);
rewriter.eraseBlock(op.getDestTrue());
}
}
};

struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
using OpRewritePattern<ScopeOp>::OpRewritePattern;

Expand Down Expand Up @@ -145,7 +102,6 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
RemoveRedudantBranches,
MergeTrivialConditionalBranches,
RemoveEmptyScope,
RemoveEmptySwitch
>(patterns.getContext());
Expand Down
43 changes: 12 additions & 31 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,25 +406,14 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
using mlir::OpConversionPattern<mlir::cir::LoopOp>::OpConversionPattern;
using LoopKind = mlir::cir::LoopOpKind;

mlir::LogicalResult
fetchCondRegionYields(mlir::Region &condRegion,
mlir::cir::YieldOp &yieldToBody,
mlir::cir::YieldOp &yieldToCont) const {
for (auto &bb : condRegion) {
if (auto yieldOp = dyn_cast<mlir::cir::YieldOp>(bb.getTerminator())) {
if (!yieldOp.getKind().has_value())
yieldToCont = yieldOp;
else if (yieldOp.getKind() == mlir::cir::YieldOpKind::Continue)
yieldToBody = yieldOp;
else
return mlir::failure();
}
}

// Succeed only if both yields are found.
if (!yieldToBody)
return mlir::failure();
return mlir::success();
inline void
lowerConditionOp(mlir::cir::ConditionOp op, mlir::Block *body,
mlir::Block *exit,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::BrCondOp>(op, op.getCondition(),
body, exit);
}

mlir::LogicalResult
Expand All @@ -438,9 +427,6 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
// Fetch required info from the condition region.
auto &condRegion = loopOp.getCond();
auto &condFrontBlock = condRegion.front();
mlir::cir::YieldOp yieldToBody, yieldToCont;
if (fetchCondRegionYields(condRegion, yieldToBody, yieldToCont).failed())
return loopOp.emitError("failed to fetch yields in cond region");

// Fetch required info from the body region.
auto &bodyRegion = loopOp.getBody();
Expand Down Expand Up @@ -472,15 +458,10 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
auto &entry = (kind != LoopKind::DoWhile ? condFrontBlock : bodyFrontBlock);
rewriter.create<mlir::cir::BrOp>(loopOp.getLoc(), &entry);

// Set loop exit point to continue block.
if (yieldToCont) {
rewriter.setInsertionPoint(yieldToCont);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToCont, continueBlock);
}

// Branch from condition to body.
rewriter.setInsertionPoint(yieldToBody);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToBody, &bodyFrontBlock);
// Branch from condition region to body or exit.
auto conditionOp =
cast<mlir::cir::ConditionOp>(condFrontBlock.getTerminator());
lowerConditionOp(conditionOp, &bodyFrontBlock, continueBlock, rewriter);

// Branch from body to condition or to step on for-loop cases.
rewriter.setInsertionPoint(bodyYield);
Expand Down
Loading