@@ -637,7 +637,7 @@ CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
 template <typename T>
 mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
     const T *stmt, mlir::Type condType,
-    SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {
+    SmallVector<mlir::Attribute, 4> &caseAttrs) {
 
   assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
          "only case or default stmt go here");
@@ -647,20 +647,18 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
   // Update scope information with the current region we are
   // emitting code for. This is useful to allow return blocks to be
   // automatically and properly placed during cleanup.
-  auto *region = os.addRegion();
+  auto *region = currLexScope->createSwitchRegion();
   auto *block = builder.createBlock(region);
   builder.setInsertionPointToEnd(block);
-  currLexScope->updateCurrentSwitchCaseRegion();
 
   auto *sub = stmt->getSubStmt();
 
   if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
     builder.createYield(getLoc(stmt->getBeginLoc()));
-    res =
-        buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
+    res = buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs);
   } else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
     builder.createYield(getLoc(stmt->getBeginLoc()));
-    res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
+    res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs);
   } else {
     res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
   }
@@ -670,27 +668,37 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
 
 mlir::LogicalResult
 CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
-                              SmallVector<mlir::Attribute, 4> &caseAttrs,
-                              mlir::OperationState &os) {
+                              SmallVector<mlir::Attribute, 4> &caseAttrs) {
   assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
          "case ranges not implemented");
 
   auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
-  return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
+  return buildCaseDefaultCascade(caseStmt, condType, caseAttrs);
 }
 
 mlir::LogicalResult
 CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
-                                 SmallVector<mlir::Attribute, 4> &caseAttrs,
-                                 mlir::OperationState &os) {
+                                 SmallVector<mlir::Attribute, 4> &caseAttrs) {
   auto ctxt = builder.getContext();
 
   auto defAttr = mlir::cir::CaseAttr::get(
       ctxt, builder.getArrayAttr({}),
       CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));
 
   caseAttrs.push_back(defAttr);
-  return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
+  return buildCaseDefaultCascade(&S, condType, caseAttrs);
+}
+
+mlir::LogicalResult
+CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType,
+                                SmallVector<mlir::Attribute, 4> &caseAttrs) {
+  if (S.getStmtClass() == Stmt::CaseStmtClass)
+    return buildCaseStmt(cast<CaseStmt>(S), condType, caseAttrs);
+
+  if (S.getStmtClass() == Stmt::DefaultStmtClass)
+    return buildDefaultStmt(cast<DefaultStmt>(S), condType, caseAttrs);
+
+  llvm_unreachable("expect case or default stmt");
 }
 
 mlir::LogicalResult
@@ -953,6 +961,36 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
   return mlir::success();
 }
 
+mlir::LogicalResult CIRGenFunction::buildSwitchBody(
+    const Stmt *S, mlir::Type condType,
+    llvm::SmallVector<mlir::Attribute, 4> &caseAttrs) {
+  if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
+    mlir::Block *lastCaseBlock = nullptr;
+    auto res = mlir::success();
+    for (auto *c : compoundStmt->body()) {
+      if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
+        res = buildSwitchCase(*switchCase, condType, caseAttrs);
+      } else if (lastCaseBlock) {
+        // This means it's a random stmt following up a case, just
+        // emit it as part of previous known case.
+        mlir::OpBuilder::InsertionGuard guardCase(builder);
+        builder.setInsertionPointToEnd(lastCaseBlock);
+        res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
+      } else {
+        llvm_unreachable("statement doesn't belong to any case region, NYI");
+      }
+
+      lastCaseBlock = builder.getBlock();
+
+      if (res.failed())
+        break;
+    }
+    return res;
+  }
+
+  llvm_unreachable("switch body is not CompoundStmt, NYI");
+}
+
 mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
   // TODO: LLVM codegen does some early optimization to fold the condition and
   // only emit live cases. CIR should use MLIR to achieve similar things,
@@ -975,49 +1013,17 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
   // TODO: PGO and likelihood (e.g. PGO.haveRegionCounts())
   // TODO: if the switch has a condition wrapped by __builtin_unpredictable?
 
-  // FIXME: track switch to handle nested stmts.
   swop = builder.create<SwitchOp>(
       getLoc(S.getBeginLoc()), condV,
       /*switchBuilder=*/
      [&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
-        auto *cs = dyn_cast<CompoundStmt>(S.getBody());
-        assert(cs && "expected compound stmt");
-        SmallVector<mlir::Attribute, 4> caseAttrs;
-
         currLexScope->setAsSwitch();
-        mlir::Block *lastCaseBlock = nullptr;
-        for (auto *c : cs->body()) {
-          bool caseLike = isa<CaseStmt, DefaultStmt>(c);
-          if (!caseLike) {
-            // This means it's a random stmt following up a case, just
-            // emit it as part of previous known case.
-            assert(lastCaseBlock && "expects pre-existing case block");
-            mlir::OpBuilder::InsertionGuard guardCase(builder);
-            builder.setInsertionPointToEnd(lastCaseBlock);
-            res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
-            lastCaseBlock = builder.getBlock();
-            if (res.failed())
-              break;
-            continue;
-          }
-
-          auto *caseStmt = dyn_cast<CaseStmt>(c);
-
-          if (caseStmt)
-            res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
-          else {
-            auto *defaultStmt = dyn_cast<DefaultStmt>(c);
-            assert(defaultStmt && "expected default stmt");
-            res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
-                                   os);
-          }
-
-          lastCaseBlock = builder.getBlock();
-
-          if (res.failed())
-            break;
-        }
 
+        llvm::SmallVector<mlir::Attribute, 4> caseAttrs;
+
+        res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs);
+
+        os.addRegions(currLexScope->getSwitchRegions());
         os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
       });
 
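
For context, a minimal sketch of the switch-body shape the new `buildSwitchBody` walk handles (this snippet is not from the patch or its tests; the function and variable names are invented). The `x += 1;` statement follows a case label but is not itself a `SwitchCase`, so the walk emits it at the end of the previous case's region via `lastCaseBlock`; a statement appearing before any case label would instead hit the `llvm_unreachable` NYI path.

```cpp
// Hypothetical input for illustration only.
int classify(int x) {
  switch (x) {
  case 0:
    x += 1; // plain stmt after a case label: appended to case 0's region
  case 1:   // case 0 falls through into this region
    return x;
  default:
    return -1;
  }
}
```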