Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][ThroughMLIR] Support lowering ptrStrideOp with loadOp or storeOp to memref #585

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 156 additions & 10 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,60 @@ class CIRAllocaOpLowering
}
};

// Walk the chain of memref.reinterpret_cast ops that defines `addr`,
// collecting each cast's second operand (the stride/offset value recorded by
// the ptr_stride lowering) into `indices` and queueing the casts themselves
// in `eraseList` for later removal. On success, `base` is set to the
// underlying memref at the head of the chain and `indices` holds the
// collected values in outermost-first order. Returns false when `addr` is
// not defined by any reinterpret_cast (then `base` is simply `addr`).
static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
                               SmallVector<mlir::Value> &indices,
                               SmallVector<mlir::Operation *> &eraseList,
                               mlir::ConversionPatternRewriter &rewriter) {
  for (mlir::Operation *defOp = addr.getDefiningOp();
       defOp && isa<mlir::memref::ReinterpretCastOp>(defOp);
       defOp = addr.getDefiningOp()) {
    indices.push_back(defOp->getOperand(1));
    addr = defOp->getOperand(0);
    eraseList.push_back(defOp);
  }
  base = addr;
  if (indices.empty())
    return false;
  // The chain was walked innermost-to-outermost; callers expect the
  // outermost index first.
  std::reverse(indices.begin(), indices.end());
  return true;
}

// A memref.reinterpret_cast chain may feed several loads/stores. Erase the
// casts in `eraseList` only after the last such user has been lowered, i.e.
// when the number of memref.load/memref.store users of the new base equals
// the number of remaining uses of the original cir address.
static void eraseIfSafe(mlir::Value oldAddr, mlir::Value newAddr,
                        SmallVector<mlir::Operation *> &eraseList,
                        mlir::ConversionPatternRewriter &rewriter) {
  auto oldUses = oldAddr.getUses();
  unsigned remainingUses = std::distance(oldUses.begin(), oldUses.end());

  unsigned loweredUses = 0;
  for (mlir::Operation *user : newAddr.getUsers())
    if (isa<mlir::memref::LoadOp>(*user) || isa<mlir::memref::StoreOp>(*user))
      ++loweredUses;

  // Not every use has been rewritten yet; keep the casts alive for now.
  if (remainingUses != loweredUses)
    return;
  for (mlir::Operation *castOp : eraseList)
    rewriter.eraseOp(castOp);
}

// Lowers cir.load to memref.load. When the address is produced by a chain of
// memref.reinterpret_cast ops (emitted by CIRPtrStrideOpLowering), the chain
// is folded into the load's index list.
class CIRLoadOpLowering : public mlir::OpConversionPattern<mlir::cir::LoadOp> {
public:
  using OpConversionPattern<mlir::cir::LoadOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::cir::LoadOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Value base;
    SmallVector<mlir::Value> indices;
    SmallVector<mlir::Operation *> eraseList;
    // If the address decomposes into base + indices, emit an indexed load on
    // the underlying memref; the reinterpret_casts are erased once all of
    // their users have been lowered.
    if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
                           rewriter)) {
      rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
      eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
    } else
      rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, adaptor.getAddr());
    return mlir::LogicalResult::success();
  }
};
Expand All @@ -135,8 +181,17 @@ class CIRStoreOpLowering
// Lowers cir.store to memref.store, folding a memref.reinterpret_cast chain
// on the address into the store's index list (mirrors CIRLoadOpLowering).
mlir::LogicalResult
matchAndRewrite(mlir::cir::StoreOp op, OpAdaptor adaptor,
                mlir::ConversionPatternRewriter &rewriter) const override {
  mlir::Value base;
  SmallVector<mlir::Value> indices;
  SmallVector<mlir::Operation *> eraseList;
  // If the address decomposes into base + indices, emit an indexed store on
  // the underlying memref; the reinterpret_casts are erased once all of
  // their users have been lowered.
  if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
                         rewriter)) {
    rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
                                                       base, indices);
    eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
  } else
    rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
                                                       adaptor.getAddr());
  return mlir::LogicalResult::success();
}
};
Expand Down Expand Up @@ -747,6 +802,12 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
auto dstType = op.getResult().getType();
using CIR = mlir::cir::CastKind;
switch (op.getKind()) {
case CIR::array_to_ptrdecay: {
auto newDstType = convertTy(dstType).cast<mlir::MemRefType>();
rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
op, newDstType, src, 0, std::nullopt, std::nullopt);
return mlir::success();
}
case CIR::int_to_bool: {
auto zero = rewriter.create<mlir::cir::ConstantOp>(
src.getLoc(), op.getSrc().getType(),
Expand Down Expand Up @@ -838,17 +899,102 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
}
};

