Skip to content

Commit

Permalink
Remove mlir::cir::FloatType and replace it with an interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancern committed Feb 19, 2024
1 parent fa5395b commit 80aee20
Show file tree
Hide file tree
Showing 16 changed files with 158 additions and 112 deletions.
2 changes: 1 addition & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
}]>,
];
let extraClassDeclaration = [{
static FPAttr getZero(mlir::cir::FloatType type);
static FPAttr getZero(mlir::Type type);
}];
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
Expand Down
38 changes: 1 addition & 37 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,10 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "clang/CIR/Interfaces/FPTypeInterface.h"

#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"

//===----------------------------------------------------------------------===//
// CIR FloatType
//
// The base type for all floating-point types.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace cir {

class SingleType;
class DoubleType;

class FloatType : public Type {
public:
using Type::Type;

// Convenience factories.
static SingleType getSingle(MLIRContext *ctx);
static DoubleType getDouble(MLIRContext *ctx);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);

/// Return the bitwidth of this float type.
unsigned getWidth() const;

/// Return the width of the mantissa of this type.
/// The width includes the integer bit.
unsigned getFPMantissaWidth() const;

/// Return the float semantics of this floating-point type.
const llvm::fltSemantics &getFloatSemantics() const;
};

} // namespace cir
} // namespace mlir

//===----------------------------------------------------------------------===//
// CIR Dialect Tablegen'd Types
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 6 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

include "clang/CIR/Dialect/IR/CIRDialect.td"
include "clang/CIR/Interfaces/ASTAttrInterfaces.td"
include "clang/CIR/Interfaces/FPTypeInterface.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
Expand Down Expand Up @@ -102,8 +103,10 @@ def SInt64 : SInt<64>;

class CIR_FloatType<string name, string mnemonic>
: CIR_Type<name, mnemonic,
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>],
"::mlir::cir::FloatType"> {}
[
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>,
DeclareTypeInterfaceMethods<FPTypeInterface>,
]> {}

def CIR_Single : CIR_FloatType<"Single", "float"> {
let summary = "CIR single-precision float type";
Expand All @@ -123,7 +126,7 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {

// Constraints

def CIR_AnyFloat: Type<CPred<"$_self.isa<::mlir::cir::FloatType>()">>;
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double]>;

//===----------------------------------------------------------------------===//
// PointerType
Expand Down
9 changes: 9 additions & 0 deletions clang/include/clang/CIR/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ function(add_clang_mlir_op_interface interface)
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
endfunction()

function(add_clang_mlir_type_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRCIR${interface}IncGen)
add_dependencies(mlir-generic-headers MLIRCIR${interface}IncGen)
endfunction()

add_clang_mlir_attr_interface(ASTAttrInterfaces)
add_clang_mlir_op_interface(CIROpInterfaces)
add_clang_mlir_op_interface(CIRLoopOpInterface)
add_clang_mlir_type_interface(FPTypeInterface)
22 changes: 22 additions & 0 deletions clang/include/clang/CIR/Interfaces/FPTypeInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- FPTypeInterface.h - Interface for CIR FP types ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// Defines the interface to generically handle CIR floating-point types.
//
//===----------------------------------------------------------------------===//

#ifndef CLANG_INTERFACES_CIR_FPTYPEINTERFACE_H
#define CLANG_INTERFACES_CIR_FPTYPEINTERFACE_H

#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"

/// Include the tablegen'd interface declarations.
#include "clang/CIR/Interfaces/FPTypeInterface.h.inc"

#endif // CLANG_INTERFACES_CIR_FPTYPEINTERFACE_H
52 changes: 52 additions & 0 deletions clang/include/clang/CIR/Interfaces/FPTypeInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===- FPTypeInterface.td - CIR FP Interface Definitions --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CIR_INTERFACES_FP_TYPE_INTERFACE
#define MLIR_CIR_INTERFACES_FP_TYPE_INTERFACE

include "mlir/IR/OpBase.td"

def FPTypeInterface : TypeInterface<"FPTypeInterface"> {
let description = [{
Contains helper functions to query properties about a floating-point type.
}];
let cppNamespace = "::mlir::cir";

let methods = [
InterfaceMethod<[{
Returns the bit width of this floating-point type.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getWidth",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::APFloat::semanticsSizeInBits($_type.getFloatSemantics());
}]
>,
InterfaceMethod<[{
Return the mantissa width.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getFPMantissaWidth",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::APFloat::semanticsPrecision($_type.getFloatSemantics());
}]
>,
InterfaceMethod<[{
Return the float semantics of this floating-point type.
}],
/*retTy=*/"const llvm::fltSemantics &",
/*methodName=*/"getFloatSemantics"
>,
];
}

#endif // MLIR_CIR_INTERFACES_FP_TYPE_INTERFACE
14 changes: 7 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
if (ty.isa<mlir::cir::IntType>())
return mlir::cir::IntAttr::get(ty, 0);
if (auto fltType = ty.dyn_cast<mlir::cir::FloatType>())
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
return getZeroAttr(arrTy);
Expand Down Expand Up @@ -343,13 +345,11 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
}
bool isInt(mlir::Type i) { return i.isa<mlir::cir::IntType>(); }

