Skip to content

Commit 3d5e17e

Browse files
authored
MLX_SWITCH macros to templates (#2320)
1 parent 33bf1a2 commit 3d5e17e

27 files changed

+702
-701
lines changed

mlx/backend/cuda/arg_reduce.cu

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -152,35 +152,29 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
152152
encoder.set_input_array(in);
153153
encoder.set_output_array(out);
154154
encoder.launch_kernel([&](cudaStream_t stream) {
155-
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
156-
using InType = cuda_type_t<CTYPE>;
155+
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
156+
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
157157
constexpr uint32_t N_READS = 4;
158-
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
159-
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
160-
dim3 block_dims{BLOCK_DIM, 1, 1};
161-
auto kernel = &cu::arg_reduce_general<
162-
InType,
163-
cu::ArgMax<InType>,
164-
BLOCK_DIM,
165-
N_READS>;
166-
if (reduce_type_ == ArgReduce::ArgMin) {
167-
kernel = &cu::arg_reduce_general<
168-
InType,
169-
cu::ArgMin<InType>,
170-
BLOCK_DIM,
171-
N_READS>;
172-
}
173-
kernel<<<num_blocks, block_dims, 0, stream>>>(
174-
in.data<InType>(),
175-
out.data<uint32_t>(),
176-
out.size(),
177-
const_param(shape),
178-
const_param(in_strides),
179-
const_param(out_strides),
180-
ndim,
181-
axis_stride,
182-
axis_size);
183-
});
158+
dispatch_block_dim(
159+
cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
160+
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
161+
auto kernel =
162+
cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
163+
if (reduce_type_ == ArgReduce::ArgMin) {
164+
kernel = cu::
165+
arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
166+
}
167+
kernel<<<num_blocks, block_dim(), 0, stream>>>(
168+
in.data<T>(),
169+
out.data<uint32_t>(),
170+
out.size(),
171+
const_param(shape),
172+
const_param(in_strides),
173+
const_param(out_strides),
174+
ndim,
175+
axis_stride,
176+
axis_size);
177+
});
184178
});
185179
});
186180
}

