Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CIRGen] Add CIRGen support for float16 and bfloat #571

Merged
merged 2 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,20 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {
}];
}

def CIR_FP16 : CIR_FloatType<"FP16", "f16"> {
let summary = "CIR type that represents IEEE-754 binary16 format";
let description = [{
Floating-point type that represents the IEEE-754 binary16 format.
}];
}

def CIR_BFloat16 : CIR_FloatType<"BF16", "bf16"> {
let summary = "CIR type that represents";
let description = [{
Floating-point type that represents the bfloat16 format.
}];
}

def CIR_FP80 : CIR_FloatType<"FP80", "f80"> {
let summary = "CIR type that represents x87 80-bit floating-point format";
let description = [{
Expand Down Expand Up @@ -179,7 +193,7 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {

// Constraints

def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_LongDouble]>;
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_LongDouble]>;
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -475,7 +489,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
CIR_AnyFloat,
CIR_AnyFloat, CIR_FP16, CIR_BFloat16
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
4 changes: 4 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = ty.dyn_cast<mlir::cir::FP16Type>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = ty.dyn_cast<mlir::cir::BF16Type>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
return getZeroAttr(arrTy);
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
Expand Down
175 changes: 145 additions & 30 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,16 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
/// Emit a value that corresponds to null for the given type.
mlir::Value buildNullValue(QualType Ty, mlir::Location loc);

mlir::Value buildPromotedValue(mlir::Value result, QualType PromotionType) {
return Builder.createFloatingCast(result, ConvertType(PromotionType));
}

mlir::Value buildUnPromotedValue(mlir::Value result, QualType ExprType) {
return Builder.createFloatingCast(result, ConvertType(ExprType));
}

mlir::Value buildPromoted(const Expr *E, QualType PromotionType);

//===--------------------------------------------------------------------===//
// Visitor Methods
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -478,14 +488,45 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
} else if (type->isVectorType()) {
llvm_unreachable("no vector inc/dec yet");
} else if (type->isRealFloatingType()) {
auto isFloatOrDouble = type->isSpecificBuiltinType(BuiltinType::Float) ||
type->isSpecificBuiltinType(BuiltinType::Double);
assert(isFloatOrDouble && "Non-float/double NYI");
// TODO(cir): CGFPOptionsRAII
assert(!UnimplementedFeature::CGFPOptionsRAII());

// Create the inc/dec operation.
auto kind =
(isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec);
value = buildUnaryOp(E, kind, input);
if (type->isHalfType() && !CGF.getContext().getLangOpts().NativeHalfType)
llvm_unreachable("__fp16 type NYI");

if (value.getType().isa<mlir::cir::SingleType, mlir::cir::DoubleType>()) {
// Create the inc/dec operation.
// NOTE(CIR): clang calls CreateAdd but folds this to a unary op
auto kind =
(isInc ? mlir::cir::UnaryOpKind::Inc : mlir::cir::UnaryOpKind::Dec);
value = buildUnaryOp(E, kind, input);
} else {
// Remaining types are Half, Bfloat16, LongDouble, __ibm128 or
// __float128. Convert from float.

llvm::APFloat F(static_cast<float>(amount));
bool ignored;
const llvm::fltSemantics *FS;
// Don't use getFloatTypeSemantics because Half isn't
// necessarily represented using the "half" LLVM type.
if (value.getType().isa<mlir::cir::LongDoubleType>())
FS = &CGF.getTarget().getLongDoubleFormat();
else if (value.getType().isa<mlir::cir::FP16Type>())
FS = &CGF.getTarget().getHalfFormat();
else if (value.getType().isa<mlir::cir::BF16Type>())
FS = &CGF.getTarget().getBFloat16Format();
else
llvm_unreachable("fp128 / ppc_fp128 NYI");
F.convert(*FS, llvm::APFloat::rmTowardZero, &ignored);

auto loc = CGF.getLoc(E->getExprLoc());
auto amt = Builder.getConstant(
loc, mlir::cir::FPAttr::get(value.getType(), F));
value = Builder.createBinop(value, mlir::cir::BinOpKind::Add, amt);
}

if (type->isHalfType() && !CGF.getContext().getLangOpts().NativeHalfType)
llvm_unreachable("NYI");

} else if (type->isFixedPointType()) {
llvm_unreachable("no fixed point inc/dec yet");
Expand Down Expand Up @@ -549,21 +590,22 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
return Visit(E->getSubExpr()); // the actual value should be unused
return buildLoadOfLValue(E);
}
mlir::Value VisitUnaryPlus(const UnaryOperator *E) {
// NOTE(cir): QualType function parameter still not used, so don´t replicate
// it here yet.
QualType promotionTy = getPromotionType(E->getSubExpr()->getType());
mlir::Value VisitUnaryPlus(const UnaryOperator *E,
QualType PromotionType = QualType()) {
QualType promotionTy = PromotionType.isNull()
? getPromotionType(E->getSubExpr()->getType())
: PromotionType;
auto result = VisitPlus(E, promotionTy);
if (result && !promotionTy.isNull())
assert(0 && "not implemented yet");
result = buildUnPromotedValue(result, E->getType());
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Plus, result);
}

mlir::Value VisitPlus(const UnaryOperator *E, QualType PromotionType) {
// This differs from gcc, though, most likely due to a bug in gcc.
TestAndClearIgnoreResultAssign();
if (!PromotionType.isNull())
assert(0 && "scalar promotion not implemented yet");
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);
return Visit(E->getSubExpr());
}

