
Commit 2ca533b

Fix compilation with CUDA 11 (#2331)
1 parent 4a9b29a commit 2ca533b

File tree

11 files changed: +116 -57 lines

mlx/backend/cuda/arg_reduce.cu

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/iterators/strided_iterator.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"

mlx/backend/cuda/device.cpp

Lines changed: 17 additions & 10 deletions
@@ -264,19 +264,26 @@ void CommandEncoder::commit() {
   graph_key_ += std::to_string(graph_node_count_);
   graph_key_ += ".";
   graph_key_ += std::to_string(empty_node_count_);
-  auto [it, _] = graph_cache_.emplace(graph_key_, nullptr);
-  auto& graph_exec = it->second;
-
-  if (graph_exec != NULL) {
-    cudaGraphExecUpdateResultInfo update_result;
-    cudaGraphExecUpdate(graph_exec, graph_, &update_result);
-    if (update_result.result != cudaGraphExecUpdateSuccess) {
-      cudaGetLastError();
+
+  cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
+
+  if (graph_exec != nullptr) {
+    cudaGraphExecUpdateResult update_result;
+#if CUDART_VERSION >= 12000
+    cudaGraphExecUpdateResultInfo info;
+    cudaGraphExecUpdate(graph_exec, graph_, &info);
+    update_result = info.result;
+#else
+    cudaGraphNode_t error_node;
+    cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
+#endif // CUDART_VERSION >= 12000
+    if (update_result != cudaGraphExecUpdateSuccess) {
+      cudaGetLastError(); // reset error
       CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
-      graph_exec = NULL;
+      graph_exec = nullptr;
     }
   }
-  if (graph_exec == NULL) {
+  if (graph_exec == nullptr) {
     CHECK_CUDA_ERROR(
         cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
   }
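For context, here is a minimal standalone sketch of the same CUDART_VERSION guard: CUDA 12 reports the update status through a single cudaGraphExecUpdateResultInfo out-parameter, while CUDA 11 uses a cudaGraphNode_t error-node pointer plus a cudaGraphExecUpdateResult. Everything besides those two signatures (the scale kernel, the try_update helper, the capture loop) is illustrative and not part of MLX.

// Standalone sketch of the version guard above; scale and try_update are
// hypothetical names used only for this example.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

static bool try_update(cudaGraphExec_t exec, cudaGraph_t graph) {
  cudaGraphExecUpdateResult result;
#if CUDART_VERSION >= 12000
  cudaGraphExecUpdateResultInfo info;  // CUDA 12: one result-info struct
  cudaGraphExecUpdate(exec, graph, &info);
  result = info.result;
#else
  cudaGraphNode_t error_node;          // CUDA 11: separate out-parameters
  cudaGraphExecUpdate(exec, graph, &error_node, &result);
#endif
  if (result != cudaGraphExecUpdateSuccess) {
    cudaGetLastError();                // clear the sticky error
    return false;
  }
  return true;
}

int main() {
  int n = 1 << 20;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraphExec_t exec = nullptr;
  float factors[] = {2.f, 3.f};
  for (float factor : factors) {
    // Capture the work into a graph.
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    scale<<<(n + 255) / 256, 256, 0, stream>>>(d, factor, n);
    cudaStreamEndCapture(stream, &graph);

    // Refresh the cached executable graph when possible, otherwise rebuild it.
    if (exec == nullptr || !try_update(exec, graph)) {
      if (exec != nullptr) cudaGraphExecDestroy(exec);
      cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);
    }
    cudaGraphLaunch(exec, stream);
    cudaGraphDestroy(graph);
  }
  cudaStreamSynchronize(stream);
  cudaGraphExecDestroy(exec);
  cudaStreamDestroy(stream);
  cudaFree(d);
  printf("done\n");
  return 0;
}

Compiled with nvcc against either toolkit, the second iteration exercises the update path; a failed update falls back to re-instantiating the executable graph, mirroring the logic in CommandEncoder::commit().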

mlx/backend/cuda/device/cast_op.cuh

Lines changed: 68 additions & 1 deletion
@@ -3,6 +3,8 @@
 #pragma once
 
 #include <cuComplex.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
 #include <thrust/iterator/transform_iterator.h>
 
 namespace mlx::core::cu {
@@ -17,6 +19,26 @@ struct CastOp {
   }
 };
 
+// Castings between complex and boolean.
+// TODO: Should make a custom complex type.
+template <>
+struct CastOp<cuComplex, bool> {
+  static constexpr bool is_castable = true;
+
+  __device__ bool operator()(cuComplex x) {
+    return x.x != 0 && x.y != 0;
+  }
+};
+
+template <>
+struct CastOp<bool, cuComplex> {
+  static constexpr bool is_castable = true;
+
+  __device__ cuComplex operator()(bool x) {
+    return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
+  }
+};
+
 // Converting a complex number to real number discards the imaginary part.
 template <typename DstT>
 struct CastOp<
@@ -45,6 +67,7 @@ struct CastOp<
   }
 };
 
+// Do nothing when no casting is needed.
 template <typename SrcT, typename DstT>
 struct CastOp<
     SrcT,
@@ -57,9 +80,53 @@ struct CastOp<
   }
 };
 
+// In CUDA 11 the half types do not define conversions between some types,
+// provide fallbacks here.
+#if CUDART_VERSION < 12000
+template <typename SrcT, typename DstT>
+struct CastOp<
+    SrcT,
+    DstT,
+    cuda::std::enable_if_t<
+        !cuda::std::is_convertible_v<SrcT, DstT> &&
+        !cuda::std::is_same_v<SrcT, cuComplex> &&
+        (cuda::std::is_same_v<DstT, __half> ||
+         cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
+  static constexpr bool is_castable = true;
+
+  __device__ DstT operator()(SrcT x) {
+    return DstT(static_cast<float>(x));
+  }
+};
+
+template <typename SrcT, typename DstT>
+struct CastOp<
+    SrcT,
+    DstT,
+    cuda::std::enable_if_t<
+        !cuda::std::is_convertible_v<SrcT, DstT> &&
+        !cuda::std::is_same_v<DstT, cuComplex> &&
+        !cuda::std::is_same_v<DstT, __half> &&
+        !cuda::std::is_same_v<DstT, __nv_bfloat16> &&
+        (cuda::std::is_same_v<SrcT, __half> ||
+         cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {
+  static constexpr bool is_castable = true;
+
+  __device__ DstT operator()(SrcT x) {
+    return DstT(static_cast<float>(x));
+  }
+};
+#endif // CUDART_VERSION < 12000
+
+// Helper to deduce the SrcT.
+template <typename DstT, typename SrcT>
+inline __host__ __device__ auto cast_to(SrcT x) {
+  return CastOp<SrcT, DstT>{}(x);
+}
+
 // Return an iterator that cast the value to DstT using CastOp.
 template <typename DstT, typename Iterator>
-__host__ __device__ auto make_cast_iterator(Iterator it) {
+inline __host__ __device__ auto make_cast_iterator(Iterator it) {
   using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
   if constexpr (std::is_same_v<SrcT, DstT>) {
     return it;
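The new cast_to<DstT>(x) helper and the CUDA 11 fallback specializations route otherwise-unsupported conversions through float, since both __half and __nv_bfloat16 can always be constructed from float. A minimal standalone illustration of that round-trip (to_extended_float and iota_bf16 are hypothetical names, not MLX code):

// Illustrative only: a float round-trip cast in the spirit of the CUDA 11
// fallback specializations above.
#include <cuda_bf16.h>
#include <cuda_fp16.h>

template <typename DstT, typename SrcT>
__device__ DstT to_extended_float(SrcT x) {
  // An unsupported SrcT -> DstT conversion is split into two steps that both
  // toolkits provide: SrcT -> float, then float -> __half/__nv_bfloat16.
  return DstT(static_cast<float>(x));
}

__global__ void iota_bf16(__nv_bfloat16* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // A direct int -> __nv_bfloat16 conversion may be missing on CUDA 11;
    // going through float compiles on both CUDA 11 and CUDA 12.
    out[i] = to_extended_float<__nv_bfloat16>(i);
  }
}

Call sites in the reduce kernels below switch from the old __cast<U, T>(x) helper to cast_to<U>(x), so they pick up these fallbacks automatically when building with CUDA 11.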

mlx/backend/cuda/device/utils.cuh

Lines changed: 6 additions & 6 deletions
@@ -99,20 +99,20 @@ struct Limits<
     return cuda::std::numeric_limits<T>::infinity();
   }
   static constexpr __host__ __device__ T min() {
-#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
-    return -cuda::std::numeric_limits<T>::infinity();
-#else
+#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
     return -cuda::std::numeric_limits<float>::infinity();
+#else
+    return -cuda::std::numeric_limits<T>::infinity();
 #endif
   }
   static constexpr __host__ __device__ T finite_max() {
     return cuda::std::numeric_limits<T>::max();
   }
   static constexpr __host__ __device__ T finite_min() {
-#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
-    return cuda::std::numeric_limits<T>::lowest();
-#else
+#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
     return cuda::std::numeric_limits<float>::lowest();
+#else
+    return cuda::std::numeric_limits<T>::lowest();
 #endif
   }
 };

mlx/backend/cuda/reduce/all_reduce.cu

Lines changed: 3 additions & 3 deletions
@@ -37,15 +37,15 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
   for (; i + block.size() * N <= check; i += block.size() * N) {
     cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
     for (int j = 0; j < N; j++) {
-      accs[0] = op(accs[0], __cast<U, T>(vals[j]));
+      accs[0] = op(accs[0], cast_to<U>(vals[j]));
     }
   }
 
   if (i < check) {
     cub::LoadDirectBlocked(
-        block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
+        block.thread_rank(), in + i, vals, check - i, cast_to<T>(init));
     for (int i = 0; i < N; i++) {
-      accs[0] = op(accs[0], __cast<U, T>(vals[i]));
+      accs[0] = op(accs[0], cast_to<U>(vals[i]));
     }
   }

mlx/backend/cuda/reduce/col_reduce.cu

Lines changed: 4 additions & 5 deletions
@@ -3,7 +3,6 @@
 #include <numeric>
 
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
 
 #include <cooperative_groups.h>
@@ -128,7 +127,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
     T vals[N_READS];
     cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
     for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      totals[i] = op(totals[i], cast_to<U>(vals[i]));
     }
     loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
   }
@@ -137,7 +136,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
     T vals[N_READS];
     cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
     for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      totals[i] = op(totals[i], cast_to<U>(vals[i]));
     }
     loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
   }
@@ -150,9 +149,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
         in + loop.location(),
         vals,
         args.reduction_stride - tile_x * BN,
-        __cast<T, U>(ReduceInit<Op, T>::value()));
+        cast_to<T>(ReduceInit<Op, T>::value()));
     for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(totals[i], __cast<U, T>(vals[i]));
+      totals[i] = op(totals[i], cast_to<U>(vals[i]));
     }
     loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
   }

mlx/backend/cuda/reduce/reduce_ops.cuh

Lines changed: 7 additions & 5 deletions
@@ -2,6 +2,8 @@
 
 #pragma once
 
+#include "mlx/backend/cuda/device/atomic_ops.cuh"
+#include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 #include "mlx/backend/cuda/reduce/reduce_utils.cuh"
 
@@ -40,15 +42,15 @@ struct Sum {
   }
 
   __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
-    atomicAdd(x, y);
+    atomic_add(x, y);
   }
 
  __device__ void atomic_update(int* x, int y) {
-    atomicAdd(x, y);
+    atomic_add(x, y);
  }
 
  __device__ void atomic_update(float* x, float y) {
-    atomicAdd(x, y);
+    atomic_add(x, y);
  }
 };
 
@@ -152,7 +154,7 @@ struct ReduceInit<Sum, T> {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
       return T{0, 0};
     } else {
-      return typename ReduceResult<Sum, T>::type{0};
+      return cast_to<typename ReduceResult<Sum, T>::type>(0);
     }
   }
 };
@@ -163,7 +165,7 @@ struct ReduceInit<Prod, T> {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
       return T{1, 0};
     } else {
-      return typename ReduceResult<Prod, T>::type{1};
+      return cast_to<typename ReduceResult<Prod, T>::type>(1);
    }
  }
 };
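The atomic_update overloads now go through MLX's atomic_add wrapper from atomic_ops.cuh rather than calling atomicAdd directly; the wrapper's implementation is not part of this diff. As a rough sketch of why a wrapper helps for 16-bit types, assuming no native atomicAdd overload for __nv_bfloat16 is available on the target toolkit or architecture, the add can be emulated with an atomicCAS loop on the enclosing 32-bit word (hypothetical code, not MLX's implementation):

// Hypothetical CAS-based fallback: adds a __nv_bfloat16 value by retrying on
// the aligned 32-bit word that contains it.
#include <cuda_bf16.h>
#include <cstdint>

__device__ void atomic_add_bf16_fallback(__nv_bfloat16* addr, __nv_bfloat16 val) {
  uintptr_t raw_addr = reinterpret_cast<uintptr_t>(addr);
  unsigned int* word = reinterpret_cast<unsigned int*>(raw_addr & ~uintptr_t{3});
  bool high_half = (raw_addr & 2) != 0;  // which 16 bits of the word we own

  unsigned int old = *word;
  unsigned int assumed;
  do {
    assumed = old;
    unsigned short bits = high_half ? (unsigned short)(assumed >> 16)
                                    : (unsigned short)(assumed & 0xffffu);
    // Do the arithmetic in float, then round back to bfloat16.
    float sum = __bfloat162float(__ushort_as_bfloat16(bits)) + __bfloat162float(val);
    unsigned short new_bits = __bfloat16_as_ushort(__float2bfloat16(sum));
    unsigned int desired = high_half
        ? ((assumed & 0x0000ffffu) | ((unsigned int)new_bits << 16))
        : ((assumed & 0xffff0000u) | new_bits);
    old = atomicCAS(word, assumed, desired);
  } while (old != assumed);  // retry if another thread changed the word
}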

mlx/backend/cuda/reduce/reduce_utils.cuh

Lines changed: 0 additions & 16 deletions
@@ -55,22 +55,6 @@ __device__ void atomic_reduce(T* x, T y) {
   }
 }
 
-// TODO: Should make a custom complex type
-template <typename U, typename T>
-inline __device__ U __cast(T x) {
-  return static_cast<U>(x);
-}
-
-template <>
-inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
-  return x.x != 0 && x.y != 0;
-}
-
-template <>
-inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
-  return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
-}
-
 template <typename T, int N, typename Block, typename Warp, typename Op>
 inline __device__ void
 block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {

mlx/backend/cuda/reduce/row_reduce.cu

Lines changed: 7 additions & 8 deletions
@@ -3,7 +3,6 @@
 #include <numeric>
 
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
 
 #include <cooperative_groups.h>
@@ -113,7 +112,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
           in + k * size + r * (block.size() * N),
           vals[k]);
       for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
      }
    }
  }
@@ -125,7 +124,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
           in + k * size + r * (block.size() * N),
           vals[k]);
       for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
      }
    }
  }
@@ -138,9 +137,9 @@
           in + k * size + final_offset,
           vals[k],
           size,
-          __cast<T, U>(init));
+          cast_to<T>(init));
       for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U, T>(vals[k][j]));
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
      }
    }
  }
@@ -199,7 +198,7 @@ __global__ void row_reduce_looped(
         in + loop.location() + r * BLOCK_DIM * N_READS,
         vals);
     for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U, T>(vals[i]));
+      total[0] = op(total[0], cast_to<U>(vals[i]));
     }
   }
   if (final_offset < args.row_size) {
@@ -209,9 +208,9 @@
         in + loop.location() + final_offset,
         vals,
         args.row_size - final_offset,
-        __cast<T, U>(init));
+        cast_to<T>(init));
     for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U, T>(vals[i]));
+      total[0] = op(total[0], cast_to<U>(vals[i]));
    }
  }
  // TODO: Maybe block.sync() here?

mlx/backend/cuda/rms_norm.cu

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ __global__ void rms_norm(
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
     T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
     for (int i = 0; i < N_READS; ++i) {
       float t = static_cast<float>(xn[i]);
       normalizer += t * t;
@@ -130,7 +130,7 @@ __global__ void rms_norm_vjp(
     T wn[N_READS] = {};
     T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
     cub::LoadDirectBlocked(index, g, gn, axis_size);
     cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
     for (int i = 0; i < N_READS; i++) {
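The only change in this file is the out-of-bounds default passed to cub::LoadDirectBlocked: a bare 0 relies on an implicit int-to-T conversion that the CUDA 11 half types do not always provide, so the default is now built with cast_to<T>(0). A small standalone sketch of the same guarded-load pattern (block_sum and its launch are illustrative, not MLX code):

// Illustrative guarded block load: elements past `size` are filled with an
// explicitly constructed T instead of a bare integer literal.
#include <cub/cub.cuh>
#include <cuda_fp16.h>

template <typename T, int N_READS>
__global__ void block_sum(const T* in, float* out, int size) {
  T vals[N_READS];
  // Thread t loads in[t * N_READS .. t * N_READS + N_READS - 1]; anything
  // beyond `size` becomes T(0.0f) (constructing from float also works for
  // __half and __nv_bfloat16 on CUDA 11).
  cub::LoadDirectBlocked(threadIdx.x, in, vals, size, T(0.0f));

  float acc = 0.0f;
  for (int i = 0; i < N_READS; ++i) {
    acc += static_cast<float>(vals[i]);
  }
  atomicAdd(out, acc);  // float atomicAdd is available on all supported archs
}

// Example launch: one block whose threads cover the whole row.
// block_sum<__half, 4><<<1, 256>>>(d_in, d_out, size);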