mlir::cir::FloatType getLongDouble80BitsTy() const {
llvm_unreachable("NYI");
}
mlir::Type getLongDouble80BitsTy() const { llvm_unreachable("NYI"); }

/// Get the proper floating point type for the given semantics.
mlir::cir::FloatType getFloatTyForFormat(const llvm::fltSemantics &format,
bool useNativeHalf) const {
mlir::Type getFloatTyForFormat(const llvm::fltSemantics &format,
bool useNativeHalf) const {
if (&format == &llvm::APFloat::IEEEhalf()) {
llvm_unreachable("IEEEhalf float format is NYI");
}
Expand Down Expand Up @@ -472,7 +472,7 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
bool isSized(mlir::Type ty) {
if (ty.isa<mlir::cir::PointerType, mlir::cir::StructType,
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType,
mlir::cir::FloatType>())
mlir::cir::FPTypeInterface>())
return true;
assert(0 && "Unimplemented size for type");
return false;
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,8 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
assert(0 && "not implemented");
else {
mlir::Type ty = CGM.getCIRType(DestType);
assert(ty.isa<mlir::cir::FloatType>() && "expected floating-point type");
assert(ty.isa<mlir::cir::FPTypeInterface>() &&
"expected floating-point type");
return CGM.getBuilder().getAttr<mlir::cir::FPAttr>(ty, Init);
}
}
Expand Down
13 changes: 7 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
}
mlir::Value VisitFloatingLiteral(const FloatingLiteral *E) {
mlir::Type Ty = CGF.getCIRType(E->getType());
assert(Ty.isa<mlir::cir::FloatType>() && "expect floating-point type");
assert(Ty.isa<mlir::cir::FPTypeInterface>() &&
"expect floating-point type");
return Builder.create<mlir::cir::ConstantOp>(
CGF.getLoc(E->getExprLoc()), Ty,
Builder.getAttr<mlir::cir::FPAttr>(Ty, E->getValue()));
Expand Down Expand Up @@ -1202,7 +1203,7 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
llvm_unreachable("NYI");

assert(!UnimplementedFeature::cirVectorType());
if (Ops.LHS.getType().isa<mlir::cir::FloatType>()) {
if (Ops.LHS.getType().isa<mlir::cir::FPTypeInterface>()) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFSub(Ops.LHS, Ops.RHS);
}
Expand Down Expand Up @@ -1670,20 +1671,20 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
llvm_unreachable("NYI: signed bool");
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::bool_to_int;
} else if (DstTy.isa<mlir::cir::FloatType>()) {
} else if (DstTy.isa<mlir::cir::FPTypeInterface>()) {
CastKind = mlir::cir::CastKind::bool_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (CGF.getBuilder().isInt(SrcTy)) {
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::integral;
} else if (DstTy.isa<mlir::cir::FloatType>()) {
} else if (DstTy.isa<mlir::cir::FPTypeInterface>()) {
CastKind = mlir::cir::CastKind::int_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (SrcTy.isa<mlir::cir::FloatType>()) {
} else if (SrcTy.isa<mlir::cir::FPTypeInterface>()) {
if (CGF.getBuilder().isInt(DstTy)) {
// If we can't recognize overflow as undefined behavior, assume that
// overflow saturates. This protects against normal optimizations if we
Expand All @@ -1693,7 +1694,7 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
if (Builder.getIsFPConstrained())
llvm_unreachable("NYI");
CastKind = mlir::cir::CastKind::float_to_int;
} else if (DstTy.isa<mlir::cir::FloatType>()) {
} else if (DstTy.isa<mlir::cir::FPTypeInterface>()) {
// TODO: split this to createFPExt/createFPTrunc
return Builder.createFloatingCast(Src, DstTy);
} else {
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,

// TODO: HalfTy
// TODO: BFloatTy
FloatTy = ::mlir::cir::FloatType::getSingle(builder.getContext());
DoubleTy = ::mlir::cir::FloatType::getDouble(builder.getContext());
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());
// TODO(cir): perhaps we should abstract long double variations into a custom
// cir.long_double type. Said type would also hold the semantics for lowering.

Expand Down
21 changes: 11 additions & 10 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Attribute cir::FPAttr::parse(AsmParser &parser, Type odsType) {
double value;

if (!odsType.isa<cir::FloatType>())
auto odsTypeFpInterface = odsType.dyn_cast<cir::FPTypeInterface>();
if (!odsTypeFpInterface)
return {};
auto ty = odsType.cast<cir::FloatType>();

if (parser.parseLess())
return {};
Expand All @@ -318,28 +318,29 @@ Attribute cir::FPAttr::parse(AsmParser &parser, Type odsType) {

auto losesInfo = false;
APFloat convertedValue{value};
convertedValue.convert(ty.getFloatSemantics(), llvm::RoundingMode::TowardZero,
&losesInfo);
convertedValue.convert(odsTypeFpInterface.getFloatSemantics(),
llvm::RoundingMode::TowardZero, &losesInfo);

return cir::FPAttr::get(ty, convertedValue);
return cir::FPAttr::get(odsType, convertedValue);
}

void cir::FPAttr::print(AsmPrinter &printer) const {
printer << '<' << getValue() << '>';
}

cir::FPAttr cir::FPAttr::getZero(mlir::cir::FloatType type) {
return get(type, APFloat::getZero(type.getFloatSemantics()));
cir::FPAttr cir::FPAttr::getZero(mlir::Type type) {
return get(type, APFloat::getZero(
type.cast<cir::FPTypeInterface>().getFloatSemantics()));
}

LogicalResult cir::FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, APFloat value) {
auto fltType = type.dyn_cast<cir::FloatType>();
if (!fltType) {
auto fltTypeInterface = type.dyn_cast<cir::FPTypeInterface>();
if (!fltTypeInterface) {
emitError() << "expected floating-point type";
return failure();
}
if (APFloat::SemanticsToEnum(fltType.getFloatSemantics()) !=
if (APFloat::SemanticsToEnum(fltTypeInterface.getFloatSemantics()) !=
APFloat::SemanticsToEnum(value.getSemantics())) {
emitError() << "floating-point semantics mismatch";
return failure();
Expand Down
12 changes: 6 additions & 6 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,13 @@ LogicalResult CastOp::verify() {
return success();
}
case cir::CastKind::floating: {
if (!srcType.isa<mlir::cir::FloatType>() ||
!resType.isa<mlir::cir::FloatType>())
if (!srcType.isa<mlir::cir::FPTypeInterface>() ||
!resType.isa<mlir::cir::FPTypeInterface>())
return emitOpError() << "requries floating for source and result";
return success();
}
case cir::CastKind::float_to_int: {
if (!srcType.isa<mlir::cir::FloatType>())
if (!srcType.isa<mlir::cir::FPTypeInterface>())
return emitOpError() << "requires floating for source";
if (!resType.dyn_cast<mlir::cir::IntType>())
return emitOpError() << "requires !IntegerType for result";
Expand All @@ -450,7 +450,7 @@ LogicalResult CastOp::verify() {
return success();
}
case cir::CastKind::float_to_bool: {
if (!srcType.isa<mlir::cir::FloatType>())
if (!srcType.isa<mlir::cir::FPTypeInterface>())
return emitOpError() << "requires float for source";
if (!resType.isa<mlir::cir::BoolType>())
return emitOpError() << "requires !cir.bool for result";
Expand All @@ -466,14 +466,14 @@ LogicalResult CastOp::verify() {
case cir::CastKind::int_to_float: {
if (!srcType.isa<mlir::cir::IntType>())
return emitOpError() << "requires !cir.int for source";
if (!resType.isa<mlir::cir::FloatType>())
if (!resType.isa<mlir::cir::FPTypeInterface>())
return emitOpError() << "requires !cir.float for result";
return success();
}
case cir::CastKind::bool_to_float: {
if (!srcType.isa<mlir::cir::BoolType>())
return emitOpError() << "requires !cir.bool for source";
if (!resType.isa<mlir::cir::FloatType>())
if (!resType.isa<mlir::cir::FPTypeInterface>())
return emitOpError() << "requires !cir.float for result";
return success();
}
Expand Down
Loading

0 comments on commit 80aee20

Please sign in to comment.