Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][IR] Refactor parsing/printing of implicitly terminated regions #310

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 52 additions & 88 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,39 @@ static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
return static_cast<RetTy>(index);
}

// Check if a region's termination omission is valid and, if so, creates and
// inserts the omitted terminator into the region.
LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
SMLoc errLoc) {
Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
OpBuilder builder(parser.getBuilder().getContext());

// Region is empty or properly terminated: nothing to do.
if (region.empty() || region.back().hasTerminator())
return success();

// Check for invalid terminator omissions.
if (!region.hasOneBlock())
return parser.emitError(errLoc,
"multi-block region must not omit terminator");
if (region.back().empty())
return parser.emitError(errLoc, "empty region must not omit terminator");

// Terminator was omited correctly: recreate it.
region.back().push_back(builder.create<cir::YieldOp>(eLoc));
return success();
}

// True if the region's terminator should be omitted.
bool omitRegionTerm(mlir::Region &r) {
const auto singleNonEmptyBlock = r.hasOneBlock() && !r.front().empty();
const auto yieldsNothing = [&r]() {
YieldOp y = dyn_cast<YieldOp>(r.back().back());
return y && y.isPlain() && y.getArgs().empty();
};
return singleNonEmptyBlock && yieldsNothing();
}

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -413,53 +446,6 @@ mlir::LogicalResult ThrowOp::verify() {
// IfOp
//===----------------------------------------------------------------------===//

static LogicalResult checkBlockTerminator(OpAsmParser &parser,
llvm::SMLoc parserLoc,
std::optional<Location> l, Region *r,
bool ensureTerm = true) {
mlir::Builder &builder = parser.getBuilder();
if (r->hasOneBlock()) {
if (ensureTerm) {
::mlir::impl::ensureRegionTerminator(
*r, builder, *l, [](OpBuilder &builder, Location loc) {
OperationState state(loc, YieldOp::getOperationName());
YieldOp::build(builder, state);
return Operation::create(state);
});
} else {
assert(r && "region must not be empty");
Block &block = r->back();
if (block.empty() || !block.back().hasTrait<OpTrait::IsTerminator>()) {
return parser.emitError(
parser.getCurrentLocation(),
"blocks are expected to be explicitly terminated");
}
}
return success();
}

// Empty regions don't need any handling.
auto &blocks = r->getBlocks();
if (blocks.empty())
return success();

// Test that at least one block has a yield/return/throw terminator. We can
// probably make this a bit more strict.
for (Block &block : blocks) {
if (block.empty())
continue;
auto &op = block.back();
if (op.hasTrait<mlir::OpTrait::IsTerminator>() &&
isa<YieldOp, ReturnOp, ThrowOp>(op)) {
return success();
}
}

parser.emitError(parserLoc,
"expected at least one block with cir.yield or cir.return");
return failure();
}

ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
result.regions.reserve(2);
Expand All @@ -479,17 +465,15 @@ ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseRegion(*thenRegion, /*arguments=*/{},
/*argTypes=*/{}))
return failure();
if (checkBlockTerminator(parser, parseThenLoc, result.location, thenRegion)
.failed())
if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
return failure();

// If we find an 'else' keyword, parse the 'else' region.
if (!parser.parseOptionalKeyword("else")) {
auto parseElseLoc = parser.getCurrentLocation();
if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
if (checkBlockTerminator(parser, parseElseLoc, result.location, elseRegion)
.failed())
if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
return failure();
}

Expand All @@ -499,36 +483,20 @@ ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}

bool shouldPrintTerm(mlir::Region &r) {
if (!r.hasOneBlock())
return true;
auto *entryBlock = &r.front();
if (entryBlock->empty())
return false;
if (isa<ReturnOp>(entryBlock->back()))
return true;
if (isa<ThrowOp>(entryBlock->back()))
return true;
YieldOp y = dyn_cast<YieldOp>(entryBlock->back());
if (y && (!y.isPlain() || !y.getArgs().empty()))
return true;
return false;
}

void cir::IfOp::print(OpAsmPrinter &p) {
p << " " << getCondition() << " ";
auto &thenRegion = this->getThenRegion();
p.printRegion(thenRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/shouldPrintTerm(thenRegion));
/*printBlockTerminators=*/!omitRegionTerm(thenRegion));

