
Commit d7c27e7

ShivaChen authored and lanza committed
[CIR][ThroughMLIR] Support lowering ptrStrideOp with loadOp or storeOp to memref (#585)
This commit introduces CIRPtrStrideOpLowering, which lowers the following pattern to a memref load or store. Rewrite

    %0 = cir.cast(array_to_ptrdecay, %base)
    %1 = cir.ptr_stride(%0, %index)
    cir.load %1

to

    memref.load %base[%index]
1 parent 965786d commit d7c27e7
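
In terms of the test added below, the 1-D read a[1] lowers as follows (a sketch assembled from the test's input IR and its MLIR CHECK lines; the SSA names in the lowered form are illustrative):

    // CIR input (from get_1d_array_value):
    %1 = cir.get_global @a : !cir.ptr<!cir.array<!s32i x 100>>
    %2 = cir.const #cir.int<1> : !s32i
    %3 = cir.cast(array_to_ptrdecay, %1 : !cir.ptr<!cir.array<!s32i x 100>>), !cir.ptr<!s32i>
    %4 = cir.ptr_stride(%3 : !cir.ptr<!s32i>, %2 : !s32i), !cir.ptr<!s32i>
    %5 = cir.load %4 : !cir.ptr<!s32i>, !s32i

    // Lowered MemRef form:
    %base  = memref.get_global @a : memref<100xi32>
    %one   = arith.constant 1 : i32
    %index = arith.index_cast %one : i32 to index
    %value = memref.load %base[%index] : memref<100xi32>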

2 files changed (+234, -10)


clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp (+156, -10)
@@ -115,14 +115,60 @@ class CIRAllocaOpLowering
   }
 };
 
+// Find the base and indices from a chain of memref.reinterpret_cast ops
+// and add those casts to eraseList.
+static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
+                               SmallVector<mlir::Value> &indices,
+                               SmallVector<mlir::Operation *> &eraseList,
+                               mlir::ConversionPatternRewriter &rewriter) {
+  while (mlir::Operation *addrOp = addr.getDefiningOp()) {
+    if (!isa<mlir::memref::ReinterpretCastOp>(addrOp))
+      break;
+    indices.push_back(addrOp->getOperand(1));
+    addr = addrOp->getOperand(0);
+    eraseList.push_back(addrOp);
+  }
+  base = addr;
+  if (indices.size() == 0)
+    return false;
+  std::reverse(indices.begin(), indices.end());
+  return true;
+}
+
+// When a memref.reinterpret_cast has multiple users, only erase it after
+// the last load or store has been generated.
+static void eraseIfSafe(mlir::Value oldAddr, mlir::Value newAddr,
+                        SmallVector<mlir::Operation *> &eraseList,
+                        mlir::ConversionPatternRewriter &rewriter) {
+  unsigned oldUsedNum =
+      std::distance(oldAddr.getUses().begin(), oldAddr.getUses().end());
+  unsigned newUsedNum = 0;
+  for (auto *user : newAddr.getUsers()) {
+    if (isa<mlir::memref::LoadOp>(*user) || isa<mlir::memref::StoreOp>(*user))
+      ++newUsedNum;
+  }
+  if (oldUsedNum == newUsedNum) {
+    for (auto op : eraseList)
+      rewriter.eraseOp(op);
+  }
+}
+
 class CIRLoadOpLowering : public mlir::OpConversionPattern<mlir::cir::LoadOp> {
 public:
   using OpConversionPattern<mlir::cir::LoadOp>::OpConversionPattern;
 
   mlir::LogicalResult
   matchAndRewrite(mlir::cir::LoadOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, adaptor.getAddr());
+    mlir::Value base;
+    SmallVector<mlir::Value> indices;
+    SmallVector<mlir::Operation *> eraseList;
+    if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
+                           rewriter)) {
+      rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
+      eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
+    } else
+      rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, adaptor.getAddr());
     return mlir::LogicalResult::success();
   }
 };
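
For nested strides (a 2-D access), each lowered cir.ptr_stride leaves one memref.reinterpret_cast behind, and findBaseAndIndices peels the chain back-to-front into an index list. For aa[1][2] in the test below, the load lowering therefore produces (a sketch matching the test's MLIR CHECK lines; SSA names follow the test's capture names):

    %base   = memref.get_global @aa : memref<100x100xi32>
    %one    = arith.constant 1 : i32
    %index1 = arith.index_cast %one : i32 to index
    %two    = arith.constant 2 : i32
    %index2 = arith.index_cast %two : i32 to index
    %value  = memref.load %base[%index1, %index2] : memref<100x100xi32>
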
@@ -135,8 +181,17 @@ class CIRStoreOpLowering
   mlir::LogicalResult
   matchAndRewrite(mlir::cir::StoreOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
-                                                       adaptor.getAddr());
+    mlir::Value base;
+    SmallVector<mlir::Value> indices;
+    SmallVector<mlir::Operation *> eraseList;
+    if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
+                           rewriter)) {
+      rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
+                                                         base, indices);
+      eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
+    } else
+      rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
+                                                         adaptor.getAddr());
     return mlir::LogicalResult::success();
   }
 };
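
A folded address can have several users. In inc_1d_array_value from the test below, the same reinterpret_cast chain feeds one load and one store, which is why eraseIfSafe compares the old CIR address's use count against the number of already-lowered memref.load/store users before erasing. The lowered result (per the test's CHECK lines; SSA names are illustrative):

    %two   = arith.constant 2 : i32
    %base  = memref.get_global @a : memref<100xi32>
    %one   = arith.constant 1 : i32
    %index = arith.index_cast %one : i32 to index
    %value = memref.load %base[%index] : memref<100xi32>
    %inc   = arith.addi %value, %two : i32
    memref.store %inc, %base[%index] : memref<100xi32>
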
@@ -747,6 +802,12 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
     auto dstType = op.getResult().getType();
     using CIR = mlir::cir::CastKind;
     switch (op.getKind()) {
+    case CIR::array_to_ptrdecay: {
+      auto newDstType = convertTy(dstType).cast<mlir::MemRefType>();
+      rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
+          op, newDstType, src, 0, std::nullopt, std::nullopt);
+      return mlir::success();
+    }
     case CIR::int_to_bool: {
       auto zero = rewriter.create<mlir::cir::ConstantOp>(
           src.getLoc(), op.getSrc().getType(),
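
This new case turns an array-decay cast such as (taken from the test below)

    %3 = cir.cast(array_to_ptrdecay, %1 : !cir.ptr<!cir.array<!s32i x 100>>), !cir.ptr<!s32i>

into a zero-offset memref.reinterpret_cast that merely forwards the base memref; CIRPtrStrideOpLowering and the load/store lowerings then consume and erase it.
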
@@ -838,17 +899,102 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
   }
 };
 
+class CIRPtrStrideOpLowering
+    : public mlir::OpConversionPattern<mlir::cir::PtrStrideOp> {
+public:
+  using mlir::OpConversionPattern<mlir::cir::PtrStrideOp>::OpConversionPattern;
+
+  // Return true if the PtrStrideOp's base is produced by a cast with the
+  // array_to_ptrdecay kind and both ops are in the same block.
+  inline bool isCastArrayToPtrConsumer(mlir::cir::PtrStrideOp op) const {
+    auto defOp = op->getOperand(0).getDefiningOp();
+    if (!defOp)
+      return false;
+    auto castOp = dyn_cast<mlir::cir::CastOp>(defOp);
+    if (!castOp)
+      return false;
+    if (castOp.getKind() != mlir::cir::CastKind::array_to_ptrdecay)
+      return false;
+    if (!castOp->hasOneUse())
+      return false;
+    if (!castOp->isBeforeInBlock(op))
+      return false;
+    return true;
+  }
+
+  // Return true if all of the PtrStrideOp's users are load, store, or cast
+  // with the array_to_ptrdecay kind, and they are in the same block.
+  inline bool
+  isLoadStoreOrCastArrayToPtrProducer(mlir::cir::PtrStrideOp op) const {
+    if (op.use_empty())
+      return false;
+    for (auto *user : op->getUsers()) {
+      if (!op->isBeforeInBlock(user))
+        return false;
+      if (isa<mlir::cir::LoadOp>(*user) || isa<mlir::cir::StoreOp>(*user))
+        continue;
+      auto castOp = dyn_cast<mlir::cir::CastOp>(*user);
+      if (castOp &&
+          (castOp.getKind() == mlir::cir::CastKind::array_to_ptrdecay))
+        continue;
+      return false;
+    }
+    return true;
+  }
+
+  inline mlir::Type convertTy(mlir::Type ty) const {
+    return getTypeConverter()->convertType(ty);
+  }
+
+  // Rewrite
+  //   %0 = cir.cast(array_to_ptrdecay, %base)
+  //   cir.ptr_stride(%0, %stride)
+  // to
+  //   memref.reinterpret_cast (%base, %stride)
+  //
+  // The MemRef dialect has no GEP-like operation; memref.reinterpret_cast
+  // is only used to propagate %base and %stride to memref.load/store and
+  // should be erased after the conversion.
+  mlir::LogicalResult
+  matchAndRewrite(mlir::cir::PtrStrideOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    if (!isCastArrayToPtrConsumer(op))
+      return mlir::failure();
+    if (!isLoadStoreOrCastArrayToPtrProducer(op))
+      return mlir::failure();
+    auto baseOp = adaptor.getBase().getDefiningOp();
+    if (!baseOp)
+      return mlir::failure();
+    if (!isa<mlir::memref::ReinterpretCastOp>(baseOp))
+      return mlir::failure();
+    auto base = baseOp->getOperand(0);
+    auto dstType = op.getResult().getType();
+    auto newDstType = convertTy(dstType).cast<mlir::MemRefType>();
+    auto stride = adaptor.getStride();
+    auto indexType = rewriter.getIndexType();
+    // Generate a cast if the stride is not of index type.
+    if (stride.getType() != indexType)
+      stride = rewriter.create<mlir::arith::IndexCastOp>(op.getLoc(), indexType,
+                                                         stride);
+    rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
+        op, newDstType, base, stride, std::nullopt, std::nullopt);
+    rewriter.eraseOp(baseOp);
+    return mlir::success();
+  }
+};
+
 void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
                                          mlir::TypeConverter &converter) {
   patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
 
-  patterns.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
-               CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
-               CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
-               CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
-               CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
-               CIRGetGlobalOpLowering, CIRCastOpLowering>(
-      converter, patterns.getContext());
+  patterns
+      .add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
+           CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
+           CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
+           CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
+           CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
+           CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering>(
+          converter, patterns.getContext());
 }
 
 static mlir::TypeConverter prepareTypeConverter() {
New test file (+78, -0)
@@ -0,0 +1,78 @@
+// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
+// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
+
+!s32i = !cir.int<s, 32>
+module {
+  cir.global "private" external @a : !cir.array<!s32i x 100>
+  cir.global "private" external @aa : !cir.array<!cir.array<!s32i x 100> x 100>
+
+  // int get_1d_array_value() { return a[1]; }
+  // MLIR-LABEL: func.func @get_1d_array_value() -> i32
+  // LLVM-LABEL: define i32 @get_1d_array_value()
+  cir.func @get_1d_array_value() -> !s32i {
+    // MLIR-NEXT: %[[BASE:.*]] = memref.get_global @a : memref<100xi32>
+    // MLIR-NEXT: %[[ONE:.*]] = arith.constant 1 : i32
+    // MLIR-NEXT: %[[INDEX:.*]] = arith.index_cast %[[ONE]] : i32 to index
+    // MLIR-NEXT: %[[VALUE:.*]] = memref.load %[[BASE]][%[[INDEX]]] : memref<100xi32>
+
+    // LLVM-NEXT: load i32, ptr getelementptr (i32, ptr @a, i64 1)
+
+    %1 = cir.get_global @a : !cir.ptr<!cir.array<!s32i x 100>>
+    %2 = cir.const #cir.int<1> : !s32i
+    %3 = cir.cast(array_to_ptrdecay, %1 : !cir.ptr<!cir.array<!s32i x 100>>), !cir.ptr<!s32i>
+    %4 = cir.ptr_stride(%3 : !cir.ptr<!s32i>, %2 : !s32i), !cir.ptr<!s32i>
+    %5 = cir.load %4 : !cir.ptr<!s32i>, !s32i
+    cir.return %5 : !s32i
+  }
+
+  // int get_2d_array_value() { return aa[1][2]; }
+  // MLIR-LABEL: func.func @get_2d_array_value() -> i32
+  // LLVM-LABEL: define i32 @get_2d_array_value()
+  cir.func @get_2d_array_value() -> !s32i {
+    // MLIR-NEXT: %[[BASE:.*]] = memref.get_global @aa : memref<100x100xi32>
+    // MLIR-NEXT: %[[ONE:.*]] = arith.constant 1 : i32
+    // MLIR-NEXT: %[[INDEX1:.*]] = arith.index_cast %[[ONE]] : i32 to index
+    // MLIR-NEXT: %[[TWO:.*]] = arith.constant 2 : i32
+    // MLIR-NEXT: %[[INDEX2:.*]] = arith.index_cast %[[TWO]] : i32 to index
+    // MLIR-NEXT: %[[VALUE:.*]] = memref.load %[[BASE]][%[[INDEX1]], %[[INDEX2]]] : memref<100x100xi32>
+
+    // LLVM-NEXT: load i32, ptr getelementptr (i32, ptr @aa, i64 102)
+
+    %1 = cir.get_global @aa : !cir.ptr<!cir.array<!cir.array<!s32i x 100> x 100>>
+    %2 = cir.const #cir.int<1> : !s32i
+    %3 = cir.cast(array_to_ptrdecay, %1 : !cir.ptr<!cir.array<!cir.array<!s32i x 100> x 100>>), !cir.ptr<!cir.array<!s32i x 100>>
+    %4 = cir.ptr_stride(%3 : !cir.ptr<!cir.array<!s32i x 100>>, %2 : !s32i), !cir.ptr<!cir.array<!s32i x 100>>
+    %5 = cir.const #cir.int<2> : !s32i
+    %6 = cir.cast(array_to_ptrdecay, %4 : !cir.ptr<!cir.array<!s32i x 100>>), !cir.ptr<!s32i>
+    %7 = cir.ptr_stride(%6 : !cir.ptr<!s32i>, %5 : !s32i), !cir.ptr<!s32i>
+    %8 = cir.load %7 : !cir.ptr<!s32i>, !s32i
+    cir.return %8 : !s32i
+  }
+
+  // void inc_1d_array_value() { a[1] += 2; }
+  // MLIR-LABEL: func.func @inc_1d_array_value()
+  // LLVM-LABEL: define void @inc_1d_array_value()
+  cir.func @inc_1d_array_value() {
+    // MLIR-NEXT: %[[TWO:.*]] = arith.constant 2 : i32
+    // MLIR-NEXT: %[[BASE:.*]] = memref.get_global @a : memref<100xi32>
+    // MLIR-NEXT: %[[ONE:.*]] = arith.constant 1 : i32
+    // MLIR-NEXT: %[[INDEX:.*]] = arith.index_cast %[[ONE]] : i32 to index
+    // MLIR-NEXT: %[[VALUE:.*]] = memref.load %[[BASE]][%[[INDEX]]] : memref<100xi32>
+    // MLIR-NEXT: %[[VALUE_INC:.*]] = arith.addi %[[VALUE]], %[[TWO]] : i32
+    // MLIR-NEXT: memref.store %[[VALUE_INC]], %[[BASE]][%[[INDEX]]] : memref<100xi32>
+
+    // LLVM-NEXT: %[[VALUE:.*]] = load i32, ptr getelementptr (i32, ptr @a, i64 1)
+    // LLVM-NEXT: %[[VALUE_INC:.*]] = add i32 %[[VALUE]], 2
+    // LLVM-NEXT: store i32 %[[VALUE_INC]], ptr getelementptr (i32, ptr @a, i64 1)
+
+    %0 = cir.const #cir.int<2> : !s32i
+    %1 = cir.get_global @a : !cir.ptr<!cir.array<!s32i x 100>>
+    %2 = cir.const #cir.int<1> : !s32i
+    %3 = cir.cast(array_to_ptrdecay, %1 : !cir.ptr<!cir.array<!s32i x 100>>), !cir.ptr<!s32i>
+    %4 = cir.ptr_stride(%3 : !cir.ptr<!s32i>, %2 : !s32i), !cir.ptr<!s32i>
+    %5 = cir.load %4 : !cir.ptr<!s32i>, !s32i
+    %6 = cir.binop(add, %5, %0) : !s32i
+    cir.store %6, %4 : !s32i, !cir.ptr<!s32i>
+    cir.return
+  }
+}
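
Note how the 2-D LLVM check folds both indices into a single constant offset, getelementptr (i32, ptr @aa, i64 102): the row-major memref<100x100xi32> puts aa[1][2] at 1 * 100 + 2 = 102.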
