Skip to content

Commit e3bc1b3

Browse files
committed
[CIR] Support lowering CastOp to arith
This commit introduce CIRCastOpLowering for lowering to arith.
1 parent 97b7280 commit e3bc1b3

File tree

2 files changed

+243
-1
lines changed

2 files changed

+243
-1
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+141-1
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,145 @@ class CIRGetGlobalOpLowering
709709
}
710710
};
711711

712+
class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
713+
public:
714+
using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;
715+
716+
inline mlir::Type convertTy(mlir::Type ty) const {
717+
return getTypeConverter()->convertType(ty);
718+
}
719+
720+
/// If the given type is a vector type, return the vector's element type.
721+
/// Otherwise return the given type unchanged.
722+
inline mlir::Type elementTypeIfVector(mlir::Type type) const {
723+
if (auto VecType = type.dyn_cast<mlir::cir::VectorType>()) {
724+
return VecType.getEltType();
725+
}
726+
return type;
727+
}
728+
729+
mlir::LogicalResult
730+
matchAndRewrite(mlir::cir::CastOp op, OpAdaptor adaptor,
731+
mlir::ConversionPatternRewriter &rewriter) const override {
732+
auto src = adaptor.getSrc();
733+
auto dstType = op.getResult().getType();
734+
using CIR = mlir::cir::CastKind;
735+
switch (op.getKind()) {
736+
case CIR::int_to_bool: {
737+
auto zero = rewriter.create<mlir::cir::ConstantOp>(
738+
src.getLoc(), op.getSrc().getType(),
739+
mlir::cir::IntAttr::get(op.getSrc().getType(), 0));
740+
rewriter.replaceOpWithNewOp<mlir::cir::CmpOp>(
741+
op, mlir::cir::BoolType::get(getContext()), mlir::cir::CmpOpKind::ne,
742+
op.getSrc(), zero);
743+
return mlir::success();
744+
}
745+
case CIR::integral: {
746+
auto srcType = op.getSrc().getType();
747+
auto newDstType = convertTy(dstType);
748+
mlir::cir::IntType srcIntType =
749+
elementTypeIfVector(srcType).cast<mlir::cir::IntType>();
750+
mlir::cir::IntType dstIntType =
751+
elementTypeIfVector(dstType).cast<mlir::cir::IntType>();
752+
753+
if (dstIntType.getWidth() < srcIntType.getWidth()) {
754+
// Bigger to smaller. Truncate.
755+
rewriter.replaceOpWithNewOp<mlir::arith::TruncIOp>(op, newDstType, src);
756+
} else if (dstIntType.getWidth() > srcIntType.getWidth()) {
757+
// Smaller to bigger. Zero extend or sign extend based on signedness.
758+
if (srcIntType.isUnsigned())
759+
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, newDstType,
760+
src);
761+
else
762+
rewriter.replaceOpWithNewOp<mlir::arith::ExtSIOp>(op, newDstType,
763+
src);
764+
} else {
765+
// Same size. Signedness changes doesn't matter. Do nothing.
766+
rewriter.replaceOp(op, src);
767+
}
768+
return mlir::success();
769+
}
770+
case CIR::floating: {
771+
auto newDstType = convertTy(dstType);
772+
auto srcTy = elementTypeIfVector(op.getSrc().getType());
773+
auto dstTy = elementTypeIfVector(op.getResult().getType());
774+
775+
if (!dstTy.isa<mlir::cir::CIRFPTypeInterface>() ||
776+
!srcTy.isa<mlir::cir::CIRFPTypeInterface>())
777+
return op.emitError() << "NYI cast from " << srcTy << " to " << dstTy;
778+
779+
auto getFloatWidth = [](mlir::Type ty) -> unsigned {
780+
return ty.cast<mlir::cir::CIRFPTypeInterface>().getWidth();
781+
};
782+
783+
if (getFloatWidth(srcTy) > getFloatWidth(dstTy))
784+
rewriter.replaceOpWithNewOp<mlir::arith::TruncFOp>(op, newDstType, src);
785+
else
786+
rewriter.replaceOpWithNewOp<mlir::arith::ExtFOp>(op, newDstType, src);
787+
return mlir::success();
788+
}
789+
case CIR::float_to_bool: {
790+
auto dstTy = op.getType().cast<mlir::cir::BoolType>();
791+
auto newDstType = convertTy(dstTy);
792+
auto kind = mlir::arith::CmpFPredicate::UNE;
793+
794+
// Check if float is not equal to zero.
795+
auto zeroFloat = rewriter.create<mlir::arith::ConstantOp>(
796+
op.getLoc(), src.getType(), mlir::FloatAttr::get(src.getType(), 0.0));
797+
798+
// Extend comparison result to either bool (C++) or int (C).
799+
mlir::Value cmpResult = rewriter.create<mlir::arith::CmpFOp>(
800+
op.getLoc(), kind, src, zeroFloat);
801+
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, newDstType,
802+
cmpResult);
803+
return mlir::success();
804+
}
805+
case CIR::bool_to_int: {
806+
auto newSrcTy = src.getType().cast<mlir::IntegerType>();
807+
auto dstTy = op.getType().cast<mlir::cir::IntType>();
808+
auto newDstType = convertTy(dstTy).cast<mlir::IntegerType>();
809+
if (newSrcTy.getWidth() == newDstType.getWidth())
810+
rewriter.replaceOpWithNewOp<mlir::arith::BitcastOp>(op, newDstType,
811+
src);
812+
else
813+
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, newDstType, src);
814+
return mlir::success();
815+
}
816+
case CIR::bool_to_float: {
817+
auto dstTy = op.getType();
818+
auto newDstType = convertTy(dstTy);
819+
rewriter.replaceOpWithNewOp<mlir::arith::UIToFPOp>(op, newDstType, src);
820+
return mlir::success();
821+
}
822+
case CIR::int_to_float: {
823+
auto dstTy = op.getType();
824+
auto newDstType = convertTy(dstTy);
825+
if (elementTypeIfVector(op.getSrc().getType())
826+
.cast<mlir::cir::IntType>()
827+
.isSigned())
828+
rewriter.replaceOpWithNewOp<mlir::arith::SIToFPOp>(op, newDstType, src);
829+
else
830+
rewriter.replaceOpWithNewOp<mlir::arith::UIToFPOp>(op, newDstType, src);
831+
return mlir::success();
832+
}
833+
case CIR::float_to_int: {
834+
auto dstTy = op.getType();
835+
auto newDstType = convertTy(dstTy);
836+
if (elementTypeIfVector(op.getResult().getType())
837+
.cast<mlir::cir::IntType>()
838+
.isSigned())
839+
rewriter.replaceOpWithNewOp<mlir::arith::FPToSIOp>(op, newDstType, src);
840+
else
841+
rewriter.replaceOpWithNewOp<mlir::arith::FPToUIOp>(op, newDstType, src);
842+
return mlir::success();
843+
}
844+
default:
845+
break;
846+
}
847+
return mlir::failure();
848+
}
849+
};
850+
712851
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
713852
mlir::TypeConverter &converter) {
714853
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
@@ -718,7 +857,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
718857
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
719858
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
720859
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
721-
CIRGetGlobalOpLowering>(converter, patterns.getContext());
860+
CIRGetGlobalOpLowering, CIRCastOpLowering>(
861+
converter, patterns.getContext());
722862
}
723863

