Skip to content

Commit 3ea7d4c

Browse files
committed
[CIR] Add support for float16 and bfloat
This patch adds two new CIR floating-point types, namely `!cir.f16` and `!cir.bf16`, to represent the float16 format and bfloat format, respectively. CIRGen for the two new types and scalar expressions involving these two new types is also included in this patch. This patch converts the clang extension type `_Float16` to `!cir.f16`, and converts the clang extension type `__bf16` type to `!cir.bf16`.
1 parent 3bad644 commit 3ea7d4c

File tree

10 files changed

+2182
-43
lines changed

10 files changed

+2182
-43
lines changed

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

+16-2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,20 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {
150150
}];
151151
}
152152

153+
def CIR_FP16 : CIR_FloatType<"FP16", "f16"> {
154+
let summary = "CIR type that represents IEEE-754 binary16 format";
155+
let description = [{
156+
Floating-point type that represents the IEEE-754 binary16 format.
157+
}];
158+
}
159+
160+
def CIR_BFloat16 : CIR_FloatType<"BF16", "bf16"> {
161+
let summary = "CIR type that represents";
162+
let description = [{
163+
Floating-point type that represents the bfloat16 format.
164+
}];
165+
}
166+
153167
def CIR_FP80 : CIR_FloatType<"FP80", "f80"> {
154168
let summary = "CIR type that represents x87 80-bit floating-point format";
155169
let description = [{
@@ -179,7 +193,7 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
179193

180194
// Constraints
181195

182-
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_LongDouble]>;
196+
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_LongDouble]>;
183197
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;
184198

185199
//===----------------------------------------------------------------------===//
@@ -475,7 +489,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
475489
def CIR_AnyType : AnyTypeOf<[
476490
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
477491
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
478-
CIR_AnyFloat,
492+
CIR_AnyFloat, CIR_FP16, CIR_BFloat16
479493
]>;
480494

481495
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

+4
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
250250
return mlir::cir::FPAttr::getZero(fltType);
251251
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
252252
return mlir::cir::FPAttr::getZero(fltType);
253+
if (auto fltType = ty.dyn_cast<mlir::cir::FP16Type>())
254+
return mlir::cir::FPAttr::getZero(fltType);
255+
if (auto fltType = ty.dyn_cast<mlir::cir::BF16Type>())
256+
return mlir::cir::FPAttr::getZero(fltType);
253257
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
254258
return getZeroAttr(arrTy);
255259
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+138-31
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,16 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
115115
/// Emit a value that corresponds to null for the given type.
116116
mlir::Value buildNullValue(QualType Ty, mlir::Location loc);
117117

118+
mlir::Value buildPromotedValue(mlir::Value result, QualType PromotionType) {
119+
return Builder.createFloatingCast(result, ConvertType(PromotionType));
120+
}
121+
122+
mlir::Value buildUnPromotedValue(mlir::Value result, QualType ExprType) {
123+
return Builder.createFloatingCast(result, ConvertType(ExprType));
124+
}
125+
126+
mlir::Value buildPromoted(const Expr *E, QualType PromotionType);
127+
118128
//===--------------------------------------------------------------------===//
119129
// Visitor Methods
120130
//===--------------------------------------------------------------------===//
@@ -478,14 +488,38 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
478488
} else if (type->isVectorType()) {
479489
llvm_unreachable("no vector inc/dec yet");
480490
} else if (type->isRealFloatingType()) {
481-
auto isFloatOrDouble = type->isSpecificBuiltinType(BuiltinType::Float) ||
482-
type->isSpecificBuiltinType(BuiltinType::Double);
483-
assert(isFloatOrDouble && "Non-float/double NYI");
484-
485-
// Create the inc/dec operation.
486-
auto kind =
487-
(isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec);
488-
value = buildUnaryOp(E, kind, input);
491+
if (type->isHalfType())
492+
llvm_unreachable("__fp16 type NYI");
493+
494+
if (value.getType().isa<mlir::cir::SingleType, mlir::cir::DoubleType>()) {
495+
// Create the inc/dec operation.
496+
auto kind =
497+
(isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec);
498+
value = buildUnaryOp(E, kind, input);
499+
} else {
500+
// Remaining types are Half, Bfloat16, LongDouble, __ibm128 or
501+
// __float128. Convert from float.
502+
503+
llvm::APFloat F(static_cast<float>(amount));
504+
bool ignored;
505+
const llvm::fltSemantics *FS;
506+
// Don't use getFloatTypeSemantics because Half isn't
507+
// necessarily represented using the "half" LLVM type.
508+
if (value.getType().isa<mlir::cir::LongDoubleType>())
509+
FS = &CGF.getTarget().getLongDoubleFormat();
510+
else if (value.getType().isa<mlir::cir::FP16Type>())
511+
FS = &CGF.getTarget().getHalfFormat();
512+
else if (value.getType().isa<mlir::cir::BF16Type>())
513+
FS = &CGF.getTarget().getBFloat16Format();
514+
else
515+
llvm_unreachable("fp128 / ppc_fp128 NYI");
516+
F.convert(*FS, llvm::APFloat::rmTowardZero, &ignored);
517+
518+
auto loc = CGF.getLoc(E->getExprLoc());
519+
auto amt = Builder.getConstant(
520+
loc, mlir::cir::FPAttr::get(value.getType(), F));
521+
value = Builder.createBinop(value, mlir::cir::BinOpKind::Add, amt);
522+
}
489523

490524
} else if (type->isFixedPointType()) {
491525
llvm_unreachable("no fixed point inc/dec yet");
@@ -549,21 +583,22 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
549583
return Visit(E->getSubExpr()); // the actual value should be unused
550584
return buildLoadOfLValue(E);
551585
}
552-
mlir::Value VisitUnaryPlus(const UnaryOperator *E) {
553-
// NOTE(cir): QualType function parameter still not used, so don´t replicate
554-
// it here yet.
555-
QualType promotionTy = getPromotionType(E->getSubExpr()->getType());
586+
mlir::Value VisitUnaryPlus(const UnaryOperator *E,
587+
QualType PromotionType = QualType()) {
588+
QualType promotionTy = PromotionType.isNull()
589+
? getPromotionType(E->getSubExpr()->getType())
590+
: PromotionType;
556591
auto result = VisitPlus(E, promotionTy);
557592
if (result && !promotionTy.isNull())
558-
assert(0 && "not implemented yet");
593+
result = buildUnPromotedValue(result, E->getType());
559594
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Plus, result);
560595
}
561596

