Skip to content

Commit eacaabb

Browse files
authored
[CIR] Add support for casting pointer-to-data-member values (#1188)
This PR adds support for base-to-derived and derived-to-base casts on pointer-to-data-member values. Related to #973.
1 parent 67bbd1e commit eacaabb

File tree

8 files changed

+323
-20
lines changed

8 files changed

+323
-20
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+52
Original file line numberDiff line numberDiff line change
@@ -3279,6 +3279,58 @@ def DerivedClassAddrOp : CIR_Op<"derived_class_addr"> {
32793279
let hasVerifier = 0;
32803280
}
32813281

3282+
//===----------------------------------------------------------------------===//
3283+
// BaseDataMemberOp & DerivedDataMemberOp
3284+
//===----------------------------------------------------------------------===//
3285+
3286+
def BaseDataMemberOp : CIR_Op<"base_data_member", [Pure]> {
3287+
let summary =
3288+
"Cast a derived class data member pointer to a base class data member "
3289+
"pointer";
3290+
let description = [{
3291+
The `cir.base_data_member` operation casts a data member pointer of type
3292+
`T Derived::*` to a data member pointer of type `T Base::*`, where `Base`
3293+
is an accessible non-ambiguous non-virtual base class of `Derived`.
3294+
3295+
The `offset` parameter gives the offset in bytes of the `Base` base class
3296+
subobject within a `Derived` object.
3297+
}];
3298+
3299+
let arguments = (ins CIR_DataMemberType:$src, IndexAttr:$offset);
3300+
let results = (outs CIR_DataMemberType:$result);
3301+
3302+
let assemblyFormat = [{
3303+
`(` $src `:` qualified(type($src)) `)`
3304+
`[` $offset `]` `->` qualified(type($result)) attr-dict
3305+
}];
3306+
3307+
let hasVerifier = 1;
3308+
}
3309+
3310+
def DerivedDataMemberOp : CIR_Op<"derived_data_member", [Pure]> {
3311+
let summary =
3312+
"Cast a base class data member pointer to a derived class data member "
3313+
"pointer";
3314+
let description = [{
3315+
The `cir.derived_data_member` operation casts a data member pointer of type
3316+
`T Base::*` to a data member pointer of type `T Derived::*`, where `Base`
3317+
is an accessible non-ambiguous non-virtual base class of `Derived`.
3318+
3319+
The `offset` parameter gives the offset in bytes of the `Base` base class
3320+
subobject within a `Derived` object.
3321+
}];
3322+
3323+
let arguments = (ins CIR_DataMemberType:$src, IndexAttr:$offset);
3324+
let results = (outs CIR_DataMemberType:$result);
3325+
3326+
let assemblyFormat = [{
3327+
`(` $src `:` qualified(type($src)) `)`
3328+
`[` $offset `]` `->` qualified(type($result)) attr-dict
3329+
}];
3330+
3331+
let hasVerifier = 1;
3332+
}
3333+
32823334
//===----------------------------------------------------------------------===//
32833335
// FuncOp
32843336
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+24-3
Original file line numberDiff line numberDiff line change
@@ -1744,9 +1744,30 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
17441744
case CK_ReinterpretMemberPointer:
17451745
llvm_unreachable("NYI");
17461746
case CK_BaseToDerivedMemberPointer:
1747-
llvm_unreachable("NYI");
1748-
case CK_DerivedToBaseMemberPointer:
1749-
llvm_unreachable("NYI");
1747+
case CK_DerivedToBaseMemberPointer: {
1748+
mlir::Value src = Visit(E);
1749+
1750+
QualType derivedTy =
1751+
Kind == CK_DerivedToBaseMemberPointer ? E->getType() : CE->getType();
1752+
const CXXRecordDecl *derivedClass = derivedTy->castAs<MemberPointerType>()
1753+
->getClass()
1754+
->getAsCXXRecordDecl();
1755+
CharUnits offset = CGF.CGM.computeNonVirtualBaseClassOffset(
1756+
derivedClass, CE->path_begin(), CE->path_end());
1757+
1758+
if (E->getType()->isMemberFunctionPointerType())
1759+
llvm_unreachable("NYI");
1760+
1761+
mlir::Location loc = CGF.getLoc(E->getExprLoc());
1762+
mlir::Type resultTy = CGF.getCIRType(DestTy);
1763+
mlir::IntegerAttr offsetAttr = Builder.getIndexAttr(offset.getQuantity());
1764+
1765+
if (Kind == CK_BaseToDerivedMemberPointer)
1766+
return Builder.create<cir::DerivedDataMemberOp>(loc, resultTy, src,
1767+
offsetAttr);
1768+
return Builder.create<cir::BaseDataMemberOp>(loc, resultTy, src,
1769+
offsetAttr);
1770+
}
17501771
case CK_ARCProduceObject:
17511772
llvm_unreachable("NYI");
17521773
case CK_ARCConsumeObject:

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,32 @@ LogicalResult cir::DynamicCastOp::verify() {
799799
return mlir::success();
800800
}
801801

802+
//===----------------------------------------------------------------------===//
803+
// BaseDataMemberOp & DerivedDataMemberOp
804+
//===----------------------------------------------------------------------===//
805+
806+
static LogicalResult verifyDataMemberCast(Operation *op, mlir::Value src,
807+
mlir::Type resultTy) {
808+
// Let the operand type be T1 C1::*, let the result type be T2 C2::*.
809+
// Verify that T1 and T2 are the same type.
810+
auto inputMemberTy =
811+
mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
812+
auto resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
813+
if (inputMemberTy != resultMemberTy)
814+
return op->emitOpError()
815+
<< "member types of the operand and the result do not match";
816+
817+
return mlir::success();
818+
}
819+
820+
LogicalResult cir::BaseDataMemberOp::verify() {
821+
return verifyDataMemberCast(getOperation(), getSrc(), getType());
822+
}
823+
824+
LogicalResult cir::DerivedDataMemberOp::verify() {
825+
return verifyDataMemberCast(getOperation(), getSrc(), getType());
826+
}
827+
802828
//===----------------------------------------------------------------------===//
803829
// ComplexCreateOp
804830
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h

+12
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ class CIRCXXABI {
8585
lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
8686
mlir::Value loweredAddr, mlir::Value loweredMember,
8787
mlir::OpBuilder &builder) const = 0;
88+
89+
/// Lower the given cir.base_data_member op to a sequence of more "primitive"
90+
/// CIR operations that act on the ABI types.
91+
virtual mlir::Value lowerBaseDataMember(cir::BaseDataMemberOp op,
92+
mlir::Value loweredSrc,
93+
mlir::OpBuilder &builder) const = 0;
94+
95+
/// Lower the given cir.derived_data_member op to a sequence of more
96+
/// "primitive" CIR operations that act on the ABI types.
97+
virtual mlir::Value
98+
lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc,
99+
mlir::OpBuilder &builder) const = 0;
88100
};
89101

