Skip to content

Commit 9af56b7

Browse files
authored
[CIR][IR] Refactor parsing/printing of implicitly terminated regions (#310)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #314 * #313 * #312 * #311 * __->__ #310 The `shouldPrintTerm` and `checkBlockTerminator` were replaced in favor of `omitRegionTerm` and `ensureRegionTerm` respectively. The first is essentially the same method but simplified. The latter was refactored to do only two things: check if the terminator omission of a region is valid and, if so, insert the omitted terminator into the region. The simplifications mostly leverage the fact that we only omit empty yield values in a single-block region.
1 parent ed1bd8f commit 9af56b7

File tree

2 files changed

+55
-91
lines changed

2 files changed

+55
-91
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 52 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,39 @@ static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) {
148148
return static_cast<RetTy>(index);
149149
}
150150

151+
// Check if a region's termination omission is valid and, if so, creates and
152+
// inserts the omitted terminator into the region.
153+
LogicalResult ensureRegionTerm(OpAsmParser &parser, Region &region,
154+
SMLoc errLoc) {
155+
Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
156+
OpBuilder builder(parser.getBuilder().getContext());
157+
158+
// Region is empty or properly terminated: nothing to do.
159+
if (region.empty() || region.back().hasTerminator())
160+
return success();
161+
162+
// Check for invalid terminator omissions.
163+
if (!region.hasOneBlock())
164+
return parser.emitError(errLoc,
165+
"multi-block region must not omit terminator");
166+
if (region.back().empty())
167+
return parser.emitError(errLoc, "empty region must not omit terminator");
168+
169+
// Terminator was omited correctly: recreate it.
170+
region.back().push_back(builder.create<cir::YieldOp>(eLoc));
171+
return success();
172+
}
173+
174+
// True if the region's terminator should be omitted.
175+
bool omitRegionTerm(mlir::Region &r) {
176+
const auto singleNonEmptyBlock = r.hasOneBlock() && !r.front().empty();
177+
const auto yieldsNothing = [&r]() {
178+
YieldOp y = dyn_cast<YieldOp>(r.back().back());
179+
return y && y.isPlain() && y.getArgs().empty();
180+
};
181+
return singleNonEmptyBlock && yieldsNothing();
182+
}
183+
151184
//===----------------------------------------------------------------------===//
152185
// AllocaOp
153186
//===----------------------------------------------------------------------===//
@@ -413,53 +446,6 @@ mlir::LogicalResult ThrowOp::verify() {
413446
// IfOp
414447
//===----------------------------------------------------------------------===//
415448

416-
static LogicalResult checkBlockTerminator(OpAsmParser &parser,
417-
llvm::SMLoc parserLoc,
418-
std::optional<Location> l, Region *r,
419-
bool ensureTerm = true) {
420-
mlir::Builder &builder = parser.getBuilder();
421-
if (r->hasOneBlock()) {
422-
if (ensureTerm) {
423-
::mlir::impl::ensureRegionTerminator(
424-
*r, builder, *l, [](OpBuilder &builder, Location loc) {
425-
OperationState state(loc, YieldOp::getOperationName());
426-
YieldOp::build(builder, state);
427-
return Operation::create(state);
428-
});
429-
} else {
430-
assert(r && "region must not be empty");
431-
Block &block = r->back();
432-
if (block.empty() || !block.back().hasTrait<OpTrait::IsTerminator>()) {
433-
return parser.emitError(
434-
parser.getCurrentLocation(),
435-
"blocks are expected to be explicitly terminated");
436-
}
437-
}
438-
return success();
439-
}
440-
441-
// Empty regions don't need any handling.
442-
auto &blocks = r->getBlocks();
443-
if (blocks.empty())
444-
return success();
445-
446-
// Test that at least one block has a yield/return/throw terminator. We can
447-
// probably make this a bit more strict.
448-
for (Block &block : blocks) {
449-
if (block.empty())
450-
continue;
451-
auto &op = block.back();
452-
if (op.hasTrait<mlir::OpTrait::IsTerminator>() &&
453-
isa<YieldOp, ReturnOp, ThrowOp>(op)) {
454-
return success();
455-
}
456-
}
457-
458-
parser.emitError(parserLoc,
459-
"expected at least one block with cir.yield or cir.return");
460-
return failure();
461-
}
462-
463449
ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
464450
// Create the regions for 'then'.
465451
result.regions.reserve(2);
@@ -479,17 +465,15 @@ ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
479465
if (parser.parseRegion(*thenRegion, /*arguments=*/{},
480466
/*argTypes=*/{}))
481467
return failure();
482-
if (checkBlockTerminator(parser, parseThenLoc, result.location, thenRegion)
483-
.failed())
468+
if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed())
484469
return failure();
485470

