Skip to content

Commit 95a5485

Browse files
committed
[CIR] Lower cir.get_member to named_tuple + memref casts
Emulate the member access through memory for now.
1 parent 7fdc651 commit 95a5485

File tree

5 files changed

+123
-15
lines changed

5 files changed

+123
-15
lines changed

clang/include/clang/CIR/LowerToMLIR.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ namespace cir {
2020
void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
2121
mlir::TypeConverter &converter);
2222
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout);
23-
void runAtStartOfConvertCIRToMLIRPass(std::function<void(mlir::ConversionTarget)>);
23+
void runAtStartOfConvertCIRToMLIRPass(
24+
std::function<void(mlir::ConversionTarget)>);
2425
} // namespace cir
2526

2627
#endif // CLANG_CIR_LOWERTOMLIR_H_

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+86-9
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,17 @@
2929
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3030
#include "mlir/Dialect/Math/IR/Math.h"
3131
#include "mlir/Dialect/MemRef/IR/MemRef.h"
32+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
3233
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
34+
#include "mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
35+
#include "mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
3336
#include "mlir/Dialect/SCF/IR/SCF.h"
3437
#include "mlir/Dialect/SCF/Transforms/Passes.h"
3538
#include "mlir/Dialect/Vector/IR/VectorOps.h"
39+
#include "mlir/IR/Attributes.h"
40+
#include "mlir/IR/BuiltinAttributes.h"
3641
#include "mlir/IR/BuiltinDialect.h"
42+
#include "mlir/IR/BuiltinOps.h"
3743
#include "mlir/IR/BuiltinTypes.h"
3844
#include "mlir/IR/Operation.h"
3945
#include "mlir/IR/Region.h"
@@ -48,11 +54,13 @@
4854
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
4955
#include "mlir/Target/LLVMIR/Export.h"
5056
#include "mlir/Transforms/DialectConversion.h"
57+
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
5158
#include "clang/CIR/Dialect/IR/CIRDialect.h"
5259
#include "clang/CIR/Dialect/IR/CIRTypes.h"
5360
#include "clang/CIR/LowerToMLIR.h"
5461
#include "clang/CIR/LoweringHelpers.h"
5562
#include "clang/CIR/Passes.h"
63+
#include "llvm/ADT/ArrayRef.h"
5664
#include "llvm/ADT/STLExtras.h"
5765
#include "llvm/ADT/Sequence.h"
5866
#include "llvm/ADT/SmallVector.h"
@@ -175,7 +183,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
175183
mlir::Type mlirType =
176184
convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType());
177185

178-
// FIXME: Some types can not be converted yet (e.g. struct)
186+
// FIXME: Some types can not be converted yet
179187
if (!mlirType)
180188
return mlir::LogicalResult::failure();
181189

@@ -277,6 +285,71 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
277285
}
278286
};
279287

288+
// Lower cir.get_member
289+
//
290+
// clang-format off
291+
//
292+
// %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
293+
//
294+
// to something like
295+
//
296+
// %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
297+
// %c8 = arith.constant 8 : index
298+
// %view_1 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>
299+
// clang-format on
300+
class CIRGetMemberOpLowering
301+
: public mlir::OpConversionPattern<cir::GetMemberOp> {
302+
cir::CIRDataLayout const &dataLayout;
303+
304+
public:
305+
using mlir::OpConversionPattern<cir::GetMemberOp>::OpConversionPattern;
306+
307+
CIRGetMemberOpLowering(const mlir::TypeConverter &typeConverter,
308+
mlir::MLIRContext *context,
309+
cir::CIRDataLayout const &dataLayout)
310+
: OpConversionPattern{typeConverter, context}, dataLayout{dataLayout} {}
311+
312+
mlir::LogicalResult
313+
matchAndRewrite(cir::GetMemberOp op, OpAdaptor adaptor,
314+
mlir::ConversionPatternRewriter &rewriter) const override {
315+
auto pointeeType = op.getAddrTy().getPointee();
316+
if (!mlir::isa<cir::StructType>(pointeeType))
317+
op.emitError("GetMemberOp only works on pointer to cir::StructType");
318+
auto structType = mlir::cast<cir::StructType>(pointeeType);
319+
// For now, just rely on the datalayout of the high-level type since the
320+
// datalayout of low-level type is not implemented yet. But since C++ is a
321+
// concrete datalayout, both datalayouts are the same.
322+
auto *structLayout = dataLayout.getStructLayout(structType);
323+
324+
// Get the lowered type: memref<!named_tuple.named_tuple<>>
325+
auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr().getType());
326+
// Alias the memref of struct to a memref of an i8 array of the same size.
327+
const std::array linearizedSize{
328+
static_cast<std::int64_t>(dataLayout.getTypeStoreSize(structType))};
329+
auto flattenMemRef = mlir::MemRefType::get(
330+
linearizedSize, mlir::IntegerType::get(memref.getContext(), 8));
331+
// Use a special cast because normal memref cast cannot do such an extreme
332+
// cast.
333+
auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
334+
op.getLoc(), mlir::TypeRange{flattenMemRef},
335+
mlir::ValueRange{adaptor.getAddr()});
336+
337+
auto memberIndex = op.getIndex();
338+
auto namedTupleType =
339+
mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
340+
// The lowered type of the element to access in the named_tuple.
341+
auto loweredMemberType = namedTupleType.getType(memberIndex);
342+
auto elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
343+
auto offset = structLayout->getElementOffset(memberIndex);
344+
// Synthesize the byte access to right lowered type.
345+
auto byteShift =
346+
rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), offset);
347+
rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
348+
op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
349+
return mlir::LogicalResult::success();
350+
}
351+
};
352+
280353
class CIRCosOpLowering : public mlir::OpConversionPattern<cir::CosOp> {
281354
public:
282355
using OpConversionPattern<cir::CosOp>::OpConversionPattern;
@@ -1353,7 +1426,8 @@ class CIRPtrStrideOpLowering
13531426
};
13541427