90102
/// Creates an Itanium-family ABI.

clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp

+46
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ class ItaniumCXXABI : public CIRCXXABI {
6565
lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
6666
mlir::Value loweredAddr, mlir::Value loweredMember,
6767
mlir::OpBuilder &builder) const override;
68+
69+
mlir::Value lowerBaseDataMember(cir::BaseDataMemberOp op,
70+
mlir::Value loweredSrc,
71+
mlir::OpBuilder &builder) const override;
72+
73+
mlir::Value lowerDerivedDataMember(cir::DerivedDataMemberOp op,
74+
mlir::Value loweredSrc,
75+
mlir::OpBuilder &builder) const override;
6876
};
6977

7078
} // namespace
@@ -129,6 +137,44 @@ mlir::Operation *ItaniumCXXABI::lowerGetRuntimeMember(
129137
memberBytesPtr);
130138
}
131139

140+
static mlir::Value lowerDataMemberCast(mlir::Operation *op,
141+
mlir::Value loweredSrc,
142+
std::int64_t offset,
143+
bool isDerivedToBase,
144+
mlir::OpBuilder &builder) {
145+
if (offset == 0)
146+
return loweredSrc;
147+
148+
auto nullValue = builder.create<cir::ConstantOp>(
149+
op->getLoc(), mlir::IntegerAttr::get(loweredSrc.getType(), -1));
150+
auto isNull = builder.create<cir::CmpOp>(op->getLoc(), cir::CmpOpKind::eq,
151+
loweredSrc, nullValue);
152+
153+
auto offsetValue = builder.create<cir::ConstantOp>(
154+
op->getLoc(), mlir::IntegerAttr::get(loweredSrc.getType(), offset));
155+
auto binOpKind = isDerivedToBase ? cir::BinOpKind::Sub : cir::BinOpKind::Add;
156+
auto adjustedPtr = builder.create<cir::BinOp>(
157+
op->getLoc(), loweredSrc.getType(), binOpKind, loweredSrc, offsetValue);
158+
159+
return builder.create<cir::SelectOp>(op->getLoc(), loweredSrc.getType(),
160+
isNull, nullValue, adjustedPtr);
161+
}
162+
163+
mlir::Value ItaniumCXXABI::lowerBaseDataMember(cir::BaseDataMemberOp op,
164+
mlir::Value loweredSrc,
165+
mlir::OpBuilder &builder) const {
166+
return lowerDataMemberCast(op, loweredSrc, op.getOffset().getSExtValue(),
167+
/*isDerivedToBase=*/true, builder);
168+
}
169+
170+
mlir::Value
171+
ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
172+
mlir::Value loweredSrc,
173+
mlir::OpBuilder &builder) const {
174+
return lowerDataMemberCast(op, loweredSrc, op.getOffset().getSExtValue(),
175+
/*isDerivedToBase=*/false, builder);
176+
}
177+
132178
CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
133179
switch (LM.getCXXABIKind()) {
134180
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+57-17
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,24 @@ mlir::LogicalResult CIRToLLVMDerivedClassAddrOpLowering::matchAndRewrite(
914914
return mlir::success();
915915
}
916916

917+
mlir::LogicalResult CIRToLLVMBaseDataMemberOpLowering::matchAndRewrite(
918+
cir::BaseDataMemberOp op, OpAdaptor adaptor,
919+
mlir::ConversionPatternRewriter &rewriter) const {
920+
mlir::Value loweredResult =
921+
lowerMod->getCXXABI().lowerBaseDataMember(op, adaptor.getSrc(), rewriter);
922+
rewriter.replaceOp(op, loweredResult);
923+
return mlir::success();
924+
}
925+
926+
mlir::LogicalResult CIRToLLVMDerivedDataMemberOpLowering::matchAndRewrite(
927+
cir::DerivedDataMemberOp op, OpAdaptor adaptor,
928+
mlir::ConversionPatternRewriter &rewriter) const {
929+
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDerivedDataMember(
930+
op, adaptor.getSrc(), rewriter);
931+
rewriter.replaceOp(op, loweredResult);
932+
return mlir::success();
933+
}
934+
917935
static mlir::Value
918936
getValueForVTableSymbol(mlir::Operation *op,
919937
mlir::ConversionPatternRewriter &rewriter,
@@ -1518,7 +1536,13 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
15181536
mlir::ConversionPatternRewriter &rewriter) const {
15191537
mlir::Attribute attr = op.getValue();
15201538

1521-
if (mlir::isa<cir::BoolType>(op.getType())) {
1539+
if (mlir::isa<mlir::IntegerType>(op.getType())) {
1540+
// Verified cir.const operations cannot actually be of these types, but the
1541+
// lowering pass may generate temporary cir.const operations with these
1542+
// types. This is OK since MLIR allows unverified operations to be alive
1543+
// during a pass as long as they don't live past the end of the pass.
1544+
attr = op.getValue();
1545+
} else if (mlir::isa<cir::BoolType>(op.getType())) {
15221546
int value = (op.getValue() ==
15231547
cir::BoolAttr::get(getContext(),
15241548
cir::BoolType::get(getContext()), true));
@@ -2412,11 +2436,12 @@ CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const {
24122436
mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
24132437
cir::BinOp op, OpAdaptor adaptor,
24142438
mlir::ConversionPatternRewriter &rewriter) const {
2415-
assert((op.getLhs().getType() == op.getRhs().getType()) &&
2439+
assert((adaptor.getLhs().getType() == adaptor.getRhs().getType()) &&
24162440
"inconsistent operands' types not supported yet");
2441+
24172442
mlir::Type type = op.getRhs().getType();
2418-
assert((mlir::isa<cir::IntType, cir::CIRFPTypeInterface, cir::VectorType>(
2419-
type)) &&
2443+
assert((mlir::isa<cir::IntType, cir::CIRFPTypeInterface, cir::VectorType,
2444+
mlir::IntegerType>(type)) &&
24202445
"operand type not supported yet");
24212446

24222447
auto llvmTy = getTypeConverter()->convertType(op.getType());
@@ -2427,38 +2452,44 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
24272452

24282453
switch (op.getKind()) {
24292454
case cir::BinOpKind::Add:
2430-
if (mlir::isa<cir::IntType>(type))
2455+
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
24312456
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs,
24322457
getIntOverflowFlag(op));
24332458
else
24342459
rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmTy, lhs, rhs);
24352460
break;
24362461
case cir::BinOpKind::Sub:
2437-
if (mlir::isa<cir::IntType>(type))
2462+
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
24382463
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs,
24392464
getIntOverflowFlag(op));
24402465
else
24412466
rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, llvmTy, lhs, rhs);
24422467
break;
24432468
case cir::BinOpKind::Mul:
2444-
if (mlir::isa<cir::IntType>(type))
2469+
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
24452470
rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs,
24462471
getIntOverflowFlag(op));
24472472
else
24482473
rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, llvmTy, lhs, rhs);
24492474
break;
24502475
case cir::BinOpKind::Div:
2451-
if (auto ty = mlir::dyn_cast<cir::IntType>(type)) {
2452-
if (ty.isUnsigned())
2476+
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
2477+
auto isUnsigned = mlir::isa<cir::IntType>(type)
2478+
? mlir::cast<cir::IntType>(type).isUnsigned()
2479+
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
2480+
if (isUnsigned)
24532481
rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, llvmTy, lhs, rhs);
24542482
else
24552483
rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, llvmTy, lhs, rhs);
24562484
} else
24572485
rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, llvmTy, lhs, rhs);
24582486
break;
24592487
case cir::BinOpKind::Rem:
2460-
if (auto ty = mlir::dyn_cast<cir::IntType>(type)) {
2461-
if (ty.isUnsigned())
2488+
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
2489+
auto isUnsigned = mlir::isa<cir::IntType>(type)
2490+
? mlir::cast<cir::IntType>(type).isUnsigned()
2491+
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
2492+
if (isUnsigned)
24622493
rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, llvmTy, lhs, rhs);
24632494
else
24642495
rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, llvmTy, lhs, rhs);
@@ -2642,9 +2673,12 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
26422673
mlir::Value llResult;
26432674

