diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 35a194df22f1..29c2cde0b384 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -148,6 +148,39 @@ static RetTy parseOptionalCIRKeyword(AsmParser &parser, EnumTy defaultValue) { return static_cast(index); } +// Check if a region's termination omission is valid and, if so, creates and +// inserts the omitted terminator into the region. +LogicalResult ensureRegionTerm(OpAsmParser &parser, Region ®ion, + SMLoc errLoc) { + Location eLoc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + OpBuilder builder(parser.getBuilder().getContext()); + + // Region is empty or properly terminated: nothing to do. + if (region.empty() || region.back().hasTerminator()) + return success(); + + // Check for invalid terminator omissions. + if (!region.hasOneBlock()) + return parser.emitError(errLoc, + "multi-block region must not omit terminator"); + if (region.back().empty()) + return parser.emitError(errLoc, "empty region must not omit terminator"); + + // Terminator was omited correctly: recreate it. + region.back().push_back(builder.create(eLoc)); + return success(); +} + +// True if the region's terminator should be omitted. +bool omitRegionTerm(mlir::Region &r) { + const auto singleNonEmptyBlock = r.hasOneBlock() && !r.front().empty(); + const auto yieldsNothing = [&r]() { + YieldOp y = dyn_cast(r.back().back()); + return y && y.isPlain() && y.getArgs().empty(); + }; + return singleNonEmptyBlock && yieldsNothing(); +} + //===----------------------------------------------------------------------===// // AllocaOp //===----------------------------------------------------------------------===// @@ -413,53 +446,6 @@ mlir::LogicalResult ThrowOp::verify() { // IfOp //===----------------------------------------------------------------------===// -static LogicalResult checkBlockTerminator(OpAsmParser &parser, - llvm::SMLoc parserLoc, - std::optional l, Region *r, - bool ensureTerm = true) { - mlir::Builder &builder = parser.getBuilder(); - if (r->hasOneBlock()) { - if (ensureTerm) { - ::mlir::impl::ensureRegionTerminator( - *r, builder, *l, [](OpBuilder &builder, Location loc) { - OperationState state(loc, YieldOp::getOperationName()); - YieldOp::build(builder, state); - return Operation::create(state); - }); - } else { - assert(r && "region must not be empty"); - Block &block = r->back(); - if (block.empty() || !block.back().hasTrait()) { - return parser.emitError( - parser.getCurrentLocation(), - "blocks are expected to be explicitly terminated"); - } - } - return success(); - } - - // Empty regions don't need any handling. - auto &blocks = r->getBlocks(); - if (blocks.empty()) - return success(); - - // Test that at least one block has a yield/return/throw terminator. We can - // probably make this a bit more strict. - for (Block &block : blocks) { - if (block.empty()) - continue; - auto &op = block.back(); - if (op.hasTrait() && - isa(op)) { - return success(); - } - } - - parser.emitError(parserLoc, - "expected at least one block with cir.yield or cir.return"); - return failure(); -} - ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions for 'then'. result.regions.reserve(2); @@ -479,8 +465,7 @@ ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); - if (checkBlockTerminator(parser, parseThenLoc, result.location, thenRegion) - .failed()) + if (ensureRegionTerm(parser, *thenRegion, parseThenLoc).failed()) return failure(); // If we find an 'else' keyword, parse the 'else' region. @@ -488,8 +473,7 @@ ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) { auto parseElseLoc = parser.getCurrentLocation(); if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); - if (checkBlockTerminator(parser, parseElseLoc, result.location, elseRegion) - .failed()) + if (ensureRegionTerm(parser, *elseRegion, parseElseLoc).failed()) return failure(); } @@ -499,28 +483,12 @@ ParseResult cir::IfOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -bool shouldPrintTerm(mlir::Region &r) { - if (!r.hasOneBlock()) - return true; - auto *entryBlock = &r.front(); - if (entryBlock->empty()) - return false; - if (isa(entryBlock->back())) - return true; - if (isa(entryBlock->back())) - return true; - YieldOp y = dyn_cast(entryBlock->back()); - if (y && (!y.isPlain() || !y.getArgs().empty())) - return true; - return false; -} - void cir::IfOp::print(OpAsmPrinter &p) { p << " " << getCondition() << " "; auto &thenRegion = this->getThenRegion(); p.printRegion(thenRegion, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/shouldPrintTerm(thenRegion)); + /*printBlockTerminators=*/!omitRegionTerm(thenRegion)); // Print the 'else' regions if it exists and has a block. auto &elseRegion = this->getElseRegion(); @@ -528,7 +496,7 @@ void cir::IfOp::print(OpAsmPrinter &p) { p << " else "; p.printRegion(elseRegion, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/shouldPrintTerm(elseRegion)); + /*printBlockTerminators=*/!omitRegionTerm(elseRegion)); } p.printOptionalAttrDict(getOperation()->getAttrs()); @@ -611,7 +579,7 @@ ParseResult cir::ScopeOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseRegion(*scopeRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); - if (checkBlockTerminator(parser, loc, result.location, scopeRegion).failed()) + if (ensureRegionTerm(parser, *scopeRegion, loc).failed()) return failure(); // Parse the optional attribute list. @@ -625,7 +593,7 @@ void cir::ScopeOp::print(OpAsmPrinter &p) { auto &scopeRegion = this->getScopeRegion(); p.printRegion(scopeRegion, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/shouldPrintTerm(scopeRegion)); + /*printBlockTerminators=*/!omitRegionTerm(scopeRegion)); p.printOptionalAttrDict(getOperation()->getAttrs()); } @@ -866,10 +834,10 @@ parseSwitchOp(OpAsmParser &parser, "case region shall not be empty"); } - if (checkBlockTerminator(parser, parserLoc, std::nullopt, &currRegion, - /*ensureTerm=*/false) - .failed()) - return failure(); + if (!currRegion.back().hasTerminator()) + return parser.emitError(parserLoc, + "case regions must be explicitly terminated"); + return success(); }; @@ -1134,10 +1102,10 @@ parseCatchOp(OpAsmParser &parser, "catch region shall not be empty"); } - if (checkBlockTerminator(parser, parserLoc, std::nullopt, &currRegion, - /*ensureTerm=*/false) - .failed()) - return failure(); + if (!currRegion.back().hasTerminator()) + return parser.emitError( + parserLoc, "blocks are expected to be explicitly terminated"); + return success(); }; @@ -1388,9 +1356,7 @@ static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, if (ctorRegion.back().empty()) return parser.emitError(parser.getCurrentLocation(), "ctor region shall not be empty"); - if (checkBlockTerminator(parser, parseLoc, - ctorRegion.back().back().getLoc(), &ctorRegion) - .failed()) + if (ensureRegionTerm(parser, ctorRegion, parseLoc).failed()) return failure(); } else { // Parse constant with initializer, examples: @@ -1417,9 +1383,7 @@ static ParseResult parseGlobalOpTypeAndInitialValue(OpAsmParser &parser, if (dtorRegion.back().empty()) return parser.emitError(parser.getCurrentLocation(), "dtor region shall not be empty"); - if (checkBlockTerminator(parser, parseLoc, - dtorRegion.back().back().getLoc(), &dtorRegion) - .failed()) + if (ensureRegionTerm(parser, dtorRegion, parseLoc).failed()) return failure(); } } @@ -2445,10 +2409,10 @@ LogicalResult GetMemberOp::verify() { // these still need to be patched. // Also we bypass the typechecking for the fields of incomplete types. bool shouldSkipMemberTypeMismatch = - recordTy.isClass() || isIncompleteType(recordTy.getMembers()[getIndex()]); + recordTy.isClass() || isIncompleteType(recordTy.getMembers()[getIndex()]); - if (!shouldSkipMemberTypeMismatch - && recordTy.getMembers()[getIndex()] != getResultTy().getPointee()) + if (!shouldSkipMemberTypeMismatch && + recordTy.getMembers()[getIndex()] != getResultTy().getPointee()) return emitError() << "member type mismatch"; return mlir::success(); diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index e40d2d4aab96..ca14e8530336 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -41,7 +41,7 @@ cir.func @if0() { #true = #cir.bool : !cir.bool cir.func @yield0() { %0 = cir.const(#true) : !cir.bool - cir.if %0 { // expected-error {{custom op 'cir.if' expected at least one block with cir.yield or cir.return}} + cir.if %0 { // expected-error {{custom op 'cir.if' multi-block region must not omit terminator}} cir.br ^a ^a: } @@ -90,10 +90,10 @@ cir.func @yieldcontinue() { cir.func @s0() { %1 = cir.const(#cir.int<2> : !s32i) : !s32i cir.switch (%1 : !s32i) [ - case (equal, 5) { + case (equal, 5) { // expected-error {{custom op 'cir.switch' case regions must be explicitly terminated}} %2 = cir.const(#cir.int<3> : !s32i) : !s32i } - ] // expected-error {{blocks are expected to be explicitly terminated}} + ] cir.return }