Skip to content

Commit 5a49237

Browse files
wenpenlanza
authored andcommitted
[CIR[CIRGen][NFC] Refactor build switch op (llvm#552)
Make logic cleaner and more extensible. Separate collecting `SwitchStmt` information and building op logic into different functions. Add more UT to cover nested switch, which also worked before this pr. This pr is split from llvm#528.
1 parent 16c9b4f commit 5a49237

File tree

3 files changed

+105
-58
lines changed

3 files changed

+105
-58
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.h

+23-9
Original file line numberDiff line numberDiff line change
@@ -1085,18 +1085,23 @@ class CIRGenFunction : public CIRGenTypeCache {
10851085
template <typename T>
10861086
mlir::LogicalResult
10871087
buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
1088-
SmallVector<mlir::Attribute, 4> &caseAttrs,
1089-
mlir::OperationState &os);
1088+
SmallVector<mlir::Attribute, 4> &caseAttrs);
10901089

10911090
mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
10921091
mlir::Type condType,
1093-
SmallVector<mlir::Attribute, 4> &caseAttrs,
1094-
mlir::OperationState &op);
1092+
SmallVector<mlir::Attribute, 4> &caseAttrs);
10951093

10961094
mlir::LogicalResult
10971095
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
1098-
SmallVector<mlir::Attribute, 4> &caseAttrs,
1099-
mlir::OperationState &op);
1096+
SmallVector<mlir::Attribute, 4> &caseAttrs);
1097+
1098+
mlir::LogicalResult
1099+
buildSwitchCase(const clang::SwitchCase &S, mlir::Type condType,
1100+
SmallVector<mlir::Attribute, 4> &caseAttrs);
1101+
1102+
mlir::LogicalResult
1103+
buildSwitchBody(const clang::Stmt *S, mlir::Type condType,
1104+
SmallVector<mlir::Attribute, 4> &caseAttrs);
11001105

11011106
mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
11021107
const CIRGenFunctionInfo &FnInfo);
@@ -1964,7 +1969,7 @@ class CIRGenFunction : public CIRGenTypeCache {
19641969
// have their own scopes but are distinct regions nonetheless.
19651970
llvm::SmallVector<mlir::Block *> RetBlocks;
19661971
llvm::SmallVector<std::optional<mlir::Location>> RetLocs;
1967-
unsigned int CurrentSwitchRegionIdx = -1;
1972+
llvm::SmallVector<std::unique_ptr<mlir::Region>> SwitchRegions;
19681973

19691974
// There's usually only one ret block per scope, but this needs to be
19701975
// get or create because of potential unreachable return statements, note
@@ -1985,16 +1990,25 @@ class CIRGenFunction : public CIRGenTypeCache {
19851990
void buildImplicitReturn();
19861991

19871992
public:
1988-
void updateCurrentSwitchCaseRegion() { CurrentSwitchRegionIdx++; }
19891993
llvm::ArrayRef<mlir::Block *> getRetBlocks() { return RetBlocks; }
19901994
llvm::ArrayRef<std::optional<mlir::Location>> getRetLocs() {
19911995
return RetLocs;
19921996
}
1997+
llvm::MutableArrayRef<std::unique_ptr<mlir::Region>> getSwitchRegions() {
1998+
assert(isSwitch() && "expected switch scope");
1999+
return SwitchRegions;
2000+
}
2001+
2002+
mlir::Region *createSwitchRegion() {
2003+
assert(isSwitch() && "expected switch scope");
2004+
SwitchRegions.push_back(std::make_unique<mlir::Region>());
2005+
return SwitchRegions.back().get();
2006+
}
19932007

19942008
mlir::Block *getOrCreateRetBlock(CIRGenFunction &CGF, mlir::Location loc) {
19952009
unsigned int regionIdx = 0;
19962010
if (isSwitch())
1997-
regionIdx = CurrentSwitchRegionIdx;
2011+
regionIdx = SwitchRegions.size() - 1;
19982012
if (regionIdx >= RetBlocks.size())
19992013
return createRetBlock(CGF, loc);
20002014
return &*RetBlocks.back();

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

+55-49
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,7 @@ CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
637637
template <typename T>
638638
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
639639
const T *stmt, mlir::Type condType,
640-
SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {
640+
SmallVector<mlir::Attribute, 4> &caseAttrs) {
641641

642642
assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
643643
"only case or default stmt go here");
@@ -647,20 +647,18 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
647647
// Update scope information with the current region we are
648648
// emitting code for. This is useful to allow return blocks to be
649649
// automatically and properly placed during cleanup.
650-
auto *region = os.addRegion();
650+
auto *region = currLexScope->createSwitchRegion();
651651
auto *block = builder.createBlock(region);
652652
builder.setInsertionPointToEnd(block);
653-
currLexScope->updateCurrentSwitchCaseRegion();
654653

655654
auto *sub = stmt->getSubStmt();
656655

657656
if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
658657
builder.createYield(getLoc(stmt->getBeginLoc()));
659-
res =
660-
buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
658+
res = buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs);
661659
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
662660
builder.createYield(getLoc(stmt->getBeginLoc()));
663-
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
661+
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs);
664662
} else {
665663
res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
666664
}
@@ -670,27 +668,37 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
670668

