@@ -709,6 +709,135 @@ class CIRGetGlobalOpLowering
709
709
}
710
710
};
711
711
712
+ static mlir::Value createIntCast (mlir::ConversionPatternRewriter &rewriter,
713
+ mlir::Value src, mlir::Type dstTy,
714
+ bool isSigned = false ) {
715
+ auto srcTy = src.getType ();
716
+ assert (isa<mlir::IntegerType>(srcTy));
717
+ assert (isa<mlir::IntegerType>(dstTy));
718
+
719
+ auto srcWidth = srcTy.cast <mlir::IntegerType>().getWidth ();
720
+ auto dstWidth = dstTy.cast <mlir::IntegerType>().getWidth ();
721
+ auto loc = src.getLoc ();
722
+
723
+ if (dstWidth > srcWidth && isSigned)
724
+ return rewriter.create <mlir::arith::ExtSIOp>(loc, dstTy, src);
725
+ else if (dstWidth > srcWidth)
726
+ return rewriter.create <mlir::arith::ExtUIOp>(loc, dstTy, src);
727
+ else if (dstWidth < srcWidth)
728
+ return rewriter.create <mlir::arith::TruncIOp>(loc, dstTy, src);
729
+ else
730
+ return rewriter.create <mlir::arith::BitcastOp>(loc, dstTy, src);
731
+ }
732
+
733
+ class CIRCastOpLowering : public mlir ::OpConversionPattern<mlir::cir::CastOp> {
734
+ public:
735
+ using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;
736
+
737
+ inline mlir::Type convertTy (mlir::Type ty) const {
738
+ return getTypeConverter ()->convertType (ty);
739
+ }
740
+
741
+ mlir::LogicalResult
742
+ matchAndRewrite (mlir::cir::CastOp op, OpAdaptor adaptor,
743
+ mlir::ConversionPatternRewriter &rewriter) const override {
744
+ if (isa<mlir::cir::VectorType>(op.getSrc ().getType ()))
745
+ llvm_unreachable (" CastOp lowering for vector type is not supported yet" );
746
+ auto src = adaptor.getSrc ();
747
+ auto dstType = op.getResult ().getType ();
748
+ using CIR = mlir::cir::CastKind;
749
+ switch (op.getKind ()) {
750
+ case CIR::int_to_bool: {
751
+ auto zero = rewriter.create <mlir::cir::ConstantOp>(
752
+ src.getLoc (), op.getSrc ().getType (),
753
+ mlir::cir::IntAttr::get (op.getSrc ().getType (), 0 ));
754
+ rewriter.replaceOpWithNewOp <mlir::cir::CmpOp>(
755
+ op, mlir::cir::BoolType::get (getContext ()), mlir::cir::CmpOpKind::ne,
756
+ op.getSrc (), zero);
757
+ return mlir::success ();
758
+ }
759
+ case CIR::integral: {
760
+ auto newDstType = convertTy (dstType);
761
+ auto srcType = op.getSrc ().getType ();
762
+ mlir::cir::IntType srcIntType = srcType.cast <mlir::cir::IntType>();
763
+ auto newOp =
764
+ createIntCast (rewriter, src, newDstType, srcIntType.isSigned ());
765
+ rewriter.replaceOp (op, newOp);
766
+ return mlir::success ();
767
+ }
768
+ case CIR::floating: {
769
+ auto newDstType = convertTy (dstType);
770
+ auto srcTy = op.getSrc ().getType ();
771
+ auto dstTy = op.getResult ().getType ();
772
+
773
+ if (!dstTy.isa <mlir::cir::CIRFPTypeInterface>() ||
774
+ !srcTy.isa <mlir::cir::CIRFPTypeInterface>())
775
+ return op.emitError () << " NYI cast from " << srcTy << " to " << dstTy;
776
+
777
+ auto getFloatWidth = [](mlir::Type ty) -> unsigned {
778
+ return ty.cast <mlir::cir::CIRFPTypeInterface>().getWidth ();
779
+ };
780
+
781
+ if (getFloatWidth (srcTy) > getFloatWidth (dstTy))
782
+ rewriter.replaceOpWithNewOp <mlir::arith::TruncFOp>(op, newDstType, src);
783
+ else
784
+ rewriter.replaceOpWithNewOp <mlir::arith::ExtFOp>(op, newDstType, src);
785
+ return mlir::success ();
786
+ }
787
+ case CIR::float_to_bool: {
788
+ auto dstTy = op.getType ().cast <mlir::cir::BoolType>();
789
+ auto newDstType = convertTy (dstTy);
790
+ auto kind = mlir::arith::CmpFPredicate::UNE;
791
+
792
+ // Check if float is not equal to zero.
793
+ auto zeroFloat = rewriter.create <mlir::arith::ConstantOp>(
794
+ op.getLoc (), src.getType (), mlir::FloatAttr::get (src.getType (), 0.0 ));
795
+
796
+ // Extend comparison result to either bool (C++) or int (C).
797
+ mlir::Value cmpResult = rewriter.create <mlir::arith::CmpFOp>(
798
+ op.getLoc (), kind, src, zeroFloat);
799
+ rewriter.replaceOpWithNewOp <mlir::arith::ExtUIOp>(op, newDstType,
800
+ cmpResult);
801
+ return mlir::success ();
802
+ }
803
+ case CIR::bool_to_int: {
804
+ auto dstTy = op.getType ().cast <mlir::cir::IntType>();
805
+ auto newDstType = convertTy (dstTy).cast <mlir::IntegerType>();
806
+ auto newOp = createIntCast (rewriter, src, newDstType);
807
+ rewriter.replaceOp (op, newOp);
808
+ return mlir::success ();
809
+ }
810
+ case CIR::bool_to_float: {
811
+ auto dstTy = op.getType ();
812
+ auto newDstType = convertTy (dstTy);
813
+ rewriter.replaceOpWithNewOp <mlir::arith::UIToFPOp>(op, newDstType, src);
814
+ return mlir::success ();
815
+ }
816
+ case CIR::int_to_float: {
817
+ auto dstTy = op.getType ();
818
+ auto newDstType = convertTy (dstTy);
819
+ if (op.getSrc ().getType ().cast <mlir::cir::IntType>().isSigned ())
820
+ rewriter.replaceOpWithNewOp <mlir::arith::SIToFPOp>(op, newDstType, src);
821
+ else
822
+ rewriter.replaceOpWithNewOp <mlir::arith::UIToFPOp>(op, newDstType, src);
823
+ return mlir::success ();
824
+ }
825
+ case CIR::float_to_int: {
826
+ auto dstTy = op.getType ();
827
+ auto newDstType = convertTy (dstTy);
828
+ if (op.getResult ().getType ().cast <mlir::cir::IntType>().isSigned ())
829
+ rewriter.replaceOpWithNewOp <mlir::arith::FPToSIOp>(op, newDstType, src);
830
+ else
831
+ rewriter.replaceOpWithNewOp <mlir::arith::FPToUIOp>(op, newDstType, src);
832
+ return mlir::success ();
833
+ }
834
+ default :
835
+ break ;
836
+ }
837
+ return mlir::failure ();
838
+ }
839
+ };
840
+
712
841
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
713
842
mlir::TypeConverter &converter) {
714
843
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
@@ -718,7 +847,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
718
847
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
719
848
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
720
849
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
721
- CIRGetGlobalOpLowering>(converter, patterns.getContext ());
850
+ CIRGetGlobalOpLowering, CIRCastOpLowering>(
851
+ converter, patterns.getContext ());
722
852
}
723
853
724
854
static mlir::TypeConverter prepareTypeConverter () {
0 commit comments