486471
// If we find an 'else' keyword, parse the 'else' region.
487472
if (!parser.parseOptionalKeyword("else")) {
488473
auto parseElseLoc = parser.getCurrentLocation();
489474
if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
490475
return failure();
491-
if (checkBlockTerminator(parser, parseElseLoc, result.location, elseRegion)
492-
.failed())
476+
if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed())
493477
return failure();
494478
}
495479

@@ -499,36 +483,20 @@ ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) {
499483
return success();
500484
}
501485

502-
bool shouldPrintTerm(mlir::Region &r) {
503-
if (!r.hasOneBlock())
504-
return true;
505-
auto *entryBlock = &r.front();
506-
if (entryBlock->empty())
507-
return false;
508-
if (isa<ReturnOp>(entryBlock->back()))
509-
return true;
510-
if (isa<ThrowOp>(entryBlock->back()))
511-
return true;
512-
YieldOp y = dyn_cast<YieldOp>(entryBlock->back());
513-
if (y && (!y.isPlain() || !y.getArgs().empty()))
514-
return true;
515-
return false;
516-
}
517-
518486
void cir::IfOp::print(OpAsmPrinter &p) {
519487
p << " " << getCondition() << " ";
520488
auto &thenRegion = this->getThenRegion();
521489
p.printRegion(thenRegion,
522490
/*printEntryBlockArgs=*/false,
523-
/*printBlockTerminators=*/shouldPrintTerm(thenRegion));
491+
/*printBlockTerminators=*/!omitRegionTerm(thenRegion));
524492

525493
// Print the 'else' regions if it exists and has a block.
526494
auto &elseRegion = this->getElseRegion();
527495
if (!elseRegion.empty()) {
528496
p << " else ";
529497
p.printRegion(elseRegion,
530498
/*printEntryBlockArgs=*/false,
531-
/*printBlockTerminators=*/shouldPrintTerm(elseRegion));
499+
/*printBlockTerminators=*/!omitRegionTerm(elseRegion));
532500
}
533501

534502
p.printOptionalAttrDict(getOperation()->getAttrs());
@@ -611,7 +579,7 @@ ParseResult cir::ScopeOp::parse(OpAsmParser &parser, OperationState &result) {
611579
if (parser.parseRegion(*scopeRegion, /*arguments=*/{}, /*argTypes=*/{}))
612580
return failure();
613581

614-
if (checkBlockTerminator(parser, loc, result.location, scopeRegion).failed())
582+
if (ensureRegionTerm(parser, *scopeRegion, loc).failed())
615583
return failure();
616584

617585
// Parse the optional attribute list.
@@ -625,7 +593,7 @@ void cir::ScopeOp::print(OpAsmPrinter &p) {
625593
auto &scopeRegion = this->getScopeRegion();
626594
p.printRegion(scopeRegion,
627595
/*printEntryBlockArgs=*/false,
628-
/*printBlockTerminators=*/shouldPrintTerm(scopeRegion));
596+
/*printBlockTerminators=*/!omitRegionTerm(scopeRegion));
629597

630598
p.printOptionalAttrDict(getOperation()->getAttrs());
631599
}
@@ -866,10 +834,10 @@ parseSwitchOp(OpAsmParser &parser,
866834
"case region shall not be empty");
867835
}
868836

