@@ -709,6 +709,145 @@ class CIRGetGlobalOpLowering
709
709
}
710
710
};
711
711
712
+ class CIRCastOpLowering : public mlir ::OpConversionPattern<mlir::cir::CastOp> {
713
+ public:
714
+ using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;
715
+
716
+ inline mlir::Type convertTy (mlir::Type ty) const {
717
+ return getTypeConverter ()->convertType (ty);
718
+ }
719
+
720
+ // / If the given type is a vector type, return the vector's element type.
721
+ // / Otherwise return the given type unchanged.
722
+ inline mlir::Type elementTypeIfVector (mlir::Type type) const {
723
+ if (auto VecType = type.dyn_cast <mlir::cir::VectorType>()) {
724
+ return VecType.getEltType ();
725
+ }
726
+ return type;
727
+ }
728
+
729
+ mlir::LogicalResult
730
+ matchAndRewrite (mlir::cir::CastOp op, OpAdaptor adaptor,
731
+ mlir::ConversionPatternRewriter &rewriter) const override {
732
+ auto src = adaptor.getSrc ();
733
+ auto dstType = op.getResult ().getType ();
734
+ using CIR = mlir::cir::CastKind;
735
+ switch (op.getKind ()) {
736
+ case CIR::int_to_bool: {
737
+ auto zero = rewriter.create <mlir::cir::ConstantOp>(
738
+ src.getLoc (), op.getSrc ().getType (),
739
+ mlir::cir::IntAttr::get (op.getSrc ().getType (), 0 ));
740
+ rewriter.replaceOpWithNewOp <mlir::cir::CmpOp>(
741
+ op, mlir::cir::BoolType::get (getContext ()), mlir::cir::CmpOpKind::ne,
742
+ op.getSrc (), zero);
743
+ return mlir::success ();
744
+ }
745
+ case CIR::integral: {
746
+ auto srcType = op.getSrc ().getType ();
747
+ auto newDstType = convertTy (dstType);
748
+ mlir::cir::IntType srcIntType =
749
+ elementTypeIfVector (srcType).cast <mlir::cir::IntType>();
750
+ mlir::cir::IntType dstIntType =
751
+ elementTypeIfVector (dstType).cast <mlir::cir::IntType>();
752
+
753
+ if (dstIntType.getWidth () < srcIntType.getWidth ()) {
754
+ // Bigger to smaller. Truncate.
755
+ rewriter.replaceOpWithNewOp <mlir::arith::TruncIOp>(op, newDstType, src);
756
+ } else if (dstIntType.getWidth () > srcIntType.getWidth ()) {
757
+ // Smaller to bigger. Zero extend or sign extend based on signedness.
758
+ if (srcIntType.isUnsigned ())
759
+ rewriter.replaceOpWithNewOp <mlir::arith::ExtUIOp>(op, newDstType,
760
+ src);
761
+ else
762
+ rewriter.replaceOpWithNewOp <mlir::arith::ExtSIOp>(op, newDstType,
763
+ src);
764
+ } else {
765
+ // Same size. Signedness changes doesn't matter. Do nothing.
766
+ rewriter.replaceOp (op, src);
767
+ }
768
+ return mlir::success ();
769
+ }
770
+ case CIR::floating: {
771
+ auto newDstType = convertTy (dstType);
772
+ auto srcTy = elementTypeIfVector (op.getSrc ().getType ());
773
+ auto dstTy = elementTypeIfVector (op.getResult ().getType ());
774
+
775
+ if (!dstTy.isa <mlir::cir::CIRFPTypeInterface>() ||
776
+ !srcTy.isa <mlir::cir::CIRFPTypeInterface>())
777
+ return op.emitError () << " NYI cast from " << srcTy << " to " << dstTy;
778
+
779
+ auto getFloatWidth = [](mlir::Type ty) -> unsigned {
780
+ return ty.cast <mlir::cir::CIRFPTypeInterface>().getWidth ();
781
+ };
782
+
783
+ if (getFloatWidth (srcTy) > getFloatWidth (dstTy))
784
+ rewriter.replaceOpWithNewOp <mlir::arith::TruncFOp>(op, newDstType, src);
785
+ else
786
+ rewriter.replaceOpWithNewOp <mlir::arith::ExtFOp>(op, newDstType, src);
787
+ return mlir::success ();
788
+ }
789
+ case CIR::float_to_bool: {
790
+ auto dstTy = op.getType ().cast <mlir::cir::BoolType>();
791
+ auto newDstType = convertTy (dstTy);
792
+ auto kind = mlir::arith::CmpFPredicate::UNE;
793
+
794
+ // Check if float is not equal to zero.
795
+ auto zeroFloat = rewriter.create <mlir::arith::ConstantOp>(
796
+ op.getLoc (), src.getType (), mlir::FloatAttr::get (src.getType (), 0.0 ));
797
+
798
+ // Extend comparison result to either bool (C++) or int (C).
799
+ mlir::Value cmpResult = rewriter.create <mlir::arith::CmpFOp>(
800
+ op.getLoc (), kind, src, zeroFloat);
801
+ rewriter.replaceOpWithNewOp <mlir::arith::ExtUIOp>(op, newDstType,
802
+ cmpResult);
803
+ return mlir::success ();
804
+ }
805
+ case CIR::bool_to_int: {
806
+ auto newSrcTy = src.getType ().cast <mlir::IntegerType>();
807
+ auto dstTy = op.getType ().cast <mlir::cir::IntType>();
808
+ auto newDstType = convertTy (dstTy).cast <mlir::IntegerType>();
809
+ if (newSrcTy.getWidth () == newDstType.getWidth ())
810
+ rewriter.replaceOpWithNewOp <mlir::arith::BitcastOp>(op, newDstType,
811
+ src);
812
+ else
813
+ rewriter.replaceOpWithNewOp <mlir::arith::ExtUIOp>(op, newDstType, src);
814
+ return mlir::success ();
815
+ }
816
+ case CIR::bool_to_float: {
817
+ auto dstTy = op.getType ();
818
+ auto newDstType = convertTy (dstTy);
819
+ rewriter.replaceOpWithNewOp <mlir::arith::UIToFPOp>(op, newDstType, src);
820
+ return mlir::success ();
821
+ }
822
+ case CIR::int_to_float: {
823
+ auto dstTy = op.getType ();
824
+ auto newDstType = convertTy (dstTy);
825
+ if (elementTypeIfVector (op.getSrc ().getType ())
826
+ .cast <mlir::cir::IntType>()
827
+ .isSigned ())
828
+ rewriter.replaceOpWithNewOp <mlir::arith::SIToFPOp>(op, newDstType, src);
829
+ else
830
+ rewriter.replaceOpWithNewOp <mlir::arith::UIToFPOp>(op, newDstType, src);
831
+ return mlir::success ();
832
+ }
833
+ case CIR::float_to_int: {
834
+ auto dstTy = op.getType ();
835
+ auto newDstType = convertTy (dstTy);
836
+ if (elementTypeIfVector (op.getResult ().getType ())
837
+ .cast <mlir::cir::IntType>()
838
+ .isSigned ())
839
+ rewriter.replaceOpWithNewOp <mlir::arith::FPToSIOp>(op, newDstType, src);
840
+ else
841
+ rewriter.replaceOpWithNewOp <mlir::arith::FPToUIOp>(op, newDstType, src);
842
+ return mlir::success ();
843
+ }
844
+ default :
845
+ break ;
846
+ }
847
+ return mlir::failure ();
848
+ }
849
+ };
850
+
712
851
void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
713
852
mlir::TypeConverter &converter) {
714
853
patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
@@ -718,7 +857,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
718
857
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
719
858
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
720
859
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
721
- CIRGetGlobalOpLowering>(converter, patterns.getContext ());
860
+ CIRGetGlobalOpLowering, CIRCastOpLowering>(
861
+ converter, patterns.getContext ());
722
862
}
723
863
724
864
static mlir::TypeConverter prepareTypeConverter () {
0 commit comments