// Lowers cir.ptr_stride into memref.reinterpret_cast. Only fires for the
// restricted array-indexing pattern described on matchAndRewrite below.
class CIRPtrStrideOpLowering
    : public mlir::OpConversionPattern<mlir::cir::PtrStrideOp> {
public:
  using mlir::OpConversionPattern<mlir::cir::PtrStrideOp>::OpConversionPattern;

  // Return true if PtrStrideOp is produced by cast with array_to_ptrdecay kind
  // and they are in the same block.
  inline bool isCastArrayToPtrConsumer(mlir::cir::PtrStrideOp op) const {
    auto defOp = op->getOperand(0).getDefiningOp();
    if (!defOp)
      return false;
    auto castOp = dyn_cast<mlir::cir::CastOp>(defOp);
    if (!castOp)
      return false;
    if (castOp.getKind() != mlir::cir::CastKind::array_to_ptrdecay)
      return false;
    // The decay cast must feed only this ptr_stride, and precede it in the
    // same block, so it is safe to fold and later erase.
    if (!castOp->hasOneUse())
      return false;
    if (!castOp->isBeforeInBlock(op))
      return false;
    return true;
  }

  // Return true if all the PtrStrideOp users are load, store or cast
  // with array_to_ptrdecay kind and they are in the same block.
  inline bool
  isLoadStoreOrCastArrayToPtrProducer(mlir::cir::PtrStrideOp op) const {
    if (op.use_empty())
      return false;
    for (auto *user : op->getUsers()) {
      if (!op->isBeforeInBlock(user))
        return false;
      if (isa<mlir::cir::LoadOp>(*user) || isa<mlir::cir::StoreOp>(*user))
        continue;
      auto castOp = dyn_cast<mlir::cir::CastOp>(*user);
      if (castOp &&
          (castOp.getKind() == mlir::cir::CastKind::array_to_ptrdecay))
        continue;
      return false;
    }
    return true;
  }

  // Convert a CIR type through the pattern's registered type converter.
  inline mlir::Type convertTy(mlir::Type ty) const {
    return getTypeConverter()->convertType(ty);
  }

  // Rewrite
  //   %0 = cir.cast(array_to_ptrdecay, %base)
  //   cir.ptr_stride(%0, %stride)
  // to
  //   memref.reinterpret_cast (%base, %stride)
  //
  // MemRef Dialect doesn't have a GEP-like operation. memref.reinterpret_cast
  // is only used to propagate %base and %stride to memref.load/store and
  // should be erased after the conversion.
  mlir::LogicalResult
  matchAndRewrite(mlir::cir::PtrStrideOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    if (!isCastArrayToPtrConsumer(op))
      return mlir::failure();
    if (!isLoadStoreOrCastArrayToPtrProducer(op))
      return mlir::failure();
    auto baseOp = adaptor.getBase().getDefiningOp();
    if (!baseOp)
      return mlir::failure();
    // The decay cast has already been lowered to a reinterpret_cast; peel it
    // off to reach the underlying base memref.
    if (!isa<mlir::memref::ReinterpretCastOp>(baseOp))
      return mlir::failure();
    auto base = baseOp->getOperand(0);
    auto dstType = op.getResult().getType();
    auto newDstType = convertTy(dstType).cast<mlir::MemRefType>();
    auto stride = adaptor.getStride();
    auto indexType = rewriter.getIndexType();
    // Generate a cast if the stride is not already of index type.
    if (stride.getType() != indexType)
      stride = rewriter.create<mlir::arith::IndexCastOp>(op.getLoc(),
                                                         indexType, stride);
    rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
        op, newDstType, base, stride, std::nullopt, std::nullopt);
    rewriter.eraseOp(baseOp);
    return mlir::success();
  }
};

