Skip to content

Commit 6afddce

Browse files
ShivaChenlanza
authored andcommitted
[CIR] Support lowering GlobalOp and GetGlobalOp to memref (llvm#574)
This commit introduce CIRGlobalOpLowering and CIRGetGlobalOpLowering for lowering to memref.
1 parent 37d5011 commit 6afddce

File tree

2 files changed

+159
-4
lines changed

2 files changed

+159
-4
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+104-4
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,95 @@ class CIRYieldOpLowering
620620
}
621621
};
622622

623+
class CIRGlobalOpLowering
624+
: public mlir::OpConversionPattern<mlir::cir::GlobalOp> {
625+
public:
626+
using OpConversionPattern<mlir::cir::GlobalOp>::OpConversionPattern;
627+
mlir::LogicalResult
628+
matchAndRewrite(mlir::cir::GlobalOp op, OpAdaptor adaptor,
629+
mlir::ConversionPatternRewriter &rewriter) const override {
630+
auto moduleOp = op->getParentOfType<mlir::ModuleOp>();
631+
if (!moduleOp)
632+
return mlir::failure();
633+
634+
mlir::OpBuilder b(moduleOp.getContext());
635+
636+
const auto CIRSymType = op.getSymType();
637+
auto convertedType = getTypeConverter()->convertType(CIRSymType);
638+
if (!convertedType)
639+
return mlir::failure();
640+
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
641+
if (!memrefType)
642+
memrefType = mlir::MemRefType::get({}, convertedType);
643+
// Add an optional alignment to the global memref.
644+
mlir::IntegerAttr memrefAlignment =
645+
op.getAlignment()
646+
? mlir::IntegerAttr::get(b.getI64Type(), op.getAlignment().value())
647+
: mlir::IntegerAttr();
648+
// Add an optional initial value to the global memref.
649+
mlir::Attribute initialValue = mlir::Attribute();
650+
std::optional<mlir::Attribute> init = op.getInitialValue();
651+
if (init.has_value()) {
652+
if (auto constArr = init.value().dyn_cast<mlir::cir::ZeroAttr>()) {
653+
if (memrefType.getShape().size()) {
654+
auto rtt = mlir::RankedTensorType::get(memrefType.getShape(),
655+
memrefType.getElementType());
656+
initialValue = mlir::DenseIntElementsAttr::get(rtt, 0);
657+
} else {
658+
auto rtt = mlir::RankedTensorType::get({}, convertedType);
659+
initialValue = mlir::DenseIntElementsAttr::get(rtt, 0);
660+
}
661+
} else if (auto intAttr = init.value().dyn_cast<mlir::cir::IntAttr>()) {
662+
auto rtt = mlir::RankedTensorType::get({}, convertedType);
663+
initialValue = mlir::DenseIntElementsAttr::get(rtt, intAttr.getValue());
664+
} else if (auto fltAttr = init.value().dyn_cast<mlir::cir::FPAttr>()) {
665+
auto rtt = mlir::RankedTensorType::get({}, convertedType);
666+
initialValue = mlir::DenseFPElementsAttr::get(rtt, fltAttr.getValue());
667+
} else if (auto boolAttr = init.value().dyn_cast<mlir::cir::BoolAttr>()) {
668+
auto rtt = mlir::RankedTensorType::get({}, convertedType);
669+
initialValue =
670+
mlir::DenseIntElementsAttr::get(rtt, (char)boolAttr.getValue());
671+
} else
672+
llvm_unreachable(
673+
"GlobalOp lowering with initial value is not fully supported yet");
674+
}
675+
676+
// Add symbol visibility
677+
std::string sym_visibility = op.isPrivate() ? "private" : "public";
678+
679+
rewriter.replaceOpWithNewOp<mlir::memref::GlobalOp>(
680+
op, b.getStringAttr(op.getSymName()),
681+
/*sym_visibility=*/b.getStringAttr(sym_visibility),
682+
/*type=*/memrefType, initialValue,
683+
/*constant=*/op.getConstant(),
684+
/*alignment=*/memrefAlignment);
685+
686+
return mlir::success();
687+
}
688+
};
689+
690+
class CIRGetGlobalOpLowering
691+
: public mlir::OpConversionPattern<mlir::cir::GetGlobalOp> {
692+
public:
693+
using OpConversionPattern<mlir::cir::GetGlobalOp>::OpConversionPattern;
694+
695+
mlir::LogicalResult
696+
matchAndRewrite(mlir::cir::GetGlobalOp op, OpAdaptor adaptor,
697+
mlir::ConversionPatternRewriter &rewriter) const override {
698+
// FIXME(cir): Premature DCE to avoid lowering stuff we're not using.
699+
// CIRGen should mitigate this and not emit the get_global.
700+
if (op->getUses().empty()) {
701+
rewriter.eraseOp(op);
702+
return mlir::success();
703+
}
704+
705+
auto type = getTypeConverter()->convertType(op.getType());
706+
auto symbol = op.getName();
707+
rewriter.replaceOpWithNewOp<mlir::memref::GetGlobalOp>(op, type, symbol);
708+
return mlir::success();
709+
}
710+
};
711+
623712
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
624713
mlir::TypeConverter &converter) {
625714
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
@@ -628,8 +717,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
628717
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
629718
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
630719
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
631-
CIRYieldOpLowering, CIRCosOpLowering>(converter,
632-
patterns.getContext());
720+
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
721+
CIRGetGlobalOpLowering>(converter, patterns.getContext());
633722
}
634723

