Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Support lowering GlobalOp and GetGlobalOp to memref #574

Merged
merged 1 commit into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 104 additions & 4 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,95 @@ class CIRYieldOpLowering
}
};

class CIRGlobalOpLowering
: public mlir::OpConversionPattern<mlir::cir::GlobalOp> {
public:
using OpConversionPattern<mlir::cir::GlobalOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::cir::GlobalOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto moduleOp = op->getParentOfType<mlir::ModuleOp>();
if (!moduleOp)
return mlir::failure();

mlir::OpBuilder b(moduleOp.getContext());

const auto CIRSymType = op.getSymType();
auto convertedType = getTypeConverter()->convertType(CIRSymType);
if (!convertedType)
return mlir::failure();
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
if (!memrefType)
memrefType = mlir::MemRefType::get({}, convertedType);
// Add an optional alignment to the global memref.
mlir::IntegerAttr memrefAlignment =
op.getAlignment()
? mlir::IntegerAttr::get(b.getI64Type(), op.getAlignment().value())
: mlir::IntegerAttr();
// Add an optional initial value to the global memref.
mlir::Attribute initialValue = mlir::Attribute();
std::optional<mlir::Attribute> init = op.getInitialValue();
if (init.has_value()) {
if (auto constArr = init.value().dyn_cast<mlir::cir::ZeroAttr>()) {
if (memrefType.getShape().size()) {
auto rtt = mlir::RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType());
initialValue = mlir::DenseIntElementsAttr::get(rtt, 0);
} else {
auto rtt = mlir::RankedTensorType::get({}, convertedType);
initialValue = mlir::DenseIntElementsAttr::get(rtt, 0);
}
} else if (auto intAttr = init.value().dyn_cast<mlir::cir::IntAttr>()) {
auto rtt = mlir::RankedTensorType::get({}, convertedType);
initialValue = mlir::DenseIntElementsAttr::get(rtt, intAttr.getValue());
} else if (auto fltAttr = init.value().dyn_cast<mlir::cir::FPAttr>()) {
auto rtt = mlir::RankedTensorType::get({}, convertedType);
initialValue = mlir::DenseFPElementsAttr::get(rtt, fltAttr.getValue());
} else if (auto boolAttr = init.value().dyn_cast<mlir::cir::BoolAttr>()) {
auto rtt = mlir::RankedTensorType::get({}, convertedType);
initialValue =
mlir::DenseIntElementsAttr::get(rtt, (char)boolAttr.getValue());
} else
llvm_unreachable(
"GlobalOp lowering with initial value is not fully supported yet");
}

// Add symbol visibility
std::string sym_visibility = op.isPrivate() ? "private" : "public";

rewriter.replaceOpWithNewOp<mlir::memref::GlobalOp>(
op, b.getStringAttr(op.getSymName()),
/*sym_visibility=*/b.getStringAttr(sym_visibility),
/*type=*/memrefType, initialValue,
/*constant=*/op.getConstant(),
/*alignment=*/memrefAlignment);

return mlir::success();
}
};

class CIRGetGlobalOpLowering
: public mlir::OpConversionPattern<mlir::cir::GetGlobalOp> {
public:
using OpConversionPattern<mlir::cir::GetGlobalOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::GetGlobalOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
// FIXME(cir): Premature DCE to avoid lowering stuff we're not using.
// CIRGen should mitigate this and not emit the get_global.
if (op->getUses().empty()) {
rewriter.eraseOp(op);
return mlir::success();
}

auto type = getTypeConverter()->convertType(op.getType());
auto symbol = op.getName();
rewriter.replaceOpWithNewOp<mlir::memref::GetGlobalOp>(op, type, symbol);
return mlir::success();
}
};