26442675
// Lower to LLVM comparison op.
2645-
if (auto intTy = mlir::dyn_cast<cir::IntType>(type)) {
2646-
auto kind =
2647-
convertCmpKindToICmpPredicate(cmpOp.getKind(), intTy.isSigned());
2676+
// if (auto intTy = mlir::dyn_cast<cir::IntType>(type)) {
2677+
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
2678+
auto isSigned = mlir::isa<cir::IntType>(type)
2679+
? mlir::cast<cir::IntType>(type).isSigned()
2680+
: mlir::cast<mlir::IntegerType>(type).isSigned();
2681+
auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned);
26482682
llResult = rewriter.create<mlir::LLVM::ICmpOp>(
26492683
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
26502684
} else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
@@ -3847,9 +3881,15 @@ void populateCIRToLLVMConversionPatterns(
38473881
patterns.add<CIRToLLVMAllocaOpLowering>(converter, dataLayout,
38483882
stringGlobalsMap, argStringGlobalsMap,
38493883
argsVarMap, patterns.getContext());
3850-
patterns.add<CIRToLLVMConstantOpLowering, CIRToLLVMGlobalOpLowering,
3851-
CIRToLLVMGetRuntimeMemberOpLowering>(
3852-
converter, patterns.getContext(), lowerModule);
3884+
patterns.add<
3885+
// clang-format off
3886+
CIRToLLVMBaseDataMemberOpLowering,
3887+
CIRToLLVMConstantOpLowering,
3888+
CIRToLLVMDerivedDataMemberOpLowering,
3889+
CIRToLLVMGetRuntimeMemberOpLowering,
3890+
CIRToLLVMGlobalOpLowering
3891+
// clang-format on
3892+
>(converter, patterns.getContext(), lowerModule);
38533893
patterns.add<
38543894
// clang-format off
38553895
CIRToLLVMAbsOpLowering,

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

+30
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,36 @@ class CIRToLLVMDerivedClassAddrOpLowering
165165
mlir::ConversionPatternRewriter &) const override;
166166
};
167167

