Skip to content

Commit

Permalink
[CIR][IR] Refactor parsing/printing of implicitly terminated regions (#…
Browse files Browse the repository at this point in the history
…310)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #314
* #313
* #312
* #311
* __->__ #310

The `shouldPrintTerm` and `checkBlockTerminator` were replaced in favor
of `omitRegionTerm` and `ensureRegionTerm` respectively. The first is
essentially the same method but simplified. The latter was refactored to
do only two things: check if the terminator omission of a region is
valid and, if so, insert the omitted terminator into the region.

The simplifications mostly leverage the fact that we only omit empty
yield values in a single-block region.
  • Loading branch information
sitio-couto authored Nov 17, 2023
1 parent ed1bd8f commit 9af56b7
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 91 deletions.
140 changes: 52 additions & 88 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,39 @@ static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
return static_cast<RetTy>(index);
}

// Check if a region's termination omission is valid and, if so, creates and
// inserts the omitted terminator into the region.
LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
SMLoc errLoc) {
Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
OpBuilder builder(parser.getBuilder().getContext());

// Region is empty or properly terminated: nothing to do.
if (region.empty() || region.back().hasTerminator())
return success();

// Check for invalid terminator omissions.
if (!region.hasOneBlock())
return parser.emitError(errLoc,
"multi-block region must not omit terminator");
if (region.back().empty())
return parser.emitError(errLoc, "empty region must not omit terminator");

// Terminator was omited correctly: recreate it.
region.back().push_back(builder.create<cir::YieldOp>(eLoc));
return success();
}

// True if the region's terminator should be omitted.
bool omitRegionTerm(mlir::Region &r) {
const auto singleNonEmptyBlock = r.hasOneBlock() && !r.front().empty();
const auto yieldsNothing = [&r]() {
YieldOp y = dyn_cast<YieldOp>(r.back().back());
return y && y.isPlain() && y.getArgs().empty();
};
return singleNonEmptyBlock && yieldsNothing();
}

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -413,53 +446,6 @@ mlir::LogicalResult ThrowOp::verify() {
// IfOp
//===----------------------------------------------------------------------===//

static LogicalResult checkBlockTerminator(OpAsmParser &parser,
llvm::SMLoc parserLoc,
std::optional<Location> l, Region *r,
bool ensureTerm = true) {
mlir::Builder &builder = parser.getBuilder();
if (r->hasOneBlock()) {
if (ensureTerm) {
::mlir::impl::ensureRegionTerminator(
*r, builder, *l, [](OpBuilder &builder, Location loc) {
OperationState state(loc, YieldOp::getOperationName());
YieldOp::build(builder, state);
return Operation::create(state);
});
} else {
assert(r && "region must not be empty");
Block &block = r->back();
if (block.empty() || !block.back().hasTrait<OpTrait::IsTerminator>()) {
return parser.emitError(
parser.getCurrentLocation(),
"blocks are expected to be explicitly terminated");
}
}
return success();
}

// Empty regions don't need any handling.
auto &blocks = r->getBlocks();
if (blocks.empty())
return success();

// Test that at least one block has a yield/return/throw terminator. We can
// probably make this a bit more strict.
for (Block &block : blocks) {
if (block.empty())
continue;
auto &op = block.back();
if (op.hasTrait<mlir::OpTrait::IsTerminator>() &&
isa<YieldOp, ReturnOp, ThrowOp>(op)) {
return success();
}
}

parser.emitError(parserLoc,
"expected at least one block with cir.yield or cir.return");
return failure();
}

ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
result.regions.reserve(2);
Expand All @@ -479,17 +465,15 @@ ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseRegion(*thenRegion, /*arguments=*/{},
/*argTypes=*/{}))
return failure();
if (checkBlockTerminator(parser, parseThenLoc, result.location, thenRegion)
.failed())
if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
return failure();

// If we find an 'else' keyword, parse the 'else' region.
if (!parser.parseOptionalKeyword("else")) {
auto parseElseLoc = parser.getCurrentLocation();
if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
if (checkBlockTerminator(parser, parseElseLoc, result.location, elseRegion)
.failed())
if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
return failure();
}

Expand All @@ -499,36 +483,20 @@ ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}

bool shouldPrintTerm(mlir::Region &r) {
if (!r.hasOneBlock())
return true;
auto *entryBlock = &r.front();
if (entryBlock->empty())
return false;
if (isa<ReturnOp>(entryBlock->back()))
return true;
if (isa<ThrowOp>(entryBlock->back()))
return true;
YieldOp y = dyn_cast<YieldOp>(entryBlock->back());
if (y && (!y.isPlain() || !y.getArgs().empty()))
return true;
return false;
}