void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
Expand All @@ -628,8 +717,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRCosOpLowering>(converter,
patterns.getContext());
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
CIRGetGlobalOpLowering>(converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand All @@ -639,6 +728,8 @@ static mlir::TypeConverter prepareTypeConverter() {
// FIXME: The pointee type might not be converted (e.g. struct)
if (!ty)
return nullptr;
if (isa<mlir::cir::ArrayType>(type.getPointee()))
return ty;
return mlir::MemRefType::get({}, ty);
});
converter.addConversion(
Expand Down Expand Up @@ -669,8 +760,17 @@ static mlir::TypeConverter prepareTypeConverter() {
return converter.convertType(type.getUnderlying());
});
converter.addConversion([&](mlir::cir::ArrayType type) -> mlir::Type {
auto elementType = converter.convertType(type.getEltType());
return mlir::MemRefType::get(type.getSize(), elementType);
SmallVector<int64_t> shape;
mlir::Type curType = type;
while (auto arrayType = dyn_cast<mlir::cir::ArrayType>(curType)) {
shape.push_back(arrayType.getSize());
curType = arrayType.getEltType();
}
auto elementType = converter.convertType(curType);
// FIXME: The element type might not be converted (e.g. struct)
if (!elementType)
return nullptr;
return mlir::MemRefType::get(shape, elementType);
});

return converter;
Expand Down
55 changes: 55 additions & 0 deletions clang/test/CIR/Lowering/ThroughMLIR/global.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM

!u32i = !cir.int<u, 32>
module {
cir.global external @i = #cir.int<2> : !u32i
cir.global external @f = #cir.fp<3.000000e+00> : !cir.float
cir.global external @b = #cir.bool<true> : !cir.bool
cir.global "private" external @a : !cir.array<!u32i x 100>
cir.global external @aa = #cir.zero : !cir.array<!cir.array<!u32i x 256> x 256>

cir.func @get_global_int_value() -> !u32i {
%0 = cir.get_global @i : cir.ptr <!u32i>
%1 = cir.load %0 : cir.ptr <!u32i>, !u32i
cir.return %1 : !u32i
}
cir.func @get_global_float_value() -> !cir.float {
%0 = cir.get_global @f : cir.ptr <!cir.float>
%1 = cir.load %0 : cir.ptr <!cir.float>, !cir.float
cir.return %1 : !cir.float
}
cir.func @get_global_bool_value() -> !cir.bool {
%0 = cir.get_global @b : cir.ptr <!cir.bool>
%1 = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
cir.return %1 : !cir.bool
}
cir.func @get_global_array_pointer() -> !cir.ptr<!cir.array<!u32i x 100>> {
%0 = cir.get_global @a : cir.ptr <!cir.array<!u32i x 100>>
cir.return %0 : !cir.ptr<!cir.array<!u32i x 100>>
}
cir.func @get_global_multi_array_pointer() -> !cir.ptr<!cir.array<!cir.array<!u32i x 256> x 256>> {
%0 = cir.get_global @aa : cir.ptr <!cir.array<!cir.array<!u32i x 256> x 256>>
cir.return %0 : !cir.ptr<!cir.array<!cir.array<!u32i x 256> x 256>>
}
}

// MLIR: memref.global "public" @i : memref<i32> = dense<2>
// MLIR: memref.global "public" @f : memref<f32> = dense<3.000000e+00>
// MLIR: memref.global "public" @b : memref<i8> = dense<1>
// MLIR: memref.global "private" @a : memref<100xi32>
// MLIR: memref.global "public" @aa : memref<256x256xi32> = dense<0>
// MLIR: memref.get_global @i : memref<i32>
// MLIR: memref.get_global @f : memref<f32>
// MLIR: memref.get_global @b : memref<i8>
// MLIR: memref.get_global @a : memref<100xi32>
// MLIR: memref.get_global @aa : memref<256x256xi32>

// LLVM: @i = global i32 2
// LLVM: @f = global float 3.000000e+00
// LLVM: @b = global i8 1
// LLVM: @a = private global [100 x i32] undef
// LLVM: @aa = global [256 x [256 x i32]] zeroinitializer
// LLVM: load i32, ptr @i
// LLVM: load float, ptr @f
// LLVM: load i8, ptr @b
Loading