Skip to content

Commit 43094d7

Browse files
authored
[CIR][CIRGen] Add CIRGen support for float16 and bfloat (#571)
This PR adds two new CIR floating-point types, namely `!cir.f16` and `!cir.bf16`, to represent the float16 format and bfloat format, respectively. This PR converts the clang extension type `_Float16` to `!cir.f16`, and converts the clang extension type `__bf16` type to `!cir.bf16`. The type conversion for clang extension type `__fp16` is not included in this PR since it requires additional work during CIRGen. Only CIRGen is implemented here, LLVMIR lowering / MLIR lowering should come next.
1 parent 3bad644 commit 43094d7

File tree

11 files changed

+2191
-42
lines changed

11 files changed

+2191
-42
lines changed

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+16-2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,20 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {
150150
}];
151151
}
152152

153+
def CIR_FP16 : CIR_FloatType<"FP16", "f16"> {
154+
let summary = "CIR type that represents IEEE-754 binary16 format";
155+
let description = [{
156+
Floating-point type that represents the IEEE-754 binary16 format.
157+
}];
158+
}
159+
160+
def CIR_BFloat16 : CIR_FloatType<"BF16", "bf16"> {
161+
let summary = "CIR type that represents";
162+
let description = [{
163+
Floating-point type that represents the bfloat16 format.
164+
}];
165+
}
166+
153167
def CIR_FP80 : CIR_FloatType<"FP80", "f80"> {
154168
let summary = "CIR type that represents x87 80-bit floating-point format";
155169
let description = [{
@@ -179,7 +193,7 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
179193

180194
// Constraints
181195

182-
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_LongDouble]>;
196+
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_LongDouble]>;
183197
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;
184198

185199
//===----------------------------------------------------------------------===//
@@ -475,7 +489,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
475489
def CIR_AnyType : AnyTypeOf<[
476490
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
477491
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
478-
CIR_AnyFloat,
492+
CIR_AnyFloat, CIR_FP16, CIR_BFloat16
479493
]>;
480494

481495
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+4
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
250250
return mlir::cir::FPAttr::getZero(fltType);
251251
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
252252
return mlir::cir::FPAttr::getZero(fltType);
253+
if (auto fltType = ty.dyn_cast<mlir::cir::FP16Type>())
254+
return mlir::cir::FPAttr::getZero(fltType);
255+
if (auto fltType = ty.dyn_cast<mlir::cir::BF16Type>())
256+
return mlir::cir::FPAttr::getZero(fltType);
253257
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
254258
return getZeroAttr(arrTy);
255259
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+145-30
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,16 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
115115
/// Emit a value that corresponds to null for the given type.
116116
mlir::Value buildNullValue(QualType Ty, mlir::Location loc);
117117

118+
mlir::Value buildPromotedValue(mlir::Value result, QualType PromotionType) {
119+
return Builder.createFloatingCast(result, ConvertType(PromotionType));
120+
}
121+
122+
mlir::Value buildUnPromotedValue(mlir::Value result, QualType ExprType) {
123+
return Builder.createFloatingCast(result, ConvertType(ExprType));
124+
}
125+
126+
mlir::Value buildPromoted(const Expr *E, QualType PromotionType);
127+
118128
//===--------------------------------------------------------------------===//
119129
// Visitor Methods
120130
//===--------------------------------------------------------------------===//
@@ -478,14 +488,45 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
478488
} else if (type->isVectorType()) {
479489
llvm_unreachable("no vector inc/dec yet");
480490
} else if (type->isRealFloatingType()) {
481-
auto isFloatOrDouble = type->isSpecificBuiltinType(BuiltinType::Float) ||
482-
type->isSpecificBuiltinType(BuiltinType::Double);
483-
assert(isFloatOrDouble && "Non-float/double NYI");
491+
// TODO(cir): CGFPOptionsRAII
492+
assert(!UnimplementedFeature::CGFPOptionsRAII());
484493

485-
// Create the inc/dec operation.
486-
auto kind =
487-
(isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec);
488-
value = buildUnaryOp(E, kind, input);
494+
if (type->isHalfType() && !CGF.getContext().getLangOpts().NativeHalfType)
495+
llvm_unreachable("__fp16 type NYI");
496+
497+
if (value.getType().isa<mlir::cir::SingleType, mlir::cir::DoubleType>()) {
498+
// Create the inc/dec operation.
499+
// NOTE(CIR): clang calls CreateAdd but folds this to a unary op
500+
auto kind =
501+
(isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec);
502+
value = buildUnaryOp(E, kind, input);
503+
} else {
504+
// Remaining types are Half, Bfloat16, LongDouble, __ibm128 or
505+
// __float128. Convert from float.
506+
507+
llvm::APFloat F(static_cast<float>(amount));
508+
bool ignored;
509+
const llvm::fltSemantics *FS;
510+
// Don't use getFloatTypeSemantics because Half isn't
511+
// necessarily represented using the "half" LLVM type.
512+
if (value.getType().isa<mlir::cir::LongDoubleType>())
513+
FS = &CGF.getTarget().getLongDoubleFormat();
514+
else if (value.getType().isa<mlir::cir::FP16Type>())
515+
FS = &CGF.getTarget().getHalfFormat();
516+
else if (value.getType().isa<mlir::cir::BF16Type>())
517+
FS = &CGF.getTarget().getBFloat16Format();
518+
else
519+
llvm_unreachable("fp128 / ppc_fp128 NYI");
520+
F.convert(*FS, llvm::APFloat::rmTowardZero, &ignored);
521+
522+
auto loc = CGF.getLoc(E->getExprLoc());
523+
auto amt = Builder.getConstant(
524+
loc, mlir::cir::FPAttr::get(value.getType(), F));
525+
value = Builder.createBinop(value, mlir::cir::BinOpKind::Add, amt);
526+
}
527+
528+
if (type->isHalfType() && !CGF.getContext().getLangOpts().NativeHalfType)
529+
llvm_unreachable("NYI");
489530

490531
} else if (type->isFixedPointType()) {
491532
llvm_unreachable("no fixed point inc/dec yet");
@@ -549,21 +590,22 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
549590
return Visit(E->getSubExpr()); // the actual value should be unused
550591
return buildLoadOfLValue(E);
551592
}
552-
mlir::Value VisitUnaryPlus(const UnaryOperator *E) {
553-
// NOTE(cir): QualType function parameter still not used, so don´t replicate
554-
// it here yet.
555-
QualType promotionTy = getPromotionType(E->getSubExpr()->getType());
593+
mlir::Value VisitUnaryPlus(const UnaryOperator *E,
594+
QualType PromotionType = QualType()) {
595+
QualType promotionTy = PromotionType.isNull()
596+
? getPromotionType(E->getSubExpr()->getType())
597+
: PromotionType;
556598
auto result = VisitPlus(E, promotionTy);
557599
if (result && !promotionTy.isNull())
558-
assert(0 && "not implemented yet");
600+
result = buildUnPromotedValue(result, E->getType());
559601
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Plus, result);
560602
}
561603

