Skip to content

Commit 198ce40

Browse files
author
Kaifeng Lin
committed
[CIR][ThroughMLIR] Support lowering SwitchOp without fallthrough to scf
1 parent 8311717 commit 198ce40

File tree

2 files changed

+147
-16
lines changed

2 files changed

+147
-16
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+91-16
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,24 @@ class CIRYieldOpLowering
884884
}
885885
};
886886

887+
class CIRBreakOpLowering
888+
: public mlir::OpConversionPattern<mlir::cir::BreakOp> {
889+
public:
890+
using OpConversionPattern<mlir::cir::BreakOp>::OpConversionPattern;
891+
mlir::LogicalResult
892+
matchAndRewrite(mlir::cir::BreakOp op, OpAdaptor adaptor,
893+
mlir::ConversionPatternRewriter &rewriter) const override {
894+
auto *parentOp = op->getParentOp();
895+
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
896+
.Case<mlir::scf::IndexSwitchOp>([&](auto) {
897+
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(
898+
op, adaptor.getOperands());
899+
return mlir::success();
900+
})
901+
.Default([](auto) { return mlir::failure(); });
902+
}
903+
};
904+
887905
class CIRIfOpLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
888906
public:
889907
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
@@ -909,6 +927,62 @@ class CIRIfOpLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
909927
}
910928
};
911929

930+
class CIRSwitchOpLowering
931+
: public mlir::OpConversionPattern<mlir::cir::SwitchOp> {
932+
public:
933+
using mlir::OpConversionPattern<mlir::cir::SwitchOp>::OpConversionPattern;
934+
935+
mlir::LogicalResult
936+
matchAndRewrite(mlir::cir::SwitchOp op, OpAdaptor adaptor,
937+
mlir::ConversionPatternRewriter &rewriter) const override {
938+
llvm::SmallVector<mlir::Type> resultTypes;
939+
if (mlir::failed(getTypeConverter()->convertTypes(op->getResultTypes(),
940+
resultTypes)))
941+
return mlir::failure();
942+
943+
auto caseValue = rewriter.create<mlir::arith::IndexCastOp>(
944+
adaptor.getCondition().getLoc(), rewriter.getIndexType(),
945+
adaptor.getCondition());
946+
947+
llvm::SmallVector<int64_t, 3> cases;
948+
auto caseAttrList = op.getCasesAttr();
949+
for (auto &caseAttr : caseAttrList) {
950+
mlir::Attribute caseAttrValue;
951+
caseAttr.walkImmediateSubElements(
952+
[&caseAttrValue](mlir::Attribute subAttr) {
953+
if (!caseAttrValue)
954+
caseAttrValue = subAttr;
955+
},
956+
[](mlir::Type type) {});
957+
958+
mlir::cir::IntAttr cirIntAttr;
959+
caseAttrValue.walkImmediateSubElements(
960+
[&cirIntAttr](mlir::Attribute subAttr) {
961+
if (!cirIntAttr)
962+
cirIntAttr = mlir::dyn_cast_or_null<mlir::cir::IntAttr>(subAttr);
963+
},
964+
[](mlir::Type type) {});
965+
966+
if (cirIntAttr != nullptr)
967+
cases.push_back(cirIntAttr.getSInt());
968+
}
969+
970+
auto casesRegionCount = cases.size();
971+
972+
auto indexSwitchOp = rewriter.create<mlir::scf::IndexSwitchOp>(
973+
op.getLoc(), mlir::TypeRange(resultTypes), caseValue, cases,
974+
casesRegionCount);
975+
976+
for (unsigned int i = 0; i < op.getNumRegions(); i++) {
977+
rewriter.inlineRegionBefore(op->getRegion(i), indexSwitchOp.getRegion(i),
978+
indexSwitchOp.getRegion(i).end());
979+
}
980+
981+
rewriter.replaceOp(op, indexSwitchOp);
982+
return mlir::success();
983+
}
984+
};
985+
912986
class CIRGlobalOpLowering
913987
: public mlir::OpConversionPattern<mlir::cir::GlobalOp> {
914988
public:
@@ -1316,22 +1390,23 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13161390
mlir::TypeConverter &converter) {
13171391
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
13181392

1319-
patterns.add<
1320-
CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1321-
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1322-
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1323-
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1324-
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1325-
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1326-
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1327-
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
1328-
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
1329-
CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering,
1330-
CIRBitCtzOpLowering, CIRBitPopcountOpLowering, CIRBitClrsbOpLowering,
1331-
CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering,
1332-
CIRVectorCreateLowering, CIRVectorInsertLowering,
1333-
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
1334-
patterns.getContext());
1393+
patterns.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
1394+
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
1395+
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
1396+
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1397+
CIRYieldOpLowering, CIRBreakOpLowering, CIRCosOpLowering,
1398+
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1399+
CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1400+
CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1401+
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
1402+
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
1403+
CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering,
1404+
CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1405+
CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1406+
CIRBitParityOpLowering, CIRIfOpLowering, CIRSwitchOpLowering,
1407+
CIRVectorCreateLowering, CIRVectorInsertLowering,
1408+
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(
1409+
converter, patterns.getContext());
13351410
}
13361411

13371412
static mlir::TypeConverter prepareTypeConverter() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
int switch_test(int cond) {
5+
6+
// CHECK: %alloca = memref.alloca() {alignment = 4 : i64} : memref<i32>
7+
// CHECK: %alloca_0 = memref.alloca() {alignment = 4 : i64} : memref<i32>
8+
// CHECK: %alloca_1 = memref.alloca() {alignment = 4 : i64} : memref<i32>
9+
10+
// CHECK: memref.store %arg0, %alloca[] : memref<i32>
11+
12+
int ret;
13+
14+
// CHECK: memref.alloca_scope {
15+
16+
// CHECK: %2 = memref.load %alloca[] : memref<i32>
17+
// CHECK: %3 = arith.index_cast %2 : i32 to index
18+
19+
switch (cond) {
20+
21+
// CHECK: scf.index_switch %3
22+
23+
case 0: ret = 10; break;
24+
25+
// CHECK: case 0 {
26+
// CHECK: %c100_i32 = arith.constant 100 : i32
27+
// CHECK: memref.store %c100_i32, %alloca_1[] : memref<i32>
28+
// CHECK: scf.yield
29+
// CHECK: }
30+
31+
case 1: ret = 100; break;
32+
33+
// CHECK: case 1 {
34+
// CHECK: %c1000_i32 = arith.constant 1000 : i32
35+
// CHECK: memref.store %c1000_i32, %alloca_1[] : memref<i32>
36+
// CHECK: scf.yield
37+
// CHECK: }
38+
39+
default: ret = 1000; break;
40+
41+
// CHECK: default {
42+
// CHECK: %c10_i32 = arith.constant 10 : i32
43+
// CHECK: memref.store %c10_i32, %alloca_1[] : memref<i32>
44+
// CHECK: }
45+
46+
}
47+
48+
return ret;
49+
50+
// CHECK: }
51+
52+
// CHECK: %0 = memref.load %alloca_1[] : memref<i32>
53+
// CHECK: memref.store %0, %alloca_0[] : memref<i32>
54+
// CHECK: %1 = memref.load %alloca_0[] : memref<i32>
55+
// CHECK: return %1 : i32
56+
}

0 commit comments

Comments
 (0)