671669
mlir::LogicalResult
672670
CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
673-
SmallVector<mlir::Attribute, 4> &caseAttrs,
674-
mlir::OperationState &os) {
671+
SmallVector<mlir::Attribute, 4> &caseAttrs) {
675672
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
676673
"case ranges not implemented");
677674

678675
auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
679-
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
676+
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs);
680677
}
681678

682679
mlir::LogicalResult
683680
CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
684-
SmallVector<mlir::Attribute, 4> &caseAttrs,
685-
mlir::OperationState &os) {
681+
SmallVector<mlir::Attribute, 4> &caseAttrs) {
686682
auto ctxt = builder.getContext();
687683

688684
auto defAttr = mlir::cir::CaseAttr::get(
689685
ctxt, builder.getArrayAttr({}),
690686
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));
691687

692688
caseAttrs.push_back(defAttr);
693-
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
689+
return buildCaseDefaultCascade(&S, condType, caseAttrs);
690+
}
691+
692+
mlir::LogicalResult
693+
CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType,
694+
SmallVector<mlir::Attribute, 4> &caseAttrs) {
695+
if (S.getStmtClass() == Stmt::CaseStmtClass)
696+
return buildCaseStmt(cast<CaseStmt>(S), condType, caseAttrs);
697+
698+
if (S.getStmtClass() == Stmt::DefaultStmtClass)
699+
return buildDefaultStmt(cast<DefaultStmt>(S), condType, caseAttrs);
700+
701+
llvm_unreachable("expect case or default stmt");
694702
}
695703

696704
mlir::LogicalResult
@@ -953,6 +961,36 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
953961
return mlir::success();
954962
}
955963

964+
mlir::LogicalResult CIRGenFunction::buildSwitchBody(
965+
const Stmt *S, mlir::Type condType,
966+
llvm::SmallVector<mlir::Attribute, 4> &caseAttrs) {
967+
if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
968+
mlir::Block *lastCaseBlock = nullptr;
969+
auto res = mlir::success();
970+
for (auto *c : compoundStmt->body()) {
971+
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
972+
res = buildSwitchCase(*switchCase, condType, caseAttrs);
973+
} else if (lastCaseBlock) {
974+
// This means it's a random stmt following up a case, just
975+
// emit it as part of previous known case.
976+
mlir::OpBuilder::InsertionGuard guardCase(builder);
977+
builder.setInsertionPointToEnd(lastCaseBlock);
978+
res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
979+
} else {
980+
llvm_unreachable("statement doesn't belong to any case region, NYI");
981+
}
982+
983+
lastCaseBlock = builder.getBlock();
984+
985+
if (res.failed())
986+
break;
987+
}
988+
return res;
989+
}
990+
991+
llvm_unreachable("switch body is not CompoundStmt, NYI");
992+
}
993+
956994
mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
957995
// TODO: LLVM codegen does some early optimization to fold the condition and
958996
// only emit live cases. CIR should use MLIR to achieve similar things,
@@ -975,49 +1013,17 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
9751013
// TODO: PGO and likelihood (e.g. PGO.haveRegionCounts())
9761014
// TODO: if the switch has a condition wrapped by __builtin_unpredictable?
9771015