mlx/backend/cuda/binary.cu

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -140,54 +140,64 @@ void binary_op_gpu_inplace(
140140
encoder.set_input_array(b);
141141
encoder.set_output_array(out);
142142
encoder.launch_kernel([&](cudaStream_t stream) {
143-
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
144-
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
143+
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
144+
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
145+
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
146+
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
145147
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
146148
using InType = cuda_type_t<CTYPE_IN>;
147149
using OutType = cuda_type_t<CTYPE_OUT>;
148150
auto bopt = get_binary_op_type(a, b);
149151
if (bopt == BinaryOpType::General) {
150-
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
151-
auto& a_strides = strides[0];
152-
auto& b_strides = strides[1];
153-
bool large = a.data_size() > INT32_MAX ||
154-
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
155-
MLX_SWITCH_BOOL(large, LARGE, {
156-
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
157-
int ndim = shape.size();
158-
if (ndim <= 3) {
159-
MLX_SWITCH_1_2_3(ndim, NDIM, {
160-
auto kernel =
161-
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
162-
auto [num_blocks, block_dims] =
163-
get_launch_args(kernel, out, large);
164-
kernel<<<num_blocks, block_dims, 0, stream>>>(
165-
a.data<InType>(),
166-
b.data<InType>(),
167-
out.data<OutType>(),
168-
out.size(),
169-
const_param<NDIM>(shape),
170-
const_param<NDIM>(a_strides),
171-
const_param<NDIM>(b_strides));
152+
dispatch_bool(
153+
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
154+
out.data_size() > INT32_MAX,
155+
[&](auto large) {
156+
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
157+
Shape shape;
158+
std::vector<Strides> strides;
159+
std::tie(shape, strides) =
160+
collapse_contiguous_dims(a, b, out);
161+
auto& a_strides = strides[0];
162+
auto& b_strides = strides[1];
163+
int ndim = shape.size();
164+
if (ndim <= 3) {
165+
dispatch_1_2_3(ndim, [&](auto dims_constant) {
166+
auto kernel = cu::binary_g_nd<
167+
Op,
168+
InType,
169+
OutType,
170+
IdxT,
171+
dims_constant()>;
172+
auto [num_blocks, block_dims] =
173+
get_launch_args(kernel, out, large());
174+
kernel<<<num_blocks, block_dims, 0, stream>>>(
175+
a.data<InType>(),
176+
b.data<InType>(),
177+
out.data<OutType>(),
178+
out.size(),
179+
const_param<dims_constant()>(shape),
180+
const_param<dims_constant()>(a_strides),
181+
const_param<dims_constant()>(b_strides));
182+
});
183+
} else {
184+
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
185+
auto [num_blocks, block_dims] =
186+
get_launch_args(kernel, out, large());
187+
kernel<<<num_blocks, block_dims, 0, stream>>>(
188+
a.data<InType>(),
189+
b.data<InType>(),
190+
out.data<OutType>(),
191+
out.size(),
192+
const_param(shape),
193+
const_param(a_strides),
194+
const_param(b_strides),
195+
ndim);
196+
}
172197
});
173-
} else {
174-
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
175-
auto [num_blocks, block_dims] =
176-
get_launch_args(kernel, out, large);
177-
kernel<<<num_blocks, block_dims, 0, stream>>>(
178-
a.data<InType>(),
179-
b.data<InType>(),
180-
out.data<OutType>(),
181-
out.size(),
182-
const_param(shape),
183-
const_param(a_strides),
184-
const_param(b_strides),
185-
ndim);
186-
}
187-
});
188198
} else {
189-
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
190-
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
199+
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
200+
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
191201
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
192202
if (bopt == BinaryOpType::ScalarVector) {
193203
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
@@ -197,7 +207,7 @@ void binary_op_gpu_inplace(
197207
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
198208
}
199209
auto [num_blocks, block_dims] = get_launch_args(
200-
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
210+
kernel, out.data_size(), out.shape(), out.strides(), large());
201211
kernel<<<num_blocks, block_dims, 0, stream>>>(
202212
a.data<InType>(),
203213
b.data<InType>(),

mlx/backend/cuda/binary_two.cu

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -138,57 +138,67 @@ void binary_op_gpu_inplace(
138138
encoder.set_output_array(out_a);
139139
encoder.set_output_array(out_b);
140140
encoder.launch_kernel([&](cudaStream_t stream) {
141-
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
142-
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
141+
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
142+
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
143+
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
144+
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
143145
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
144146
using InType = cuda_type_t<CTYPE_IN>;
145147
using OutType = cuda_type_t<CTYPE_OUT>;
146148

147149
auto bopt = get_binary_op_type(a, b);
148150
if (bopt == BinaryOpType::General) {
149-
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
150-
auto& a_strides = strides[0];
151-
auto& b_strides = strides[1];
152-
bool large = a.data_size() > INT32_MAX ||
153-
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
154-
MLX_SWITCH_BOOL(large, LARGE, {
155-
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
156-
int ndim = shape.size();
157-
if (ndim <= 3) {
158-
MLX_SWITCH_1_2_3(ndim, NDIM, {
159-
auto kernel =
160-
cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
161-
auto [num_blocks, block_dims] =
162-
get_launch_args(kernel, out_a, large);
163-
kernel<<<num_blocks, block_dims, 0, stream>>>(
164-
a.data<InType>(),
165-
b.data<InType>(),
166-
out_a.data<OutType>(),
167-
out_b.data<OutType>(),
168-
out_a.size(),
169-
const_param<NDIM>(shape),
170-
const_param<NDIM>(a_strides),
171-
const_param<NDIM>(b_strides));
151+
dispatch_bool(
152+
a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
153+
out_a.data_size() > INT32_MAX,
154+
[&](auto large) {
155+
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
156+
Shape shape;
157+
std::vector<Strides> strides;
158+
std::tie(shape, strides) =
159+
collapse_contiguous_dims(a, b, out_a);
160+
auto& a_strides = strides[0];
161+
auto& b_strides = strides[1];
162+
int ndim = shape.size();
163+
if (ndim <= 3) {
164+
dispatch_1_2_3(ndim, [&](auto dims_constant) {
165+
auto kernel = cu::binary_g_nd<
166+
Op,
167+
InType,
168+
OutType,
169+
IdxT,
170+
dims_constant()>;
171+
auto [num_blocks, block_dims] =
172+
get_launch_args(kernel, out_a, large());
173+
kernel<<<num_blocks, block_dims, 0, stream>>>(
174+
a.data<InType>(),
175+
b.data<InType>(),
176+
out_a.data<OutType>(),
177+
out_b.data<OutType>(),
178+
out_a.size(),
179+
const_param<dims_constant()>(shape),
180+
const_param<dims_constant()>(a_strides),
181+
const_param<dims_constant()>(b_strides));
182+
});
183+
} else {
184+
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
185+
auto [num_blocks, block_dims] =
186+
get_launch_args(kernel, out_a, large());
187+
kernel<<<num_blocks, block_dims, 0, stream>>>(
188+
a.data<InType>(),
189+
b.data<InType>(),
190+
out_a.data<OutType>(),
191+
out_b.data<OutType>(),
192+
out_a.size(),
193+
const_param(shape),
194+
const_param(a_strides),
195+
const_param(b_strides),
196+
ndim);
197+
}
172198
});
173-
} else {
174-
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
175-
auto [num_blocks, block_dims] =
176-
get_launch_args(kernel, out_a, large);
177-
kernel<<<num_blocks, block_dims, 0, stream>>>(
178-
a.data<InType>(),
179-
b.data<InType>(),
180-
out_a.data<OutType>(),
181-
out_b.data<OutType>(),
182-
out_a.size(),
183-
const_param(shape),
184-
const_param(a_strides),
185-
const_param(b_strides),
186-
ndim);
187-
}
188-
});
189199
} else {
190-
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
191-
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
200+
dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
201+
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
192202
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
193203
if (bopt == BinaryOpType::ScalarVector) {
194204
kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
@@ -202,7 +212,7 @@ void binary_op_gpu_inplace(
202212
out_a.data_size(),
203213
out_a.shape(),
204214
out_a.strides(),
205-
LARGE);
215+
large());
206216
kernel<<<num_blocks, block_dims, 0, stream>>>(
207217
a.data<InType>(),
208218
b.data<InType>(),

mlx/backend/cuda/copy/copy.cuh

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,6 @@
1010

1111
namespace mlx::core {
1212

13-
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
14-
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
15-
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
16-
using InType = cuda_type_t<CTYPE_IN>; \
17-
using OutType = cuda_type_t<CTYPE_OUT>; \
18-
__VA_ARGS__; \
19-
}); \
20-
})
21-
2213
void copy_contiguous(
2314
cu::CommandEncoder& encoder,
2415
CopyType ctype,

mlx/backend/cuda/copy/copy_contiguous.cu

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,23 @@ void copy_contiguous(
3636
int64_t in_offset,
3737
int64_t out_offset) {
3838
encoder.launch_kernel([&](cudaStream_t stream) {
39-
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
40-
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
41-
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
42-
auto kernel = cu::copy_s<InType, OutType, IdxT>;
43-
if (ctype == CopyType::Vector) {
44-
kernel = cu::copy_v<InType, OutType, IdxT>;
45-
}
46-
auto [num_blocks, block_dims] = get_launch_args(
47-
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
48-
kernel<<<num_blocks, block_dims, 0, stream>>>(
49-
in.data<InType>() + in_offset,
50-
out.data<OutType>() + out_offset,
51-
out.data_size());
39+
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
40+
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
41+
dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
42+
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
43+
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
44+
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
45+
auto kernel = cu::copy_s<InType, OutType, IdxT>;
46+
if (ctype == CopyType::Vector) {
47+
kernel = cu::copy_v<InType, OutType, IdxT>;
48+
}
49+
auto [num_blocks, block_dims] = get_launch_args(
50+
kernel, out.data_size(), out.shape(), out.strides(), large());
51+
kernel<<<num_blocks, block_dims, 0, stream>>>(
52+
in.data<InType>() + in_offset,
53+
out.data<OutType>() + out_offset,
54+
out.data_size());
55+
});
5256
});
5357
});
5458
});

0 commit comments

Comments (0)