Skip to content

Commit

Permalink
Revert removing builtin floating point support
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancern committed Jan 25, 2024
1 parent 70223d6 commit f582e0c
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 78 deletions.
3 changes: 2 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def CIR_Double : CIR_FloatType<"Double", "double"> {

// Constraints

def CIR_AnyFloat: Type<CPred<"$_self.isa<::mlir::cir::FloatType>()">>;
def CIR_AnyFloat: Type<
CPred<"$_self.isa<::mlir::FloatType, ::mlir::cir::FloatType>()">>;

//===----------------------------------------------------------------------===//
// PointerType
Expand Down
16 changes: 12 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
if (ty.isa<mlir::cir::IntType>())
return mlir::cir::IntAttr::get(ty, 0);
if (ty.isa<mlir::FloatType>())
return mlir::FloatAttr::get(ty, 0.0);
if (auto fltType = ty.dyn_cast<mlir::cir::FloatType>())
return mlir::cir::FloatAttr::getZero(fltType);
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
Expand All @@ -250,12 +252,18 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
if (const auto intVal = attr.dyn_cast<mlir::cir::IntAttr>())
return intVal.isNullValue();

if (const auto fpVal = attr.dyn_cast<mlir::cir::FloatAttr>()) {
if (attr.isa<mlir::FloatAttr, mlir::cir::FloatAttr>()) {
auto fpVal = [&attr] {
if (auto fpAttr = attr.dyn_cast<mlir::cir::FloatAttr>())
return fpAttr.getValue();
return attr.cast<mlir::FloatAttr>().getValue();
}();

bool ignored;
llvm::APFloat FV(+0.0);
FV.convert(fpVal.getValue().getSemantics(),
llvm::APFloat::rmNearestTiesToEven, &ignored);
return FV.bitwiseIsEqual(fpVal.getValue());
FV.convert(fpVal.getSemantics(), llvm::APFloat::rmNearestTiesToEven,
&ignored);
return FV.bitwiseIsEqual(fpVal);
}

if (const auto structVal = attr.dyn_cast<mlir::cir::ConstStructAttr>()) {
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,9 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
assert(0 && "not implemented");
else {
mlir::Type ty = CGM.getCIRType(DestType);
return CGM.getBuilder().getAttr<mlir::cir::FloatAttr>(ty, Init);
if (ty.isa<mlir::cir::FloatType>())
return CGM.getBuilder().getAttr<mlir::cir::FloatAttr>(ty, Init);
return builder.getFloatAttr(ty, Init);
}
}
case APValue::Array: {
Expand Down
16 changes: 10 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,13 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
}
mlir::Value VisitFloatingLiteral(const FloatingLiteral *E) {
mlir::Type Ty = CGF.getCIRType(E->getType());
if (Ty.isa<mlir::cir::FloatType>())
return Builder.create<mlir::cir::ConstantOp>(
CGF.getLoc(E->getExprLoc()), Ty,
Builder.getAttr<mlir::cir::FloatAttr>(Ty, E->getValue()));
return Builder.create<mlir::cir::ConstantOp>(
CGF.getLoc(E->getExprLoc()), Ty,
Builder.getAttr<mlir::cir::FloatAttr>(Ty, E->getValue()));
Builder.getFloatAttr(Ty, E->getValue()));
}
mlir::Value VisitCharacterLiteral(const CharacterLiteral *E) {
mlir::Type Ty = CGF.getCIRType(E->getType());
Expand Down Expand Up @@ -1200,7 +1204,7 @@ mlir::Value ScalarExprEmitter::buildSub(const BinOpInfo &Ops) {
llvm_unreachable("NYI");

assert(!UnimplementedFeature::cirVectorType());
if (Ops.LHS.getType().isa<mlir::cir::FloatType>()) {
if (Ops.LHS.getType().isa<mlir::FloatType, mlir::cir::FloatType>()) {
CIRGenFunction::CIRGenFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return Builder.createFSub(Ops.LHS, Ops.RHS);
}
Expand Down Expand Up @@ -1668,20 +1672,20 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
llvm_unreachable("NYI: signed bool");
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::bool_to_int;
} else if (DstTy.isa<mlir::cir::FloatType>()) {
} else if (DstTy.isa<mlir::FloatType, mlir::cir::FloatType>()) {
CastKind = mlir::cir::CastKind::bool_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (CGF.getBuilder().isInt(SrcTy)) {
if (CGF.getBuilder().isInt(DstTy)) {
CastKind = mlir::cir::CastKind::integral;
} else if (DstTy.isa<mlir::cir::FloatType>()) {
} else if (DstTy.isa<mlir::FloatType, mlir::cir::FloatType>()) {
CastKind = mlir::cir::CastKind::int_to_float;
} else {
llvm_unreachable("Internal error: Cast to unexpected type");
}
} else if (SrcTy.isa<mlir::cir::FloatType>()) {
} else if (SrcTy.isa<mlir::FloatType, mlir::cir::FloatType>()) {
if (CGF.getBuilder().isInt(DstTy)) {
// If we can't recognize overflow as undefined behavior, assume that
// overflow saturates. This protects against normal optimizations if we
Expand All @@ -1691,7 +1695,7 @@ mlir::Value ScalarExprEmitter::buildScalarCast(
if (Builder.getIsFPConstrained())
llvm_unreachable("NYI");
CastKind = mlir::cir::CastKind::float_to_int;
} else if (DstTy.isa<mlir::cir::FloatType>()) {
} else if (DstTy.isa<mlir::FloatType, mlir::cir::FloatType>()) {
// TODO: split this to createFPExt/createFPTrunc
return Builder.createFloatingCast(Src, DstTy);
} else {
Expand Down
15 changes: 8 additions & 7 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return success();
}

if (attrType.isa<mlir::cir::IntAttr, mlir::cir::FloatAttr>()) {
if (attrType
.isa<mlir::cir::IntAttr, mlir::FloatAttr, mlir::cir::FloatAttr>()) {
auto at = attrType.cast<TypedAttr>();
if (at.getType() != opType) {
return op->emitOpError("result type (")
Expand Down Expand Up @@ -422,13 +423,13 @@ LogicalResult CastOp::verify() {
return success();
}
case cir::CastKind::floating: {
if (!srcType.dyn_cast<mlir::cir::FloatType>() ||
!resType.dyn_cast<mlir::cir::FloatType>())
if (!srcType.isa<mlir::FloatType, mlir::cir::FloatType>() ||
!resType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requries floating for source and result";
return success();
}
case cir::CastKind::float_to_int: {
if (!srcType.dyn_cast<mlir::cir::FloatType>())
if (!srcType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requires floating for source";
if (!resType.dyn_cast<mlir::cir::IntType>())
return emitOpError() << "requires !IntegerType for result";
Expand All @@ -449,7 +450,7 @@ LogicalResult CastOp::verify() {
return success();
}
case cir::CastKind::float_to_bool: {
if (!srcType.isa<mlir::cir::FloatType>())
if (!srcType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requires float for source";
if (!resType.isa<mlir::cir::BoolType>())
return emitOpError() << "requires !cir.bool for result";
Expand All @@ -465,14 +466,14 @@ LogicalResult CastOp::verify() {
case cir::CastKind::int_to_float: {
if (!srcType.isa<mlir::cir::IntType>())
return emitOpError() << "requires !cir.int for source";
if (!resType.isa<mlir::cir::FloatType>())
if (!resType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requires !cir.float for result";
return success();
}
case cir::CastKind::bool_to_float: {
if (!srcType.isa<mlir::cir::BoolType>())
return emitOpError() << "requires !cir.bool for source";
if (!resType.isa<mlir::cir::FloatType>())
if (!resType.isa<mlir::FloatType, mlir::cir::FloatType>())
return emitOpError() << "requires !cir.float for result";
return success();
}
Expand Down
72 changes: 47 additions & 25 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::cir::ConstPtrAttr ptrAttr,
}

/// FloatAttr visitor.
inline mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::FloatAttr fltAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
}

inline mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::cir::FloatAttr fltAttr,
mlir::ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -304,6 +313,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
const mlir::TypeConverter *converter) {
if (const auto intAttr = attr.dyn_cast<mlir::cir::IntAttr>())
return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter);
if (const auto fltAttr = attr.dyn_cast<mlir::FloatAttr>())
return lowerCirAttrAsValue(parentOp, fltAttr, rewriter, converter);
if (const auto fltAttr = attr.dyn_cast<mlir::cir::FloatAttr>())
return lowerCirAttrAsValue(parentOp, fltAttr, rewriter, converter);
if (const auto ptrAttr = attr.dyn_cast<mlir::cir::ConstPtrAttr>())
Expand Down Expand Up @@ -573,23 +584,33 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
break;
}
case mlir::cir::CastKind::floating: {
auto dstTy = castOp.getResult().getType().cast<mlir::cir::FloatType>();
auto dstTy = castOp.getResult().getType();
auto srcTy = castOp.getSrc().getType();
auto llvmSrcVal = adaptor.getOperands().front();
auto llvmDstTy =
getTypeConverter()->convertType(dstTy).cast<mlir::FloatType>();

if (auto fpSrcTy = srcTy.dyn_cast<mlir::cir::FloatType>()) {
if (fpSrcTy.getWidth() > dstTy.getWidth())
rewriter.replaceOpWithNewOp<mlir::LLVM::FPTruncOp>(castOp, llvmDstTy,
llvmSrcVal);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FPExtOp>(castOp, llvmDstTy,
llvmSrcVal);
return mlir::success();
}
if (!dstTy.isa<mlir::FloatType, mlir::cir::FloatType>() ||
!srcTy.isa<mlir::FloatType, mlir::cir::FloatType>())
return castOp.emitError()
<< "NYI cast from " << srcTy << " to " << dstTy;

return castOp.emitError() << "NYI cast from " << srcTy << " to " << dstTy;
auto llvmSrcVal = adaptor.getOperands().front();
auto llvmDstTy = dstTy.dyn_cast<mlir::FloatType>();
if (!llvmDstTy)
llvmDstTy =
getTypeConverter()->convertType(dstTy).cast<mlir::FloatType>();

auto getFloatWidth = [](mlir::Type ty) -> unsigned {
if (auto fltTy = ty.dyn_cast<mlir::FloatType>())
return fltTy.getWidth();
return ty.cast<mlir::cir::FloatType>().getWidth();
};

if (getFloatWidth(srcTy) > getFloatWidth(dstTy))
rewriter.replaceOpWithNewOp<mlir::LLVM::FPTruncOp>(castOp, llvmDstTy,
llvmSrcVal);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FPExtOp>(castOp, llvmDstTy,
llvmSrcVal);
return mlir::success();
}
case mlir::cir::CastKind::int_to_ptr: {
auto dstTy = castOp.getType().cast<mlir::cir::PointerType>();
Expand Down Expand Up @@ -993,13 +1014,12 @@ lowerConstArrayAttr(mlir::cir::ConstArrayAttr constArr,
if (type.isa<mlir::cir::IntType>())
return convertToDenseElementsAttr<mlir::cir::IntAttr, mlir::APInt>(
constArr, dims, converter->convertType(type));
if (auto fltTy = type.dyn_cast<mlir::cir::FloatType>()) {
// TODO(lancern): convert all elements in the constant array to proper
// format.

if (type.isa<mlir::FloatType>())
return convertToDenseElementsAttr<mlir::FloatAttr, mlir::APFloat>(
constArr, dims, converter->convertType(type));
if (type.isa<mlir::cir::FloatType>())
return convertToDenseElementsAttr<mlir::cir::FloatAttr, mlir::APFloat>(
constArr, dims, converter->convertType(type));
}

return std::nullopt;
}
Expand All @@ -1025,10 +1045,9 @@ class CIRConstantLowering
attr = rewriter.getIntegerAttr(
typeConverter->convertType(op.getType()),
op.getValue().cast<mlir::cir::IntAttr>().getValue());
} else if (op.getType().isa<mlir::FloatType>()) {
attr = op.getValue();
} else if (op.getType().isa<mlir::cir::FloatType>()) {
// TODO(cir): ppcfp128 format floating-point type is lowered to the
// llvm.ppc_fp128 type, which is not supported by mlir::FloatAttr. How to
// lower a constant op with a value of such type to LLVMIR?
attr = rewriter.getFloatAttr(
typeConverter->convertType(op.getType()),
op.getValue().cast<mlir::cir::FloatAttr>().getValue());
Expand Down Expand Up @@ -1478,6 +1497,9 @@ class CIRGlobalOpLowering
<< constArr.getElts();
return mlir::failure();
}
} else if (llvm::isa<mlir::FloatAttr>(init.value())) {
// Nothing to do since LLVM already supports these types as
// initializers.
} else if (auto fltAttr = init.value().dyn_cast<mlir::cir::FloatAttr>()) {
// Initializer is a constant floating-point number: convert to MLIR
// builtin constant.
Expand Down Expand Up @@ -1622,7 +1644,7 @@ class CIRUnaryOpLowering
}

// Floating point unary operations: + - ++ --
if (elementType.isa<mlir::cir::FloatType>()) {
if (elementType.isa<mlir::FloatType, mlir::cir::FloatType>()) {
switch (op.getKind()) {
case mlir::cir::UnaryOpKind::Inc: {
assert(!IsVector && "++ not allowed on vector types");
Expand Down Expand Up @@ -1701,7 +1723,7 @@ class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> {
assert((op.getLhs().getType() == op.getRhs().getType()) &&
"inconsistent operands' types not supported yet");
mlir::Type type = op.getRhs().getType();
assert((type.isa<mlir::cir::IntType, mlir::cir::FloatType,
assert((type.isa<mlir::cir::IntType, mlir::FloatType, mlir::cir::FloatType,
mlir::cir::VectorType>()) &&
"operand type not supported yet");

Expand Down Expand Up @@ -1922,7 +1944,7 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
auto kind = convertToICmpPredicate(cmpOp.getKind(), /* isSigned=*/false);
llResult = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::cir::FloatType>()) {
} else if (type.isa<mlir::FloatType, mlir::cir::FloatType>()) {
auto kind = convertToFCmpPredicate(cmpOp.getKind());
llResult = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
Expand Down
Loading

0 comments on commit f582e0c

Please sign in to comment.