562604
mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType) {
563605
// This differs from gcc, though, most likely due to a bug in gcc.
564606
TestAndClearIgnoreResultAssign();
565607
if (!PromotionType.isNull())
566-
assert(0 && "scalar promotion not implemented yet");
608+
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
567609
return Visit(E->getSubExpr());
568610
}
569611

@@ -573,14 +615,14 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
573615
QualType promotionTy = getPromotionType(E->getSubExpr()->getType());
574616
auto result = VisitMinus(E, promotionTy);
575617
if (result && !promotionTy.isNull())
576-
assert(0 && "not implemented yet");
618+
result = buildUnPromotedValue(result, E->getType());
577619
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Minus, result);
578620
}
579621

580622
mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType) {
581623
TestAndClearIgnoreResultAssign();
582624
if (!PromotionType.isNull())
583-
assert(0 && "scalar promotion not implemented yet");
625+
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
584626

585627
// NOTE: LLVM codegen will lower this directly to either a FNeg
586628
// or a Sub instruction. In CIR this will be handled later in LowerToLLVM.
@@ -752,18 +794,23 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
752794
QualType DstType, mlir::Type SrcTy,
753795
mlir::Type DstTy, ScalarConversionOpts Opts);
754796

755-
BinOpInfo buildBinOps(const BinaryOperator *E) {
797+
BinOpInfo buildBinOps(const BinaryOperator *E,
798+
QualType PromotionType = QualType()) {
756799
BinOpInfo Result;
757-
Result.LHS = Visit(E->getLHS());
758-
Result.RHS = Visit(E->getRHS());
759-
Result.FullType = E->getType();
760-
Result.CompType = E->getType();
761-
if (auto VecType = dyn_cast_or_null<VectorType>(E->getType())) {
800+
Result.LHS = CGF.buildPromotedScalarExpr(E->getLHS(), PromotionType);
801+
Result.RHS = CGF.buildPromotedScalarExpr(E->getRHS(), PromotionType);
802+
if (!PromotionType.isNull())
803+
Result.FullType = PromotionType;
804+
else
805+
Result.FullType = E->getType();
806+
Result.CompType = Result.FullType;
807+
if (const auto *VecType = dyn_cast_or_null<VectorType>(Result.FullType)) {
762808
Result.CompType = VecType->getElementType();
763809
}
764810
Result.Opcode = E->getOpcode();
765811
Result.Loc = E->getSourceRange();
766812
// TODO: Result.FPFeatures
813+
assert(!UnimplementedFeature::getFPFeaturesInEffect());
767814
Result.E = E;
768815
return Result;
769816
}
@@ -793,15 +840,22 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
793840
if (auto *CT = Ty->getAs<ComplexType>()) {
794841
llvm_unreachable("NYI");
795842
}
796-
if (Ty.UseExcessPrecision(CGF.getContext()))
797-
llvm_unreachable("NYI");
843+
if (Ty.UseExcessPrecision(CGF.getContext())) {
844+
if (auto *VT = Ty->getAs<VectorType>())
845+
llvm_unreachable("NYI");
846+
return CGF.getContext().FloatTy;
847+
}
798848
return QualType();
799849
}
800850

801851
// Binary operators and binary compound assignment operators.
802852
#define HANDLEBINOP(OP) \
803853
mlir::Value VisitBin##OP(const BinaryOperator *E) { \
804-
return build##OP(buildBinOps(E)); \
854+
QualType promotionTy = getPromotionType(E->getType()); \
855+
auto result = build##OP(buildBinOps(E, promotionTy)); \
856+
if (result && !promotionTy.isNull()) \
857+
result = buildUnPromotedValue(result, E->getType()); \
858+
return result; \
805859
} \
806860
mlir::Value VisitBin##OP##Assign(const CompoundAssignOperator *E) { \
807861
return buildCompoundAssign(E, &ScalarExprEmitter::build##OP); \
@@ -1053,6 +1107,13 @@ mlir::Value CIRGenFunction::buildScalarExpr(const Expr *E) {
10531107
return ScalarExprEmitter(*this, builder).Visit(const_cast<Expr *>(E));
10541108
}
10551109

1110+
mlir::Value CIRGenFunction::buildPromotedScalarExpr(const Expr *E,
1111+
QualType PromotionType) {
1112+
if (!PromotionType.isNull())
1113+
return ScalarExprEmitter(*this, builder).buildPromoted(E, PromotionType);
1114+
return ScalarExprEmitter(*this, builder).Visit(const_cast<Expr *>(E));
1115+
}
1116+
10561117
[[maybe_unused]] static bool MustVisitNullValue(const Expr *E) {
10571118
// If a null pointer expression's type is the C++0x nullptr_t, then
10581119
// it's not necessarily a simple constant and it must be evaluated
@@ -1885,8 +1946,20 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue(
18851946

18861947
// Emit the RHS first. __block variables need to have the rhs evaluated
18871948
// first, plus this should improve codegen a little.
1888-
OpInfo.RHS = Visit(E->getRHS());
1889-
OpInfo.FullType = E->getComputationResultType();
1949+
1950+
QualType PromotionTypeCR = getPromotionType(E->getComputationResultType());
1951+
if (PromotionTypeCR.isNull())
1952+
PromotionTypeCR = E->getComputationResultType();
1953+
1954+
QualType PromotionTypeLHS = getPromotionType(E->getComputationLHSType());
1955+
QualType PromotionTypeRHS = getPromotionType(E->getRHS()->getType());
1956+
1957+
if (!PromotionTypeRHS.isNull())
1958+
OpInfo.RHS = CGF.buildPromotedScalarExpr(E->getRHS(), PromotionTypeRHS);
1959+
else
1960+
OpInfo.RHS = Visit(E->getRHS());
1961+
1962+
OpInfo.FullType = PromotionTypeCR;
18901963
OpInfo.CompType = OpInfo.FullType;
18911964
if (auto VecType = dyn_cast_or_null<VectorType>(OpInfo.FullType)) {
18921965
OpInfo.CompType = VecType->getElementType();
@@ -1908,16 +1981,20 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue(
19081981
CIRGenFunction::SourceLocRAIIObject sourceloc{
19091982
CGF, CGF.getLoc(E->getSourceRange())};
19101983
SourceLocation Loc = E->getExprLoc();
1911-
OpInfo.LHS =
1912-
buildScalarConversion(OpInfo.LHS, LHSTy, E->getComputationLHSType(), Loc);
1984+
if (!PromotionTypeLHS.isNull())
1985+
OpInfo.LHS = buildScalarConversion(OpInfo.LHS, LHSTy, PromotionTypeLHS,
1986+
E->getExprLoc());
1987+
else
1988+
OpInfo.LHS = buildScalarConversion(OpInfo.LHS, LHSTy,
1989+
E->getComputationLHSType(), Loc);
19131990

19141991
// Expand the binary operator.
19151992
Result = (this->*Func)(OpInfo);
19161993

19171994
// Convert the result back to the LHS type,
19181995
// potentially with Implicit Conversion sanitizer check.
1919-
Result = buildScalarConversion(Result, E->getComputationResultType(), LHSTy,
1920-
Loc, ScalarConversionOpts(CGF.SanOpts));
1996+
Result = buildScalarConversion(Result, PromotionTypeCR, LHSTy, Loc,
1997+
ScalarConversionOpts(CGF.SanOpts));
19211998

19221999
// Store the result value into the LHS lvalue. Bit-fields are handled
19232000
// specially because the result is altered by the store, i.e., [C99 6.5.16p1]
@@ -1938,6 +2015,44 @@ mlir::Value ScalarExprEmitter::buildNullValue(QualType Ty, mlir::Location loc) {
19382015
return CGF.buildFromMemory(CGF.CGM.buildNullConstant(Ty, loc), Ty);
19392016
}
19402017

2018+
mlir::Value ScalarExprEmitter::buildPromoted(const Expr *E,
2019+
QualType PromotionType) {
2020+
E = E->IgnoreParens();
2021+
if (const auto *BO = dyn_cast<BinaryOperator>(E)) {
2022+
switch (BO->getOpcode()) {
2023+
#define HANDLE_BINOP(OP) \
2024+
case BO_##OP: \
2025+
return build##OP(buildBinOps(BO, PromotionType));
2026+
HANDLE_BINOP(Add)
2027+
HANDLE_BINOP(Sub)
2028+
HANDLE_BINOP(Mul)
2029+
HANDLE_BINOP(Div)
2030+
#undef HANDLE_BINOP
2031+
default:
2032+
break;
2033+
}
2034+
} else if (const auto *UO = dyn_cast<UnaryOperator>(E)) {
2035+
switch (UO->getOpcode()) {
2036+
case UO_Imag:
2037+
case UO_Real:
2038+
llvm_unreachable("NYI");
2039+
case UO_Minus:
2040+
return VisitMinus(UO, PromotionType);
2041+
case UO_Plus:
2042+
return VisitPlus(UO, PromotionType);
2043+
default:
2044+
break;
2045+
}
2046+
}
2047+
auto result = Visit(const_cast<Expr *>(E));
2048+
if (result) {
2049+
if (!PromotionType.isNull())
2050+
return buildPromotedValue(result, PromotionType);
2051+
return buildUnPromotedValue(result, E->getType());
2052+
}
2053+
return result;
2054+
}
2055+
19412056
mlir::Value ScalarExprEmitter::buildCompoundAssign(
19422057
const CompoundAssignOperator *E,
19432058
mlir::Value (ScalarExprEmitter::*Func)(const BinOpInfo &)) {

clang/lib/CIR/CodeGen/CIRGenFunction.h

+3
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,9 @@ class CIRGenFunction : public CIRGenTypeCache {
10861086
mlir::Value buildScalarExpr(const clang::Expr *E);
10871087
mlir::Value buildScalarConstant(const ConstantEmission &Constant, Expr *E);
10881088

1089+
mlir::Value buildPromotedScalarExpr(const clang::Expr *E,
1090+
QualType PromotionType);
1091+
10891092
mlir::Type getCIRType(const clang::QualType &type);
10901093

10911094
const CaseStmt *foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,11 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,
131131
// Initialize CIR pointer types cache.
132132
VoidPtrTy = ::mlir::cir::PointerType::get(builder.getContext(), VoidTy);
133133

134-
// TODO: HalfTy
135-
// TODO: BFloatTy
134+
FP16Ty = ::mlir::cir::FP16Type::get(builder.getContext());
135+
BFloat16Ty = ::mlir::cir::BF16Type::get(builder.getContext());
136136
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
137137
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());
138138
FP80Ty = ::mlir::cir::FP80Type::get(builder.getContext());
139-
// TODO(cir): perhaps we should abstract long double variations into a custom
140-
// cir.long_double type. Said type would also hold the semantics for lowering.
141139

142140
// TODO: PointerWidthInBits
143141
PointerAlignInBytes =

clang/lib/CIR/CodeGen/CIRGenTypeCache.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@ struct CIRGenTypeCache {
3434
mlir::cir::IntType SInt8Ty, SInt16Ty, SInt32Ty, SInt64Ty;
3535
// usigned char, unsigned, unsigned short, unsigned long
3636
mlir::cir::IntType UInt8Ty, UInt16Ty, UInt32Ty, UInt64Ty;
37-
/// half, bfloat, float, double
38-
// mlir::Type HalfTy, BFloatTy;
39-
// TODO(cir): perhaps we should abstract long double variations into a custom
40-
// cir.long_double type. Said type would also hold the semantics for lowering.
37+
/// half, bfloat, float, double, fp80
38+
mlir::cir::FP16Type FP16Ty;
39+
mlir::cir::BF16Type BFloat16Ty;
4140
mlir::cir::SingleType FloatTy;
4241
mlir::cir::DoubleType DoubleTy;
4342
mlir::cir::FP80Type FP80Ty;

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -464,14 +464,14 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
464464
break;
465465

466466
case BuiltinType::Float16:
467-
ResultType = Builder.getF16Type();
467+
ResultType = CGM.FP16Ty;
468468
break;
469469
case BuiltinType::Half:
470470
// Should be the same as above?
471471
assert(0 && "not implemented");
472472
break;
473473
case BuiltinType::BFloat16:
474-
ResultType = Builder.getBF16Type();
474+
ResultType = CGM.BFloat16Ty;
475475
break;
476476
case BuiltinType::Float:
477477
ResultType = CGM.FloatTy;

clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h

+2
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ struct UnimplementedFeature {
130130
static bool shouldEmitLifetimeMarkers() { return false; }
131131
static bool peepholeProtection() { return false; }
132132
static bool CGCapturedStmtInfo() { return false; }
133+
static bool CGFPOptionsRAII() { return false; }
134+
static bool getFPFeaturesInEffect() { return false; }
133135
static bool cxxABI() { return false; }
134136
static bool openCL() { return false; }
135137
static bool CUDA() { return false; }

0 commit comments

Comments
 (0)