@@ -884,6 +884,24 @@ class CIRYieldOpLowering
884
884
}
885
885
};
886
886
887
+ class CIRBreakOpLowering
888
+ : public mlir::OpConversionPattern<mlir::cir::BreakOp> {
889
+ public:
890
+ using OpConversionPattern<mlir::cir::BreakOp>::OpConversionPattern;
891
+ mlir::LogicalResult
892
+ matchAndRewrite (mlir::cir::BreakOp op, OpAdaptor adaptor,
893
+ mlir::ConversionPatternRewriter &rewriter) const override {
894
+ auto *parentOp = op->getParentOp ();
895
+ return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
896
+ .Case <mlir::scf::IndexSwitchOp>([&](auto ) {
897
+ rewriter.replaceOpWithNewOp <mlir::scf::YieldOp>(
898
+ op, adaptor.getOperands ());
899
+ return mlir::success ();
900
+ })
901
+ .Default ([](auto ) { return mlir::failure (); });
902
+ }
903
+ };
904
+
887
905
class CIRIfOpLowering : public mlir ::OpConversionPattern<mlir::cir::IfOp> {
888
906
public:
889
907
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
@@ -909,6 +927,62 @@ class CIRIfOpLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
909
927
}
910
928
};
911
929
930
+ class CIRSwitchOpLowering
931
+ : public mlir::OpConversionPattern<mlir::cir::SwitchOp> {
932
+ public:
933
+ using mlir::OpConversionPattern<mlir::cir::SwitchOp>::OpConversionPattern;
934
+
935
+ mlir::LogicalResult
936
+ matchAndRewrite (mlir::cir::SwitchOp op, OpAdaptor adaptor,
937
+ mlir::ConversionPatternRewriter &rewriter) const override {
938
+ llvm::SmallVector<mlir::Type> resultTypes;
939
+ if (mlir::failed (getTypeConverter ()->convertTypes (op->getResultTypes (),
940
+ resultTypes)))
941
+ return mlir::failure ();
942
+
943
+ auto caseValue = rewriter.create <mlir::arith::IndexCastOp>(
944
+ adaptor.getCondition ().getLoc (), rewriter.getIndexType (),
945
+ adaptor.getCondition ());
946
+
947
+ llvm::SmallVector<int64_t , 3 > cases;
948
+ auto caseAttrList = op.getCasesAttr ();
949
+ for (auto &caseAttr : caseAttrList) {
950
+ mlir::Attribute caseAttrValue;
951
+ caseAttr.walkImmediateSubElements (
952
+ [&caseAttrValue](mlir::Attribute subAttr) {
953
+ if (!caseAttrValue)
954
+ caseAttrValue = subAttr;
955
+ },
956
+ [](mlir::Type type) {});
957
+
958
+ mlir::cir::IntAttr cirIntAttr;
959
+ caseAttrValue.walkImmediateSubElements (
960
+ [&cirIntAttr](mlir::Attribute subAttr) {
961
+ if (!cirIntAttr)
962
+ cirIntAttr = mlir::dyn_cast_or_null<mlir::cir::IntAttr>(subAttr);
963
+ },
964
+ [](mlir::Type type) {});
965
+
966
+ if (cirIntAttr != nullptr )
967
+ cases.push_back (cirIntAttr.getSInt ());
968
+ }
969
+
970
+ auto casesRegionCount = cases.size ();
971
+
972
+ auto indexSwitchOp = rewriter.create <mlir::scf::IndexSwitchOp>(
973
+ op.getLoc (), mlir::TypeRange (resultTypes), caseValue, cases,
974
+ casesRegionCount);
975
+
976
+ for (unsigned int i = 0 ; i < op.getNumRegions (); i++) {
977
+ rewriter.inlineRegionBefore (op->getRegion (i), indexSwitchOp.getRegion (i),
978
+ indexSwitchOp.getRegion (i).end ());
979
+ }
980
+
981
+ rewriter.replaceOp (op, indexSwitchOp);
982
+ return mlir::success ();
983
+ }
984
+ };
985
+
912
986
class CIRGlobalOpLowering
913
987
: public mlir::OpConversionPattern<mlir::cir::GlobalOp> {
914
988
public:
@@ -1316,22 +1390,23 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1316
1390
mlir::TypeConverter &converter) {
1317
1391
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
1318
1392
1319
- patterns.add <
1320
- CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1321
- CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1322
- CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1323
- CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1324
- CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1325
- CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1326
- CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1327
- CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
1328
- CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
1329
- CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering,
1330
- CIRBitCtzOpLowering, CIRBitPopcountOpLowering, CIRBitClrsbOpLowering,
1331
- CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering,
1332
- CIRVectorCreateLowering, CIRVectorInsertLowering,
1333
- CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
1334
- patterns.getContext ());
1393
+ patterns.add <CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
1394
+ CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
1395
+ CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
1396
+ CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1397
+ CIRYieldOpLowering, CIRBreakOpLowering, CIRCosOpLowering,
1398
+ CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1399
+ CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1400
+ CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1401
+ CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
1402
+ CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
1403
+ CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering,
1404
+ CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1405
+ CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1406
+ CIRBitParityOpLowering, CIRIfOpLowering, CIRSwitchOpLowering,
1407
+ CIRVectorCreateLowering, CIRVectorInsertLowering,
1408
+ CIRVectorExtractLowering, CIRVectorCmpOpLowering>(
1409
+ converter, patterns.getContext ());
1335
1410
}
1336
1411
1337
1412
static mlir::TypeConverter prepareTypeConverter () {
0 commit comments