635724
static mlir::TypeConverter prepareTypeConverter() {
@@ -639,6 +728,8 @@ static mlir::TypeConverter prepareTypeConverter() {
639728
// FIXME: The pointee type might not be converted (e.g. struct)
640729
if (!ty)
641730
return nullptr;
731+
if (isa<mlir::cir::ArrayType>(type.getPointee()))
732+
return ty;
642733
return mlir::MemRefType::get({}, ty);
643734
});
644735
converter.addConversion(
@@ -669,8 +760,17 @@ static mlir::TypeConverter prepareTypeConverter() {
669760
return converter.convertType(type.getUnderlying());
670761
});
671762
converter.addConversion([&](mlir::cir::ArrayType type) -> mlir::Type {
672-
auto elementType = converter.convertType(type.getEltType());
673-
return mlir::MemRefType::get(type.getSize(), elementType);
763+
SmallVector<int64_t> shape;
764+
mlir::Type curType = type;
765+
while (auto arrayType = dyn_cast<mlir::cir::ArrayType>(curType)) {
766+
shape.push_back(arrayType.getSize());
767+
curType = arrayType.getEltType();
768+
}
769+
auto elementType = converter.convertType(curType);
770+
// FIXME: The element type might not be converted (e.g. struct)
771+
if (!elementType)
772+
return nullptr;
773+
return mlir::MemRefType::get(shape, elementType);
674774
});
675775

676776
return converter;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
2+
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
3+
4+
!u32i = !cir.int<u, 32>
5+
module {
6+
cir.global external @i = #cir.int<2> : !u32i
7+
cir.global external @f = #cir.fp<3.000000e+00> : !cir.float
8+
cir.global external @b = #cir.bool<true> : !cir.bool
9+
cir.global "private" external @a : !cir.array<!u32i x 100>
10+
cir.global external @aa = #cir.zero : !cir.array<!cir.array<!u32i x 256> x 256>
11+
12+
cir.func @get_global_int_value() -> !u32i {
13+
%0 = cir.get_global @i : cir.ptr <!u32i>
14+
%1 = cir.load %0 : cir.ptr <!u32i>, !u32i
15+
cir.return %1 : !u32i
16+
}
17+
cir.func @get_global_float_value() -> !cir.float {
18+
%0 = cir.get_global @f : cir.ptr <!cir.float>
19+
%1 = cir.load %0 : cir.ptr <!cir.float>, !cir.float
20+
cir.return %1 : !cir.float
21+
}
22+
cir.func @get_global_bool_value() -> !cir.bool {
23+
%0 = cir.get_global @b : cir.ptr <!cir.bool>
24+
%1 = cir.load %0 : cir.ptr <!cir.bool>, !cir.bool
25+
cir.return %1 : !cir.bool
26+
}
27+
cir.func @get_global_array_pointer() -> !cir.ptr<!cir.array<!u32i x 100>> {
28+
%0 = cir.get_global @a : cir.ptr <!cir.array<!u32i x 100>>
29+
cir.return %0 : !cir.ptr<!cir.array<!u32i x 100>>
30+
}
31+
cir.func @get_global_multi_array_pointer() -> !cir.ptr<!cir.array<!cir.array<!u32i x 256> x 256>> {
32+
%0 = cir.get_global @aa : cir.ptr <!cir.array<!cir.array<!u32i x 256> x 256>>
33+
cir.return %0 : !cir.ptr<!cir.array<!cir.array<!u32i x 256> x 256>>
34+
}
35+
}
36+
37+
// MLIR: memref.global "public" @i : memref<i32> = dense<2>
38+
// MLIR: memref.global "public" @f : memref<f32> = dense<3.000000e+00>
39+
// MLIR: memref.global "public" @b : memref<i8> = dense<1>
40+
// MLIR: memref.global "private" @a : memref<100xi32>
41+
// MLIR: memref.global "public" @aa : memref<256x256xi32> = dense<0>
42+
// MLIR: memref.get_global @i : memref<i32>
43+
// MLIR: memref.get_global @f : memref<f32>
44+
// MLIR: memref.get_global @b : memref<i8>
45+
// MLIR: memref.get_global @a : memref<100xi32>
46+
// MLIR: memref.get_global @aa : memref<256x256xi32>
47+
48+
// LLVM: @i = global i32 2
49+
// LLVM: @f = global float 3.000000e+00
50+
// LLVM: @b = global i8 1
51+
// LLVM: @a = private global [100 x i32] undef
52+
// LLVM: @aa = global [256 x [256 x i32]] zeroinitializer
53+
// LLVM: load i32, ptr @i
54+
// LLVM: load float, ptr @f
55+
// LLVM: load i8, ptr @b

0 commit comments

Comments
 (0)