Skip to content

Commit 53c57cf

Browse files
asraa authored and copybara-github committed
feat: add option to disable using affine for loops when expanding memref.copy
PiperOrigin-RevId: 604312248
1 parent 4a117a0 commit 53c57cf

File tree

3 files changed

+95
-36
lines changed

3 files changed

+95
-36
lines changed

include/Conversion/MemrefToArith/MemrefToArith.td

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,37 @@ def ExpandCopyPass : Pass<"expand-copy", "mlir::ModuleOp"> {
9191
}
9292
}
9393
```
94+
95+
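When the pass option `disable-affine-loop=true` is set (for example via `--expand-copy=disable-affine-loop=true`), the affine loops are fully unrolled and the output becomes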
96+
```
97+
module {
98+
func.func @memref_copy() {
99+
%alloc = memref.alloc() : memref<2x3xi32>
100+
%alloc_0 = memref.alloc() : memref<2x3xi32>
101+
%c0 = arith.constant 0 : index
102+
%c1 = arith.constant 1 : index
103+
%c2 = arith.constant 2 : index
104+
%0 = affine.load %alloc[%c0, %c0] : memref<2x3xi32>
105+
affine.store %0, %alloc_0[%c0, %c0] : memref<2x3xi32>
106+
%1 = affine.load %alloc[%c0, %c1] : memref<2x3xi32>
107+
affine.store %1, %alloc_0[%c0, %c1] : memref<2x3xi32>
108+
%2 = affine.load %alloc[%c0, %c2] : memref<2x3xi32>
109+
affine.store %2, %alloc_0[%c0, %c2] : memref<2x3xi32>
110+
[...]
111+
}
112+
}
113+
```
94114
}];
95115

116+
let options = [
117+
Option<"disableAffineLoop", "disable-affine-loop", "bool", /*default=*/"false",
118+
"Use this to control to disable using affine loops">,
119+
];
120+
96121
let dependentDialects = [
97122
"mlir::affine::AffineDialect",
98123
"mlir::arith::ArithDialect",
99124
"mlir::memref::MemRefDialect",
100-
"mlir::scf::SCFDialect",
101125
];
102126
}
103127

lib/Conversion/MemrefToArith/ExpandCopy.cpp

Lines changed: 65 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,111 @@
1-
#include <memory>
21
#include <utility>
32

43
#include "include/Conversion/MemrefToArith/MemrefToArith.h"
54
#include "mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h" // from @llvm-project
5+
#include "mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project
66
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
7+
#include "mlir/include/mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project
78
#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project
89
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
910
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
1011
#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
1112
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
1213
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
1314
#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
15+
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
1416
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
17+
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
1518
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
1619
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1720
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
1821
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
22+
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
1923

2024
namespace mlir {
2125
namespace heir {
2226

2327
#define GEN_PASS_DEF_EXPANDCOPYPASS
2428
#include "include/Conversion/MemrefToArith/MemrefToArith.h.inc"
2529

30+
namespace {
31+
32+
SmallVector<affine::AffineForOp> expandWithAffineLoops(OpBuilder& builder,
33+
memref::CopyOp copy) {
34+
ImplicitLocOpBuilder b(copy.getLoc(), builder);
35+
36+
// Create an affine for loop over the dimensions of the memref and
37+
// explicitly copy using affine loads and stores.
38+
MemRefType memRefType = cast<MemRefType>(copy.getSource().getType());
39+
SmallVector<mlir::Value, 4> indices;
40+
SmallVector<affine::AffineForOp> loops;
41+
42+
auto zero = b.create<arith::ConstantIndexOp>(0);
43+
for (auto dim : memRefType.getShape()) {
44+
if (dim == 1) {
45+
// No need to create a loop for a dimension of size one; index it with zero.
46+
indices.push_back(zero);
47+
continue;
48+
}
49+
auto loop = b.create<mlir::affine::AffineForOp>(0, dim);
50+
b.setInsertionPointToStart(loop.getBody());
51+
indices.push_back(loop.getInductionVar());
52+
loops.push_back(loop);
53+
}
54+
55+
auto load = b.create<mlir::affine::AffineLoadOp>(copy.getSource(), indices);
56+
b.create<mlir::affine::AffineStoreOp>(load, copy.getTarget(), indices);
57+
return loops;
58+
}
59+
60+
} // namespace
61+
2662
// MemrefCopyExpansionPattern expands a `memref.copy` with explicit affine loads
2763
// and stores.
28-
class MemrefCopyExpansionPattern final
64+
class MemrefCopyExpansionPattern
2965
: public mlir::OpRewritePattern<mlir::memref::CopyOp> {
30-
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
66+
public:
67+
MemrefCopyExpansionPattern(mlir::MLIRContext* context,
68+
bool disableAffineLoops)
69+
: OpRewritePattern<memref::CopyOp>(context, /*benefit=*/3),
70+
disableAffineLoops_(disableAffineLoops) {}
3171

3272
LogicalResult matchAndRewrite(memref::CopyOp copy,
33-
PatternRewriter &rewriter) const override {
34-
auto loc = copy.getLoc();
35-
auto memrefType = copy.getSource().getType().cast<MemRefType>();
36-
37-
// Create an affine for loop over the dimensions of the memref and
38-
// explicitly copy using affine loads and stores.
39-
mlir::SmallVector<mlir::Value, 4> indices;
40-
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
41-
for (auto dim : memrefType.getShape()) {
42-
if (dim == 1) {
43-
// No need to create a loop for a one-dimensional index.
44-
indices.push_back(zero);
45-
continue;
73+
PatternRewriter& rewriter) const override {
74+
auto nestedLoops = expandWithAffineLoops(rewriter, copy);
75+
76+
if (disableAffineLoops_ && !nestedLoops.empty()) {
77+
nestedLoops[0].getBody(0)->walk<WalkOrder::PostOrder>(
78+
[&](affine::AffineForOp forOp) {
79+
if (failed(loopUnrollFull(forOp))) {
80+
return WalkResult::skip();
81+
}
82+
return WalkResult::advance();
83+
});
84+
if (failed(loopUnrollFull(nestedLoops[0]))) {
85+
return mlir::failure();
4686
}
47-
auto loop = rewriter.create<mlir::affine::AffineForOp>(loc, 0, dim);
48-
rewriter.setInsertionPointToStart(loop.getBody());
49-
indices.push_back(loop.getInductionVar());
5087
}
5188

52-
auto load = rewriter.create<mlir::affine::AffineLoadOp>(
53-
loc, copy.getSource(), indices);
54-
rewriter.create<mlir::affine::AffineStoreOp>(loc, load, copy.getTarget(),
55-
indices);
56-
5789
rewriter.eraseOp(copy);
5890
return mlir::success();
5991
}
92+
93+
private:
94+
bool disableAffineLoops_;
6095
};
6196

6297
// ExpandCopyPass intends to remove all memref copy operations.
6398
struct ExpandCopyPass : impl::ExpandCopyPassBase<ExpandCopyPass> {
6499
using ExpandCopyPassBase::ExpandCopyPassBase;
65-
void runOnOperation() override {
66-
mlir::ConversionTarget target(getContext());
67-
target.addIllegalOp<mlir::memref::CopyOp>();
68-
target.addLegalDialect<mlir::arith::ArithDialect,
69-
mlir::affine::AffineDialect>();
70100

101+
void runOnOperation() override {
102+
GreedyRewriteConfig config;
103+
config.strictMode = GreedyRewriteStrictness::ExistingOps;
71104
mlir::RewritePatternSet patterns(&getContext());
72-
patterns.add<MemrefCopyExpansionPattern>(&getContext());
105+
patterns.add<MemrefCopyExpansionPattern>(&getContext(), disableAffineLoop);
73106

74-
(void)applyPartialConversion(getOperation(), target, std::move(patterns));
107+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
108+
config);
75109
}
76110
};
77111

tests/memref_copy_multidim.mlir

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: heir-opt --expand-copy %s | FileCheck %s
1+
// RUN: heir-opt --expand-copy %s | FileCheck --check-prefix=LOOP --check-prefix=CHECK %s
2+
// RUN: heir-opt --expand-copy=disable-affine-loop=true %s | FileCheck --check-prefix=NO-LOOP --check-prefix=CHECK %s
23

34
// This verifies that --expand-copy removes memref.copy and rewrites with affine
45
// loads and stores.
@@ -13,9 +14,9 @@ func.func @memref_copy() -> i32 {
1314
%alloc0 = memref.alloc() : memref<2x2xi32>
1415
affine.store %c_42, %alloc[%c0, %c0] : memref<2x2xi32>
1516
// CHECK-NOT: memref.copy
16-
// CHECK: affine.for
17-
// CHECK: affine.for
18-
// CHECK-NEXT: affine.load {{.*}}[[MEM1]]
17+
// LOOP-COUNT-2: affine.for
18+
// NO-LOOP-NOT: affine.for
19+
// CHECK: affine.load {{.*}}[[MEM1]]
1920
// CHECK-NEXT: affine.store {{.*}}[[MEM2]]
2021
memref.copy %alloc, %alloc0 : memref<2x2xi32> to memref<2x2xi32>
2122
%v1 = arith.addi %c_42, %c_42 : i32

0 commit comments

Comments
 (0)