724864
static mlir::TypeConverter prepareTypeConverter() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
2+
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
3+
4+
!s32i = !cir.int<s, 32>
5+
!s16i = !cir.int<s, 16>
6+
!u32i = !cir.int<u, 32>
7+
!u16i = !cir.int<u, 16>
8+
!u8i = !cir.int<u, 8>
9+
module {
10+
cir.func @cast_int_to_bool(%i : !u32i) -> !cir.bool {
11+
%1 = cir.cast(int_to_bool, %i : !u32i), !cir.bool
12+
cir.return %1 : !cir.bool
13+
}
14+
cir.func @cast_integral_trunc(%i : !u32i) -> !u16i {
15+
%1 = cir.cast(integral, %i : !u32i), !u16i
16+
cir.return %1 : !u16i
17+
}
18+
cir.func @cast_integral_extu(%i : !u16i) -> !u32i {
19+
%1 = cir.cast(integral, %i : !u16i), !u32i
20+
cir.return %1 : !u32i
21+
}
22+
cir.func @cast_integral_exts(%i : !s16i) -> !s32i {
23+
%1 = cir.cast(integral, %i : !s16i), !s32i
24+
cir.return %1 : !s32i
25+
}
26+
cir.func @cast_integral_same_size(%i : !u32i) -> !s32i {
27+
%1 = cir.cast(integral, %i : !u32i), !s32i
28+
cir.return %1 : !s32i
29+
}
30+
cir.func @cast_floating_trunc(%d : !cir.double) -> !cir.float {
31+
%1 = cir.cast(floating, %d : !cir.double), !cir.float
32+
cir.return %1 : !cir.float
33+
}
34+
cir.func @cast_floating_extf(%f : !cir.float) -> !cir.double {
35+
%1 = cir.cast(floating, %f : !cir.float), !cir.double
36+
cir.return %1 : !cir.double
37+
}
38+
cir.func @cast_float_to_bool(%f : !cir.float) -> !cir.bool {
39+
%1 = cir.cast(float_to_bool, %f : !cir.float), !cir.bool
40+
cir.return %1 : !cir.bool
41+
}
42+
cir.func @cast_bool_to_int8(%b : !cir.bool) -> !u8i {
43+
%1 = cir.cast(bool_to_int, %b : !cir.bool), !u8i
44+
cir.return %1 : !u8i
45+
}
46+
cir.func @cast_bool_to_int(%b : !cir.bool) -> !u32i {
47+
%1 = cir.cast(bool_to_int, %b : !cir.bool), !u32i
48+
cir.return %1 : !u32i
49+
}
50+
cir.func @cast_bool_to_float(%b : !cir.bool) -> !cir.float {
51+
%1 = cir.cast(bool_to_float, %b : !cir.bool), !cir.float
52+
cir.return %1 : !cir.float
53+
}
54+
cir.func @cast_signed_int_to_float(%i : !s32i) -> !cir.float {
55+
%1 = cir.cast(int_to_float, %i : !s32i), !cir.float
56+
cir.return %1 : !cir.float
57+
}
58+
cir.func @cast_unsigned_int_to_float(%i : !u32i) -> !cir.float {
59+
%1 = cir.cast(int_to_float, %i : !u32i), !cir.float
60+
cir.return %1 : !cir.float
61+
}
62+
cir.func @cast_float_to_int_signed(%f : !cir.float) -> !s32i {
63+
%1 = cir.cast(float_to_int, %f : !cir.float), !s32i
64+
cir.return %1 : !s32i
65+
}
66+
cir.func @cast_float_to_int_unsigned(%f : !cir.float) -> !u32i {
67+
%1 = cir.cast(float_to_int, %f : !cir.float), !u32i
68+
cir.return %1 : !u32i
69+
}
70+
}
71+
72+
// MLIR: arith.cmpi ne, %arg0, %c0_i32
73+
// MLIR: arith.trunci %arg0 : i32 to i16
74+
// MLIR: arith.extui %arg0 : i16 to i32
75+
// MLIR: arith.extsi %arg0 : i16 to i32
76+
// MLIR: return %arg0 : i32
77+
// MLIR: arith.truncf %arg0 : f64 to f32
78+
// MLIR: arith.extf %arg0 : f32 to f64
79+
// MLIR: arith.cmpf une, %arg0, %cst : f32
80+
// MLIR: arith.bitcast %arg0 : i8 to i8
81+
// MLIR: arith.extui %arg0 : i8 to i32
82+
// MLIR: arith.uitofp %arg0 : i8 to f32
83+
// MLIR: arith.sitofp %arg0 : i32 to f32
84+
// MLIR: arith.uitofp %arg0 : i32 to f32
85+
// MLIR: arith.fptosi %arg0 : f32 to i32
86+
// MLIR: arith.fptoui %arg0 : f32 to i32
87+
88+
// LLVM: icmp ne i32 %0, 0
89+
// LLVM: trunc i32 %0 to i16
90+
// LLVM: zext i16 %0 to i32
91+
// LLVM: sext i16 %0 to i32
92+
// LLVM: ret i32 %0
93+
// LLVM: fptrunc double %0 to float
94+
// LLVM: fpext float %0 to double
95+
// LLVM: fcmp une float %0, 0.000000e+00
96+
// LLVM: ret i8 %0
97+
// LLVM: zext i8 %0 to i32
98+
// LLVM: uitofp i8 %0 to float
99+
// LLVM: sitofp i32 %0 to float
100+
// LLVM: uitofp i32 %0 to float
101+
// LLVM: fptosi float %0 to i32
102+
// LLVM: fptoui float %0 to i32

0 commit comments

Comments
 (0)