Skip to content

Commit 80aee20

Browse files
committed
Remove mlir::cir::FloatType and replace it with an interface
1 parent fa5395b commit 80aee20

File tree

16 files changed

+158
-112
lines changed

16 files changed

+158
-112
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
220220
}]>,
221221
];
222222
let extraClassDeclaration = [{
223-
static FPAttr getZero(mlir::cir::FloatType type);
223+
static FPAttr getZero(mlir::Type type);
224224
}];
225225
let genVerifyDecl = 1;
226226
let hasCustomAssemblyFormat = 1;

clang/include/clang/CIR/Dialect/IR/CIRTypes.h

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,10 @@
1616
#include "mlir/IR/BuiltinAttributes.h"
1717
#include "mlir/IR/Types.h"
1818
#include "mlir/Interfaces/DataLayoutInterfaces.h"
19+
#include "clang/CIR/Interfaces/FPTypeInterface.h"
1920

2021
#include "clang/CIR/Interfaces/ASTAttrInterfaces.h"
2122

22-
//===----------------------------------------------------------------------===//
23-
// CIR FloatType
24-
//
25-
// The base type for all floating-point types.
26-
//===----------------------------------------------------------------------===//
27-
28-
namespace mlir {
29-
namespace cir {
30-
31-
class SingleType;
32-
class DoubleType;
33-
34-
class FloatType : public Type {
35-
public:
36-
using Type::Type;
37-
38-
// Convenience factories.
39-
static SingleType getSingle(MLIRContext *ctx);
40-
static DoubleType getDouble(MLIRContext *ctx);
41-
42-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
43-
static bool classof(Type type);
44-
45-
/// Return the bitwidth of this float type.
46-
unsigned getWidth() const;
47-
48-
/// Return the width of the mantissa of this type.
49-
/// The width includes the integer bit.
50-
unsigned getFPMantissaWidth() const;
51-
52-
/// Return the float semantics of this floating-point type.
53-
const llvm::fltSemantics &getFloatSemantics() const;
54-
};
55-
56-
} // namespace cir
57-
} // namespace mlir
58-
5923
//===----------------------------------------------------------------------===//
6024
// CIR Dialect Tablegen'd Types
6125
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
include "clang/CIR/Dialect/IR/CIRDialect.td"
1717
include "clang/CIR/Interfaces/ASTAttrInterfaces.td"
18+
include "clang/CIR/Interfaces/FPTypeInterface.td"
1819
include "mlir/Interfaces/DataLayoutInterfaces.td"
1920
include "mlir/IR/AttrTypeBase.td"
2021
include "mlir/IR/EnumAttr.td"
@@ -102,8 +103,10 @@ def SInt64 : SInt<64>;
102103

103104
class CIR_FloatType<string name, string mnemonic>
104105
: CIR_Type<name, mnemonic,
105-
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>],
106-
"::mlir::cir::FloatType"> {}
106+
[
107+
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>,
108+
DeclareTypeInterfaceMethods<FPTypeInterface>,
109+
]> {}
107110

