@@ -115,14 +115,60 @@ class CIRAllocaOpLowering
115
115
}
116
116
};
117
117
118
+ // Find base and indices from memref.reinterpret_cast
119
+ // and put it into eraseList.
120
+ static bool findBaseAndIndices (mlir::Value addr, mlir::Value &base,
121
+ SmallVector<mlir::Value> &indices,
122
+ SmallVector<mlir::Operation *> &eraseList,
123
+ mlir::ConversionPatternRewriter &rewriter) {
124
+ while (mlir::Operation *addrOp = addr.getDefiningOp ()) {
125
+ if (!isa<mlir::memref::ReinterpretCastOp>(addrOp))
126
+ break ;
127
+ indices.push_back (addrOp->getOperand (1 ));
128
+ addr = addrOp->getOperand (0 );
129
+ eraseList.push_back (addrOp);
130
+ }
131
+ base = addr;
132
+ if (indices.size () == 0 )
133
+ return false ;
134
+ std::reverse (indices.begin (), indices.end ());
135
+ return true ;
136
+ }
137
+
138
+ // For memref.reinterpret_cast has multiple users, erasing the operation
139
+ // after the last load or store been generated.
140
+ static void eraseIfSafe (mlir::Value oldAddr, mlir::Value newAddr,
141
+ SmallVector<mlir::Operation *> &eraseList,
142
+ mlir::ConversionPatternRewriter &rewriter) {
143
+ unsigned oldUsedNum =
144
+ std::distance (oldAddr.getUses ().begin (), oldAddr.getUses ().end ());
145
+ unsigned newUsedNum = 0 ;
146
+ for (auto *user : newAddr.getUsers ()) {
147
+ if (isa<mlir::memref::LoadOp>(*user) || isa<mlir::memref::StoreOp>(*user))
148
+ ++newUsedNum;
149
+ }
150
+ if (oldUsedNum == newUsedNum) {
151
+ for (auto op : eraseList)
152
+ rewriter.eraseOp (op);
153
+ }
154
+ }
155
+
118
156
class CIRLoadOpLowering : public mlir ::OpConversionPattern<mlir::cir::LoadOp> {
119
157
public:
120
158
using OpConversionPattern<mlir::cir::LoadOp>::OpConversionPattern;
121
159
122
160
mlir::LogicalResult
123
161
matchAndRewrite (mlir::cir::LoadOp op, OpAdaptor adaptor,
124
162
mlir::ConversionPatternRewriter &rewriter) const override {
125
- rewriter.replaceOpWithNewOp <mlir::memref::LoadOp>(op, adaptor.getAddr ());
163
+ mlir::Value base;
164
+ SmallVector<mlir::Value> indices;
165
+ SmallVector<mlir::Operation *> eraseList;
166
+ if (findBaseAndIndices (adaptor.getAddr (), base, indices, eraseList,
167
+ rewriter)) {
168
+ rewriter.replaceOpWithNewOp <mlir::memref::LoadOp>(op, base, indices);
169
+ eraseIfSafe (op.getAddr (), base, eraseList, rewriter);
170
+ } else
171
+ rewriter.replaceOpWithNewOp <mlir::memref::LoadOp>(op, adaptor.getAddr ());
126
172
return mlir::LogicalResult::success ();
127
173
}
128
174
};
@@ -135,8 +181,17 @@ class CIRStoreOpLowering
135
181
mlir::LogicalResult
136
182
matchAndRewrite (mlir::cir::StoreOp op, OpAdaptor adaptor,
137
183
mlir::ConversionPatternRewriter &rewriter) const override {
138
- rewriter.replaceOpWithNewOp <mlir::memref::StoreOp>(op, adaptor.getValue (),
139
- adaptor.getAddr ());
184
+ mlir::Value base;
185
+ SmallVector<mlir::Value> indices;
186
+ SmallVector<mlir::Operation *> eraseList;
187
+ if (findBaseAndIndices (adaptor.getAddr (), base, indices, eraseList,
188
+ rewriter)) {
189
+ rewriter.replaceOpWithNewOp <mlir::memref::StoreOp>(op, adaptor.getValue (),
190
+ base, indices);
191
+ eraseIfSafe (op.getAddr (), base, eraseList, rewriter);
192
+ } else
193
+ rewriter.replaceOpWithNewOp <mlir::memref::StoreOp>(op, adaptor.getValue (),
194
+ adaptor.getAddr ());
140
195
return mlir::LogicalResult::success ();
141
196
}
142
197
};
@@ -747,6 +802,12 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
747
802
auto dstType = op.getResult ().getType ();
748
803
using CIR = mlir::cir::CastKind;
749
804
switch (op.getKind ()) {
805
+ case CIR::array_to_ptrdecay: {
806
+ auto newDstType = convertTy (dstType).cast <mlir::MemRefType>();
807
+ rewriter.replaceOpWithNewOp <mlir::memref::ReinterpretCastOp>(
808
+ op, newDstType, src, 0 , std::nullopt, std::nullopt);
809
+ return mlir::success ();
810
+ }
750
811
case CIR::int_to_bool: {
751
812
auto zero = rewriter.create <mlir::cir::ConstantOp>(
752
813
src.getLoc (), op.getSrc ().getType (),
@@ -838,17 +899,102 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
838
899
}
839
900
};
840
901
902
+ class CIRPtrStrideOpLowering
903
+ : public mlir::OpConversionPattern<mlir::cir::PtrStrideOp> {
904
+ public:
905
+ using mlir::OpConversionPattern<mlir::cir::PtrStrideOp>::OpConversionPattern;
906
+
907
+ // Return true if PtrStrideOp is produced by cast with array_to_ptrdecay kind
908
+ // and they are in the same block.
909
+ inline bool isCastArrayToPtrConsumer (mlir::cir::PtrStrideOp op) const {
910
+ auto defOp = op->getOperand (0 ).getDefiningOp ();
911
+ if (!defOp)
912
+ return false ;
913
+ auto castOp = dyn_cast<mlir::cir::CastOp>(defOp);
914
+ if (!castOp)
915
+ return false ;
916
+ if (castOp.getKind () != mlir::cir::CastKind::array_to_ptrdecay)
917
+ return false ;
918
+ if (!castOp->hasOneUse ())
919
+ return false ;
920
+ if (!castOp->isBeforeInBlock (op))
921
+ return false ;
922
+ return true ;
923
+ }
924
+
925
+ // Return true if all the PtrStrideOp users are load, store or cast
926
+ // with array_to_ptrdecay kind and they are in the same block.
927
+ inline bool
928
+ isLoadStoreOrCastArrayToPtrProduer (mlir::cir::PtrStrideOp op) const {
929
+ if (op.use_empty ())
930
+ return false ;
931
+ for (auto *user : op->getUsers ()) {
932
+ if (!op->isBeforeInBlock (user))
933
+ return false ;
934
+ if (isa<mlir::cir::LoadOp>(*user) || isa<mlir::cir::StoreOp>(*user))
935
+ continue ;
936
+ auto castOp = dyn_cast<mlir::cir::CastOp>(*user);
937
+ if (castOp &&
938
+ (castOp.getKind () == mlir::cir::CastKind::array_to_ptrdecay))
939
+ continue ;
940
+ return false ;
941
+ }
942
+ return true ;
943
+ }
944
+
945
+ inline mlir::Type convertTy (mlir::Type ty) const {
946
+ return getTypeConverter ()->convertType (ty);
947
+ }
948
+
949
+ // Rewrite
950
+ // %0 = cir.cast(array_to_ptrdecay, %base)
951
+ // cir.ptr_stride(%0, %stride)
952
+ // to
953
+ // memref.reinterpret_cast (%base, %stride)
954
+ //
955
+ // MemRef Dialect doesn't have GEP-like operation. memref.reinterpret_cast
956
+ // only been used to propogate %base and %stride to memref.load/store and
957
+ // should be erased after the conversion.
958
+ mlir::LogicalResult
959
+ matchAndRewrite (mlir::cir::PtrStrideOp op, OpAdaptor adaptor,
960
+ mlir::ConversionPatternRewriter &rewriter) const override {
961
+ if (!isCastArrayToPtrConsumer (op))
962
+ return mlir::failure ();
963
+ if (!isLoadStoreOrCastArrayToPtrProduer (op))
964
+ return mlir::failure ();
965
+ auto baseOp = adaptor.getBase ().getDefiningOp ();
966
+ if (!baseOp)
967
+ return mlir::failure ();
968
+ if (!isa<mlir::memref::ReinterpretCastOp>(baseOp))
969
+ return mlir::failure ();
970
+ auto base = baseOp->getOperand (0 );
971
+ auto dstType = op.getResult ().getType ();
972
+ auto newDstType = convertTy (dstType).cast <mlir::MemRefType>();
973
+ auto stride = adaptor.getStride ();
974
+ auto indexType = rewriter.getIndexType ();
975
+ // Generate casting if the stride is not index type.
976
+ if (stride.getType () != indexType)
977
+ stride = rewriter.create <mlir::arith::IndexCastOp>(op.getLoc (), indexType,
978
+ stride);
979
+ rewriter.replaceOpWithNewOp <mlir::memref::ReinterpretCastOp>(
980
+ op, newDstType, base, stride, std::nullopt, std::nullopt);
981
+ rewriter.eraseOp (baseOp);
982
+ return mlir::success ();
983
+ }
984
+ };
985
+
841
986
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
842
987
mlir::TypeConverter &converter) {
843
988
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
844
989
845
- patterns.add <CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
846
- CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
847
- CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
848
- CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
849
- CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
850
- CIRGetGlobalOpLowering, CIRCastOpLowering>(
851
- converter, patterns.getContext ());
990
+ patterns
991
+ .add <CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
992
+ CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
993
+ CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
994
+ CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
995
+ CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
996
+ CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering>(
997
+ converter, patterns.getContext ());
852
998
}
853
999
854
1000
static mlir::TypeConverter prepareTypeConverter () {
0 commit comments