Skip to content

Commit 1b3db30

Browse files
ShivaChenlanza
authored andcommitted
[CIR][ThroughMLIR] Support lowering ForOp to scf (#605)
This commit introduces CIRForOpLowering for lowering to scf. The initial commit only support increment loop with lt or le comparison.
1 parent 3541487 commit 1b3db30

File tree

5 files changed

+503
-1
lines changed

5 files changed

+503
-1
lines changed

clang/include/clang/CIR/LowerToMLIR.h

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//====- LowerToMLIR.h- Lowering from CIR to MLIR --------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares functions for lowering CIR modules to MLIR.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef CLANG_CIR_LOWERTOMLIR_H
13+
#define CLANG_CIR_LOWERTOMLIR_H
14+
15+
namespace cir {
16+
17+
void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
18+
mlir::TypeConverter &converter);
19+
} // namespace cir
20+
21+
#endif // CLANG_CIR_LOWERTOMLIR_H_

clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set(LLVM_LINK_COMPONENTS
66
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
77

88
add_clang_library(clangCIRLoweringThroughMLIR
9+
LowerCIRLoopToSCF.cpp
910
LowerCIRToMLIR.cpp
1011
LowerMLIRToLLVM.cpp
1112

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
//====- LowerCIRLoopToSCF.cpp - Lowering from CIR Loop to SCF -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements lowering of CIR loop operations to SCF.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Arith/IR/Arith.h"
14+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
17+
#include "mlir/IR/BuiltinDialect.h"
18+
#include "mlir/IR/BuiltinTypes.h"
19+
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Pass/PassManager.h"
21+
#include "mlir/Support/LogicalResult.h"
22+
#include "mlir/Transforms/DialectConversion.h"
23+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
24+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
25+
#include "clang/CIR/LowerToMLIR.h"
26+
#include "clang/CIR/Passes.h"
27+
28+
using namespace cir;
29+
using namespace llvm;
30+
31+
namespace cir {
32+
33+
class SCFLoop {
34+
public:
35+
SCFLoop(mlir::cir::ForOp op, mlir::ConversionPatternRewriter *rewriter)
36+
: forOp(op), rewriter(rewriter) {}
37+
38+
int64_t getStep() { return step; }
39+
mlir::Value getLowerBound() { return lowerBound; }
40+
mlir::Value getUpperBound() { return upperBound; }
41+
42+
int64_t findStepAndIV(mlir::Value &addr);
43+
mlir::cir::CmpOp findCmpOp();
44+
mlir::Value findIVInitValue();
45+
void analysis();
46+
47+
mlir::Value plusConstant(mlir::Value V, mlir::Location loc, int addend);
48+
void transferToSCFForOp();
49+
50+
private:
51+
mlir::cir::ForOp forOp;
52+
mlir::cir::CmpOp cmpOp;
53+
mlir::Value IVAddr, lowerBound = nullptr, upperBound = nullptr;
54+
mlir::ConversionPatternRewriter *rewriter;
55+
int64_t step = 0;
56+
};
57+
58+
static int64_t getConstant(mlir::cir::ConstantOp op) {
59+
auto attr = op->getAttrs().front().getValue();
60+
const auto IntAttr = attr.dyn_cast<mlir::cir::IntAttr>();
61+
return IntAttr.getValue().getSExtValue();
62+
}
63+
64+
int64_t SCFLoop::findStepAndIV(mlir::Value &addr) {
65+
auto *stepBlock =
66+
(forOp.maybeGetStep() ? &forOp.maybeGetStep()->front() : nullptr);
67+
assert(stepBlock && "Can not find step block");
68+
69+
int64_t step = 0;
70+
mlir::Value IV = nullptr;
71+
// Try to match "IV load addr; ++IV; store IV, addr" to find step.
72+
for (mlir::Operation &op : *stepBlock)
73+
if (auto loadOp = dyn_cast<mlir::cir::LoadOp>(op)) {
74+
addr = loadOp.getAddr();
75+
IV = loadOp.getResult();
76+
} else if (auto cop = dyn_cast<mlir::cir::ConstantOp>(op)) {
77+
if (step)
78+
llvm_unreachable(
79+
"Not support multiple constant in step calculation yet");
80+
step = getConstant(cop);
81+
} else if (auto bop = dyn_cast<mlir::cir::BinOp>(op)) {
82+
if (bop.getLhs() != IV)
83+
llvm_unreachable("Find BinOp not operate on IV");
84+
if (bop.getKind() != mlir::cir::BinOpKind::Add)
85+
llvm_unreachable(
86+
"Not support BinOp other than Add in step calculation yet");
87+
} else if (auto uop = dyn_cast<mlir::cir::UnaryOp>(op)) {
88+
if (uop.getInput() != IV)
89+
llvm_unreachable("Find UnaryOp not operate on IV");
90+
if (uop.getKind() == mlir::cir::UnaryOpKind::Inc)
91+
step = 1;
92+
else if (uop.getKind() == mlir::cir::UnaryOpKind::Dec)
93+
llvm_unreachable("Not support decrement step yet");
94+
} else if (auto storeOp = dyn_cast<mlir::cir::StoreOp>(op)) {
95+
assert(storeOp.getAddr() == addr && "Can't find IV when lowering ForOp");
96+
}
97+
assert(step && "Can't find step when lowering ForOp");
98+
99+
return step;
100+
}
101+
102+
static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
103+
if (!op)
104+
return false;
105+
if (isa<mlir::cir::LoadOp>(op)) {
106+
if (!op->getOperand(0))
107+
return false;
108+
if (op->getOperand(0) == IVAddr)
109+
return true;
110+
}
111+
return false;
112+
}
113+
114+
mlir::cir::CmpOp SCFLoop::findCmpOp() {
115+
cmpOp = nullptr;
116+
for (auto *user : IVAddr.getUsers()) {
117+
if (user->getParentRegion() != &forOp.getCond())
118+
continue;
119+
if (auto loadOp = dyn_cast<mlir::cir::LoadOp>(*user)) {
120+
if (!loadOp->hasOneUse())
121+
continue;
122+
if (auto op = dyn_cast<mlir::cir::CmpOp>(*loadOp->user_begin())) {
123+
cmpOp = op;
124+
break;
125+
}
126+
}
127+
}
128+
if (!cmpOp)
129+
llvm_unreachable("Can't find loop CmpOp");
130+
131+
auto type = cmpOp.getLhs().getType();
132+
if (!type.isa<mlir::cir::IntType>())
133+
llvm_unreachable("Non-integer type IV is not supported");
134+
135+
auto lhsDefOp = cmpOp.getLhs().getDefiningOp();
136+
if (!lhsDefOp)
137+
llvm_unreachable("Can't find IV load");
138+
if (!isIVLoad(lhsDefOp, IVAddr))
139+
llvm_unreachable("cmpOp LHS is not IV");
140+
141+
if (cmpOp.getKind() != mlir::cir::CmpOpKind::le &&
142+
cmpOp.getKind() != mlir::cir::CmpOpKind::lt)
143+
llvm_unreachable("Not support lowering other than le or lt comparison");
144+
145+
return cmpOp;
146+
}
147+
148+
mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
149+
int addend) {
150+
auto type = V.getType();
151+
auto c1 = rewriter->create<mlir::arith::ConstantOp>(
152+
loc, type, mlir::IntegerAttr::get(type, addend));
153+
return rewriter->create<mlir::arith::AddIOp>(loc, V, c1);
154+
}
155+
156+
// Return IV initial value by searching the store before the loop.
157+
// The operations before the loop have been transferred to MLIR.
158+
// So we need to go through getRemappedValue to find the value.
159+
mlir::Value SCFLoop::findIVInitValue() {
160+
auto remapAddr = rewriter->getRemappedValue(IVAddr);
161+
if (!remapAddr)
162+
return nullptr;
163+
if (!remapAddr.hasOneUse())
164+
return nullptr;
165+
auto memrefStore = dyn_cast<mlir::memref::StoreOp>(*remapAddr.user_begin());
166+
if (!memrefStore)
167+
return nullptr;
168+
return memrefStore->getOperand(0);
169+
}
170+
171+
void SCFLoop::analysis() {
172+
step = findStepAndIV(IVAddr);
173+
cmpOp = findCmpOp();
174+
auto IVInit = findIVInitValue();
175+
// The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
176+
// So we could get the value by getRemappedValue.
177+
auto IVEndBound = rewriter->getRemappedValue(cmpOp.getRhs());
178+
// If the loop end bound is not loop invariant and can't be hoisted.
179+
// The following assertion will be triggerred.
180+
assert(IVEndBound && "can't find IV end boundary");
181+
182+
if (step > 0) {
183+
lowerBound = IVInit;
184+
if (cmpOp.getKind() == mlir::cir::CmpOpKind::lt)
185+
upperBound = IVEndBound;
186+
else if (cmpOp.getKind() == mlir::cir::CmpOpKind::le)
187+
upperBound = plusConstant(IVEndBound, cmpOp.getLoc(), 1);
188+
}
189+
assert(lowerBound && "can't find loop lower bound");
190+
assert(upperBound && "can't find loop upper bound");
191+
}
192+
193+
// Return true if op operation is in the loop body.
194+
static bool isInLoopBody(mlir::Operation *op) {
195+
mlir::Operation *parentOp = op->getParentOp();
196+
if (!parentOp)
197+
return false;
198+
if (isa<mlir::scf::ForOp>(parentOp))
199+
return true;
200+
auto forOp = dyn_cast<mlir::cir::ForOp>(parentOp);
201+
if (forOp && (&forOp.getBody() == op->getParentRegion()))
202+
return true;
203+
return false;
204+
}
205+
206+
void SCFLoop::transferToSCFForOp() {
207+
auto ub = getUpperBound();
208+
auto lb = getLowerBound();
209+
auto loc = forOp.getLoc();
210+
auto type = lb.getType();
211+
auto step = rewriter->create<mlir::arith::ConstantOp>(
212+
loc, type, mlir::IntegerAttr::get(type, getStep()));
213+
auto scfForOp = rewriter->create<mlir::scf::ForOp>(loc, lb, ub, step);
214+
SmallVector<mlir::Value> bbArg;
215+
rewriter->eraseOp(&scfForOp.getBody()->back());
216+
rewriter->inlineBlockBefore(&forOp.getBody().front(), scfForOp.getBody(),
217+
scfForOp.getBody()->end(), bbArg);
218+
scfForOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
219+
if (isa<mlir::cir::BreakOp>(op) || isa<mlir::cir::ContinueOp>(op) ||
220+
isa<mlir::cir::IfOp>(op))
221+
llvm_unreachable(
222+
"Not support lowering loop with break, continue or if yet");
223+
// Replace the IV usage to scf loop induction variable.
224+
if (isIVLoad(op, IVAddr)) {
225+
auto newIV = scfForOp.getInductionVar();
226+
op->getResult(0).replaceAllUsesWith(newIV);
227+
// Only erase the IV load in the loop body because all the operations
228+
// in loop step and condition regions will be erased.
229+
if (isInLoopBody(op))
230+
rewriter->eraseOp(op);
231+
}
232+
return mlir::WalkResult::advance();
233+
});
234+
}
235+
236+
class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
237+
public:
238+
using OpConversionPattern<mlir::cir::ForOp>::OpConversionPattern;
239+
240+
mlir::LogicalResult
241+
matchAndRewrite(mlir::cir::ForOp op, OpAdaptor adaptor,
242+
mlir::ConversionPatternRewriter &rewriter) const override {
243+
SCFLoop loop(op, &rewriter);
244+
loop.analysis();
245+
loop.transferToSCFForOp();
246+
rewriter.eraseOp(op);
247+
return mlir::success();
248+
}
249+
};
250+
251+
void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
252+
mlir::TypeConverter &converter) {
253+
patterns.add<CIRForOpLowering>(converter, patterns.getContext());
254+
}
255+
256+
} // namespace cir

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "mlir/Transforms/DialectConversion.h"
4141
#include "clang/CIR/Dialect/IR/CIRDialect.h"
4242
#include "clang/CIR/Dialect/IR/CIRTypes.h"
43+
#include "clang/CIR/LowerToMLIR.h"
4344
#include "clang/CIR/Passes.h"
4445
#include "llvm/ADT/Sequence.h"
4546
#include "llvm/ADT/TypeSwitch.h"
@@ -802,7 +803,7 @@ class CIRYieldOpLowering
802803
mlir::ConversionPatternRewriter &rewriter) const override {
803804
auto *parentOp = op->getParentOp();
804805
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
805-
.Case<mlir::scf::IfOp>([&](auto) {
806+
.Case<mlir::scf::IfOp, mlir::scf::ForOp>([&](auto) {
806807
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(
807808
op, adaptor.getOperands());
808809
return mlir::success();
@@ -1199,6 +1200,7 @@ void ConvertCIRToMLIRPass::runOnOperation() {
11991200

12001201
mlir::RewritePatternSet patterns(&getContext());
12011202

1203+
populateCIRLoopToSCFConversionPatterns(patterns, converter);
12021204
populateCIRToMLIRConversionPatterns(patterns, converter);
12031205

12041206
mlir::ConversionTarget target(getContext());

0 commit comments

Comments
 (0)