Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] introduce CIR floating-point types #385

Merged
merged 4 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,33 @@ def IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// FPAttr
//===----------------------------------------------------------------------===//

def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
let summary = "An attribute containing a floating-point value";
let description = [{
An fp attribute is a literal attribute that represents a floating-point
value of the specified floating-point type.
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APFloat":$value);
let builders = [
AttrBuilderWithInferredContext<(ins "Type":$type,
"const APFloat &":$value), [{
return $_get(type.getContext(), type, value);
}]>,
];
let extraClassDeclaration = [{
static FPAttr getZero(mlir::Type type);
}];
let genVerifyDecl = 1;

let assemblyFormat = [{
`<` custom<FloatLiteral>($value, ref($type)) `>`
}];
}

//===----------------------------------------------------------------------===//
// ConstPointerAttr
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2615,8 +2615,8 @@ def IterEndOp : CIR_Op<"iterator_end"> {

class UnaryFPToFPBuiltinOp<string mnemonic>
: CIR_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
let arguments = (ins AnyFloat:$src);
let results = (outs AnyFloat:$result);
let arguments = (ins CIR_AnyFloat:$src);
let results = (outs CIR_AnyFloat:$result);
let summary = "libc builtin equivalent ignoring "
"floating point exceptions and errno";
let assemblyFormat = "$src `:` type($src) attr-dict";
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h"

#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"

Expand Down
40 changes: 37 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@

include "clang/CIR/Dialect/IR/CIRDialect.td"
include "clang/CIR/Interfaces/ASTAttrInterfaces.td"
include "clang/CIR/Interfaces/CIRFPTypeInterface.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// CIR Types
//===----------------------------------------------------------------------===//

class CIR_Type<string name, string typeMnemonic, list<Trait> traits = []> :
TypeDef<CIR_Dialect, name, traits> {
class CIR_Type<string name, string typeMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: TypeDef<CIR_Dialect, name, traits, baseCppClass> {
let mnemonic = typeMnemonic;
}

Expand Down Expand Up @@ -94,6 +97,37 @@ def SInt16 : SInt<16>;
def SInt32 : SInt<32>;
def SInt64 : SInt<64>;

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//

class CIR_FloatType<string name, string mnemonic>
: CIR_Type<name, mnemonic,
[
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>,
DeclareTypeInterfaceMethods<CIRFPTypeInterface>,
]> {}

def CIR_Single : CIR_FloatType<"Single", "float"> {
let summary = "CIR single-precision float type";
let description = [{
Floating-point type that represents the `float` type in C/C++. Its
underlying floating-point format is the IEEE-754 binary32 format.
}];
}

def CIR_Double : CIR_FloatType<"Double", "double"> {
let summary = "CIR double-precision float type";
let description = [{
Floating-point type that represents the `double` type in C/C++. Its
underlying floating-point format is the IEEE-754 binar64 format.
}];
}

// Constraints

def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double]>;

//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -318,7 +352,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,

def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_BoolType, CIR_ArrayType, CIR_VectorType,
CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo, AnyFloat,
CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo, CIR_AnyFloat,
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
22 changes: 22 additions & 0 deletions clang/include/clang/CIR/Interfaces/CIRFPTypeInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- CIRFPTypeInterface.h - Interface for CIR FP types -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// Defines the interface to generically handle CIR floating-point types.
//
//===----------------------------------------------------------------------===//

#ifndef CLANG_INTERFACES_CIR_CIR_FPTYPEINTERFACE_H
#define CLANG_INTERFACES_CIR_CIR_FPTYPEINTERFACE_H

#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"

/// Include the tablegen'd interface declarations.
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h.inc"

#endif // CLANG_INTERFACES_CIR_CIR_FPTYPEINTERFACE_H
52 changes: 52 additions & 0 deletions clang/include/clang/CIR/Interfaces/CIRFPTypeInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===- CIRFPTypeInterface.td - CIR FP Interface Definitions -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CIR_INTERFACES_CIR_FP_TYPE_INTERFACE
#define MLIR_CIR_INTERFACES_CIR_FP_TYPE_INTERFACE

include "mlir/IR/OpBase.td"

def CIRFPTypeInterface : TypeInterface<"CIRFPTypeInterface"> {
let description = [{
Contains helper functions to query properties about a floating-point type.
}];
let cppNamespace = "::mlir::cir";

let methods = [
InterfaceMethod<[{
Returns the bit width of this floating-point type.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getWidth",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::APFloat::semanticsSizeInBits($_type.getFloatSemantics());
}]
>,
InterfaceMethod<[{
Return the mantissa width.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getFPMantissaWidth",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::APFloat::semanticsPrecision($_type.getFloatSemantics());
}]
>,
InterfaceMethod<[{
Return the float semantics of this floating-point type.
}],
/*retTy=*/"const llvm::fltSemantics &",
/*methodName=*/"getFloatSemantics"
>,
];
}

#endif // MLIR_CIR_INTERFACES_CIR_FP_TYPE_INTERFACE
9 changes: 9 additions & 0 deletions clang/include/clang/CIR/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ function(add_clang_mlir_op_interface interface)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

function(add_clang_mlir_type_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIR${interface}IncGen)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

add_clang_mlir_attr_interface(ASTAttrInterfaces)
add_clang_mlir_op_interface(CIROpInterfaces)
add_clang_mlir_op_interface(CIRLoopOpInterface)
add_clang_mlir_type_interface(CIRFPTypeInterface)
33 changes: 17 additions & 16 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,10 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
if (ty.isa<mlir::cir::IntType>())
return mlir::cir::IntAttr::get(ty, 0);
if (ty.isa<mlir::FloatType>())
return mlir::FloatAttr::get(ty, 0.0);
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
return getZeroAttr(arrTy);
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
Expand Down Expand Up @@ -256,12 +258,13 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
if (const auto boolVal = attr.dyn_cast<mlir::cir::BoolAttr>())
return !boolVal.getValue();

if (const auto fpVal = attr.dyn_cast<mlir::FloatAttr>()) {
if (auto fpAttr = attr.dyn_cast<mlir::cir::FPAttr>()) {
auto fpVal = fpAttr.getValue();
bool ignored;
llvm::APFloat FV(+0.0);
FV.convert(fpVal.getValue().getSemantics(),
llvm::APFloat::rmNearestTiesToEven, &ignored);
return FV.bitwiseIsEqual(fpVal.getValue());
FV.convert(fpVal.getSemantics(), llvm::APFloat::rmNearestTiesToEven,
&ignored);
return FV.bitwiseIsEqual(fpVal);
}

if (const auto structVal = attr.dyn_cast<mlir::cir::ConstStructAttr>()) {
Expand Down Expand Up @@ -348,23 +351,21 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}
bool isInt(mlir::Type i) { return i.isa<mlir::cir::IntType>(); }

mlir::FloatType getLongDouble80BitsTy() const {
return typeCache.LongDouble80BitsTy;
}
mlir::Type getLongDouble80BitsTy() const { llvm_unreachable("NYI"); }

/// Get the proper floating point type for the given semantics.
mlir::FloatType getFloatTyForFormat(const llvm::fltSemantics &format,
bool useNativeHalf) const {
mlir::Type getFloatTyForFormat(const llvm::fltSemantics &format,
bool useNativeHalf) const {
if (&format == &llvm::APFloat::IEEEhalf()) {
llvm_unreachable("IEEEhalf float format is NYI");
}

if (&format == &llvm::APFloat::BFloat())
llvm_unreachable("BFloat float format is NYI");
if (&format == &llvm::APFloat::IEEEsingle())
llvm_unreachable("IEEEsingle float format is NYI");
return typeCache.FloatTy;
if (&format == &llvm::APFloat::IEEEdouble())
llvm_unreachable("IEEEdouble float format is NYI");
return typeCache.DoubleTy;
if (&format == &llvm::APFloat::IEEEquad())
llvm_unreachable("IEEEquad float format is NYI");
if (&format == &llvm::APFloat::PPCDoubleDouble())
Expand Down Expand Up @@ -491,9 +492,9 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}

bool isSized(mlir::Type ty) {
if (ty.isIntOrFloat() ||
ty.isa<mlir::cir::PointerType, mlir::cir::StructType,
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType>())
if (ty.isa<mlir::cir::PointerType, mlir::cir::StructType,
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType,
mlir::cir::CIRFPTypeInterface>())
return true;
assert(0 && "Unimplemented size for type");
return false;
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1705,7 +1705,9 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
assert(0 && "not implemented");
else {
mlir::Type ty = CGM.getCIRType(DestType);
return builder.getFloatAttr(ty, Init);
assert(ty.isa<mlir::cir::CIRFPTypeInterface>() &&
"expected floating-point type");
return CGM.getBuilder().getAttr<mlir::cir::FPAttr>(ty, Init);
}
}
case APValue::Array: {
Expand Down
14 changes: 8 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
}
mlir::Value VisitFloatingLiteral(const FloatingLiteral *E) {
mlir::Type Ty = CGF.getCIRType(E->getType());
assert(Ty.isa<mlir::cir::CIRFPTypeInterface>() &&
"expect floating-point type");
return Builder.create<mlir::cir::ConstantOp>(
CGF.getLoc(E->getExprLoc()), Ty,
Builder.getFloatAttr(Ty, E->getValue()));
Builder.getAttr<mlir::cir::FPAttr>(Ty, E->getValue()));
}
mlir::Value VisitCharacterLiteral(const CharacterLiteral *E) {
mlir::Type Ty = CGF.getCIRType(E->getType());
Expand Down Expand Up @@ -1227,7 +1229,7 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
llvm_unreachable("NYI");

assert(!UnimplementedFeature::cirVectorType());
if (Ops.LHS.getType().isa<mlir::FloatType>()) {
if (Ops.LHS.getType().isa<mlir::cir::CIRFPTypeInterface>()) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFSub(Ops.LHS, Ops.RHS);
}
Expand Down Expand Up @@ -1701,20 +1703,20 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
llvm_unreachable("NYI: signed bool");
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::bool_to_int;
} else if (DstTy.isa<mlir::FloatType>()) {
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
CastKind = mlir::cir::CastKind::bool_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (CGF.getBuilder().isInt(SrcTy)) {
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::integral;
} else if (DstTy.isa<mlir::FloatType>()) {
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
CastKind = mlir::cir::CastKind::int_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (SrcTy.isa<mlir::FloatType>()) {
} else if (SrcTy.isa<mlir::cir::CIRFPTypeInterface>()) {
if (CGF.getBuilder().isInt(DstTy)) {
// If we can't recognize overflow as undefined behavior, assume that
// overflow saturates. This protects against normal optimizations if we
Expand All @@ -1724,7 +1726,7 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
if (Builder.getIsFPConstrained())
llvm_unreachable("NYI");
CastKind = mlir::cir::CastKind::float_to_int;
} else if (DstTy.isa<mlir::FloatType>()) {
} else if (DstTy.isa<mlir::cir::CIRFPTypeInterface>()) {
// TODO: split this to createFPExt/createFPTrunc
return Builder.createFloatingCast(Src, DstTy);
} else {
Expand Down
5 changes: 2 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,10 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,

// TODO: HalfTy
// TODO: BFloatTy
FloatTy = builder.getF32Type();
DoubleTy = builder.getF64Type();
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());
// TODO(cir): perhaps we should abstract long double variations into a custom
// cir.long_double type. Said type would also hold the semantics for lowering.
LongDouble80BitsTy = builder.getF80Type();

// TODO: PointerWidthInBits
PointerAlignInBytes =
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypeCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ struct CIRGenTypeCache {
// mlir::Type HalfTy, BFloatTy;
// TODO(cir): perhaps we should abstract long double variations into a custom
// cir.long_double type. Said type would also hold the semantics for lowering.
mlir::FloatType FloatTy, DoubleTy, LongDouble80BitsTy;
mlir::cir::SingleType FloatTy;
mlir::cir::DoubleType DoubleTy;

/// int
mlir::Type UIntTy;
Expand Down
Loading
Loading