33#pragma once
44
55#include < cuComplex.h>
6+ #include < cuda_bf16.h>
7+ #include < cuda_fp16.h>
68#include < thrust/iterator/transform_iterator.h>
79
810namespace mlx ::core::cu {
@@ -17,6 +19,26 @@ struct CastOp {
1719 }
1820};
1921
22+ // Castings between complex and boolean.
23+ // TODO: Should make a custom complex type.
24+ template <>
25+ struct CastOp <cuComplex, bool > {
26+ static constexpr bool is_castable = true ;
27+
28+ __device__ bool operator ()(cuComplex x) {
29+ return x.x != 0 && x.y != 0 ;
30+ }
31+ };
32+
33+ template <>
34+ struct CastOp <bool , cuComplex> {
35+ static constexpr bool is_castable = true ;
36+
37+ __device__ cuComplex operator ()(bool x) {
38+ return x ? make_cuFloatComplex (1 , 1 ) : make_cuFloatComplex (0 , 0 );
39+ }
40+ };
41+
2042// Converting a complex number to real number discards the imaginary part.
2143template <typename DstT>
2244struct CastOp <
@@ -45,6 +67,7 @@ struct CastOp<
4567 }
4668};
4769
70+ // Do nothing when no casting is needed.
4871template <typename SrcT, typename DstT>
4972struct CastOp <
5073 SrcT,
@@ -57,9 +80,53 @@ struct CastOp<
5780 }
5881};
5982
83+ // In CUDA 11 the half types do not define conversions between some types,
84+ // provide fallbacks here.
85+ #if CUDART_VERSION < 12000
86+ template <typename SrcT, typename DstT>
87+ struct CastOp <
88+ SrcT,
89+ DstT,
90+ cuda::std::enable_if_t <
91+ !cuda::std::is_convertible_v<SrcT, DstT> &&
92+ !cuda::std::is_same_v<SrcT, cuComplex> &&
93+ (cuda::std::is_same_v<DstT, __half> ||
94+ cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
95+ static constexpr bool is_castable = true ;
96+
97+ __device__ DstT operator ()(SrcT x) {
98+ return DstT (static_cast <float >(x));
99+ }
100+ };
101+
102+ template <typename SrcT, typename DstT>
103+ struct CastOp <
104+ SrcT,
105+ DstT,
106+ cuda::std::enable_if_t <
107+ !cuda::std::is_convertible_v<SrcT, DstT> &&
108+ !cuda::std::is_same_v<DstT, cuComplex> &&
109+ !cuda::std::is_same_v<DstT, __half> &&
110+ !cuda::std::is_same_v<DstT, __nv_bfloat16> &&
111+ (cuda::std::is_same_v<SrcT, __half> ||
112+ cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {
113+ static constexpr bool is_castable = true ;
114+
115+ __device__ DstT operator ()(SrcT x) {
116+ return DstT (static_cast <float >(x));
117+ }
118+ };
119+ #endif // CUDART_VERSION < 12000
120+
121+ // Helper to deduce the SrcT.
122+ template <typename DstT, typename SrcT>
123+ inline __host__ __device__ auto cast_to (SrcT x) {
124+ return CastOp<SrcT, DstT>{}(x);
125+ }
126+
60127// Return an iterator that cast the value to DstT using CastOp.
61128template <typename DstT, typename Iterator>
62- __host__ __device__ auto make_cast_iterator (Iterator it) {
129+ inline __host__ __device__ auto make_cast_iterator (Iterator it) {
63130 using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
64131 if constexpr (std::is_same_v<SrcT, DstT>) {
65132 return it;
0 commit comments