13551428
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1356-
mlir::TypeConverter &converter) {
1429+
mlir::TypeConverter &converter,
1430+
cir::CIRDataLayout &cirDataLayout) {
13571431
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
13581432

13591433
patterns.add<
@@ -1372,6 +1446,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
13721446
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
13731447
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
13741448
patterns.getContext());
1449+
patterns.add<CIRGetMemberOpLowering>(converter, patterns.getContext(),
1450+
cirDataLayout);
13751451
}
13761452

13771453
mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
@@ -1428,7 +1504,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14281504
curType = arrayType.getEltType();
14291505
}
14301506
auto elementType = converter.convertType(curType);
1431-
// FIXME: The element type might not be converted (e.g. struct)
1507+
// FIXME: The element type might not be converted
14321508
if (!elementType)
14331509
return nullptr;
14341510
return mlir::MemRefType::get(shape, elementType);
@@ -1470,20 +1546,21 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
14701546
void ConvertCIRToMLIRPass::runOnOperation() {
14711547
auto module = getOperation();
14721548
mlir::DataLayout dataLayout{module};
1549+
cir::CIRDataLayout cirDataLayout{module};
14731550
auto converter = prepareTypeConverter(dataLayout);
14741551

14751552
mlir::RewritePatternSet patterns(&getContext());
14761553

14771554
populateCIRLoopToSCFConversionPatterns(patterns, converter);
1478-
populateCIRToMLIRConversionPatterns(patterns, converter);
1555+
populateCIRToMLIRConversionPatterns(patterns, converter, cirDataLayout);
14791556

14801557
mlir::ConversionTarget target(getContext());
14811558
target.addLegalOp<mlir::ModuleOp>();
1482-
target
1483-
.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1484-
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
1485-
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
1486-
mlir::math::MathDialect, mlir::vector::VectorDialect>();
1559+
target.addLegalDialect<mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1560+
mlir::memref::MemRefDialect, mlir::func::FuncDialect,
1561+
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
1562+
mlir::math::MathDialect, mlir::vector::VectorDialect,
1563+
mlir::named_tuple::NamedTupleDialect>();
14871564
target.addIllegalDialect<cir::CIRDialect>();
14881565

14891566
if (runAtStartHook)

clang/test/CIR/Lowering/ThroughMLIR/struct.cpp

+34-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,38 @@
33

44
struct s {
55
int a;
6-
float b;
6+
double b;
7+
char c;
78
};
8-
int main() { s v; }
9-
// CHECK: memref<!named_tuple.named_tuple<"s", [i32, f32]>>
9+
10+
int main() {
11+
s v;
12+
// CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>>
13+
v.a = 7;
14+
// CHECK: %[[C_7:.+]] = arith.constant 7 : i32
15+
// CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
16+
// CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
17+
// CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<24xi8> to memref<i32>
18+
// CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>
19+
20+
v.b = 3.;
21+
// CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
22+
// CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
23+
// CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
24+
// CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<24xi8> to memref<f64>
25+
// CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>
26+
27+
v.c = 'z';
28+
// CHECK: %[[C_122:.+]] = arith.constant 122 : i8
29+
// CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
30+
// CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
31+
// CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<24xi8> to memref<i8>
32+
// memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>
33+
34+
return v.c;
35+
// CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
36+
// CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
37+
// CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<24xi8> to memref<i8>
38+
// CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>
39+
// CHECK: %[[VALUE_RET:.+]] = arith.extsi %[[VALUE_C]] : i8 to i32
40+
}

mlir/include/mlir/Dialect/NamedTuple/IR/NamedTuple.h

-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#ifndef MLIR_DIALECT_NAMED_TUPLE_IR_NAMED_TUPLE_H
1414
#define MLIR_DIALECT_NAMED_TUPLE_IR_NAMED_TUPLE_H
1515

16-
//#include "mlir/IR/Dialect.h"
1716
#include "mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
1817
#include "mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
1918

mlir/include/mlir/InitAllDialects.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@
5959
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
6060
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
6161
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
62-
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
6362
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
63+
#include "mlir/Dialect/NamedTuple/IR/NamedTuple.h"
6464
#include "mlir/Dialect/OpenACC/OpenACC.h"
6565
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
6666
#include "mlir/Dialect/PDL/IR/PDL.h"

0 commit comments

Comments
 (0)