108111
def CIR_Single : CIR_FloatType<"Single", "float"> {
109112
let summary = "CIR single-precision float type";
@@ -123,7 +126,7 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {
123126

124127
// Constraints
125128

126-
def CIR_AnyFloat: Type<CPred<"$_self.isa<::mlir::cir::FloatType>()">>;
129+
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double]>;
127130

128131
//===----------------------------------------------------------------------===//
129132
// PointerType

clang/include/clang/CIR/Interfaces/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ function(add_clang_mlir_op_interface interface)
2020
add_dependencies(mlir-generic-headers MLIR${interface}IncGen)
2121
endfunction()
2222

23+
function(add_clang_mlir_type_interface interface)
24+
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
25+
mlir_tablegen(${interface}.h.inc -gen-type-interface-decls)
26+
mlir_tablegen(${interface}.cpp.inc -gen-type-interface-defs)
27+
add_public_tablegen_target(MLIRCIR${interface}IncGen)
28+
add_dependencies(mlir-generic-headers MLIRCIR${interface}IncGen)
29+
endfunction()
30+
2331
add_clang_mlir_attr_interface(ASTAttrInterfaces)
2432
add_clang_mlir_op_interface(CIROpInterfaces)
2533
add_clang_mlir_op_interface(CIRLoopOpInterface)
34+
add_clang_mlir_type_interface(FPTypeInterface)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- FPTypeInterface.h - Interface for CIR FP types ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
//
9+
// Defines the interface to generically handle CIR floating-point types.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef CLANG_INTERFACES_CIR_FPTYPEINTERFACE_H
14+
#define CLANG_INTERFACES_CIR_FPTYPEINTERFACE_H
15+
16+
#include "mlir/IR/Types.h"
17+
#include "llvm/ADT/APFloat.h"
18+
19+
/// Include the tablegen'd interface declarations.
20+
#include "clang/CIR/Interfaces/FPTypeInterface.h.inc"
21+
22+
#endif // CLANG_INTERFACES_CIR_FPTYPEINTERFACE_H
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//===- FPTypeInterface.td - CIR FP Interface Definitions --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CIR_INTERFACES_FP_TYPE_INTERFACE
10+
#define MLIR_CIR_INTERFACES_FP_TYPE_INTERFACE
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def FPTypeInterface : TypeInterface<"FPTypeInterface"> {
15+
let description = [{
16+
Contains helper functions to query properties about a floating-point type.
17+
}];
18+
let cppNamespace = "::mlir::cir";
19+
20+
let methods = [
21+
InterfaceMethod<[{
22+
Returns the bit width of this floating-point type.
23+
}],
24+
/*retTy=*/"unsigned",
25+
/*methodName=*/"getWidth",
26+
/*args=*/(ins),
27+
/*methodBody=*/"",
28+
/*defaultImplementation=*/[{
29+
return llvm::APFloat::semanticsSizeInBits($_type.getFloatSemantics());
30+
}]
31+
>,
32+
InterfaceMethod<[{
33+
Return the mantissa width.
34+
}],
35+
/*retTy=*/"unsigned",
36+
/*methodName=*/"getFPMantissaWidth",
37+
/*args=*/(ins),
38+
/*methodBody=*/"",
39+
/*defaultImplementation=*/[{
40+
return llvm::APFloat::semanticsPrecision($_type.getFloatSemantics());
41+
}]
42+
>,
43+
InterfaceMethod<[{
44+
Return the float semantics of this floating-point type.
45+
}],
46+
/*retTy=*/"const llvm::fltSemantics &",
47+
/*methodName=*/"getFloatSemantics"
48+
>,
49+
];
50+
}
51+
52+
#endif // MLIR_CIR_INTERFACES_FP_TYPE_INTERFACE

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,9 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
224224
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
225225
if (ty.isa<mlir::cir::IntType>())
226226
return mlir::cir::IntAttr::get(ty, 0);
227-
if (auto fltType = ty.dyn_cast<mlir::cir::FloatType>())
227+
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
228+
return mlir::cir::FPAttr::getZero(fltType);
229+
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
228230
return mlir::cir::FPAttr::getZero(fltType);
229231
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
230232
return getZeroAttr(arrTy);
@@ -343,13 +345,11 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
343345
}
344346
bool isInt(mlir::Type i) { return i.isa<mlir::cir::IntType>(); }
345347

346-
mlir::cir::FloatType getLongDouble80BitsTy() const {
347-
llvm_unreachable("NYI");
348-
}
348+
mlir::Type getLongDouble80BitsTy() const { llvm_unreachable("NYI"); }
349349

