|
1 |
| -#include <memory> |
2 | 1 | #include <utility>
|
3 | 2 |
|
4 | 3 | #include "include/Conversion/MemrefToArith/MemrefToArith.h"
|
5 | 4 | #include "mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h" // from @llvm-project
|
| 5 | +#include "mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project |
6 | 6 | #include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
|
| 7 | +#include "mlir/include/mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project |
7 | 8 | #include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project
|
8 | 9 | #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
|
9 | 10 | #include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
|
10 | 11 | #include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
|
11 | 12 | #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
|
12 | 13 | #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
|
13 | 14 | #include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
|
| 15 | +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project |
14 | 16 | #include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
|
| 17 | +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project |
15 | 18 | #include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
|
16 | 19 | #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
|
17 | 20 | #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
|
18 | 21 | #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
|
| 22 | +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project |
19 | 23 |
|
20 | 24 | namespace mlir {
|
21 | 25 | namespace heir {
|
22 | 26 |
|
23 | 27 | #define GEN_PASS_DEF_EXPANDCOPYPASS
|
24 | 28 | #include "include/Conversion/MemrefToArith/MemrefToArith.h.inc"
|
25 | 29 |
|
| 30 | +namespace { |
| 31 | + |
| 32 | +SmallVector<affine::AffineForOp> expandWithAffineLoops(OpBuilder& builder, |
| 33 | + memref::CopyOp copy) { |
| 34 | + ImplicitLocOpBuilder b(copy.getLoc(), builder); |
| 35 | + |
| 36 | + // Create an affine for loop over the dimensions of the memref and |
| 37 | + // explicitly copy using affine loads and stores. |
| 38 | + MemRefType memRefType = cast<MemRefType>(copy.getSource().getType()); |
| 39 | + SmallVector<mlir::Value, 4> indices; |
| 40 | + SmallVector<affine::AffineForOp> loops; |
| 41 | + |
| 42 | + auto zero = b.create<arith::ConstantIndexOp>(0); |
| 43 | + for (auto dim : memRefType.getShape()) { |
| 44 | + if (dim == 1) { |
| 45 | + // No need to create a loop for a one-dimensional index. |
| 46 | + indices.push_back(zero); |
| 47 | + continue; |
| 48 | + } |
| 49 | + auto loop = b.create<mlir::affine::AffineForOp>(0, dim); |
| 50 | + b.setInsertionPointToStart(loop.getBody()); |
| 51 | + indices.push_back(loop.getInductionVar()); |
| 52 | + loops.push_back(loop); |
| 53 | + } |
| 54 | + |
| 55 | + auto load = b.create<mlir::affine::AffineLoadOp>(copy.getSource(), indices); |
| 56 | + b.create<mlir::affine::AffineStoreOp>(load, copy.getTarget(), indices); |
| 57 | + return loops; |
| 58 | +} |
| 59 | + |
| 60 | +} // namespace |
| 61 | + |
26 | 62 | // MemrefCopyExpansionPattern expands a `memref.copy` with explicit affine loads
|
27 | 63 | // stores.
|
28 |
| -class MemrefCopyExpansionPattern final |
| 64 | +class MemrefCopyExpansionPattern |
29 | 65 | : public mlir::OpRewritePattern<mlir::memref::CopyOp> {
|
30 |
| - using OpRewritePattern<memref::CopyOp>::OpRewritePattern; |
| 66 | + public: |
| 67 | + MemrefCopyExpansionPattern(mlir::MLIRContext* context, |
| 68 | + bool disableAffineLoops) |
| 69 | + : OpRewritePattern<memref::CopyOp>(context, /*benefit=*/3), |
| 70 | + disableAffineLoops_(disableAffineLoops) {} |
31 | 71 |
|
32 | 72 | LogicalResult matchAndRewrite(memref::CopyOp copy,
|
33 |
| - PatternRewriter &rewriter) const override { |
34 |
| - auto loc = copy.getLoc(); |
35 |
| - auto memrefType = copy.getSource().getType().cast<MemRefType>(); |
36 |
| - |
37 |
| - // Create an affine for loop over the dimensions of the memref and |
38 |
| - // explicitly copy using affine loads and stores. |
39 |
| - mlir::SmallVector<mlir::Value, 4> indices; |
40 |
| - auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
41 |
| - for (auto dim : memrefType.getShape()) { |
42 |
| - if (dim == 1) { |
43 |
| - // No need to create a loop for a one-dimensional index. |
44 |
| - indices.push_back(zero); |
45 |
| - continue; |
| 73 | + PatternRewriter& rewriter) const override { |
| 74 | + auto nestedLoops = expandWithAffineLoops(rewriter, copy); |
| 75 | + |
| 76 | + if (disableAffineLoops_ && !nestedLoops.empty()) { |
| 77 | + nestedLoops[0].getBody(0)->walk<WalkOrder::PostOrder>( |
| 78 | + [&](affine::AffineForOp forOp) { |
| 79 | + if (failed(loopUnrollFull(forOp))) { |
| 80 | + return WalkResult::skip(); |
| 81 | + } |
| 82 | + return WalkResult::advance(); |
| 83 | + }); |
| 84 | + if (failed(loopUnrollFull(nestedLoops[0]))) { |
| 85 | + return mlir::failure(); |
46 | 86 | }
|
47 |
| - auto loop = rewriter.create<mlir::affine::AffineForOp>(loc, 0, dim); |
48 |
| - rewriter.setInsertionPointToStart(loop.getBody()); |
49 |
| - indices.push_back(loop.getInductionVar()); |
50 | 87 | }
|
51 | 88 |
|
52 |
| - auto load = rewriter.create<mlir::affine::AffineLoadOp>( |
53 |
| - loc, copy.getSource(), indices); |
54 |
| - rewriter.create<mlir::affine::AffineStoreOp>(loc, load, copy.getTarget(), |
55 |
| - indices); |
56 |
| - |
57 | 89 | rewriter.eraseOp(copy);
|
58 | 90 | return mlir::success();
|
59 | 91 | }
|
| 92 | + |
| 93 | + private: |
| 94 | + bool disableAffineLoops_; |
60 | 95 | };
|
61 | 96 |
|
62 | 97 | // ExpandCopyPass intends to remove all memref copy operations.
|
63 | 98 | struct ExpandCopyPass : impl::ExpandCopyPassBase<ExpandCopyPass> {
|
64 | 99 | using ExpandCopyPassBase::ExpandCopyPassBase;
|
65 |
| - void runOnOperation() override { |
66 |
| - mlir::ConversionTarget target(getContext()); |
67 |
| - target.addIllegalOp<mlir::memref::CopyOp>(); |
68 |
| - target.addLegalDialect<mlir::arith::ArithDialect, |
69 |
| - mlir::affine::AffineDialect>(); |
70 | 100 |
|
| 101 | + void runOnOperation() override { |
| 102 | + GreedyRewriteConfig config; |
| 103 | + config.strictMode = GreedyRewriteStrictness::ExistingOps; |
71 | 104 | mlir::RewritePatternSet patterns(&getContext());
|
72 |
| - patterns.add<MemrefCopyExpansionPattern>(&getContext()); |
| 105 | + patterns.add<MemrefCopyExpansionPattern>(&getContext(), disableAffineLoop); |
73 | 106 |
|
74 |
| - (void)applyPartialConversion(getOperation(), target, std::move(patterns)); |
| 107 | + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), |
| 108 | + config); |
75 | 109 | }
|
76 | 110 | };
|
77 | 111 |
|
|
0 commit comments