29
29
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
30
30
#include " mlir/Dialect/Math/IR/Math.h"
31
31
#include " mlir/Dialect/MemRef/IR/MemRef.h"
32
+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
32
33
#include " mlir/Dialect/NamedTuple/IR/NamedTuple.h"
34
+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h"
35
+ #include " mlir/Dialect/NamedTuple/IR/NamedTupleTypes.h"
33
36
#include " mlir/Dialect/SCF/IR/SCF.h"
34
37
#include " mlir/Dialect/SCF/Transforms/Passes.h"
35
38
#include " mlir/Dialect/Vector/IR/VectorOps.h"
39
+ #include " mlir/IR/Attributes.h"
40
+ #include " mlir/IR/BuiltinAttributes.h"
36
41
#include " mlir/IR/BuiltinDialect.h"
42
+ #include " mlir/IR/BuiltinOps.h"
37
43
#include " mlir/IR/BuiltinTypes.h"
38
44
#include " mlir/IR/Operation.h"
39
45
#include " mlir/IR/Region.h"
48
54
#include " mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
49
55
#include " mlir/Target/LLVMIR/Export.h"
50
56
#include " mlir/Transforms/DialectConversion.h"
57
+ #include " clang/CIR/Dialect/IR/CIRDataLayout.h"
51
58
#include " clang/CIR/Dialect/IR/CIRDialect.h"
52
59
#include " clang/CIR/Dialect/IR/CIRTypes.h"
53
60
#include " clang/CIR/LowerToMLIR.h"
54
61
#include " clang/CIR/LoweringHelpers.h"
55
62
#include " clang/CIR/Passes.h"
63
+ #include " llvm/ADT/ArrayRef.h"
56
64
#include " llvm/ADT/STLExtras.h"
57
65
#include " llvm/ADT/Sequence.h"
58
66
#include " llvm/ADT/SmallVector.h"
@@ -175,7 +183,7 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
175
183
mlir::Type mlirType =
176
184
convertTypeForMemory (*getTypeConverter (), adaptor.getAllocaType ());
177
185
178
- // FIXME: Some types can not be converted yet (e.g. struct)
186
+ // FIXME: Some types can not be converted yet
179
187
if (!mlirType)
180
188
return mlir::LogicalResult::failure ();
181
189
@@ -277,6 +285,71 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
277
285
}
278
286
};
279
287
288
+ // Lower cir.get_member
289
+ //
290
+ // clang-format off
291
+ //
292
+ // %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
293
+ //
294
+ // to something like
295
+ //
296
+ // %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
297
+ // %c8 = arith.constant 8 : index
298
+ // %view_1 = memref.view %1[%c8][] : memref<24xi8> to memref<f64>
299
+ // clang-format on
300
+ class CIRGetMemberOpLowering
301
+ : public mlir::OpConversionPattern<cir::GetMemberOp> {
302
+ cir::CIRDataLayout const &dataLayout;
303
+
304
+ public:
305
+ using mlir::OpConversionPattern<cir::GetMemberOp>::OpConversionPattern;
306
+
307
+ CIRGetMemberOpLowering (const mlir::TypeConverter &typeConverter,
308
+ mlir::MLIRContext *context,
309
+ cir::CIRDataLayout const &dataLayout)
310
+ : OpConversionPattern{typeConverter, context}, dataLayout{dataLayout} {}
311
+
312
+ mlir::LogicalResult
313
+ matchAndRewrite (cir::GetMemberOp op, OpAdaptor adaptor,
314
+ mlir::ConversionPatternRewriter &rewriter) const override {
315
+ auto pointeeType = op.getAddrTy ().getPointee ();
316
+ if (!mlir::isa<cir::StructType>(pointeeType))
317
+ op.emitError (" GetMemberOp only works on pointer to cir::StructType" );
318
+ auto structType = mlir::cast<cir::StructType>(pointeeType);
319
+ // For now, just rely on the datalayout of the high-level type since the
320
+ // datalayout of low-level type is not implemented yet. But since C++ is a
321
+ // concrete datalayout, both datalayouts are the same.
322
+ auto *structLayout = dataLayout.getStructLayout (structType);
323
+
324
+ // Get the lowered type: memref<!named_tuple.named_tuple<>>
325
+ auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr ().getType ());
326
+ // Alias the memref of struct to a memref of an i8 array of the same size.
327
+ const std::array linearizedSize{
328
+ static_cast <std::int64_t >(dataLayout.getTypeStoreSize (structType))};
329
+ auto flattenMemRef = mlir::MemRefType::get (
330
+ linearizedSize, mlir::IntegerType::get (memref.getContext (), 8 ));
331
+ // Use a special cast because normal memref cast cannot do such an extreme
332
+ // cast.
333
+ auto bytesMemRef = rewriter.create <mlir::named_tuple::CastOp>(
334
+ op.getLoc (), mlir::TypeRange{flattenMemRef},
335
+ mlir::ValueRange{adaptor.getAddr ()});
336
+
337
+ auto memberIndex = op.getIndex ();
338
+ auto namedTupleType =
339
+ mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType ());
340
+ // The lowered type of the element to access in the named_tuple.
341
+ auto loweredMemberType = namedTupleType.getType (memberIndex);
342
+ auto elementMemRefTy = mlir::MemRefType::get ({}, loweredMemberType);
343
+ auto offset = structLayout->getElementOffset (memberIndex);
344
+ // Synthesize the byte access to right lowered type.
345
+ auto byteShift =
346
+ rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), offset);
347
+ rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
348
+ op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
349
+ return mlir::LogicalResult::success ();
350
+ }
351
+ };
352
+
280
353
class CIRCosOpLowering : public mlir ::OpConversionPattern<cir::CosOp> {
281
354
public:
282
355
using OpConversionPattern<cir::CosOp>::OpConversionPattern;
@@ -1353,7 +1426,8 @@ class CIRPtrStrideOpLowering
1353
1426
};
1354
1427
1355
1428
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1356
- mlir::TypeConverter &converter) {
1429
+ mlir::TypeConverter &converter,
1430
+ cir::CIRDataLayout &cirDataLayout) {
1357
1431
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
1358
1432
1359
1433
patterns.add <
@@ -1372,6 +1446,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1372
1446
CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1373
1447
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
1374
1448
patterns.getContext ());
1449
+ patterns.add <CIRGetMemberOpLowering>(converter, patterns.getContext (),
1450
+ cirDataLayout);
1375
1451
}
1376
1452
1377
1453
mlir::TypeConverter prepareTypeConverter (mlir::DataLayout &dataLayout) {
@@ -1428,7 +1504,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1428
1504
curType = arrayType.getEltType ();
1429
1505
}
1430
1506
auto elementType = converter.convertType (curType);
1431
- // FIXME: The element type might not be converted (e.g. struct)
1507
+ // FIXME: The element type might not be converted
1432
1508
if (!elementType)
1433
1509
return nullptr ;
1434
1510
return mlir::MemRefType::get (shape, elementType);
@@ -1470,20 +1546,21 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1470
1546
void ConvertCIRToMLIRPass::runOnOperation () {
1471
1547
auto module = getOperation ();
1472
1548
mlir::DataLayout dataLayout{module};
1549
+ cir::CIRDataLayout cirDataLayout{module};
1473
1550
auto converter = prepareTypeConverter (dataLayout);
1474
1551
1475
1552
mlir::RewritePatternSet patterns (&getContext ());
1476
1553
1477
1554
populateCIRLoopToSCFConversionPatterns (patterns, converter);
1478
- populateCIRToMLIRConversionPatterns (patterns, converter);
1555
+ populateCIRToMLIRConversionPatterns (patterns, converter, cirDataLayout );
1479
1556
1480
1557
mlir::ConversionTarget target (getContext ());
1481
1558
target.addLegalOp <mlir::ModuleOp>();
1482
- target
1483
- . addLegalDialect < mlir::affine::AffineDialect , mlir::arith::ArithDialect ,
1484
- mlir::memref::MemRefDialect , mlir::func::FuncDialect ,
1485
- mlir::scf::SCFDialect , mlir::cf::ControlFlowDialect ,
1486
- mlir::math::MathDialect, mlir::vector::VectorDialect >();
1559
+ target. addLegalDialect <mlir::affine::AffineDialect, mlir::arith::ArithDialect,
1560
+ mlir::memref::MemRefDialect , mlir::func::FuncDialect ,
1561
+ mlir::scf::SCFDialect , mlir::cf::ControlFlowDialect ,
1562
+ mlir::math::MathDialect , mlir::vector::VectorDialect ,
1563
+ mlir::named_tuple::NamedTupleDialect >();
1487
1564
target.addIllegalDialect <cir::CIRDialect>();
1488
1565
1489
1566
if (runAtStartHook)
0 commit comments