350350
/// Get the proper floating point type for the given semantics.
351-
mlir::cir::FloatType getFloatTyForFormat(const llvm::fltSemantics &format,
352-
bool useNativeHalf) const {
351+
mlir::Type getFloatTyForFormat(const llvm::fltSemantics &format,
352+
bool useNativeHalf) const {
353353
if (&format == &llvm::APFloat::IEEEhalf()) {
354354
llvm_unreachable("IEEEhalf float format is NYI");
355355
}
@@ -472,7 +472,7 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
472472
bool isSized(mlir::Type ty) {
473473
if (ty.isa<mlir::cir::PointerType, mlir::cir::StructType,
474474
mlir::cir::ArrayType, mlir::cir::BoolType, mlir::cir::IntType,
475-
mlir::cir::FloatType>())
475+
mlir::cir::FPTypeInterface>())
476476
return true;
477477
assert(0 && "Unimplemented size for type");
478478
return false;

clang/lib/CIR/CodeGen/CIRGenExprConst.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1479,7 +1479,8 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
14791479
assert(0 && "not implemented");
14801480
else {
14811481
mlir::Type ty = CGM.getCIRType(DestType);
1482-
assert(ty.isa<mlir::cir::FloatType>() && "expected floating-point type");
1482+
assert(ty.isa<mlir::cir::FPTypeInterface>() &&
1483+
"expected floating-point type");
14831484
return CGM.getBuilder().getAttr<mlir::cir::FPAttr>(ty, Init);
14841485
}
14851486
}

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
165165
}
166166
mlir::Value VisitFloatingLiteral(const FloatingLiteral *E) {
167167
mlir::Type Ty = CGF.getCIRType(E->getType());
168-
assert(Ty.isa<mlir::cir::FloatType>() && "expect floating-point type");
168+
assert(Ty.isa<mlir::cir::FPTypeInterface>() &&
169+
"expect floating-point type");
169170
return Builder.create<mlir::cir::ConstantOp>(
170171
CGF.getLoc(E->getExprLoc()), Ty,
171172
Builder.getAttr<mlir::cir::FPAttr>(Ty, E->getValue()));
@@ -1202,7 +1203,7 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
12021203
llvm_unreachable("NYI");
12031204

12041205
assert(!UnimplementedFeature::cirVectorType());
1205-
if (Ops.LHS.getType().isa<mlir::cir::FloatType>()) {
1206+
if (Ops.LHS.getType().isa<mlir::cir::FPTypeInterface>()) {
12061207
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
12071208
return Builder.createFSub(Ops.LHS, Ops.RHS);
12081209
}
@@ -1670,20 +1671,20 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
16701671
llvm_unreachable("NYI: signed bool");
16711672
if (CGF.getBuilder().isInt(DstTy)) {
16721673
CastKind = mlir::cir::CastKind::bool_to_int;
1673-
} else if (DstTy.isa<mlir::cir::FloatType>()) {
1674+
} else if (DstTy.isa<mlir::cir::FPTypeInterface>()) {
16741675
CastKind = mlir::cir::CastKind::bool_to_float;
16751676
} else {
16761677
llvm_unreachable("Internal error: Cast to unexpected type");
16771678
}
16781679
} else if (CGF.getBuilder().isInt(SrcTy)) {
16791680
if (CGF.getBuilder().isInt(DstTy)) {
16801681
CastKind = mlir::cir::CastKind::integral;
1681-
} else if (DstTy.isa<mlir::cir::FloatType>()) {
1682+
} else if (DstTy.isa<mlir::cir::FPTypeInterface>()) {
16821683
CastKind = mlir::cir::CastKind::int_to_float;
16831684
} else {
16841685
llvm_unreachable("Internal error: Cast to unexpected type");
16851686
}
1686-
} else if (SrcTy.isa<mlir::cir::FloatType>()) {
1687+
} else if (SrcTy.isa<mlir::cir::FPTypeInterface>()) {
16871688
if (CGF.getBuilder().isInt(DstTy)) {
16881689
// If we can't recognize overflow as undefined behavior, assume that
16891690
// overflow saturates. This protects against normal optimizations if we
@@ -1693,7 +1694,7 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
16931694
if (Builder.getIsFPConstrained())
16941695
llvm_unreachable("NYI");
16951696
CastKind = mlir::cir::CastKind::float_to_int;
1696-
} else if (DstTy.isa<mlir::cir::FloatType>()) {
1697+
} else if (DstTy.isa<mlir::cir::FPTypeInterface>()) {
16971698
// TODO: split this to createFPExt/createFPTrunc
16981699
return Builder.createFloatingCast(Src, DstTy);
16991700
} else {

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &context,
133133

134134
// TODO: HalfTy
135135
// TODO: BFloatTy
136-
FloatTy = ::mlir::cir::FloatType::getSingle(builder.getContext());
137-
DoubleTy = ::mlir::cir::FloatType::getDouble(builder.getContext());
136+
FloatTy = ::mlir::cir::SingleType::get(builder.getContext());
137+
DoubleTy = ::mlir::cir::DoubleType::get(builder.getContext());
138138
// TODO(cir): perhaps we should abstract long double variations into a custom
139139
// cir.long_double type. Said type would also hold the semantics for lowering.
140140

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,9 @@ LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
302302
Attribute cir::FPAttr::parse(AsmParser &parser, Type odsType) {
303303
double value;
304304

305-
if (!odsType.isa<cir::FloatType>())
305+
auto odsTypeFpInterface = odsType.dyn_cast<cir::FPTypeInterface>();
306+
if (!odsTypeFpInterface)
306307
return {};
307-
auto ty = odsType.cast<cir::FloatType>();
308308

309309
if (parser.parseLess())
310310
return {};
@@ -318,28 +318,29 @@ Attribute cir::FPAttr::parse(AsmParser &parser, Type odsType) {
318318

319319
auto losesInfo = false;
320320
APFloat convertedValue{value};
321-
convertedValue.convert(ty.getFloatSemantics(), llvm::RoundingMode::TowardZero,
322-
&losesInfo);
321+
convertedValue.convert(odsTypeFpInterface.getFloatSemantics(),
322+
llvm::RoundingMode::TowardZero, &losesInfo);
323323

324-
return cir::FPAttr::get(ty, convertedValue);
324+
return cir::FPAttr::get(odsType, convertedValue);
325325
}
326326

327327
void cir::FPAttr::print(AsmPrinter &printer) const {
328328
printer << '<' << getValue() << '>';
329329
}
330330

331-
cir::FPAttr cir::FPAttr::getZero(mlir::cir::FloatType type) {
332-
return get(type, APFloat::getZero(type.getFloatSemantics()));
331+
cir::FPAttr cir::FPAttr::getZero(mlir::Type type) {
332+
return get(type, APFloat::getZero(
333+
type.cast<cir::FPTypeInterface>().getFloatSemantics()));
333334
}
334335

335336
LogicalResult cir::FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
336337
Type type, APFloat value) {
337-
auto fltType = type.dyn_cast<cir::FloatType>();
338-
if (!fltType) {
338+
auto fltTypeInterface = type.dyn_cast<cir::FPTypeInterface>();
339+
if (!fltTypeInterface) {
339340
emitError() << "expected floating-point type";
340341
return failure();
341342
}
342-
if (APFloat::SemanticsToEnum(fltType.getFloatSemantics()) !=
343+
if (APFloat::SemanticsToEnum(fltTypeInterface.getFloatSemantics()) !=
343344
APFloat::SemanticsToEnum(value.getSemantics())) {
344345
emitError() << "floating-point semantics mismatch";
345346
return failure();

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -423,13 +423,13 @@ LogicalResult CastOp::verify() {
423423
return success();
424424
}
425425
case cir::CastKind::floating: {
426-
if (!srcType.isa<mlir::cir::FloatType>() ||
427-
!resType.isa<mlir::cir::FloatType>())
426+
if (!srcType.isa<mlir::cir::FPTypeInterface>() ||
427+
!resType.isa<mlir::cir::FPTypeInterface>())
428428
return emitOpError() << "requries floating for source and result";
429429
return success();
430430
}
431431
case cir::CastKind::float_to_int: {
432-
if (!srcType.isa<mlir::cir::FloatType>())
432+
if (!srcType.isa<mlir::cir::FPTypeInterface>())
433433
return emitOpError() << "requires floating for source";
434434
if (!resType.dyn_cast<mlir::cir::IntType>())
435435
return emitOpError() << "requires !IntegerType for result";
@@ -450,7 +450,7 @@ LogicalResult CastOp::verify() {
450450
return success();
451451
}
452452
case cir::CastKind::float_to_bool: {
453-
if (!srcType.isa<mlir::cir::FloatType>())
453+
if (!srcType.isa<mlir::cir::FPTypeInterface>())
454454
return emitOpError() << "requires float for source";
455455
if (!resType.isa<mlir::cir::BoolType>())
456456
return emitOpError() << "requires !cir.bool for result";
@@ -466,14 +466,14 @@ LogicalResult CastOp::verify() {
466466
case cir::CastKind::int_to_float: {
467467
if (!srcType.isa<mlir::cir::IntType>())
468468
return emitOpError() << "requires !cir.int for source";
469-
if (!resType.isa<mlir::cir::FloatType>())
469+
if (!resType.isa<mlir::cir::FPTypeInterface>())
470470
return emitOpError() << "requires !cir.float for result";
471471
return success();
472472
}
473473
case cir::CastKind::bool_to_float: {
474474
if (!srcType.isa<mlir::cir::BoolType>())
475475
return emitOpError() << "requires !cir.bool for source";
476-
if (!resType.isa<mlir::cir::FloatType>())
476+
if (!resType.isa<mlir::cir::FPTypeInterface>())
477477
return emitOpError() << "requires !cir.float for result";
478478
return success();
479479
}

0 commit comments

Comments
 (0)