Skip to content

Commit cd02ca7

Browse files
authored
[FlattenMemRef] Flatten MemRef ReshapeOp when its source memory reference is flattened in other passes (#8265)
1 parent 456053a commit cd02ca7

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

lib/Transforms/FlattenMemRefs.cpp

+24-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "circt/Support/LLVM.h"
1314
#include "circt/Transforms/Passes.h"
1415
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1516
#include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -25,6 +26,7 @@
2526
#include "mlir/Transforms/DialectConversion.h"
2627
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2728
#include "llvm/Support/FormatVariadic.h"
29+
#include "llvm/Support/LogicalResult.h"
2830
#include "llvm/Support/MathExtras.h"
2931

3032
namespace circt {
@@ -238,6 +240,27 @@ struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
238240
}
239241
};
240242

243+
struct ReshapeOpConversion : public OpConversionPattern<memref::ReshapeOp> {
244+
using OpConversionPattern::OpConversionPattern;
245+
246+
LogicalResult
247+
matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor,
248+
ConversionPatternRewriter &rewriter) const override {
249+
Value flattenedSource = rewriter.getRemappedValue(op.getSource());
250+
if (!flattenedSource)
251+
return failure();
252+
253+
auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType());
254+
if (isUniDimensional(flattenedSrcType) ||
255+
!flattenedSrcType.hasStaticShape()) {
256+
rewriter.replaceOp(op, flattenedSource);
257+
return success();
258+
}
259+
260+
return failure();
261+
}
262+
};
263+
241264
// A generic pattern which will replace an op with a new op of the same type
242265
// but using the adaptor (type converted) operands.
243266
template <typename TOp>
@@ -403,7 +426,7 @@ struct FlattenMemRefPass
403426
RewritePatternSet patterns(ctx);
404427
SetVector<StringRef> rewrittenCallees;
405428
patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
406-
GlobalOpConversion, GetGlobalOpConversion,
429+
GlobalOpConversion, GetGlobalOpConversion, ReshapeOpConversion,
407430
OperandConversionPattern<func::ReturnOp>,
408431
OperandConversionPattern<memref::DeallocOp>,
409432
CondBranchOpConversion,

test/Transforms/flatten_memref.mlir

+43
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,46 @@ module {
271271
}
272272
}
273273

274+
// -----
275+
276+
// CHECK: func.func @main(%[[VAL_0:arg0]]: memref<30xf32>) {
277+
// CHECK: %[[VAL_1:.*]] = arith.constant 30 : index
278+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
279+
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
280+
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
281+
// CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<30xf32>
282+
// CHECK: %[[VAL_6:.*]] = memref.get_global @const_1_30 : memref<2xi64>
283+
// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_2]] {
284+
// CHECK: scf.for %[[VAL_8:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_2]] {
285+
// CHECK: %[[VAL_9:.*]] = arith.constant 30 : index
286+
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_4]], %[[VAL_9]] : index
287+
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_8]] : index
288+
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref<30xf32>
289+
// CHECK: %[[VAL_13:.*]] = arith.constant 0 : index
290+
// CHECK: %[[VAL_14:.*]] = arith.shli %[[VAL_4]], %[[VAL_13]] : index
291+
// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_7]] : index
292+
// CHECK: memref.store %[[VAL_12]], %[[VAL_0]]{{\[}}%[[VAL_15]]] : memref<30xf32>
293+
// CHECK: }
294+
// CHECK: }
295+
// CHECK: return
296+
// CHECK: }
297+
298+
module {
299+
memref.global "private" constant @const_1_30 : memref<2xi64> = dense<[1, 30]>
300+
func.func @main(%arg0: memref<30x1xf32>) {
301+
%c30 = arith.constant 30 : index
302+
%c1 = arith.constant 1 : index
303+
%c2 = arith.constant 2 : index
304+
%c0 = arith.constant 0 : index
305+
%alloc = memref.alloc() : memref<2x5x3xf32>
306+
%0 = memref.get_global @const_1_30 : memref<2xi64>
307+
%reshape = memref.reshape %alloc(%0) : (memref<2x5x3xf32>, memref<2xi64>) -> memref<1x30xf32>
308+
scf.for %arg2 = %c0 to %c2 step %c1 {
309+
scf.for %arg3 = %c0 to %c30 step %c1 {
310+
%4 = memref.load %reshape[%c0, %arg3] : memref<1x30xf32>
311+
memref.store %4, %arg0[%c0, %arg2] : memref<30x1xf32>
312+
}
313+
}
314+
return
315+
}
316+
}

0 commit comments

Comments
 (0)