|
//===- SCFPrepare.cpp - Preparation work for SCF lowering -----------------===//
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#include "PassDetail.h" |
| 10 | +#include "mlir/IR/PatternMatch.h" |
| 11 | +#include "mlir/Support/LogicalResult.h" |
| 12 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 13 | +#include "clang/CIR/Dialect/IR/CIRDialect.h" |
| 14 | +#include "clang/CIR/Dialect/Passes.h" |
| 15 | + |
| 16 | +using namespace mlir; |
| 17 | +using namespace cir; |
| 18 | + |
| 19 | +//===----------------------------------------------------------------------===// |
| 20 | +// Rewrite patterns |
| 21 | +//===----------------------------------------------------------------------===// |
| 22 | + |
| 23 | +namespace { |
| 24 | + |
| 25 | +static Value findIVAddr(Block *step) { |
| 26 | + Value IVAddr = nullptr; |
| 27 | + for (Operation &op : *step) { |
| 28 | + if (auto loadOp = dyn_cast<LoadOp>(op)) |
| 29 | + IVAddr = loadOp.getAddr(); |
| 30 | + else if (auto storeOp = dyn_cast<StoreOp>(op)) |
| 31 | + if (IVAddr != storeOp.getAddr()) |
| 32 | + return nullptr; |
| 33 | + } |
| 34 | + return IVAddr; |
| 35 | +} |
| 36 | + |
| 37 | +static CmpOp findLoopCmpAndIV(Block *cond, Value IVAddr, Value &IV) { |
| 38 | + Operation *IVLoadOp = nullptr; |
| 39 | + for (Operation &op : *cond) { |
| 40 | + if (auto loadOp = dyn_cast<LoadOp>(op)) |
| 41 | + if (loadOp.getAddr() == IVAddr) { |
| 42 | + IVLoadOp = &op; |
| 43 | + break; |
| 44 | + } |
| 45 | + } |
| 46 | + if (!IVLoadOp) |
| 47 | + return nullptr; |
| 48 | + if (!IVLoadOp->hasOneUse()) |
| 49 | + return nullptr; |
| 50 | + IV = IVLoadOp->getResult(0); |
| 51 | + return dyn_cast<CmpOp>(*IVLoadOp->user_begin()); |
| 52 | +} |
| 53 | + |
| 54 | +// Canonicalize IV to LHS of loop comparison |
| 55 | +// For example, transfer cir.cmp(gt, %bound, %IV) to cir.cmp(lt, %IV, %bound). |
| 56 | +// So we could use RHS as boundary and use lt to determine it's an upper bound. |
| 57 | +struct canonicalizeIVtoCmpLHS : public OpRewritePattern<ForOp> { |
| 58 | + using OpRewritePattern<ForOp>::OpRewritePattern; |
| 59 | + |
| 60 | + CmpOpKind swapCmpKind(CmpOpKind kind) const { |
| 61 | + switch (kind) { |
| 62 | + case CmpOpKind::gt: |
| 63 | + return CmpOpKind::lt; |
| 64 | + case CmpOpKind::ge: |
| 65 | + return CmpOpKind::le; |
| 66 | + case CmpOpKind::lt: |
| 67 | + return CmpOpKind::gt; |
| 68 | + case CmpOpKind::le: |
| 69 | + return CmpOpKind::ge; |
| 70 | + default: |
| 71 | + break; |
| 72 | + } |
| 73 | + return kind; |
| 74 | + } |
| 75 | + |
| 76 | + void replaceWithNewCmpOp(CmpOp oldCmp, CmpOpKind newKind, Value lhs, |
| 77 | + Value rhs, PatternRewriter &rewriter) const { |
| 78 | + rewriter.setInsertionPointAfter(oldCmp.getOperation()); |
| 79 | + auto newCmp = rewriter.create<mlir::cir::CmpOp>( |
| 80 | + oldCmp.getLoc(), oldCmp.getType(), newKind, lhs, rhs); |
| 81 | + oldCmp->replaceAllUsesWith(newCmp); |
| 82 | + oldCmp->erase(); |
| 83 | + } |
| 84 | + |
| 85 | + LogicalResult matchAndRewrite(ForOp op, |
| 86 | + PatternRewriter &rewriter) const final { |
| 87 | + auto *cond = &op.getCond().front(); |
| 88 | + auto *step = (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr); |
| 89 | + if (!step) |
| 90 | + return failure(); |
| 91 | + Value IVAddr = findIVAddr(step); |
| 92 | + if (!IVAddr) |
| 93 | + return failure(); |
| 94 | + Value IV = nullptr; |
| 95 | + auto loopCmp = findLoopCmpAndIV(cond, IVAddr, IV); |
| 96 | + if (!loopCmp || !IV) |
| 97 | + return failure(); |
| 98 | + |
| 99 | + CmpOpKind cmpKind = loopCmp.getKind(); |
| 100 | + Value cmpRhs = loopCmp.getRhs(); |
| 101 | + // Canonicalize IV to LHS of loop Cmp. |
| 102 | + if (loopCmp.getLhs() != IV) { |
| 103 | + cmpKind = swapCmpKind(cmpKind); |
| 104 | + cmpRhs = loopCmp.getLhs(); |
| 105 | + replaceWithNewCmpOp(loopCmp, cmpKind, IV, cmpRhs, rewriter); |
| 106 | + return success(); |
| 107 | + } |
| 108 | + |
| 109 | + return failure(); |
| 110 | + } |
| 111 | +}; |
| 112 | + |
| 113 | +// Hoist loop invariant operations in condition block out of loop |
| 114 | +// The condition block may be generated as following which contains the |
| 115 | +// operations produced upper bound. |
| 116 | +// SCF for loop required loop boundary as input operands. So we need to |
| 117 | +// hoist the boundary operations out of loop. |
| 118 | +// |
| 119 | +// cir.for : cond { |
| 120 | +// %4 = cir.load %2 : !cir.ptr<!s32i>, !s32i |
| 121 | +// %5 = cir.const #cir.int<100> : !s32i <- upper bound |
| 122 | +// %6 = cir.cmp(lt, %4, %5) : !s32i, !s32i |
| 123 | +// %7 = cir.cast(int_to_bool, %6 : !s32i), !cir.bool |
| 124 | +// cir.condition(%7 |
| 125 | +// } body { |
| 126 | +struct hoistLoopInvariantInCondBlock : public OpRewritePattern<ForOp> { |
| 127 | + using OpRewritePattern<ForOp>::OpRewritePattern; |
| 128 | + |
| 129 | + bool isLoopInvariantLoad(Operation *op, ForOp forOp) const { |
| 130 | + auto load = dyn_cast<LoadOp>(op); |
| 131 | + if (!load) |
| 132 | + return false; |
| 133 | + |
| 134 | + auto loadAddr = load.getAddr(); |
| 135 | + auto result = |
| 136 | + forOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) { |
| 137 | + if (auto store = dyn_cast<StoreOp>(op)) { |
| 138 | + if (store.getAddr() == loadAddr) |
| 139 | + return mlir::WalkResult::interrupt(); |
| 140 | + } |
| 141 | + return mlir::WalkResult::advance(); |
| 142 | + }); |
| 143 | + |
| 144 | + if (result.wasInterrupted()) |
| 145 | + return false; |
| 146 | + |
| 147 | + return true; |
| 148 | + } |
| 149 | + |
| 150 | + LogicalResult matchAndRewrite(ForOp forOp, |
| 151 | + PatternRewriter &rewriter) const final { |
| 152 | + auto *cond = &forOp.getCond().front(); |
| 153 | + auto *step = |
| 154 | + (forOp.maybeGetStep() ? &forOp.maybeGetStep()->front() : nullptr); |
| 155 | + if (!step) |
| 156 | + return failure(); |
| 157 | + Value IVAddr = findIVAddr(step); |
| 158 | + if (!IVAddr) |
| 159 | + return failure(); |
| 160 | + Value IV = nullptr; |
| 161 | + auto loopCmp = findLoopCmpAndIV(cond, IVAddr, IV); |
| 162 | + if (!loopCmp || !IV) |
| 163 | + return failure(); |
| 164 | + |
| 165 | + Value cmpRhs = loopCmp.getRhs(); |
| 166 | + auto defOp = cmpRhs.getDefiningOp(); |
| 167 | + SmallVector<Operation *> ops; |
| 168 | + // Go through the cast if exist. |
| 169 | + if (defOp && isa<mlir::cir::CastOp>(defOp)) { |
| 170 | + ops.push_back(defOp); |
| 171 | + defOp = defOp->getOperand(0).getDefiningOp(); |
| 172 | + } |
| 173 | + if (defOp && |
| 174 | + (isa<ConstantOp>(defOp) || isLoopInvariantLoad(defOp, forOp))) { |
| 175 | + ops.push_back(defOp); |
| 176 | + for (auto op : reverse(ops)) |
| 177 | + op->moveBefore(forOp); |
| 178 | + return success(); |
| 179 | + } |
| 180 | + |
| 181 | + return failure(); |
| 182 | + } |
| 183 | +}; |
| 184 | + |
| 185 | +//===----------------------------------------------------------------------===// |
| 186 | +// SCFPreparePass |
| 187 | +//===----------------------------------------------------------------------===// |
| 188 | + |
// Pass wrapper: runs the SCF-preparation rewrite patterns over the
// operation this pass is scheduled on (see runOnOperation below).
struct SCFPreparePass : public SCFPrepareBase<SCFPreparePass> {
  using SCFPrepareBase::SCFPrepareBase;
  void runOnOperation() override;
};
| 193 | + |
| 194 | +void populateSCFPreparePatterns(RewritePatternSet &patterns) { |
| 195 | + // clang-format off |
| 196 | + patterns.add< |
| 197 | + canonicalizeIVtoCmpLHS, |
| 198 | + hoistLoopInvariantInCondBlock |
| 199 | + >(patterns.getContext()); |
| 200 | + // clang-format on |
| 201 | +} |
| 202 | + |
| 203 | +void SCFPreparePass::runOnOperation() { |
| 204 | + // Collect rewrite patterns. |
| 205 | + RewritePatternSet patterns(&getContext()); |
| 206 | + populateSCFPreparePatterns(patterns); |
| 207 | + |
| 208 | + // Collect operations to apply patterns. |
| 209 | + SmallVector<Operation *, 16> ops; |
| 210 | + getOperation()->walk([&](Operation *op) { |
| 211 | + // CastOp here is to perform a manual `fold` in |
| 212 | + // applyOpPatternsAndFold |
| 213 | + if (isa<ForOp>(op)) |
| 214 | + ops.push_back(op); |
| 215 | + }); |
| 216 | + |
| 217 | + // Apply patterns. |
| 218 | + if (applyOpPatternsAndFold(ops, std::move(patterns)).failed()) |
| 219 | + signalPassFailure(); |
| 220 | +} |
| 221 | + |
| 222 | +} // namespace |
| 223 | + |
// Factory used by the CIR pass registration machinery to create the pass.
std::unique_ptr<Pass> mlir::createSCFPreparePass() {
  return std::make_unique<SCFPreparePass>();
}
0 commit comments