Skip to content

Commit

Permalink
feat: add option to disable using affine for loops when expanding mem…
Browse files Browse the repository at this point in the history
…ref.copy

PiperOrigin-RevId: 604312248
  • Loading branch information
asraa authored and copybara-github committed Feb 5, 2024
1 parent 4a117a0 commit 53c57cf
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 36 deletions.
26 changes: 25 additions & 1 deletion include/Conversion/MemrefToArith/MemrefToArith.td
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,37 @@ def ExpandCopyPass : Pass<"expand-copy", "mlir::ModuleOp"> {
}
}
```

When `--disable-affine-loop=true` is set, then the output becomes
```
module {
func.func @memref_copy() {
%alloc = memref.alloc() : memref<2x3xi32>
%alloc_0 = memref.alloc() : memref<2x3xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%0 = affine.load %alloc[%c0, %c0] : memref<2x3xi32>
affine.store %0, %alloc_0[%c0, %c0] : memref<2x3xi32>
%1 = affine.load %alloc[%c0, %c1] : memref<2x3xi32>
affine.store %1, %alloc_0[%c0, %c1] : memref<2x3xi32>
%2 = affine.load %alloc[%c0, %c2] : memref<2x3xi32>
affine.store %2, %alloc_0[%c0, %c2] : memref<2x3xi32>
[...]
}
}
```
}];

let options = [
Option<"disableAffineLoop", "disable-affine-loop", "bool", /*default=*/"false",
"Use this to control to disable using affine loops">,
];

let dependentDialects = [
"mlir::affine::AffineDialect",
"mlir::arith::ArithDialect",
"mlir::memref::MemRefDialect",
"mlir::scf::SCFDialect",
];
}

Expand Down
96 changes: 65 additions & 31 deletions lib/Conversion/MemrefToArith/ExpandCopy.cpp
Original file line number Diff line number Diff line change
@@ -1,77 +1,111 @@
#include <memory>
#include <utility>

#include "include/Conversion/MemrefToArith/MemrefToArith.h"
#include "mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_EXPANDCOPYPASS
#include "include/Conversion/MemrefToArith/MemrefToArith.h.inc"

namespace {

SmallVector<affine::AffineForOp> expandWithAffineLoops(OpBuilder& builder,
memref::CopyOp copy) {
ImplicitLocOpBuilder b(copy.getLoc(), builder);

// Create an affine for loop over the dimensions of the memref and
// explicitly copy using affine loads and stores.
MemRefType memRefType = cast<MemRefType>(copy.getSource().getType());
SmallVector<mlir::Value, 4> indices;
SmallVector<affine::AffineForOp> loops;

auto zero = b.create<arith::ConstantIndexOp>(0);
for (auto dim : memRefType.getShape()) {
if (dim == 1) {
// No need to create a loop for a one-dimensional index.
indices.push_back(zero);
continue;
}
auto loop = b.create<mlir::affine::AffineForOp>(0, dim);
b.setInsertionPointToStart(loop.getBody());
indices.push_back(loop.getInductionVar());
loops.push_back(loop);
}

auto load = b.create<mlir::affine::AffineLoadOp>(copy.getSource(), indices);
b.create<mlir::affine::AffineStoreOp>(load, copy.getTarget(), indices);
return loops;
}

} // namespace

// MemrefCopyExpansionPattern expands a `memref.copy` with explicit affine loads
// stores.
class MemrefCopyExpansionPattern final
class MemrefCopyExpansionPattern
: public mlir::OpRewritePattern<mlir::memref::CopyOp> {
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
public:
MemrefCopyExpansionPattern(mlir::MLIRContext* context,
bool disableAffineLoops)
: OpRewritePattern<memref::CopyOp>(context, /*benefit=*/3),
disableAffineLoops_(disableAffineLoops) {}

LogicalResult matchAndRewrite(memref::CopyOp copy,
PatternRewriter &rewriter) const override {
auto loc = copy.getLoc();
auto memrefType = copy.getSource().getType().cast<MemRefType>();

// Create an affine for loop over the dimensions of the memref and
// explicitly copy using affine loads and stores.
mlir::SmallVector<mlir::Value, 4> indices;
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
for (auto dim : memrefType.getShape()) {
if (dim == 1) {
// No need to create a loop for a one-dimensional index.
indices.push_back(zero);
continue;
PatternRewriter& rewriter) const override {
auto nestedLoops = expandWithAffineLoops(rewriter, copy);

if (disableAffineLoops_ && !nestedLoops.empty()) {
nestedLoops[0].getBody(0)->walk<WalkOrder::PostOrder>(
[&](affine::AffineForOp forOp) {
if (failed(loopUnrollFull(forOp))) {
return WalkResult::skip();
}
return WalkResult::advance();
});
if (failed(loopUnrollFull(nestedLoops[0]))) {
return mlir::failure();
}
auto loop = rewriter.create<mlir::affine::AffineForOp>(loc, 0, dim);
rewriter.setInsertionPointToStart(loop.getBody());
indices.push_back(loop.getInductionVar());
}

auto load = rewriter.create<mlir::affine::AffineLoadOp>(
loc, copy.getSource(), indices);
rewriter.create<mlir::affine::AffineStoreOp>(loc, load, copy.getTarget(),
indices);

rewriter.eraseOp(copy);
return mlir::success();
}

private:
bool disableAffineLoops_;
};

// ExpandCopyPass intends to remove all memref copy operations.
struct ExpandCopyPass : impl::ExpandCopyPassBase<ExpandCopyPass> {
using ExpandCopyPassBase::ExpandCopyPassBase;
void runOnOperation() override {
mlir::ConversionTarget target(getContext());
target.addIllegalOp<mlir::memref::CopyOp>();
target.addLegalDialect<mlir::arith::ArithDialect,
mlir::affine::AffineDialect>();

void runOnOperation() override {
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
mlir::RewritePatternSet patterns(&getContext());
patterns.add<MemrefCopyExpansionPattern>(&getContext());
patterns.add<MemrefCopyExpansionPattern>(&getContext(), disableAffineLoop);

(void)applyPartialConversion(getOperation(), target, std::move(patterns));
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}
};

Expand Down
9 changes: 5 additions & 4 deletions tests/memref_copy_multidim.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: heir-opt --expand-copy %s | FileCheck %s
// RUN: heir-opt --expand-copy %s | FileCheck --check-prefix=LOOP --check-prefix=CHECK %s
// RUN: heir-opt --expand-copy=disable-affine-loop=true %s | FileCheck --check-prefix=NO-LOOP --check-prefix=CHECK %s

// This verifies that --expand-copy removes memref.copy and rewrites with affine
// loads and stores.
Expand All @@ -13,9 +14,9 @@ func.func @memref_copy() -> i32 {
%alloc0 = memref.alloc() : memref<2x2xi32>
affine.store %c_42, %alloc[%c0, %c0] : memref<2x2xi32>
// CHECK-NOT: memref.copy
// CHECK: affine.for
// CHECK: affine.for
// CHECK-NEXT: affine.load {{.*}}[[MEM1]]
// LOOP-COUNT-2: affine.for
// NO-LOOP-NOT: affine.for
// CHECK: affine.load {{.*}}[[MEM1]]
// CHECK-NEXT: affine.store {{.*}}[[MEM2]]
memref.copy %alloc, %alloc0 : memref<2x2xi32> to memref<2x2xi32>
%v1 = arith.addi %c_42, %c_42 : i32
Expand Down

0 comments on commit 53c57cf

Please sign in to comment.