Expand All @@ -573,14 +615,14 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
QualType promotionTy = getPromotionType(E->getSubExpr()->getType());
auto result = VisitMinus(E, promotionTy);
if (result && !promotionTy.isNull())
assert(0 && "not implemented yet");
result = buildUnPromotedValue(result, E->getType());
return buildUnaryOp(E, mlir::cir::UnaryOpKind::Minus, result);
}

mlir::Value VisitMinus(const UnaryOperator *E, QualType PromotionType) {
TestAndClearIgnoreResultAssign();
if (!PromotionType.isNull())
assert(0 && "scalar promotion not implemented yet");
return CGF.buildPromotedScalarExpr(E->getSubExpr(), PromotionType);

// NOTE: LLVM codegen will lower this directly to either a FNeg
// or a Sub instruction. In CIR this will be handled later in LowerToLLVM.
Expand Down Expand Up @@ -752,18 +794,23 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
QualType DstType, mlir::Type SrcTy,
mlir::Type DstTy, ScalarConversionOpts Opts);

BinOpInfo buildBinOps(const BinaryOperator *E) {
BinOpInfo buildBinOps(const BinaryOperator *E,
QualType PromotionType = QualType()) {
BinOpInfo Result;
Result.LHS = Visit(E->getLHS());
Result.RHS = Visit(E->getRHS());
Result.FullType = E->getType();
Result.CompType = E->getType();
if (auto VecType = dyn_cast_or_null<VectorType>(E->getType())) {
Result.LHS = CGF.buildPromotedScalarExpr(E->getLHS(), PromotionType);
Result.RHS = CGF.buildPromotedScalarExpr(E->getRHS(), PromotionType);
if (!PromotionType.isNull())
Result.FullType = PromotionType;
else
Result.FullType = E->getType();
Result.CompType = Result.FullType;
if (const auto *VecType = dyn_cast_or_null<VectorType>(Result.FullType)) {
Result.CompType = VecType->getElementType();
}
Result.Opcode = E->getOpcode();
Result.Loc = E->getSourceRange();
// TODO: Result.FPFeatures
assert(!UnimplementedFeature::getFPFeaturesInEffect());
Result.E = E;
return Result;
}
Expand Down Expand Up @@ -793,15 +840,22 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
if (auto *CT = Ty->getAs<ComplexType>()) {
llvm_unreachable("NYI");
}
if (Ty.UseExcessPrecision(CGF.getContext()))
llvm_unreachable("NYI");
if (Ty.UseExcessPrecision(CGF.getContext())) {
if (auto *VT = Ty->getAs<VectorType>())
llvm_unreachable("NYI");
return CGF.getContext().FloatTy;
}
return QualType();
}

// Binary operators and binary compound assignment operators.
#define HANDLEBINOP(OP) \
mlir::Value VisitBin##OP(const BinaryOperator *E) { \
return build##OP(buildBinOps(E)); \
QualType promotionTy = getPromotionType(E->getType()); \
auto result = build##OP(buildBinOps(E, promotionTy)); \
if (result && !promotionTy.isNull()) \
result = buildUnPromotedValue(result, E->getType()); \
return result; \
} \
mlir::Value VisitBin##OP##Assign(const CompoundAssignOperator *E) { \
return buildCompoundAssign(E, &ScalarExprEmitter::build##OP); \
Expand Down Expand Up @@ -1053,6 +1107,13 @@ mlir::Value CIRGenFunction::buildScalarExpr(const Expr *E) {
return ScalarExprEmitter(*this, builder).Visit(const_cast<Expr *>(E));
}

mlir::Value CIRGenFunction::buildPromotedScalarExpr(const Expr *E,
QualType PromotionType) {
if (!PromotionType.isNull())
return ScalarExprEmitter(*this, builder).buildPromoted(E, PromotionType);
return ScalarExprEmitter(*this, builder).Visit(const_cast<Expr *>(E));
}

[[maybe_unused]] static bool MustVisitNullValue(const Expr *E) {
// If a null pointer expression's type is the C++0x nullptr_t, then
// it's not necessarily a simple constant and it must be evaluated
Expand Down Expand Up @@ -1885,8 +1946,20 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue(

// Emit the RHS first. __block variables need to have the rhs evaluated
// first, plus this should improve codegen a little.
OpInfo.RHS = Visit(E->getRHS());
OpInfo.FullType = E->getComputationResultType();

QualType PromotionTypeCR = getPromotionType(E->getComputationResultType());
if (PromotionTypeCR.isNull())
PromotionTypeCR = E->getComputationResultType();

QualType PromotionTypeLHS = getPromotionType(E->getComputationLHSType());
QualType PromotionTypeRHS = getPromotionType(E->getRHS()->getType());

if (!PromotionTypeRHS.isNull())
OpInfo.RHS = CGF.buildPromotedScalarExpr(E->getRHS(), PromotionTypeRHS);
else
OpInfo.RHS = Visit(E->getRHS());

OpInfo.FullType = PromotionTypeCR;
OpInfo.CompType = OpInfo.FullType;
if (auto VecType = dyn_cast_or_null<VectorType>(OpInfo.FullType)) {
OpInfo.CompType = VecType->getElementType();
Expand All @@ -1908,16 +1981,20 @@ LValue ScalarExprEmitter::buildCompoundAssignLValue(
CIRGenFunction::SourceLocRAIIObject sourceloc{
CGF, CGF.getLoc(E->getSourceRange())};
SourceLocation Loc = E->getExprLoc();
OpInfo.LHS =
buildScalarConversion(OpInfo.LHS, LHSTy, E->getComputationLHSType(), Loc);
if (!PromotionTypeLHS.isNull())
OpInfo.LHS = buildScalarConversion(OpInfo.LHS, LHSTy, PromotionTypeLHS,
E->getExprLoc());
else
OpInfo.LHS = buildScalarConversion(OpInfo.LHS, LHSTy,
E->getComputationLHSType(), Loc);

// Expand the binary operator.
Result = (this->*Func)(OpInfo);

// Convert the result back to the LHS type,
// potentially with Implicit Conversion sanitizer check.
Result = buildScalarConversion(Result, E->getComputationResultType(), LHSTy,
Loc, ScalarConversionOpts(CGF.SanOpts));
Result = buildScalarConversion(Result, PromotionTypeCR, LHSTy, Loc,
ScalarConversionOpts(CGF.SanOpts));

// Store the result value into the LHS lvalue. Bit-fields are handled
// specially because the result is altered by the store, i.e., [C99 6.5.16p1]
Expand All @@ -1938,6 +2015,44 @@ mlir::Value ScalarExprEmitter::buildNullValue(QualType Ty, mlir::Location loc) {
return CGF.buildFromMemory(CGF.CGM.buildNullConstant(Ty, loc), Ty);
}

mlir::Value ScalarExprEmitter::buildPromoted(const Expr *E,
QualType PromotionType) {
E = E->IgnoreParens();
if (const auto *BO = dyn_cast<BinaryOperator>(E)) {
switch (BO->getOpcode()) {
#define HANDLE_BINOP(OP) \
case BO_##OP: \
return build##OP(buildBinOps(BO, PromotionType));
HANDLE_BINOP(Add)
HANDLE_BINOP(Sub)
HANDLE_BINOP(Mul)
HANDLE_BINOP(Div)
#undef HANDLE_BINOP
default:
break;
}
} else if (const auto *UO = dyn_cast<UnaryOperator>(E)) {
switch (UO->getOpcode()) {
case UO_Imag:
case UO_Real:
llvm_unreachable("NYI");
case UO_Minus:
return VisitMinus(UO, PromotionType);
case UO_Plus:
return VisitPlus(UO, PromotionType);
default:
break;
}
}
auto result = Visit(const_cast<Expr *>(E));
if (result) {
if (!PromotionType.isNull())
return buildPromotedValue(result, PromotionType);
return buildUnPromotedValue(result, E->getType());
}
return result;
}

mlir::Value ScalarExprEmitter::buildCompoundAssign(
const CompoundAssignOperator *E,
mlir::Value (ScalarExprEmitter::*Func)(const BinOpInfo &)) {
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,9 @@ class CIRGenFunction : public CIRGenTypeCache {
mlir::Value buildScalarExpr(const clang::Expr *E);
mlir::Value buildScalarConstant(const ConstantEmission &Constant, Expr *E);

mlir::Value buildPromotedScalarExpr(const clang::Expr *E,
QualType PromotionType);

mlir::Type getCIRType(const clang::QualType &type);

const CaseStmt *foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
Expand Down
6 changes: 2 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,11 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,
// Initialize CIR pointer types cache.
VoidPtrTy = ::mlir::cir::PointerType::get(builder.getContext(), VoidTy);

// TODO: HalfTy
// TODO: BFloatTy
FP16Ty = ::mlir::cir::FP16Type::get(builder.getContext());
BFloat16Ty = ::mlir::cir::BF16Type::get(builder.getContext());
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());
FP80Ty = ::mlir::cir::FP80Type::get(builder.getContext());
// TODO(cir): perhaps we should abstract long double variations into a custom
// cir.long_double type. Said type would also hold the semantics for lowering.

// TODO: PointerWidthInBits
PointerAlignInBytes =
Expand Down
7 changes: 3 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenTypeCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ struct CIRGenTypeCache {
mlir::cir::IntType SInt8Ty, SInt16Ty, SInt32Ty, SInt64Ty;
// usigned char, unsigned, unsigned short, unsigned long
mlir::cir::IntType UInt8Ty, UInt16Ty, UInt32Ty, UInt64Ty;
/// half, bfloat, float, double
// mlir::Type HalfTy, BFloatTy;
// TODO(cir): perhaps we should abstract long double variations into a custom
// cir.long_double type. Said type would also hold the semantics for lowering.
/// half, bfloat, float, double, fp80
mlir::cir::FP16Type FP16Ty;
mlir::cir::BF16Type BFloat16Ty;
mlir::cir::SingleType FloatTy;
mlir::cir::DoubleType DoubleTy;
mlir::cir::FP80Type FP80Ty;
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,14 +464,14 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
break;

case BuiltinType::Float16:
ResultType = Builder.getF16Type();
ResultType = CGM.FP16Ty;
break;
case BuiltinType::Half:
// Should be the same as above?
assert(0 && "not implemented");
break;
case BuiltinType::BFloat16:
ResultType = Builder.getBF16Type();
ResultType = CGM.BFloat16Ty;
break;
case BuiltinType::Float:
ResultType = CGM.FloatTy;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ struct UnimplementedFeature {
static bool shouldEmitLifetimeMarkers() { return false; }
static bool peepholeProtection() { return false; }
static bool CGCapturedStmtInfo() { return false; }
static bool CGFPOptionsRAII() { return false; }
static bool getFPFeaturesInEffect() { return false; }
static bool cxxABI() { return false; }
static bool openCL() { return false; }
static bool CUDA() { return false; }
Expand Down
Loading
Loading