// Registers every CIR->MLIR conversion pattern on `patterns`, including the
// new CIRPtrStrideOpLowering.
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
                                         mlir::TypeConverter &converter) {
  patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());

  patterns
      .add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
           CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
           CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
           CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
           CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
           CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering>(
          converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down
78 changes: 78 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/ptrstride.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM

// Exercises lowering of cir.ptr_stride (fed by array_to_ptrdecay casts) into
// memref loads/stores with explicit indices, for 1-D and 2-D global arrays
// and for a read-modify-write sequence that shares one address.

!s32i = !cir.int<s, 32>
module {
  cir.global "private" external @a : !cir.array<!s32i x 100>
  cir.global "private" external @aa : !cir.array<!cir.array<!s32i x 100> x 100>

  // int get_1d_array_value() { return a[1]; }
  // MLIR-LABEL: func.func @get_1d_array_value() -> i32
  // LLVM-LABEL: define i32 @get_1d_array_value()
  cir.func @get_1d_array_value() -> !s32i {
    // MLIR-NEXT: %[[BASE:.*]] = memref.get_global @a : memref<100xi32>
    // MLIR-NEXT: %[[ONE:.*]] = arith.constant 1 : i32
    // MLIR-NEXT: %[[INDEX:.*]] = arith.index_cast %[[ONE]] : i32 to index
    // MLIR-NEXT: %[[VALUE:.*]] = memref.load %[[BASE]][%[[INDEX]]] : memref<100xi32>

    // LLVM-NEXT: load i32, ptr getelementptr (i32, ptr @a, i64 1)

    %1 = cir.get_global @a : !cir.ptr<!cir.array<!s32i x 100>>
    %2 = cir.const #cir.int<1> : !s32i
    %3 = cir.cast(array_to_ptrdecay, %1 : !cir.ptr<!cir.array<!s32i x 100>>), !cir.ptr<!s32i>
    %4 = cir.ptr_stride(%3 : !cir.ptr<!s32i>, %2 : !s32i), !cir.ptr<!s32i>
    %5 = cir.load %4 : !cir.ptr<!s32i>, !s32i
    cir.return %5 : !s32i
  }

  // int get_2d_array_value() { return aa[1][2]; }
  // MLIR-LABEL: func.func @get_2d_array_value() -> i32
  // LLVM-LABEL: define i32 @get_2d_array_value()
  cir.func @get_2d_array_value() -> !s32i {
    // MLIR-NEXT: %[[BASE:.*]] = memref.get_global @aa : memref<100x100xi32>
    // MLIR-NEXT: %[[ONE:.*]] = arith.constant 1 : i32
    // MLIR-NEXT: %[[INDEX1:.*]] = arith.index_cast %[[ONE]] : i32 to index
    // MLIR-NEXT: %[[TWO:.*]] = arith.constant 2 : i32
    // MLIR-NEXT: %[[INDEX2:.*]] = arith.index_cast %[[TWO]] : i32 to index
    // MLIR-NEXT: %[[VALUE:.*]] = memref.load %[[BASE]][%[[INDEX1]], %[[INDEX2]]] : memref<100x100xi32>

    // LLVM-NEXT: load i32, ptr getelementptr (i32, ptr @aa, i64 102)

    %1 = cir.get_global @aa : !cir.ptr<!cir.array<!cir.array<!s32i x 100> x 100>>
    %2 = cir.const #cir.int<1> : !s32i
    %3 = cir.cast(array_to_ptrdecay, %1 : !cir.ptr<!cir.array<!cir.array<!s32i x 100> x 100>>), !cir.ptr<!cir.array<!s32i x 100>>
    %4 = cir.ptr_stride(%3 : !cir.ptr<!cir.array<!s32i x 100>>, %2 : !s32i), !cir.ptr<!cir.array<!s32i x 100>>
    %5 = cir.const #cir.int<2> : !s32i
    %6 = cir.cast(array_to_ptrdecay, %4 : !cir.ptr<!cir.array<!s32i x 100>>), !cir.ptr<!s32i>
    %7 = cir.ptr_stride(%6 : !cir.ptr<!s32i>, %5 : !s32i), !cir.ptr<!s32i>
    %8 = cir.load %7 : !cir.ptr<!s32i>, !s32i
    cir.return %8 : !s32i
  }

  // void inc_1d_array_value() { a[1] += 2; }
  // The load and the store below share %4, so the lowering must keep the
  // reinterpret_cast alive until both users are rewritten.
  // MLIR-LABEL: func.func @inc_1d_array_value()
  // LLVM-LABEL: define void @inc_1d_array_value()
  cir.func @inc_1d_array_value() {
    // MLIR-NEXT: %[[TWO:.*]] = arith.constant 2 : i32
    // MLIR-NEXT: %[[BASE:.*]] = memref.get_global @a : memref<100xi32>
    // MLIR-NEXT: %[[ONE:.*]] = arith.constant 1 : i32
    // MLIR-NEXT: %[[INDEX:.*]] = arith.index_cast %[[ONE]] : i32 to index
    // MLIR-NEXT: %[[VALUE:.*]] = memref.load %[[BASE]][%[[INDEX]]] : memref<100xi32>
    // MLIR-NEXT: %[[VALUE_INC:.*]] = arith.addi %[[VALUE]], %[[TWO]] : i32
    // MLIR-NEXT: memref.store %[[VALUE_INC]], %[[BASE]][%[[INDEX]]] : memref<100xi32>

    // LLVM-NEXT: %[[VALUE:.*]] = load i32, ptr getelementptr (i32, ptr @a, i64 1)
    // LLVM-NEXT: %[[VALUE_INC:.*]] = add i32 %[[VALUE]], 2
    // LLVM-NEXT: store i32 %[[VALUE_INC]], ptr getelementptr (i32, ptr @a, i64 1)

    %0 = cir.const #cir.int<2> : !s32i
    %1 = cir.get_global @a : !cir.ptr<!cir.array<!s32i x 100>>
    %2 = cir.const #cir.int<1> : !s32i
    %3 = cir.cast(array_to_ptrdecay, %1 : !cir.ptr<!cir.array<!s32i x 100>>), !cir.ptr<!s32i>
    %4 = cir.ptr_stride(%3 : !cir.ptr<!s32i>, %2 : !s32i), !cir.ptr<!s32i>
    %5 = cir.load %4 : !cir.ptr<!s32i>, !s32i
    %6 = cir.binop(add, %5, %0) : !s32i
    cir.store %6, %4 : !s32i, !cir.ptr<!s32i>
    cir.return
  }
}
Loading