Skip to content

Commit 0766c90

Browse files
committed
[CIR] Add -cir-mlir-scf-prepare to simplify lowering to SCF
This commit introduces SCFPreparePass, which does two things.
1) Canonicalize the IV to the LHS of the loop comparison. For example, it transforms cir.cmp(gt, %bound, %IV) into cir.cmp(lt, %IV, %bound), so that the RHS can be used as the loop boundary and `lt` tells us that it is an upper bound.
2) Hoist loop-invariant operations in the condition block out of the loop. The condition block may be generated as shown below, and it contains the operations that produce the upper bound. An SCF for loop requires the loop boundaries as input operands, so the operations that compute the boundary need to be hoisted out of the loop.
cir.for : cond {
  %4 = cir.load %2 : !cir.ptr<!s32i>, !s32i
  %5 = cir.const #cir.int<100> : !s32i   <- upper bound
  %6 = cir.cmp(lt, %4, %5) : !s32i, !s32i
  %7 = cir.cast(int_to_bool, %6 : !s32i), !cir.bool
  cir.condition(%7)
} body {
1 parent 5730174 commit 0766c90

File tree

9 files changed

+400
-17
lines changed

9 files changed

+400
-17
lines changed

clang/include/clang/CIR/CIRToCIRPasses.h

+7-8
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,13 @@ class ModuleOp;
2828
namespace cir {
2929

3030
// Run set of cleanup/prepare/etc passes CIR <-> CIR.
31-
mlir::LogicalResult
32-
runCIRToCIRPasses(mlir::ModuleOp theModule, mlir::MLIRContext *mlirCtx,
33-
clang::ASTContext &astCtx, bool enableVerifier,
34-
bool enableLifetime, llvm::StringRef lifetimeOpts,
35-
bool enableIdiomRecognizer,
36-
llvm::StringRef idiomRecognizerOpts, bool enableLibOpt,
37-
llvm::StringRef libOptOpts,
38-
std::string &passOptParsingFailure, bool flattenCIR);
31+
mlir::LogicalResult runCIRToCIRPasses(
32+
mlir::ModuleOp theModule, mlir::MLIRContext *mlirCtx,
33+
clang::ASTContext &astCtx, bool enableVerifier, bool enableLifetime,
34+
llvm::StringRef lifetimeOpts, bool enableIdiomRecognizer,
35+
llvm::StringRef idiomRecognizerOpts, bool enableLibOpt,
36+
llvm::StringRef libOptOpts, std::string &passOptParsingFailure,
37+
bool flattenCIR, bool emitMLIR);
3938

4039
} // namespace cir
4140

clang/include/clang/CIR/Dialect/Passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ std::unique_ptr<Pass> createLifetimeCheckPass(ArrayRef<StringRef> remark,
2828
clang::ASTContext *astCtx);
2929
std::unique_ptr<Pass> createMergeCleanupsPass();
3030
std::unique_ptr<Pass> createDropASTPass();
31+
std::unique_ptr<Pass> createSCFPreparePass();
3132
std::unique_ptr<Pass> createLoweringPreparePass();
3233
std::unique_ptr<Pass> createLoweringPreparePass(clang::ASTContext *astCtx);
3334
std::unique_ptr<Pass> createIdiomRecognizerPass();

clang/include/clang/CIR/Dialect/Passes.td

+10
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ def LoweringPrepare : Pass<"cir-lowering-prepare"> {
7575
let dependentDialects = ["cir::CIRDialect"];
7676
}
7777

78+
// Declares the SCF preparation pass under the name "cir-mlir-scf-prepare".
def SCFPrepare : Pass<"cir-mlir-scf-prepare"> {
79+
let summary = "Preparation work before lowering to SCF dialect";
80+
let description = [{
81+
This pass does preparation work for SCF lowering. For example, it may
82+
hoist the loop invariant or canonicalize the loop comparison.
83+
}];
84+
// Names the factory function that instantiates the pass.
let constructor = "mlir::createSCFPreparePass()";
85+
let dependentDialects = ["cir::CIRDialect"];
86+
}
87+
7888
def FlattenCFG : Pass<"cir-flatten-cfg"> {
7989
let summary = "Produces flatten cfg";
8090
let description = [{

clang/lib/CIR/CodeGen/CIRPasses.cpp

+10-8
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
#include "mlir/Support/LogicalResult.h"
2020

2121
namespace cir {
22-
mlir::LogicalResult
23-
runCIRToCIRPasses(mlir::ModuleOp theModule, mlir::MLIRContext *mlirCtx,
24-
clang::ASTContext &astCtx, bool enableVerifier,
25-
bool enableLifetime, llvm::StringRef lifetimeOpts,
26-
bool enableIdiomRecognizer,
27-
llvm::StringRef idiomRecognizerOpts, bool enableLibOpt,
28-
llvm::StringRef libOptOpts,
29-
std::string &passOptParsingFailure, bool flattenCIR) {
22+
mlir::LogicalResult runCIRToCIRPasses(
23+
mlir::ModuleOp theModule, mlir::MLIRContext *mlirCtx,
24+
clang::ASTContext &astCtx, bool enableVerifier, bool enableLifetime,
25+
llvm::StringRef lifetimeOpts, bool enableIdiomRecognizer,
26+
llvm::StringRef idiomRecognizerOpts, bool enableLibOpt,
27+
llvm::StringRef libOptOpts, std::string &passOptParsingFailure,
28+
bool flattenCIR, bool emitMLIR) {
3029
mlir::PassManager pm(mlirCtx);
3130
pm.addPass(mlir::createMergeCleanupsPass());
3231

@@ -68,6 +67,9 @@ runCIRToCIRPasses(mlir::ModuleOp theModule, mlir::MLIRContext *mlirCtx,
6867
if (flattenCIR)
6968
mlir::populateCIRPreLoweringPasses(pm);
7069

70+
if (emitMLIR)
71+
pm.addPass(mlir::createSCFPreparePass());
72+
7173
// FIXME: once CIRGenAction fixes emission other than CIR we
7274
// need to run this right before dialect emission.
7375
pm.addPass(mlir::createDropASTPass());

clang/lib/CIR/Dialect/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_clang_library(MLIRCIRTransforms
99
StdHelpers.cpp
1010
FlattenCFG.cpp
1111
GotoSolver.cpp
12+
SCFPrepare.cpp
1213

1314
DEPENDS
1415
MLIRCIRPassIncGen
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
//===- SCFPrepare.cpp - preparation work for SCF lowering -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PassDetail.h"
10+
#include "mlir/IR/PatternMatch.h"
11+
#include "mlir/Support/LogicalResult.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
14+
#include "clang/CIR/Dialect/Passes.h"
15+
16+
using namespace mlir;
17+
using namespace cir;
18+
19+
//===----------------------------------------------------------------------===//
20+
// Rewrite patterns
21+
//===----------------------------------------------------------------------===//
22+
23+
namespace {
24+
25+
// Scans the operations of the loop's step block to find the address of the
// induction variable (IV).
//
// Every load encountered updates the candidate address, so the address of the
// last load seen wins. A store whose address differs from the current
// candidate (including a store that appears before any load) makes the search
// give up and return a null Value. A null Value is also returned when the
// block contains no load at all.
static Value findIVAddr(Block *step) {
  Value candidate = nullptr;
  for (Operation &inst : *step) {
    if (auto load = dyn_cast<LoadOp>(inst)) {
      candidate = load.getAddr();
      continue;
    }
    auto store = dyn_cast<StoreOp>(inst);
    if (store && candidate != store.getAddr())
      return nullptr;
  }
  return candidate;
}
36+
37+
// Locates the loop comparison in the condition block.
//
// Only the first load from `IVAddr` in `cond` is considered. If that load has
// exactly one use, `IV` is set to the loaded value and the single user is
// returned when it is a CmpOp (a null CmpOp otherwise; `IV` is still set in
// that case). If there is no such load, or the load does not have exactly one
// use, a null CmpOp is returned and `IV` is left untouched.
static CmpOp findLoopCmpAndIV(Block *cond, Value IVAddr, Value &IV) {
  for (Operation &inst : *cond) {
    auto load = dyn_cast<LoadOp>(inst);
    if (!load || load.getAddr() != IVAddr)
      continue;

    // First load of the IV found: the decision is made on this one only.
    if (!inst.hasOneUse())
      return nullptr;
    IV = inst.getResult(0);
    return dyn_cast<CmpOp>(*inst.user_begin());
  }
  return nullptr;
}
53+
54+
// Canonicalize IV to LHS of loop comparison
55+
// For example, transfer cir.cmp(gt, %bound, %IV) to cir.cmp(lt, %IV, %bound).
56+
// So we could use RHS as boundary and use lt to determine it's an upper bound.
57+
struct canonicalizeIVtoCmpLHS : public OpRewritePattern<ForOp> {
58+
using OpRewritePattern<ForOp>::OpRewritePattern;
59+
60+
CmpOpKind swapCmpKind(CmpOpKind kind) const {
61+
switch (kind) {
62+
case CmpOpKind::gt:
63+
return CmpOpKind::lt;
64+
case CmpOpKind::ge:
65+
return CmpOpKind::le;
66+
case CmpOpKind::lt:
67+
return CmpOpKind::gt;
68+
case CmpOpKind::le:
69+
return CmpOpKind::ge;
70+
default:
71+
break;
72+
}
73+
return kind;
74+
}
75+
76+
void replaceWithNewCmpOp(CmpOp oldCmp, CmpOpKind newKind, Value lhs,
77+
Value rhs, PatternRewriter &rewriter) const {
78+
rewriter.setInsertionPointAfter(oldCmp.getOperation());
79+
auto newCmp = rewriter.create<mlir::cir::CmpOp>(
80+
oldCmp.getLoc(), oldCmp.getType(), newKind, lhs, rhs);
81+
oldCmp->replaceAllUsesWith(newCmp);
82+
oldCmp->erase();
83+
}
84+
85+
LogicalResult matchAndRewrite(ForOp op,
86+
PatternRewriter &rewriter) const final {
87+
auto *cond = &op.getCond().front();
88+
auto *step = (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);
89+
if (!step)
90+
return failure();
91+
Value IVAddr = findIVAddr(step);
92+
if (!IVAddr)
93+
return failure();
94+
Value IV = nullptr;
95+
auto loopCmp = findLoopCmpAndIV(cond, IVAddr, IV);
96+
if (!loopCmp || !IV)
97+
return failure();
98+
99+
CmpOpKind cmpKind = loopCmp.getKind();
100+
Value cmpRhs = loopCmp.getRhs();
101+
// Canonicalize IV to LHS of loop Cmp.
102+
if (loopCmp.getLhs() != IV) {
103+
cmpKind = swapCmpKind(cmpKind);
104+
cmpRhs = loopCmp.getLhs();
105+
replaceWithNewCmpOp(loopCmp, cmpKind, IV, cmpRhs, rewriter);
106+
return success();
107+
}
108+
109+
return failure();
110+
}
111+
};
112+
113+
// Hoist loop invariant operations in condition block out of loop
114+
// The condition block may be generated as following which contains the
115+
// operations produced upper bound.
116+
// SCF for loop required loop boundary as input operands. So we need to
117+
// hoist the boundary operations out of loop.
118+
//
119+
// cir.for : cond {
120+
// %4 = cir.load %2 : !cir.ptr<!s32i>, !s32i
121+
// %5 = cir.const #cir.int<100> : !s32i <- upper bound
122+
// %6 = cir.cmp(lt, %4, %5) : !s32i, !s32i
123+
// %7 = cir.cast(int_to_bool, %6 : !s32i), !cir.bool
124+
// cir.condition(%7
125+
// } body {
126+
struct hoistLoopInvariantInCondBlock : public OpRewritePattern<ForOp> {
127+
using OpRewritePattern<ForOp>::OpRewritePattern;
128+
129+
bool isLoopInvariantLoad(Operation *op, ForOp forOp) const {
130+
auto load = dyn_cast<LoadOp>(op);
131+
if (!load)
132+
return false;
133+
134+
auto loadAddr = load.getAddr();
135+
auto result =
136+
forOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
137+
if (auto store = dyn_cast<StoreOp>(op)) {
138+
if (store.getAddr() == loadAddr)
139+
return mlir::WalkResult::interrupt();
140+
}
141+
return mlir::WalkResult::advance();
142+
});
143+
144+
if (result.wasInterrupted())
145+
return false;
146+
147+
return true;
148+
}
149+
150+
LogicalResult matchAndRewrite(ForOp forOp,
151+
PatternRewriter &rewriter) const final {
152+
auto *cond = &forOp.getCond().front();
153+
auto *step =
154+
(forOp.maybeGetStep() ? &forOp.maybeGetStep()->front() : nullptr);
155+
if (!step)
156+
return failure();
157+
Value IVAddr = findIVAddr(step);
158+
if (!IVAddr)
159+
return failure();
160+
Value IV = nullptr;
161+
auto loopCmp = findLoopCmpAndIV(cond, IVAddr, IV);
162+
if (!loopCmp || !IV)
163+
return failure();
164+
165+
Value cmpRhs = loopCmp.getRhs();
166+
auto defOp = cmpRhs.getDefiningOp();
167+
SmallVector<Operation *> ops;
168+
// Go through the cast if exist.
169+
if (defOp && isa<mlir::cir::CastOp>(defOp)) {
170+
ops.push_back(defOp);
171+
defOp = defOp->getOperand(0).getDefiningOp();
172+
}
173+
if (defOp &&
174+
(isa<ConstantOp>(defOp) || isLoopInvariantLoad(defOp, forOp))) {
175+
ops.push_back(defOp);
176+
for (auto op : reverse(ops))
177+
op->moveBefore(forOp);
178+
return success();
179+
}
180+
181+
return failure();
182+
}
183+
};
184+
185+
//===----------------------------------------------------------------------===//
186+
// SCFPreparePass
187+
//===----------------------------------------------------------------------===//
188+
189+
// Pass object for the SCF preparation step. It derives from the generated
// SCFPrepareBase and reuses its constructors; the work itself happens in
// runOnOperation().
struct SCFPreparePass : public SCFPrepareBase<SCFPreparePass> {
  using SCFPrepareBase::SCFPrepareBase;

  // Applies the SCF preparation patterns to the cir.for operations nested in
  // the operation the pass is run on.
  void runOnOperation() override;
};
193+
194+
// Adds the two SCF preparation patterns (IV canonicalization and hoisting of
// the loop bound) to `patterns`, using the context `patterns` was built with.
void populateSCFPreparePatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<canonicalizeIVtoCmpLHS, hoistLoopInvariantInCondBlock>(ctx);
}
202+
203+
void SCFPreparePass::runOnOperation() {
204+
// Collect rewrite patterns.
205+
RewritePatternSet patterns(&getContext());
206+
populateSCFPreparePatterns(patterns);
207+
208+
// Collect operations to apply patterns.
209+
SmallVector<Operation *, 16> ops;
210+
getOperation()->walk([&](Operation *op) {
211+
// CastOp here is to perform a manual `fold` in
212+
// applyOpPatternsAndFold
213+
if (isa<ForOp>(op))
214+
ops.push_back(op);
215+
});
216+
217+
// Apply patterns.
218+
if (applyOpPatternsAndFold(ops, std::move(patterns)).failed())
219+
signalPassFailure();
220+
}
221+
222+
} // namespace
223+
224+
std::unique_ptr<Pass> mlir::createSCFPreparePass() {
225+
return std::make_unique<SCFPreparePass>();
226+
}

clang/lib/CIR/FrontendAction/CIRGenAction.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
186186
feOptions.ClangIRLifetimeCheck, lifetimeOpts,
187187
feOptions.ClangIRIdiomRecognizer, idiomRecognizerOpts,
188188
feOptions.ClangIRLibOpt, libOptOpts, passOptParsingFailure,
189-
action == CIRGenAction::OutputType::EmitCIRFlat)
189+
action == CIRGenAction::OutputType::EmitCIRFlat,
190+
action == CIRGenAction::OutputType::EmitMLIR)
190191
.failed()) {
191192
if (!passOptParsingFailure.empty())
192193
diagnosticsEngine.Report(diag::err_drv_cir_pass_opt_parsing)

0 commit comments

Comments
 (0)