869-
if (checkBlockTerminator(parser, parserLoc, std::nullopt, &currRegion,
870-
/*ensureTerm=*/false)
871-
.failed())
872-
return failure();
837+
if (!currRegion.back().hasTerminator())
838+
return parser.emitError(parserLoc,
839+
"case regions must be explicitly terminated");
840+
873841
return success();
874842
};
875843

@@ -1134,10 +1102,10 @@ parseCatchOp(OpAsmParser &parser,
11341102
"catch region shall not be empty");
11351103
}
11361104

1137-
if (checkBlockTerminator(parser, parserLoc, std::nullopt, &currRegion,
1138-
/*ensureTerm=*/false)
1139-
.failed())
1140-
return failure();
1105+
if (!currRegion.back().hasTerminator())
1106+
return parser.emitError(
1107+
parserLoc, "blocks are expected to be explicitly terminated");
1108+
11411109
return success();
11421110
};
11431111

@@ -1388,9 +1356,7 @@ static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
13881356
if (ctorRegion.back().empty())
13891357
return parser.emitError(parser.getCurrentLocation(),
13901358
"ctor region shall not be empty");
1391-
if (checkBlockTerminator(parser, parseLoc,
1392-
ctorRegion.back().back().getLoc(), &ctorRegion)
1393-
.failed())
1359+
if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed())
13941360
return failure();
13951361
} else {
13961362
// Parse constant with initializer, examples:
@@ -1417,9 +1383,7 @@ static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser,
14171383
if (dtorRegion.back().empty())
14181384
return parser.emitError(parser.getCurrentLocation(),
14191385
"dtor region shall not be empty");
1420-
if (checkBlockTerminator(parser, parseLoc,
1421-
dtorRegion.back().back().getLoc(), &dtorRegion)
1422-
.failed())
1386+
if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed())
14231387
return failure();
14241388
}
14251389
}
@@ -2445,10 +2409,10 @@ LogicalResult GetMemberOp::verify() {
24452409
// these still need to be patched.
24462410
// Also we bypass the typechecking for the fields of incomplete types.
24472411
bool shouldSkipMemberTypeMismatch =
2448-
recordTy.isClass() || isIncompleteType(recordTy.getMembers()[getIndex()]);
2412+
recordTy.isClass() || isIncompleteType(recordTy.getMembers()[getIndex()]);
24492413

2450-
if (!shouldSkipMemberTypeMismatch
2451-
&& recordTy.getMembers()[getIndex()] != getResultTy().getPointee())
2414+
if (!shouldSkipMemberTypeMismatch &&
2415+
recordTy.getMembers()[getIndex()] != getResultTy().getPointee())
24522416
return emitError() << "member type mismatch";
24532417

24542418
return mlir::success();

clang/test/CIR/IR/invalid.cir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ cir.func @if0() {
4141
#true = #cir.bool<true> : !cir.bool
4242
cir.func @yield0() {
4343
%0 = cir.const(#true) : !cir.bool
44-
cir.if %0 { // expected-error {{custom op 'cir.if' expected at least one block with cir.yield or cir.return}}
44+
cir.if %0 { // expected-error {{custom op 'cir.if' multi-block region must not omit terminator}}
4545
cir.br ^a
4646
^a:
4747
}
@@ -90,10 +90,10 @@ cir.func @yieldcontinue() {
9090
cir.func @s0() {
9191
%1 = cir.const(#cir.int<2> : !s32i) : !s32i
9292
cir.switch (%1 : !s32i) [
93-
case (equal, 5) {
93+
case (equal, 5) { // expected-error {{custom op 'cir.switch' case regions must be explicitly terminated}}
9494
%2 = cir.const(#cir.int<3> : !s32i) : !s32i
9595
}
96-
] // expected-error {{blocks are expected to be explicitly terminated}}
96+
]
9797
cir.return
9898
}
9999

0 commit comments

Comments
 (0)