// Print the 'else' regions if it exists and has a block.
auto &elseRegion = this->getElseRegion();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/shouldPrintTerm(elseRegion));
/*printBlockTerminators=*/!omitRegionTerm(elseRegion));
}

p.printOptionalAttrDict(getOperation()->getAttrs());
Expand Down Expand Up @@ -611,7 +579,7 @@ ParseResult cir::ScopeOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseRegion(*scopeRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();

if (checkBlockTerminator(parser, loc, result.location, scopeRegion).failed())
if (ensureRegionTerm(parser, *scopeRegion, loc).failed())
return failure();

// Parse the optional attribute list.
Expand All @@ -625,7 +593,7 @@ void cir::ScopeOp::print(OpAsmPrinter &p) {
auto &scopeRegion = this->getScopeRegion();
p.printRegion(scopeRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/shouldPrintTerm(scopeRegion));
/*printBlockTerminators=*/!omitRegionTerm(scopeRegion));

p.printOptionalAttrDict(getOperation()->getAttrs());
}
Expand Down Expand Up @@ -866,10 +834,10 @@ parseSwitchOp(OpAsmParser &parser,
"case region shall not be empty");
}

if (checkBlockTerminator(parser, parserLoc, std::nullopt, &currRegion,
/*ensureTerm=*/false)
.failed())
return failure();
if (!currRegion.back().hasTerminator())
return parser.emitError(parserLoc,
"case regions must be explicitly terminated");

return success();
};

Expand Down Expand Up @@ -1134,10 +1102,10 @@ parseCatchOp(OpAsmParser &parser,
"catch region shall not be empty");
}

if (checkBlockTerminator(parser, parserLoc, std::nullopt, &currRegion,
/*ensureTerm=*/false)
.failed())
return failure();
if (!currRegion.back().hasTerminator())
return parser.emitError(
parserLoc, "blocks are expected to be explicitly terminated");

return success();
};

Expand Down Expand Up @@ -1388,9 +1356,7 @@ static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
if (ctorRegion.back().empty())
return parser.emitError(parser.getCurrentLocation(),
"ctor region shall not be empty");
if (checkBlockTerminator(parser, parseLoc,
ctorRegion.back().back().getLoc(), &ctorRegion)
.failed())
if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
return failure();
} else {
// Parse constant with initializer, examples:
Expand All @@ -1417,9 +1383,7 @@ static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
if (dtorRegion.back().empty())
return parser.emitError(parser.getCurrentLocation(),
"dtor region shall not be empty");
if (checkBlockTerminator(parser, parseLoc,
dtorRegion.back().back().getLoc(), &dtorRegion)
.failed())
if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
return failure();
}
}
Expand Down Expand Up @@ -2445,10 +2409,10 @@ LogicalResult GetMemberOp::verify() {
// these still need to be patched.
// Also we bypass the typechecking for the fields of incomplete types.
bool shouldSkipMemberTypeMismatch =
recordTy.isClass() || isIncompleteType(recordTy.getMembers()[getIndex()]);
recordTy.isClass() || isIncompleteType(recordTy.getMembers()[getIndex()]);

if (!shouldSkipMemberTypeMismatch
&& recordTy.getMembers()[getIndex()] != getResultTy().getPointee())
if (!shouldSkipMemberTypeMismatch &&
recordTy.getMembers()[getIndex()] != getResultTy().getPointee())
return emitError() << "member type mismatch";

return mlir::success();
Expand Down
6 changes: 3 additions & 3 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ cir.func @if0() {
#true = #cir.bool<true> : !cir.bool
cir.func @yield0() {
%0 = cir.const(#true) : !cir.bool
cir.if %0 { // expected-error {{custom op 'cir.if' expected at least one block with cir.yield or cir.return}}
cir.if %0 { // expected-error {{custom op 'cir.if' multi-block region must not omit terminator}}
cir.br ^a
^a:
}
Expand Down Expand Up @@ -90,10 +90,10 @@ cir.func @yieldcontinue() {
cir.func @s0() {
%1 = cir.const(#cir.int<2> : !s32i) : !s32i
cir.switch (%1 : !s32i) [
case (equal, 5) {
case (equal, 5) { // expected-error {{custom op 'cir.switch' case regions must be explicitly terminated}}
%2 = cir.const(#cir.int<3> : !s32i) : !s32i
}
] // expected-error {{blocks are expected to be explicitly terminated}}
]
cir.return
}

Expand Down