562597
mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType) {
563598
// This differs from gcc, though, most likely due to a bug in gcc.
564599
TestAndClearIgnoreResultAssign();
565600
if (!PromotionType.isNull())
566-
assert(0 && "scalar promotion not implemented yet");
601+
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
567602
return Visit(E->getSubExpr());
568603
}
569604

@@ -573,14 +608,14 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
573608
QualType promotionTy = getPromotionType(E->getSubExpr()->getType());
574609
auto result = VisitMinus(E, promotionTy);
575610
if (result && !promotionTy.isNull())
576-
assert(0 && "not implemented yet");
611+
result = buildUnPromotedValue(result, E->getType());
577612
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Minus, result);
578613
}
579614

580615
mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType) {
581616
TestAndClearIgnoreResultAssign();
582617
if (!PromotionType.isNull())
583-
assert(0 && "scalar promotion not implemented yet");
618+
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
584619

585620
// NOTE: LLVM codegen will lower this directly to either a FNeg
586621
// or a Sub instruction. In CIR this will be handled later in LowerToLLVM.
@@ -752,13 +787,17 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
752787
QualType DstType, mlir::Type SrcTy,
753788
mlir::Type DstTy, ScalarConversionOpts Opts);
754789

755-
BinOpInfo buildBinOps(const BinaryOperator *E) {
790+
BinOpInfo buildBinOps(const BinaryOperator *E,
791+
QualType PromotionType = QualType()) {
756792
BinOpInfo Result;
757-
Result.LHS = Visit(E->getLHS());
758-
Result.RHS = Visit(E->getRHS());
759-
Result.FullType = E->getType();
760-
Result.CompType = E->getType();
761-
if (auto VecType = dyn_cast_or_null<VectorType>(E->getType())) {
793+
Result.LHS = CGF.buildPromotedScalarExpr(E->getLHS(), PromotionType);
794+
Result.RHS = CGF.buildPromotedScalarExpr(E->getRHS(), PromotionType);
795+
if (!PromotionType.isNull())
796+
Result.FullType = PromotionType;
797+
else
798+
Result.FullType = E->getType();
799+
Result.CompType = Result.FullType;
800+
if (const auto *VecType = dyn_cast_or_null<VectorType>(Result.FullType)) {
762801
Result.CompType = VecType->getElementType();
763802
}
764803
Result.Opcode = E->getOpcode();
@@ -793,15 +832,22 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
793832
if (auto *CT = Ty->getAs<ComplexType>()) {
794833
llvm_unreachable("NYI");
795834
}
796-
if (Ty.UseExcessPrecision(CGF.getContext()))
797-
llvm_unreachable("NYI");
835+
if (Ty.UseExcessPrecision(CGF.getContext())) {
836+
if (auto *VT = Ty->getAs<VectorType>())
837+
llvm_unreachable("NYI");
838+
return CGF.getContext().FloatTy;
839+
}
798840
return QualType();
799841
}
800842

801843
// Binary operators and binary compound assignment operators.
802844
#define HANDLEBINOP(OP) \
803845
mlir::Value VisitBin##OP(const BinaryOperator *E) { \
804-
return build##OP(buildBinOps(E)); \
846+
QualType promotionTy = getPromotionType(E->getType()); \
847+
auto result = build##OP(buildBinOps(E, promotionTy)); \
848+
if (result && !promotionTy.isNull()) \
849+
result = buildUnPromotedValue(result, E->getType()); \
850+
return result; \
805851
} \
806852
mlir::Value VisitBin##OP##Assign(const CompoundAssignOperator *E) { \
807853
return buildCompoundAssign(E, &ScalarExprEmitter::build##OP); \
@@ -1053,6 +1099,13 @@ mlir::Value CIRGenFunction::buildScalarExpr(const Expr *E) {
10531099
return ScalarExprEmitter(*this, builder).Visit(const_cast<Expr *>(E));
10541100
}
10551101

1102+
mlir::Value CIRGenFunction::buildPromotedScalarExpr(const Expr *E,
1103+
QualType PromotionType) {
1104+
if (!PromotionType.isNull())
1105+
return ScalarExprEmitter(*this, builder).buildPromoted(E, PromotionType);
1106+
return ScalarExprEmitter(*this, builder).Visit(const_cast<Expr *>(E));
1107+
}
1108+
10561109
[[maybe_unused]] static bool MustVisitNullValue(const Expr *E) {
10571110
// If a null pointer expression's type is the C++0x nullptr_t, then
10581111
// it's not necessarily a simple constant and it must be evaluated
@@ -1885,8 +1938,20 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue(
18851938

18861939
// Emit the RHS first. __block variables need to have the rhs evaluated
18871940
// first, plus this should improve codegen a little.
1888-
OpInfo.RHS = Visit(E->getRHS());
1889-
OpInfo.FullType = E->getComputationResultType();
1941+
1942+
QualType PromotionTypeCR = getPromotionType(E->getComputationResultType());
1943+
if (PromotionTypeCR.isNull())
1944+
PromotionTypeCR = E->getComputationResultType();
1945+
1946+
QualType PromotionTypeLHS = getPromotionType(E->getComputationLHSType());
1947+
QualType PromotionTypeRHS = getPromotionType(E->getRHS()->getType());
1948+
1949+
if (!PromotionTypeRHS.isNull())
1950+
OpInfo.RHS = CGF.buildPromotedScalarExpr(E->getRHS(), PromotionTypeRHS);
1951+
else
1952+
OpInfo.RHS = Visit(E->getRHS());
1953+
1954+
OpInfo.FullType = PromotionTypeCR;
18901955
OpInfo.CompType = OpInfo.FullType;
18911956
if (auto VecType = dyn_cast_or_null<VectorType>(OpInfo.FullType)) {
18921957
OpInfo.CompType = VecType->getElementType();
@@ -1908,16 +1973,20 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue(
19081973
CIRGenFunction::SourceLocRAIIObject sourceloc{
19091974
CGF, CGF.getLoc(E->getSourceRange())};
19101975
SourceLocation Loc = E->getExprLoc();
1911-
OpInfo.LHS =
1912-
buildScalarConversion(OpInfo.LHS, LHSTy, E->getComputationLHSType(), Loc);
1976+
if (!PromotionTypeLHS.isNull())
1977+
OpInfo.LHS = buildScalarConversion(OpInfo.LHS, LHSTy, PromotionTypeLHS,
1978+
E->getExprLoc());
1979+
else
1980+
OpInfo.LHS = buildScalarConversion(OpInfo.LHS, LHSTy,
1981+
E->getComputationLHSType(), Loc);
19131982

19141983
// Expand the binary operator.
19151984
Result = (this->*Func)(OpInfo);
19161985

19171986
// Convert the result back to the LHS type,
19181987
// potentially with Implicit Conversion sanitizer check.
1919-
Result = buildScalarConversion(Result, E->getComputationResultType(), LHSTy,
1920-
Loc, ScalarConversionOpts(CGF.SanOpts));
1988+
Result = buildScalarConversion(Result, PromotionTypeCR, LHSTy, Loc,
1989+
ScalarConversionOpts(CGF.SanOpts));
19211990

19221991
// Store the result value into the LHS lvalue. Bit-fields are handled
19231992
// specially because the result is altered by the store, i.e., [C99 6.5.16p1]
@@ -1938,6 +2007,44 @@ mlir::Value ScalarExprEmitter::buildNullValue(QualType Ty, mlir::Location loc) {
19382007
return CGF.buildFromMemory(CGF.CGM.buildNullConstant(Ty, loc), Ty);
19392008
}
19402009

2010+
mlir::Value ScalarExprEmitter::buildPromoted(const Expr *E,
2011+
QualType PromotionType) {
2012+
E = E->IgnoreParens();
2013+
if (const auto *BO = dyn_cast<BinaryOperator>(E)) {
2014+
switch (BO->getOpcode()) {
2015+
#define HANDLE_BINOP(OP) \
2016+
case BO_##OP: \
2017+
return build##OP(buildBinOps(BO, PromotionType));
2018+
HANDLE_BINOP(Add)
2019+
HANDLE_BINOP(Sub)
2020+
HANDLE_BINOP(Mul)
2021+
HANDLE_BINOP(Div)
2022+
#undef HANDLE_BINOP
2023+
default:
2024+
break;
2025+
}
2026+
} else if (const auto *UO = dyn_cast<UnaryOperator>(E)) {
2027+
switch (UO->getOpcode()) {
2028+
case UO_Imag:
2029+
case UO_Real:
2030+
llvm_unreachable("NYI");
2031+
case UO_Minus:
2032+
return VisitMinus(UO, PromotionType);
2033+
case UO_Plus:
2034+
return VisitPlus(UO, PromotionType);
2035+
default:
2036+
break;
2037+
}
2038+
}
2039+
auto result = Visit(const_cast<Expr *>(E));
2040+
if (result) {
2041+
if (!PromotionType.isNull())
2042+
return buildPromotedValue(result, PromotionType);
2043+
return buildUnPromotedValue(result, E->getType());
2044+
}
2045+
return result;
2046+
}
2047+
19412048
mlir::Value ScalarExprEmitter::buildCompoundAssign(
19422049
const CompoundAssignOperator *E,
19432050
mlir::Value (ScalarExprEmitter::*Func)(const BinOpInfo &)) {

clang/lib/CIR/CodeGen/CIRGenFunction.h

+3
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,9 @@ class CIRGenFunction : public CIRGenTypeCache {
10861086
mlir::Value buildScalarExpr(const clang::Expr *E);
10871087
mlir::Value buildScalarConstant(const ConstantEmission &Constant, Expr *E);
10881088

1089+
mlir::Value buildPromotedScalarExpr(const clang::Expr *E,
1090+
QualType PromotionType);
1091+
10891092
mlir::Type getCIRType(const clang::QualType &type);
10901093

10911094
const CaseStmt *foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,

clang/lib/CIR/CodeGen/CIRGenModule.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,11 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,
131131
// Initialize CIR pointer types cache.
132132
VoidPtrTy = ::mlir::cir::PointerType::get(builder.getContext(), VoidTy);
133133

134-
// TODO: HalfTy
135-
// TODO: BFloatTy
134+
FP16Ty = ::mlir::cir::FP16Type::get(builder.getContext());
135+
BFloat16Ty = ::mlir::cir::BF16Type::get(builder.getContext());
136136
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
137137
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());
138138
FP80Ty = ::mlir::cir::FP80Type::get(builder.getContext());
139-
// TODO(cir): perhaps we should abstract long double variations into a custom
140-
// cir.long_double type. Said type would also hold the semantics for lowering.
141139

142140
// TODO: PointerWidthInBits
143141
PointerAlignInBytes =

clang/lib/CIR/CodeGen/CIRGenTypeCache.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@ struct CIRGenTypeCache {
3434
mlir::cir::IntType SInt8Ty, SInt16Ty, SInt32Ty, SInt64Ty;
3535
// usigned char, unsigned, unsigned short, unsigned long
3636
mlir::cir::IntType UInt8Ty, UInt16Ty, UInt32Ty, UInt64Ty;
37-
/// half, bfloat, float, double
38-
// mlir::Type HalfTy, BFloatTy;
39-
// TODO(cir): perhaps we should abstract long double variations into a custom
40-
// cir.long_double type. Said type would also hold the semantics for lowering.
37+
/// half, bfloat, float, double, fp80
38+
mlir::cir::FP16Type FP16Ty;
39+
mlir::cir::BF16Type BFloat16Ty;
4140
mlir::cir::SingleType FloatTy;
4241
mlir::cir::DoubleType DoubleTy;
4342
mlir::cir::FP80Type FP80Ty;

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -464,14 +464,14 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
464464
break;
465465

466466
case BuiltinType::Float16:
467-
ResultType = Builder.getF16Type();
467+
ResultType = CGM.FP16Ty;
468468
break;
469469
case BuiltinType::Half:
470470
// Should be the same as above?
471471
assert(0 && "not implemented");
472472
break;
473473
case BuiltinType::BFloat16:
474-
ResultType = Builder.getBF16Type();
474+
ResultType = CGM.BFloat16Ty;
475475
break;
476476
case BuiltinType::Float:
477477
ResultType = CGM.FloatTy;

0 commit comments

Comments
 (0)