Skip to content

Commit 80b3643

Browse files
committed
[CIR][ThroughMLIR] Support lowering ForOp to scf
This commit introduces CIRForOpLowering for lowering to scf. The initial commit only support increment loop with lt or le comparison.
1 parent 43094d7 commit 80b3643

File tree

2 files changed

+439
-13
lines changed

2 files changed

+439
-13
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+219-13
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,7 @@ class CIRYieldOpLowering
802802
mlir::ConversionPatternRewriter &rewriter) const override {
803803
auto *parentOp = op->getParentOp();
804804
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
805-
.Case<mlir::scf::IfOp>([&](auto) {
805+
.Case<mlir::scf::IfOp, mlir::scf::ForOp>([&](auto) {
806806
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(
807807
op, adaptor.getOperands());
808808
return mlir::success();
@@ -1119,22 +1119,228 @@ class CIRPtrStrideOpLowering
11191119
}
11201120
};
11211121

1122+
class SCFLoop {
1123+
public:
1124+
SCFLoop(mlir::cir::ForOp op, mlir::ConversionPatternRewriter *rewriter,
1125+
const mlir::TypeConverter *converter)
1126+
: forOp(op), rewriter(rewriter), converter(converter) {}
1127+
1128+
int getStep() { return step; }
1129+
mlir::Value getLowerBound() { return lowerBound; }
1130+
mlir::Value getUpperBound() { return upperBound; }
1131+
1132+
int findStepAndIV(mlir::Value &addr);
1133+
mlir::cir::CmpOp findCmpOp();
1134+
mlir::Value findIVInitValue();
1135+
void analysis();
1136+
1137+
mlir::Value plusConstant(mlir::Value V, mlir::Location loc, int addend);
1138+
void transferToSCFForOp();
1139+
1140+
private:
1141+
mlir::cir::ForOp forOp;
1142+
mlir::cir::CmpOp cmpOp;
1143+
mlir::Value IVAddr, lowerBound = nullptr, upperBound = nullptr;
1144+
mlir::ConversionPatternRewriter *rewriter;
1145+
const mlir::TypeConverter *converter;
1146+
int step = 0;
1147+
};
1148+
1149+
int SCFLoop::findStepAndIV(mlir::Value &addr) {
1150+
auto *stepBlock =
1151+
(forOp.maybeGetStep() ? &forOp.maybeGetStep()->front() : nullptr);
1152+
assert(stepBlock && "Can not find step block");
1153+
1154+
int step = 0;
1155+
// Try to match "IV load addr; ++IV; store IV, addr" to find step.
1156+
for (mlir::Operation &op : *stepBlock)
1157+
if (auto loadOp = dyn_cast<mlir::cir::LoadOp>(op)) {
1158+
addr = loadOp.getAddr();
1159+
} else if (auto cop = dyn_cast<mlir::cir::ConstantOp>(op)) {
1160+
auto attr = cop->getAttrs().front().getValue();
1161+
const auto IntAttr = attr.dyn_cast<mlir::cir::IntAttr>();
1162+
step = IntAttr.getValue().getSExtValue();
1163+
} else if (auto bop = dyn_cast<mlir::cir::BinOp>(op)) {
1164+
if (bop.getKind() == mlir::cir::BinOpKind::Sub)
1165+
llvm_unreachable("Not support decrement step yet");
1166+
else if (bop.getKind() != mlir::cir::BinOpKind::Add)
1167+
llvm_unreachable("Not support the BinOp in step calculation yet");
1168+
} else if (auto uop = dyn_cast<mlir::cir::UnaryOp>(op)) {
1169+
if (uop.getKind() == mlir::cir::UnaryOpKind::Inc)
1170+
step = 1;
1171+
else if (uop.getKind() == mlir::cir::UnaryOpKind::Dec)
1172+
llvm_unreachable("Not support decrement step yet");
1173+
} else if (auto storeOp = dyn_cast<mlir::cir::StoreOp>(op)) {
1174+
assert(storeOp.getAddr() == addr && "Can't find IV when lowering ForOp");
1175+
}
1176+
assert(step && "Can't find step when lowering ForOp");
1177+
1178+
return step;
1179+
}
1180+
1181+
static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
1182+
if (!op)
1183+
return false;
1184+
if (isa<mlir::cir::LoadOp>(op)) {
1185+
if (!op->getOperand(0))
1186+
return false;
1187+
if (op->getOperand(0) == IVAddr)
1188+
return true;
1189+
}
1190+
return false;
1191+
}
1192+
1193+
mlir::cir::CmpOp SCFLoop::findCmpOp() {
1194+
cmpOp = nullptr;
1195+
for (auto *user : IVAddr.getUsers()) {
1196+
if (user->getParentRegion() != &forOp.getCond())
1197+
continue;
1198+
if (auto loadOp = dyn_cast<mlir::cir::LoadOp>(*user)) {
1199+
if (!loadOp->hasOneUse())
1200+
continue;
1201+
if (auto op = dyn_cast<mlir::cir::CmpOp>(*loadOp->user_begin())) {
1202+
cmpOp = op;
1203+
break;
1204+
}
1205+
}
1206+
}
1207+
if (!cmpOp)
1208+
llvm_unreachable("Can't find loop CmpOp");
1209+
1210+
auto type = cmpOp.getLhs().getType();
1211+
if (!type.isa<mlir::cir::IntType>())
1212+
llvm_unreachable("Non-integer type IV is not supported");
1213+
1214+
auto lhsDefOp = cmpOp.getLhs().getDefiningOp();
1215+
if (!lhsDefOp)
1216+
llvm_unreachable("Can't find IV load");
1217+
if (!isIVLoad(lhsDefOp, IVAddr))
1218+
llvm_unreachable("cmpOp LHS is not IV");
1219+
1220+
if (cmpOp.getKind() != mlir::cir::CmpOpKind::le &&
1221+
cmpOp.getKind() != mlir::cir::CmpOpKind::lt)
1222+
llvm_unreachable("Not support lowering other than le or lt comparison");
1223+
1224+
return cmpOp;
1225+
}
1226+
1227+
static int64_t getConstant(mlir::cir::ConstantOp op) {
1228+
auto attr = op->getAttrs().front().getValue();
1229+
const auto IntAttr = attr.dyn_cast<mlir::cir::IntAttr>();
1230+
return IntAttr.getValue().getSExtValue();
1231+
}
1232+
1233+
mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
1234+
int addend) {
1235+
auto type = V.getType();
1236+
auto c1 = rewriter->create<mlir::arith::ConstantOp>(
1237+
loc, type, mlir::IntegerAttr::get(type, addend));
1238+
return rewriter->create<mlir::arith::AddIOp>(loc, V, c1);
1239+
}
1240+
1241+
mlir::Value SCFLoop::findIVInitValue() {
1242+
auto remapAddr = rewriter->getRemappedValue(IVAddr);
1243+
if (!remapAddr)
1244+
return nullptr;
1245+
if (!remapAddr.hasOneUse())
1246+
return nullptr;
1247+
auto memrefStore = dyn_cast<mlir::memref::StoreOp>(*remapAddr.user_begin());
1248+
if (!memrefStore)
1249+
return nullptr;
1250+
return memrefStore->getOperand(0);
1251+
}
1252+
1253+
void SCFLoop::analysis() {
1254+
step = findStepAndIV(IVAddr);
1255+
cmpOp = findCmpOp();
1256+
auto IVInit = findIVInitValue();
1257+
auto IVEndBound = rewriter->getRemappedValue(cmpOp.getRhs());
1258+
assert(IVEndBound && "can't find IV end boundary");
1259+
1260+
if (step > 0) {
1261+
lowerBound = IVInit;
1262+
if (cmpOp.getKind() == mlir::cir::CmpOpKind::lt)
1263+
upperBound = IVEndBound;
1264+
else if (cmpOp.getKind() == mlir::cir::CmpOpKind::le)
1265+
upperBound = plusConstant(IVEndBound, cmpOp.getLoc(), 1);
1266+
}
1267+
assert(lowerBound && "can't find loop lower bound");
1268+
assert(upperBound && "can't find loop upper bound");
1269+
}
1270+
1271+
static bool isInLoopBody(mlir::Operation *op) {
1272+
mlir::Operation *parentOp = op->getParentOp();
1273+
if (!parentOp)
1274+
return false;
1275+
if (isa<mlir::scf::ForOp>(parentOp))
1276+
return true;
1277+
auto forOp = dyn_cast<mlir::cir::ForOp>(parentOp);
1278+
if (forOp && (&forOp.getBody() == op->getParentRegion()))
1279+
return true;
1280+
return false;
1281+
}
1282+
1283+
void SCFLoop::transferToSCFForOp() {
1284+
auto ub = getUpperBound();
1285+
auto lb = getLowerBound();
1286+
auto loc = forOp.getLoc();
1287+
auto type = lb.getType();
1288+
auto step = rewriter->create<mlir::arith::ConstantOp>(
1289+
loc, type, mlir::IntegerAttr::get(type, getStep()));
1290+
auto scfForOp = rewriter->create<mlir::scf::ForOp>(loc, lb, ub, step);
1291+
SmallVector<mlir::Value> bbArg;
1292+
rewriter->eraseOp(&scfForOp.getBody()->back());
1293+
rewriter->inlineBlockBefore(&forOp.getBody().front(), scfForOp.getBody(),
1294+
scfForOp.getBody()->end(), bbArg);
1295+
scfForOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
1296+
if (isa<mlir::cir::BreakOp>(op) || isa<mlir::cir::ContinueOp>(op) ||
1297+
isa<mlir::cir::IfOp>(op))
1298+
llvm_unreachable(
1299+
"Not support lowering loop with break, continue or if yet");
1300+
// Replace the IV usage to scf loop induction variable.
1301+
if (isIVLoad(op, IVAddr)) {
1302+
auto newIV = scfForOp.getInductionVar();
1303+
op->getResult(0).replaceAllUsesWith(newIV);
1304+
// Only erase the IV load in the loop body because all the operations
1305+
// in loop step and condition regions will be erased.
1306+
if (isInLoopBody(op))
1307+
rewriter->eraseOp(op);
1308+
}
1309+
return mlir::WalkResult::advance();
1310+
});
1311+
}
1312+
1313+
class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
1314+
public:
1315+
using OpConversionPattern<mlir::cir::ForOp>::OpConversionPattern;
1316+
1317+
mlir::LogicalResult
1318+
matchAndRewrite(mlir::cir::ForOp op, OpAdaptor adaptor,
1319+
mlir::ConversionPatternRewriter &rewriter) const override {
1320+
SCFLoop loop(op, &rewriter, getTypeConverter());
1321+
loop.analysis();
1322+
loop.transferToSCFForOp();
1323+
rewriter.eraseOp(op);
1324+
return mlir::success();
1325+
}
1326+
};
1327+
11221328
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
11231329
mlir::TypeConverter &converter) {
11241330
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
11251331

1126-
patterns
1127-
.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
1128-
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
1129-
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
1130-
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1131-
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1132-
CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
1133-
CIRSqrtOpLowering, CIRCeilOpLowering, CIRExp2OpLowering,
1134-
CIRExpOpLowering, CIRFAbsOpLowering, CIRFloorOpLowering,
1135-
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1136-
CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering>(
1137-
converter, patterns.getContext());
1332+
patterns.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
1333+
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
1334+
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
1335+
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
1336+
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
1337+
CIRGetGlobalOpLowering, CIRCastOpLowering,
1338+
CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1339+
CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1340+
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
1341+
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
1342+
CIRSinOpLowering, CIRForOpLowering>(converter,
1343+
patterns.getContext());
11381344
}
11391345

11401346
static mlir::TypeConverter prepareTypeConverter() {

0 commit comments

Comments
 (0)