void cir::IfOp::print(OpAsmPrinter &p) {
p << " " << getCondition() << " ";
auto &thenRegion = this->getThenRegion();
p.printRegion(thenRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/shouldPrintTerm(thenRegion));
/*printBlockTerminators=*/!omitRegionTerm(thenRegion));

// Print the 'else' regions if it exists and has a block.
auto &elseRegion = this->getElseRegion();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/shouldPrintTerm(elseRegion));
/*printBlockTerminators=*/!omitRegionTerm(elseRegion));
}

p.printOptionalAttrDict(getOperation()->getAttrs());
Expand Down Expand Up @@ -611,7 +579,7 @@ ParseResult cir::ScopeOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseRegion(*scopeRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();

if (checkBlockTerminator(parser, loc, result.location, scopeRegion).failed())
if (ensureRegionTerm(parser, *scopeRegion, loc).failed())
return failure();

// Parse the optional attribute list.
Expand All @@ -625,7 +593,7 @@ void cir::ScopeOp::print(OpAsmPrinter &p) {
auto &scopeRegion = this->getScopeRegion();
p.printRegion(scopeRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/shouldPrintTerm(scopeRegion));
/*printBlockTerminators=*/!omitRegionTerm(scopeRegion));

p.printOptionalAttrDict(getOperation()->getAttrs());
}
Expand Down Expand Up @@ -866,10 +834,10 @@ parseSwitchOp(OpAsmParser &parser,
"case region shall not be empty");
}

if (checkBlockTerminator(parser, parserLoc, std::nullopt, &currRegion,
/*ensureTerm=*/false)
.failed())
return failure();
if (!currRegion.back().hasTerminator())
return parser.emitError(parserLoc,
"case regions must be explicitly terminated");

return success();
};

Expand Down Expand Up @@ -1134,10 +1102,10 @@ parseCatchOp(OpAsmParser &parser,
"catch region shall not be empty");
}

if (checkBlockTerminator(parser, parserLoc, std::nullopt, &currRegion,
/*ensureTerm=*/false)
.failed())
return failure();
if (!currRegion.back().hasTerminator())
return parser.emitError(
parserLoc, "blocks are expected to be explicitly terminated");

return success();
};

Expand Down Expand Up @@ -1388,9 +1356,7 @@ static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
if (ctorRegion.back().empty())
return parser.emitError(parser.getCurrentLocation(),
"ctor region shall not be empty");
if (checkBlockTerminator(parser, parseLoc,
ctorRegion.back().back().getLoc(), &ctorRegion)
.failed())
if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
return failure();
} else {
// Parse constant with initializer, examples:
Expand All @@ -1417,9 +1383,7 @@ static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
if (dtorRegion.back().empty())
return parser.emitError(parser.getCurrentLocation(),
"dtor region shall not be empty");
if (checkBlockTerminator(parser, parseLoc,
dtorRegion.back().back().getLoc(), &dtorRegion)
.failed())
if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
return failure();
}
}
Expand Down Expand Up @@ -2445,10 +2409,10 @@ LogicalResult GetMemberOp::verify() {
// these still need to be patched.
// Also we bypass the typechecking for the fields of incomplete types.
bool shouldSkipMemberTypeMismatch =
recordTy.isClass() || isIncompleteType(recordTy.getMembers()[getIndex()]);
recordTy.isClass() || isIncompleteType(recordTy.getMembers()[getIndex()]);

if (!shouldSkipMemberTypeMismatch
&& recordTy.getMembers()[getIndex()] != getResultTy().getPointee())
if (!shouldSkipMemberTypeMismatch &&
recordTy.getMembers()[getIndex()] != getResultTy().getPointee())
return emitError() << "member type mismatch";

return mlir::success();
Expand Down
6 changes: 3 additions & 3 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ cir.func @if0() {
#true = #cir.bool<true> : !cir.bool
cir.func @yield0() {
%0 = cir.const(#true) : !cir.bool
cir.if %0 { // expected-error {{custom op 'cir.if' expected at least one block with cir.yield or cir.return}}
cir.if %0 { // expected-error {{custom op 'cir.if' multi-block region must not omit terminator}}
cir.br ^a
^a:
}
Expand Down Expand Up @@ -90,10 +90,10 @@ cir.func @yieldcontinue() {
cir.func @s0() {
%1 = cir.const(#cir.int<2> : !s32i) : !s32i
cir.switch (%1 : !s32i) [
case (equal, 5) {
case (equal, 5) { // expected-error {{custom op 'cir.switch' case regions must be explicitly terminated}}
%2 = cir.const(#cir.int<3> : !s32i) : !s32i
}
] // expected-error {{blocks are expected to be explicitly terminated}}
]
cir.return
}

Expand Down

0 comments on commit 9af56b7

Please sign in to comment.