@@ -802,7 +802,7 @@ class CIRYieldOpLowering
802
802
mlir::ConversionPatternRewriter &rewriter) const override {
803
803
auto *parentOp = op->getParentOp ();
804
804
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
805
- .Case <mlir::scf::IfOp>([&](auto ) {
805
+ .Case <mlir::scf::IfOp, mlir::scf::ForOp >([&](auto ) {
806
806
rewriter.replaceOpWithNewOp <mlir::scf::YieldOp>(
807
807
op, adaptor.getOperands ());
808
808
return mlir::success ();
@@ -1119,22 +1119,228 @@ class CIRPtrStrideOpLowering
1119
1119
}
1120
1120
};
1121
1121
1122
+ class SCFLoop {
1123
+ public:
1124
+ SCFLoop (mlir::cir::ForOp op, mlir::ConversionPatternRewriter *rewriter,
1125
+ const mlir::TypeConverter *converter)
1126
+ : forOp(op), rewriter(rewriter), converter(converter) {}
1127
+
1128
+ int getStep () { return step; }
1129
+ mlir::Value getLowerBound () { return lowerBound; }
1130
+ mlir::Value getUpperBound () { return upperBound; }
1131
+
1132
+ int findStepAndIV (mlir::Value &addr);
1133
+ mlir::cir::CmpOp findCmpOp ();
1134
+ mlir::Value findIVInitValue ();
1135
+ void analysis ();
1136
+
1137
+ mlir::Value plusConstant (mlir::Value V, mlir::Location loc, int addend);
1138
+ void transferToSCFForOp ();
1139
+
1140
+ private:
1141
+ mlir::cir::ForOp forOp;
1142
+ mlir::cir::CmpOp cmpOp;
1143
+ mlir::Value IVAddr, lowerBound = nullptr , upperBound = nullptr ;
1144
+ mlir::ConversionPatternRewriter *rewriter;
1145
+ const mlir::TypeConverter *converter;
1146
+ int step = 0 ;
1147
+ };
1148
+
1149
+ int SCFLoop::findStepAndIV (mlir::Value &addr) {
1150
+ auto *stepBlock =
1151
+ (forOp.maybeGetStep () ? &forOp.maybeGetStep ()->front () : nullptr );
1152
+ assert (stepBlock && " Can not find step block" );
1153
+
1154
+ int step = 0 ;
1155
+ // Try to match "IV load addr; ++IV; store IV, addr" to find step.
1156
+ for (mlir::Operation &op : *stepBlock)
1157
+ if (auto loadOp = dyn_cast<mlir::cir::LoadOp>(op)) {
1158
+ addr = loadOp.getAddr ();
1159
+ } else if (auto cop = dyn_cast<mlir::cir::ConstantOp>(op)) {
1160
+ auto attr = cop->getAttrs ().front ().getValue ();
1161
+ const auto IntAttr = attr.dyn_cast <mlir::cir::IntAttr>();
1162
+ step = IntAttr.getValue ().getSExtValue ();
1163
+ } else if (auto bop = dyn_cast<mlir::cir::BinOp>(op)) {
1164
+ if (bop.getKind () == mlir::cir::BinOpKind::Sub)
1165
+ llvm_unreachable (" Not support decrement step yet" );
1166
+ else if (bop.getKind () != mlir::cir::BinOpKind::Add)
1167
+ llvm_unreachable (" Not support the BinOp in step calculation yet" );
1168
+ } else if (auto uop = dyn_cast<mlir::cir::UnaryOp>(op)) {
1169
+ if (uop.getKind () == mlir::cir::UnaryOpKind::Inc)
1170
+ step = 1 ;
1171
+ else if (uop.getKind () == mlir::cir::UnaryOpKind::Dec)
1172
+ llvm_unreachable (" Not support decrement step yet" );
1173
+ } else if (auto storeOp = dyn_cast<mlir::cir::StoreOp>(op)) {
1174
+ assert (storeOp.getAddr () == addr && " Can't find IV when lowering ForOp" );
1175
+ }
1176
+ assert (step && " Can't find step when lowering ForOp" );
1177
+
1178
+ return step;
1179
+ }
1180
+
1181
+ static bool isIVLoad (mlir::Operation *op, mlir::Value IVAddr) {
1182
+ if (!op)
1183
+ return false ;
1184
+ if (isa<mlir::cir::LoadOp>(op)) {
1185
+ if (!op->getOperand (0 ))
1186
+ return false ;
1187
+ if (op->getOperand (0 ) == IVAddr)
1188
+ return true ;
1189
+ }
1190
+ return false ;
1191
+ }
1192
+
1193
+ mlir::cir::CmpOp SCFLoop::findCmpOp () {
1194
+ cmpOp = nullptr ;
1195
+ for (auto *user : IVAddr.getUsers ()) {
1196
+ if (user->getParentRegion () != &forOp.getCond ())
1197
+ continue ;
1198
+ if (auto loadOp = dyn_cast<mlir::cir::LoadOp>(*user)) {
1199
+ if (!loadOp->hasOneUse ())
1200
+ continue ;
1201
+ if (auto op = dyn_cast<mlir::cir::CmpOp>(*loadOp->user_begin ())) {
1202
+ cmpOp = op;
1203
+ break ;
1204
+ }
1205
+ }
1206
+ }
1207
+ if (!cmpOp)
1208
+ llvm_unreachable (" Can't find loop CmpOp" );
1209
+
1210
+ auto type = cmpOp.getLhs ().getType ();
1211
+ if (!type.isa <mlir::cir::IntType>())
1212
+ llvm_unreachable (" Non-integer type IV is not supported" );
1213
+
1214
+ auto lhsDefOp = cmpOp.getLhs ().getDefiningOp ();
1215
+ if (!lhsDefOp)
1216
+ llvm_unreachable (" Can't find IV load" );
1217
+ if (!isIVLoad (lhsDefOp, IVAddr))
1218
+ llvm_unreachable (" cmpOp LHS is not IV" );
1219
+
1220
+ if (cmpOp.getKind () != mlir::cir::CmpOpKind::le &&
1221
+ cmpOp.getKind () != mlir::cir::CmpOpKind::lt)
1222
+ llvm_unreachable (" Not support lowering other than le or lt comparison" );
1223
+
1224
+ return cmpOp;
1225
+ }
1226
+
1227
+ static int64_t getConstant (mlir::cir::ConstantOp op) {
1228
+ auto attr = op->getAttrs ().front ().getValue ();
1229
+ const auto IntAttr = attr.dyn_cast <mlir::cir::IntAttr>();
1230
+ return IntAttr.getValue ().getSExtValue ();
1231
+ }
1232
+
1233
+ mlir::Value SCFLoop::plusConstant (mlir::Value V, mlir::Location loc,
1234
+ int addend) {
1235
+ auto type = V.getType ();
1236
+ auto c1 = rewriter->create <mlir::arith::ConstantOp>(
1237
+ loc, type, mlir::IntegerAttr::get (type, addend));
1238
+ return rewriter->create <mlir::arith::AddIOp>(loc, V, c1);
1239
+ }
1240
+
1241
+ mlir::Value SCFLoop::findIVInitValue () {
1242
+ auto remapAddr = rewriter->getRemappedValue (IVAddr);
1243
+ if (!remapAddr)
1244
+ return nullptr ;
1245
+ if (!remapAddr.hasOneUse ())
1246
+ return nullptr ;
1247
+ auto memrefStore = dyn_cast<mlir::memref::StoreOp>(*remapAddr.user_begin ());
1248
+ if (!memrefStore)
1249
+ return nullptr ;
1250
+ return memrefStore->getOperand (0 );
1251
+ }
1252
+
1253
+ void SCFLoop::analysis () {
1254
+ step = findStepAndIV (IVAddr);
1255
+ cmpOp = findCmpOp ();
1256
+ auto IVInit = findIVInitValue ();
1257
+ auto IVEndBound = rewriter->getRemappedValue (cmpOp.getRhs ());
1258
+ assert (IVEndBound && " can't find IV end boundary" );
1259
+
1260
+ if (step > 0 ) {
1261
+ lowerBound = IVInit;
1262
+ if (cmpOp.getKind () == mlir::cir::CmpOpKind::lt)
1263
+ upperBound = IVEndBound;
1264
+ else if (cmpOp.getKind () == mlir::cir::CmpOpKind::le)
1265
+ upperBound = plusConstant (IVEndBound, cmpOp.getLoc (), 1 );
1266
+ }
1267
+ assert (lowerBound && " can't find loop lower bound" );
1268
+ assert (upperBound && " can't find loop upper bound" );
1269
+ }
1270
+
1271
+ static bool isInLoopBody (mlir::Operation *op) {
1272
+ mlir::Operation *parentOp = op->getParentOp ();
1273
+ if (!parentOp)
1274
+ return false ;
1275
+ if (isa<mlir::scf::ForOp>(parentOp))
1276
+ return true ;
1277
+ auto forOp = dyn_cast<mlir::cir::ForOp>(parentOp);
1278
+ if (forOp && (&forOp.getBody () == op->getParentRegion ()))
1279
+ return true ;
1280
+ return false ;
1281
+ }
1282
+
1283
+ void SCFLoop::transferToSCFForOp () {
1284
+ auto ub = getUpperBound ();
1285
+ auto lb = getLowerBound ();
1286
+ auto loc = forOp.getLoc ();
1287
+ auto type = lb.getType ();
1288
+ auto step = rewriter->create <mlir::arith::ConstantOp>(
1289
+ loc, type, mlir::IntegerAttr::get (type, getStep ()));
1290
+ auto scfForOp = rewriter->create <mlir::scf::ForOp>(loc, lb, ub, step);
1291
+ SmallVector<mlir::Value> bbArg;
1292
+ rewriter->eraseOp (&scfForOp.getBody ()->back ());
1293
+ rewriter->inlineBlockBefore (&forOp.getBody ().front (), scfForOp.getBody (),
1294
+ scfForOp.getBody ()->end (), bbArg);
1295
+ scfForOp->walk <mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1296
+ if (isa<mlir::cir::BreakOp>(op) || isa<mlir::cir::ContinueOp>(op) ||
1297
+ isa<mlir::cir::IfOp>(op))
1298
+ llvm_unreachable (
1299
+ " Not support lowering loop with break, continue or if yet" );
1300
+ // Replace the IV usage to scf loop induction variable.
1301
+ if (isIVLoad (op, IVAddr)) {
1302
+ auto newIV = scfForOp.getInductionVar ();
1303
+ op->getResult (0 ).replaceAllUsesWith (newIV);
1304
+ // Only erase the IV load in the loop body because all the operations
1305
+ // in loop step and condition regions will be erased.
1306
+ if (isInLoopBody (op))
1307
+ rewriter->eraseOp (op);
1308
+ }
1309
+ return mlir::WalkResult::advance ();
1310
+ });
1311
+ }
1312
+
1313
+ class CIRForOpLowering : public mlir ::OpConversionPattern<mlir::cir::ForOp> {
1314
+ public:
1315
+ using OpConversionPattern<mlir::cir::ForOp>::OpConversionPattern;
1316
+
1317
+ mlir::LogicalResult
1318
+ matchAndRewrite (mlir::cir::ForOp op, OpAdaptor adaptor,
1319
+ mlir::ConversionPatternRewriter &rewriter) const override {
1320
+ SCFLoop loop (op, &rewriter, getTypeConverter ());
1321
+ loop.analysis ();
1322
+ loop.transferToSCFForOp ();
1323
+ rewriter.eraseOp (op);
1324
+ return mlir::success ();
1325
+ }
1326
+ };
1327
+
1122
1328
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1123
1329
mlir::TypeConverter &converter) {
1124
1330
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
1125
1331
1126
- patterns
1127
- . add <CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering ,
1128
- CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering ,
1129
- CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering ,
1130
- CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering ,
1131
- CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering ,
1132
- CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering ,
1133
- CIRSqrtOpLowering, CIRCeilOpLowering, CIRExp2OpLowering ,
1134
- CIRExpOpLowering, CIRFAbsOpLowering, CIRFloorOpLowering ,
1135
- CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering ,
1136
- CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering>(
1137
- converter, patterns.getContext ());
1332
+ patterns. add <CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
1333
+ CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering ,
1334
+ CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering ,
1335
+ CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering ,
1336
+ CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering ,
1337
+ CIRGetGlobalOpLowering, CIRCastOpLowering ,
1338
+ CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering ,
1339
+ CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering ,
1340
+ CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering ,
1341
+ CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering ,
1342
+ CIRSinOpLowering, CIRForOpLowering>(converter,
1343
+ patterns.getContext ());
1138
1344
}
1139
1345
1140
1346
static mlir::TypeConverter prepareTypeConverter () {
0 commit comments