|
10 | 10 | //
|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
| 13 | +#include "circt/Support/LLVM.h" |
13 | 14 | #include "circt/Transforms/Passes.h"
|
14 | 15 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
15 | 16 | #include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
25 | 26 | #include "mlir/Transforms/DialectConversion.h"
|
26 | 27 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
27 | 28 | #include "llvm/Support/FormatVariadic.h"
|
| 29 | +#include "llvm/Support/LogicalResult.h" |
28 | 30 | #include "llvm/Support/MathExtras.h"
|
29 | 31 |
|
30 | 32 | namespace circt {
|
@@ -238,6 +240,27 @@ struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
|
238 | 240 | }
|
239 | 241 | };
|
240 | 242 |
|
| 243 | +struct ReshapeOpConversion : public OpConversionPattern<memref::ReshapeOp> { |
| 244 | + using OpConversionPattern::OpConversionPattern; |
| 245 | + |
| 246 | + LogicalResult |
| 247 | + matchAndRewrite(memref::ReshapeOp op, OpAdaptor adaptor, |
| 248 | + ConversionPatternRewriter &rewriter) const override { |
| 249 | + Value flattenedSource = rewriter.getRemappedValue(op.getSource()); |
| 250 | + if (!flattenedSource) |
| 251 | + return failure(); |
| 252 | + |
| 253 | + auto flattenedSrcType = cast<MemRefType>(flattenedSource.getType()); |
| 254 | + if (isUniDimensional(flattenedSrcType) || |
| 255 | + !flattenedSrcType.hasStaticShape()) { |
| 256 | + rewriter.replaceOp(op, flattenedSource); |
| 257 | + return success(); |
| 258 | + } |
| 259 | + |
| 260 | + return failure(); |
| 261 | + } |
| 262 | +}; |
| 263 | + |
241 | 264 | // A generic pattern which will replace an op with a new op of the same type
|
242 | 265 | // but using the adaptor (type converted) operands.
|
243 | 266 | template <typename TOp>
|
@@ -403,7 +426,7 @@ struct FlattenMemRefPass
|
403 | 426 | RewritePatternSet patterns(ctx);
|
404 | 427 | SetVector<StringRef> rewrittenCallees;
|
405 | 428 | patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
|
406 |
| - GlobalOpConversion, GetGlobalOpConversion, |
| 429 | + GlobalOpConversion, GetGlobalOpConversion, ReshapeOpConversion, |
407 | 430 | OperandConversionPattern<func::ReturnOp>,
|
408 | 431 | OperandConversionPattern<memref::DeallocOp>,
|
409 | 432 | CondBranchOpConversion,
|
|
0 commit comments