168+
class CIRToLLVMBaseDataMemberOpLowering
169+
: public mlir::OpConversionPattern<cir::BaseDataMemberOp> {
170+
cir::LowerModule *lowerMod;
171+
172+
public:
173+
CIRToLLVMBaseDataMemberOpLowering(const mlir::TypeConverter &typeConverter,
174+
mlir::MLIRContext *context,
175+
cir::LowerModule *lowerModule)
176+
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}
177+
178+
mlir::LogicalResult
179+
matchAndRewrite(cir::BaseDataMemberOp op, OpAdaptor,
180+
mlir::ConversionPatternRewriter &) const override;
181+
};
182+
183+
class CIRToLLVMDerivedDataMemberOpLowering
184+
: public mlir::OpConversionPattern<cir::DerivedDataMemberOp> {
185+
cir::LowerModule *lowerMod;
186+
187+
public:
188+
CIRToLLVMDerivedDataMemberOpLowering(const mlir::TypeConverter &typeConverter,
189+
mlir::MLIRContext *context,
190+
cir::LowerModule *lowerModule)
191+
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}
192+
193+
mlir::LogicalResult
194+
matchAndRewrite(cir::DerivedDataMemberOp op, OpAdaptor,
195+
mlir::ConversionPatternRewriter &) const override;
196+
};
197+
168198
class CIRToLLVMVTTAddrPointOpLowering
169199
: public mlir::OpConversionPattern<cir::VTTAddrPointOp> {
170200
public:

0 commit comments

Comments
 (0)