978-
// FIXME: track switch to handle nested stmts.
9791016
swop = builder.create<SwitchOp>(
9801017
getLoc(S.getBeginLoc()), condV,
9811018
/*switchBuilder=*/
9821019
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
983-
auto *cs = dyn_cast<CompoundStmt>(S.getBody());
984-
assert(cs && "expected compound stmt");
985-
SmallVector<mlir::Attribute, 4> caseAttrs;
986-
9871020
currLexScope->setAsSwitch();
988-
mlir::Block *lastCaseBlock = nullptr;
989-
for (auto *c : cs->body()) {
990-
bool caseLike = isa<CaseStmt, DefaultStmt>(c);
991-
if (!caseLike) {
992-
// This means it's a random stmt following up a case, just
993-
// emit it as part of previous known case.
994-
assert(lastCaseBlock && "expects pre-existing case block");
995-
mlir::OpBuilder::InsertionGuard guardCase(builder);
996-
builder.setInsertionPointToEnd(lastCaseBlock);
997-
res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
998-
lastCaseBlock = builder.getBlock();
999-
if (res.failed())
1000-
break;
1001-
continue;
1002-
}
1003-
1004-
auto *caseStmt = dyn_cast<CaseStmt>(c);
1005-
1006-
if (caseStmt)
1007-
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
1008-
else {
1009-
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
1010-
assert(defaultStmt && "expected default stmt");
1011-
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
1012-
os);
1013-
}
1014-
1015-
lastCaseBlock = builder.getBlock();
1016-
1017-
if (res.failed())
1018-
break;
1019-
}
10201021

1022+
llvm::SmallVector<mlir::Attribute, 4> caseAttrs;
1023+
1024+
res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs);
1025+
1026+
os.addRegions(currLexScope->getSwitchRegions());
10211027
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
10221028
});
10231029

clang/test/CIR/CodeGen/switch.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ void sw12(int a) {
266266
break;
267267
}
268268
}
269+
269270
// CHECK: cir.func @_Z4sw12i
270271
// CHECK: cir.scope {
271272
// CHECK: cir.switch
@@ -275,6 +276,32 @@ void sw12(int a) {
275276
// CHECK-NEXT: cir.break
276277
// CHECK-NEXT: }
277278

279+
void sw13(int a, int b) {
280+
switch (a) {
281+
case 1:
282+
switch (b) {
283+
case 2:
284+
break;
285+
}
286+
}
287+
}
288+
289+
// CHECK: cir.func @_Z4sw13ii
290+
// CHECK: cir.scope {
291+
// CHECK: cir.switch
292+
// CHECK-NEXT: case (equal, 1) {
293+
// CHECK-NEXT: cir.scope {
294+
// CHECK: cir.switch
295+
// CHECK-NEXT: case (equal, 2) {
296+
// CHECK-NEXT: cir.break
297+
// CHECK-NEXT: }
298+
// CHECK-NEXT: ]
299+
// CHECK-NEXT: }
300+
// CHECK-NEXT: cir.yield
301+
// CHECK-NEXT: }
302+
// CHECK: }
303+
// CHECK: cir.return
304+
278305
void fallthrough(int x) {
279306
switch (x) {
280307
case 1:

0 commit comments

Comments
 (0)