Skip to content

Commit 8dfc376

Browse files
authored
Strided reduce specialization for small reductions (#826)
* Add small column / general reduction specialization
1 parent 1efee9d commit 8dfc376

File tree

2 files changed

+123
-8
lines changed

2 files changed

+123
-8
lines changed

mlx/backend/metal/kernels/reduction/kernels/reduce_col.metal

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,69 @@
66

77
using namespace metal;
88

9+
///////////////////////////////////////////////////////////////////////////////
10+
// Small column reduce kernel
11+
///////////////////////////////////////////////////////////////////////////////
12+
13+
template <typename T, typename U, typename Op>
14+
[[kernel]] void col_reduce_small(
15+
const device T *in [[buffer(0)]],
16+
device U *out [[buffer(1)]],
17+
const constant size_t& reduction_size [[buffer(2)]],
18+
const constant size_t& reduction_stride [[buffer(3)]],
19+
const constant size_t& out_size [[buffer(4)]],
20+
const constant int* shape [[buffer(5)]],
21+
const constant size_t* strides [[buffer(6)]],
22+
const constant int& ndim [[buffer(7)]],
23+
const constant size_t& non_col_reductions [[buffer(8)]],
24+
const constant int* non_col_shapes [[buffer(9)]],
25+
const constant size_t* non_col_strides [[buffer(10)]],
26+
const constant int& non_col_ndim [[buffer(11)]],
27+
uint tid [[thread_position_in_grid]]) {
28+
29+
// Appease the compiler
30+
(void)out_size;
31+
32+
Op op;
33+
U total_val = Op::init;
34+
35+
auto out_idx = tid;
36+
37+
in += elem_to_loc(
38+
out_idx,
39+
shape + non_col_ndim,
40+
strides + non_col_ndim,
41+
ndim - non_col_ndim);
42+
43+
for(uint i = 0; i < non_col_reductions; i++) {
44+
size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
45+
46+
for(uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
47+
U val = static_cast<U>(in[in_idx]);
48+
total_val = op(total_val, val);
49+
}
50+
}
51+
52+
out[out_idx] = total_val;
53+
}
54+
55+
// Explicitly instantiates `col_reduce_small<itype, otype, op>` and registers
// it under the host-visible name "col_reduce_small_<name>" so the C++ backend
// can fetch it with d.get_kernel(). No comments inside the macro body: a //
// comment before a trailing backslash would swallow the rest of the
// continuation after line splicing.
#define instantiate_col_reduce_small(name, itype, otype, op) \
  template [[host_name("col_reduce_small_" #name)]] \
  [[kernel]] void col_reduce_small<itype, otype, op>( \
      const device itype *in [[buffer(0)]], \
      device otype *out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      const constant size_t& non_col_reductions [[buffer(8)]], \
      const constant int* non_col_shapes [[buffer(9)]], \
      const constant size_t* non_col_strides [[buffer(10)]], \
      const constant int& non_col_ndim [[buffer(11)]], \
      uint tid [[thread_position_in_grid]]);
71+
972
///////////////////////////////////////////////////////////////////////////////
1073
// Column reduce helper
1174
///////////////////////////////////////////////////////////////////////////////
@@ -171,14 +234,20 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
171234
///////////////////////////////////////////////////////////////////////////////
172235

173236
// Instantiate both the small (one-thread-per-output, serial) and the general
// tiled column-reduce kernels for each same-type in/out pair.
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general(name ##tname, type, type, op<type>)

// Same, but pairing the small kernel with the no-atomics general kernel —
// used for the 64-bit output types (see reduce.cpp's is_64b_int dispatch).
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)

// Boolean reductions change type: sum accumulates bool into uint32_t, while
// and/or stay bool. Each needs both the general and the small kernel.
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)

instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)

mlx/backend/metal/reduce.cpp

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -307,13 +307,6 @@ void strided_reduce_general_dispatch(
307307
metal::Device& d,
308308
const Stream& s) {
309309
Dtype out_dtype = out.dtype();
310-
bool is_out_64b_int = is_64b_int(out_dtype);
311-
auto kernel = (is_out_64b_int)
312-
? d.get_kernel(
313-
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
314-
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
315-
316-
compute_encoder->setComputePipelineState(kernel);
317310

318311
// Prepare the arguments for the kernel
319312
size_t reduction_size = plan.shape.back();
@@ -327,6 +320,11 @@ void strided_reduce_general_dispatch(
327320
for (auto s : shape) {
328321
non_col_reductions *= static_cast<size_t>(s);
329322
}
323+
324+
std::vector<int> non_col_shapes = shape;
325+
std::vector<size_t> non_col_strides = strides;
326+
int non_col_ndim = shape.size();
327+
330328
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
331329
for (auto s : rem_shape) {
332330
shape.push_back(s);
@@ -336,6 +334,54 @@ void strided_reduce_general_dispatch(
336334
}
337335
int ndim = shape.size();
338336

337+
// Specialize for small dims
338+
if (reduction_size * non_col_reductions < 16) {
339+
// Select kernel
340+
auto kernel =
341+
d.get_kernel("col_reduce_small_" + op_name + type_to_name(in));
342+
compute_encoder->setComputePipelineState(kernel);
343+
344+
// Select block dims
345+
MTL::Size grid_dims = MTL::Size(out_size, 1, 1);
346+
MTL::Size group_dims = MTL::Size(256ul, 1, 1);
347+
348+
if (non_col_ndim == 0) {
349+
non_col_shapes = {1};
350+
non_col_strides = {1};
351+
}
352+
353+
// Encode arrays
354+
set_array_buffer(compute_encoder, in, 0);
355+
set_array_buffer(compute_encoder, out, 1);
356+
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
357+
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
358+
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
359+
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
360+
compute_encoder->setBytes(
361+
strides.data(), strides.size() * sizeof(size_t), 6);
362+
compute_encoder->setBytes(&ndim, sizeof(int), 7);
363+
compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 8);
364+
compute_encoder->setBytes(
365+
non_col_shapes.data(), non_col_shapes.size() * sizeof(int), 9);
366+
compute_encoder->setBytes(
367+
non_col_strides.data(), non_col_shapes.size() * sizeof(size_t), 10);
368+
compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);
369+
370+
// Dispatch threads
371+
compute_encoder->dispatchThreads(grid_dims, group_dims);
372+
373+
return;
374+
}
375+
376+
// Select kernel
377+
bool is_out_64b_int = is_64b_int(out_dtype);
378+
auto kernel = (is_out_64b_int)
379+
? d.get_kernel(
380+
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
381+
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
382+
383+
compute_encoder->setComputePipelineState(kernel);
384+
339385
// Select block dimensions
340386
// Each thread reads 16 inputs to give it more work
341387
uint n_inputs_per_thread = REDUCE_N_READS;

0 commit comments

Comments
 (0)