Skip to content

Commit 1efee91

Browse files
authored
[CIR][ThroughMLIR] Support lowering CastOp to arith (#577)
This commit introduce CIRCastOpLowering for lowering to arith.
1 parent b361bbe commit 1efee91

File tree

2 files changed

+278
-1
lines changed

2 files changed

+278
-1
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+131-1
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,135 @@ class CIRGetGlobalOpLowering
709709
}
710710
};
711711

712+
static mlir::Value createIntCast(mlir::ConversionPatternRewriter &rewriter,
713+
mlir::Value src, mlir::Type dstTy,
714+
bool isSigned = false) {
715+
auto srcTy = src.getType();
716+
assert(isa<mlir::IntegerType>(srcTy));
717+
assert(isa<mlir::IntegerType>(dstTy));
718+
719+
auto srcWidth = srcTy.cast<mlir::IntegerType>().getWidth();
720+
auto dstWidth = dstTy.cast<mlir::IntegerType>().getWidth();
721+
auto loc = src.getLoc();
722+
723+
if (dstWidth > srcWidth && isSigned)
724+
return rewriter.create<mlir::arith::ExtSIOp>(loc, dstTy, src);
725+
else if (dstWidth > srcWidth)
726+
return rewriter.create<mlir::arith::ExtUIOp>(loc, dstTy, src);
727+
else if (dstWidth < srcWidth)
728+
return rewriter.create<mlir::arith::TruncIOp>(loc, dstTy, src);
729+
else
730+
return rewriter.create<mlir::arith::BitcastOp>(loc, dstTy, src);
731+
}
732+
733+
class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
734+
public:
735+
using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;
736+
737+
inline mlir::Type convertTy(mlir::Type ty) const {
738+
return getTypeConverter()->convertType(ty);
739+
}
740+
741+
mlir::LogicalResult
742+
matchAndRewrite(mlir::cir::CastOp op, OpAdaptor adaptor,
743+
mlir::ConversionPatternRewriter &rewriter) const override {
744+
if (isa<mlir::cir::VectorType>(op.getSrc().getType()))
745+
llvm_unreachable("CastOp lowering for vector type is not supported yet");
746+
auto src = adaptor.getSrc();
747+
auto dstType = op.getResult().getType();
748+
using CIR = mlir::cir::CastKind;
749+
switch (op.getKind()) {
750+
case CIR::int_to_bool: {
751+
auto zero = rewriter.create<mlir::cir::ConstantOp>(
752+
src.getLoc(), op.getSrc().getType(),
753+
mlir::cir::IntAttr::get(op.getSrc().getType(), 0));
754+
rewriter.replaceOpWithNewOp<mlir::cir::CmpOp>(
755+
op, mlir::cir::BoolType::get(getContext()), mlir::cir::CmpOpKind::ne,
756+
op.getSrc(), zero);
757+
return mlir::success();
758+
}
759+
case CIR::integral: {
760+
auto newDstType = convertTy(dstType);
761+
auto srcType = op.getSrc().getType();
762+
mlir::cir::IntType srcIntType = srcType.cast<mlir::cir::IntType>();
763+
auto newOp =
764+
createIntCast(rewriter, src, newDstType, srcIntType.isSigned());
765+
rewriter.replaceOp(op, newOp);
766+
return mlir::success();
767+
}
768+
case CIR::floating: {
769+
auto newDstType = convertTy(dstType);
770+
auto srcTy = op.getSrc().getType();
771+
auto dstTy = op.getResult().getType();
772+
773+
if (!dstTy.isa<mlir::cir::CIRFPTypeInterface>() ||
774+
!srcTy.isa<mlir::cir::CIRFPTypeInterface>())
775+
return op.emitError() << "NYI cast from " << srcTy << " to " << dstTy;
776+
777+
auto getFloatWidth = [](mlir::Type ty) -> unsigned {
778+
return ty.cast<mlir::cir::CIRFPTypeInterface>().getWidth();
779+
};
780+
781+
if (getFloatWidth(srcTy) > getFloatWidth(dstTy))
782+
rewriter.replaceOpWithNewOp<mlir::arith::TruncFOp>(op, newDstType, src);
783+
else
784+
rewriter.replaceOpWithNewOp<mlir::arith::ExtFOp>(op, newDstType, src);
785+
return mlir::success();
786+
}
787+
case CIR::float_to_bool: {
788+
auto dstTy = op.getType().cast<mlir::cir::BoolType>();
789+
auto newDstType = convertTy(dstTy);
790+
auto kind = mlir::arith::CmpFPredicate::UNE;
791+
792+
// Check if float is not equal to zero.
793+
auto zeroFloat = rewriter.create<mlir::arith::ConstantOp>(
794+
op.getLoc(), src.getType(), mlir::FloatAttr::get(src.getType(), 0.0));
795+
796+
// Extend comparison result to either bool (C++) or int (C).
797+
mlir::Value cmpResult = rewriter.create<mlir::arith::CmpFOp>(
798+
op.getLoc(), kind, src, zeroFloat);
799+
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, newDstType,
800+
cmpResult);
801+
return mlir::success();
802+
}
803+
case CIR::bool_to_int: {
804+
auto dstTy = op.getType().cast<mlir::cir::IntType>();
805+
auto newDstType = convertTy(dstTy).cast<mlir::IntegerType>();
806+
auto newOp = createIntCast(rewriter, src, newDstType);
807+
rewriter.replaceOp(op, newOp);
808+
return mlir::success();
809+
}
810+
case CIR::bool_to_float: {
811+
auto dstTy = op.getType();
812+
auto newDstType = convertTy(dstTy);
813+
rewriter.replaceOpWithNewOp<mlir::arith::UIToFPOp>(op, newDstType, src);
814+
return mlir::success();
815+
}
816+
case CIR::int_to_float: {
817+
auto dstTy = op.getType();
818+
auto newDstType = convertTy(dstTy);
819+
if (op.getSrc().getType().cast<mlir::cir::IntType>().isSigned())
820+
rewriter.replaceOpWithNewOp<mlir::arith::SIToFPOp>(op, newDstType, src);
821+
else
822+
rewriter.replaceOpWithNewOp<mlir::arith::UIToFPOp>(op, newDstType, src);
823+
return mlir::success();
824+
}
825+
case CIR::float_to_int: {
826+
auto dstTy = op.getType();
827+
auto newDstType = convertTy(dstTy);
828+
if (op.getResult().getType().cast<mlir::cir::IntType>().isSigned())
829+
rewriter.replaceOpWithNewOp<mlir::arith::FPToSIOp>(op, newDstType, src);
830+
else
831+
rewriter.replaceOpWithNewOp<mlir::arith::FPToUIOp>(op, newDstType, src);
832+
return mlir::success();
833+
}
834+
default:
835+
break;
836+
}
837+
return mlir::failure();
838+
}
839+
};
840+
712841
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
713842
mlir::TypeConverter &converter) {
714843
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
@@ -718,7 +847,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
718847
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
719848
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
720849
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
721-
CIRGetGlobalOpLowering>(converter, patterns.getContext());
850+
CIRGetGlobalOpLowering, CIRCastOpLowering>(
851+
converter, patterns.getContext());
722852
}
723853

724854
static mlir::TypeConverter prepareTypeConverter() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
2+
// RUN: cir-opt %s -cir-to-mlir -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
3+
4+
!s32i = !cir.int<s, 32>
5+
!s16i = !cir.int<s, 16>
6+
!u32i = !cir.int<u, 32>
7+
!u16i = !cir.int<u, 16>
8+
!u8i = !cir.int<u, 8>
9+
module {
10+
// MLIR-LABEL: func.func @cast_int_to_bool(%arg0: i32) -> i8
11+
// LLVM-LABEL: define i8 @cast_int_to_bool(i32 %0)
12+
cir.func @cast_int_to_bool(%i : !u32i) -> !cir.bool {
13+
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0 : i32
14+
// MLIR-NEXT: arith.cmpi ne, %arg0, %[[ZERO]]
15+
// LLVM-NEXT: icmp ne i32 %0, 0
16+
17+
%1 = cir.cast(int_to_bool, %i : !u32i), !cir.bool
18+
cir.return %1 : !cir.bool
19+
}
20+
// MLIR-LABEL: func.func @cast_integral_trunc(%arg0: i32) -> i16
21+
// LLVM-LABEL: define i16 @cast_integral_trunc(i32 %0)
22+
cir.func @cast_integral_trunc(%i : !u32i) -> !u16i {
23+
// MLIR-NEXT: arith.trunci %arg0 : i32 to i16
24+
// LLVM-NEXT: trunc i32 %0 to i16
25+
26+
%1 = cir.cast(integral, %i : !u32i), !u16i
27+
cir.return %1 : !u16i
28+
}
29+
// MLIR-LABEL: func.func @cast_integral_extu(%arg0: i16) -> i32
30+
// LLVM-LABEL: define i32 @cast_integral_extu(i16 %0)
31+
cir.func @cast_integral_extu(%i : !u16i) -> !u32i {
32+
// MLIR-NEXT: arith.extui %arg0 : i16 to i32
33+
// LLVM-NEXT: zext i16 %0 to i32
34+
35+
%1 = cir.cast(integral, %i : !u16i), !u32i
36+
cir.return %1 : !u32i
37+
}
38+
// MLIR-LABEL: func.func @cast_integral_exts(%arg0: i16) -> i32
39+
// LLVM-LABEL: define i32 @cast_integral_exts(i16 %0)
40+
cir.func @cast_integral_exts(%i : !s16i) -> !s32i {
41+
// MLIR-NEXT: arith.extsi %arg0 : i16 to i32
42+
// LLVM-NEXT: sext i16 %0 to i32
43+
44+
%1 = cir.cast(integral, %i : !s16i), !s32i
45+
cir.return %1 : !s32i
46+
}
47+
// MLIR-LABEL: func.func @cast_integral_same_size(%arg0: i32) -> i32
48+
// LLVM-LABEL: define i32 @cast_integral_same_size(i32 %0)
49+
cir.func @cast_integral_same_size(%i : !u32i) -> !s32i {
50+
// MLIR-NEXT: %0 = arith.bitcast %arg0 : i32 to i32
51+
// LLVM-NEXT: ret i32 %0
52+
53+
%1 = cir.cast(integral, %i : !u32i), !s32i
54+
cir.return %1 : !s32i
55+
}
56+
// MLIR-LABEL: func.func @cast_floating_trunc(%arg0: f64) -> f32
57+
// LLVM-LABEL: define float @cast_floating_trunc(double %0)
58+
cir.func @cast_floating_trunc(%d : !cir.double) -> !cir.float {
59+
// MLIR-NEXT: arith.truncf %arg0 : f64 to f32
60+
// LLVM-NEXT: fptrunc double %0 to float
61+
62+
%1 = cir.cast(floating, %d : !cir.double), !cir.float
63+
cir.return %1 : !cir.float
64+
}
65+
// MLIR-LABEL: func.func @cast_floating_extf(%arg0: f32) -> f64
66+
// LLVM-LABEL: define double @cast_floating_extf(float %0)
67+
cir.func @cast_floating_extf(%f : !cir.float) -> !cir.double {
68+
// MLIR-NEXT: arith.extf %arg0 : f32 to f64
69+
// LLVM-NEXT: fpext float %0 to double
70+
71+
%1 = cir.cast(floating, %f : !cir.float), !cir.double
72+
cir.return %1 : !cir.double
73+
}
74+
// MLIR-LABEL: func.func @cast_float_to_bool(%arg0: f32) -> i8
75+
// LLVM-LABEL: define i8 @cast_float_to_bool(float %0)
76+
cir.func @cast_float_to_bool(%f : !cir.float) -> !cir.bool {
77+
// MLIR-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
78+
// MLIR-NEXT: arith.cmpf une, %arg0, %[[ZERO]] : f32
79+
// LLVM-NEXT: fcmp une float %0, 0.000000e+00
80+
81+
%1 = cir.cast(float_to_bool, %f : !cir.float), !cir.bool
82+
cir.return %1 : !cir.bool
83+
}
84+
// MLIR-LABEL: func.func @cast_bool_to_int8(%arg0: i8) -> i8
85+
// LLVM-LABEL: define i8 @cast_bool_to_int8(i8 %0)
86+
cir.func @cast_bool_to_int8(%b : !cir.bool) -> !u8i {
87+
// MLIR-NEXT: arith.bitcast %arg0 : i8 to i8
88+
// LLVM-NEXT: ret i8 %0
89+
90+
%1 = cir.cast(bool_to_int, %b : !cir.bool), !u8i
91+
cir.return %1 : !u8i
92+
}
93+
// MLIR-LABEL: func.func @cast_bool_to_int(%arg0: i8) -> i32
94+
// LLVM-LABEL: define i32 @cast_bool_to_int(i8 %0)
95+
cir.func @cast_bool_to_int(%b : !cir.bool) -> !u32i {
96+
// MLIR-NEXT: arith.extui %arg0 : i8 to i32
97+
// LLVM-NEXT: zext i8 %0 to i32
98+
99+
%1 = cir.cast(bool_to_int, %b : !cir.bool), !u32i
100+
cir.return %1 : !u32i
101+
}
102+
// MLIR-LABEL: func.func @cast_bool_to_float(%arg0: i8) -> f32
103+
// LLVM-LABEL: define float @cast_bool_to_float(i8 %0)
104+
cir.func @cast_bool_to_float(%b : !cir.bool) -> !cir.float {
105+
// MLIR-NEXT: arith.uitofp %arg0 : i8 to f32
106+
// LLVM-NEXT: uitofp i8 %0 to float
107+
108+
%1 = cir.cast(bool_to_float, %b : !cir.bool), !cir.float
109+
cir.return %1 : !cir.float
110+
}
111+
// MLIR-LABEL: func.func @cast_signed_int_to_float(%arg0: i32) -> f32
112+
// LLVM-LABEL: define float @cast_signed_int_to_float(i32 %0)
113+
cir.func @cast_signed_int_to_float(%i : !s32i) -> !cir.float {
114+
// MLIR-NEXT: arith.sitofp %arg0 : i32 to f32
115+
// LLVM-NEXT: sitofp i32 %0 to float
116+
117+
%1 = cir.cast(int_to_float, %i : !s32i), !cir.float
118+
cir.return %1 : !cir.float
119+
}
120+
// MLIR-LABEL: func.func @cast_unsigned_int_to_float(%arg0: i32) -> f32
121+
// LLVM-LABEL: define float @cast_unsigned_int_to_float(i32 %0)
122+
cir.func @cast_unsigned_int_to_float(%i : !u32i) -> !cir.float {
123+
// MLIR-NEXT: arith.uitofp %arg0 : i32 to f32
124+
// LLVM-NEXT: uitofp i32 %0 to float
125+
126+
%1 = cir.cast(int_to_float, %i : !u32i), !cir.float
127+
cir.return %1 : !cir.float
128+
}
129+
// MLIR-LABEL: func.func @cast_float_to_int_signed(%arg0: f32) -> i32
130+
// LLVM-LABEL: define i32 @cast_float_to_int_signed(float %0)
131+
cir.func @cast_float_to_int_signed(%f : !cir.float) -> !s32i {
132+
// MLIR-NEXT: arith.fptosi %arg0 : f32 to i32
133+
// LLVM-NEXT: fptosi float %0 to i32
134+
135+
%1 = cir.cast(float_to_int, %f : !cir.float), !s32i
136+
cir.return %1 : !s32i
137+
}
138+
// MLIR-LABEL: func.func @cast_float_to_int_unsigned(%arg0: f32) -> i32
139+
// LLVM-LABEL: define i32 @cast_float_to_int_unsigned(float %0)
140+
cir.func @cast_float_to_int_unsigned(%f : !cir.float) -> !u32i {
141+
// MLIR-NEXT: arith.fptoui %arg0 : f32 to i32
142+
// LLVM-NEXT: fptoui float %0 to i32
143+
144+
%1 = cir.cast(float_to_int, %f : !cir.float), !u32i
145+
cir.return %1 : !u32i
146+
}
147+
}

0 commit comments

Comments
 (0)