Skip to content

Commit b57a528

Browse files
authored
Further reduction tuning (#1349)
* More reduction tuning * Forgotten pdb * Small column long row specialization
1 parent da8deb2 commit b57a528

File tree

4 files changed

+77
-29
lines changed

4 files changed

+77
-29
lines changed

mlx/backend/metal/kernels/reduction/reduce_col.h

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@ template <
2828
looped_elem_to_loc<NDIMS> loop;
2929
const device T* row;
3030

31-
// Case 1:
32-
// reduction_stride is small, reduction_size is small and non_col_reductions
33-
// is small. Each thread computes reduction_stride outputs.
34-
if (reduction_size * non_col_reductions < 64) {
31+
// Case 1: Small row small column
32+
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
3533
U totals[31];
3634
for (int i = 0; i < 31; i++) {
3735
totals[i] = Op::init;
@@ -71,10 +69,55 @@ template <
7169
}
7270
}
7371

74-
// Case 2:
75-
// Reduction stride is small but everything else can be big. We loop both
76-
// across reduction size and non_col_reductions. Each simdgroup produces
77-
// N_READS outputs.
72+
// Case 2: Long row small column
73+
else if (reduction_size * non_col_reductions < 32) {
74+
U totals[N_READS];
75+
for (int i = 0; i < N_READS; i++) {
76+
totals[i] = Op::init;
77+
}
78+
79+
short size = reduction_size;
80+
size_t offset = size_t(tid.x) * N_READS;
81+
bool safe = offset + N_READS <= reduction_stride;
82+
short extra = reduction_stride - offset;
83+
84+
size_t out_idx = tid.y + tsize.z * size_t(tid.z);
85+
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
86+
87+
for (uint r = 0; r < non_col_reductions; r++) {
88+
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
89+
90+
if (safe) {
91+
for (short i = 0; i < size; i++) {
92+
for (short j = 0; j < N_READS; j++) {
93+
totals[j] =
94+
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
95+
}
96+
}
97+
} else {
98+
for (short i = 0; i < size; i++) {
99+
for (short j = 0; j < extra; j++) {
100+
totals[j] =
101+
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
102+
}
103+
}
104+
}
105+
106+
loop.next(reduce_shape, reduce_strides);
107+
}
108+
out += out_idx * reduction_stride + offset;
109+
if (safe) {
110+
for (short i = 0; i < N_READS; i++) {
111+
out[i] = totals[i];
112+
}
113+
} else {
114+
for (short i = 0; i < extra; i++) {
115+
out[i] = totals[i];
116+
}
117+
}
118+
}
119+
120+
// Case 3: Long row medium column
78121
else {
79122
threadgroup U shared_vals[1024];
80123
U totals[N_READS];
@@ -147,17 +190,13 @@ template <
147190
/**
148191
* Our approach is the following simple looped approach:
149192
* 1. Each thread keeps running totals for BN / n_simdgroups outputs.
150-
* 2. Load a tile BM, BN in shared memory.
151-
* 3. Add the values from shared memory to the current running totals.
152-
* Neighboring threads access different rows (transposed access).
153-
* 4. Move ahead to the next tile until the M axis is exhausted.
154-
* 5. Move ahead to the next non column reduction
155-
* 6. Simd reduce the running totals
193+
* 2. Load a tile BM, BN in registers and accumulate in the running totals
194+
* 3. Move ahead by BM steps until the column axis and the non column
195+
* reductions are exhausted.
196+
* 6. If BM == 32 then transpose in SM and simd reduce the running totals.
197+
* Otherwise write in shared memory and BN threads accumulate the running
198+
* totals with a loop.
156199
* 7. Write them to the output
157-
*
158-
* The kernel becomes verbose because we support all kinds of OOB checks. For
159-
* instance if we choose that reduction_stride must be larger than BN then we
160-
* can get rid of half the kernel.
161200
*/
162201
template <
163202
typename T,

mlx/backend/metal/reduce.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ inline int threadgroup_size_from_row_size(int row_size) {
202202

203203
// 2 simdgroups per row for medium rows
204204
if (row_size <= 1024) {
205-
return 64;
205+
return 128;
206206
}
207207

208208
// up to 32 simdgroups after that
@@ -458,14 +458,25 @@ void strided_reduce_small(
458458
// Figure out the grid dims
459459
MTL::Size grid_dims, group_dims;
460460

461-
// Case 1: everything is small so launch one thread per col reduce
462-
if (args.reduction_size * args.non_col_reductions < 64) {
461+
// Case 1: Small row small column
462+
if (args.reduction_size * args.non_col_reductions < 64 &&
463+
args.reduction_stride < 32) {
463464
grid_dims = output_grid_for_col_reduce(out, args);
464465
int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
465466
group_dims = MTL::Size(threadgroup_size, 1, 1);
466467
}
467468

468-
// Case 2: Reduction in the simdgroup
469+
// Case 2: Long row small column
470+
else if (args.reduction_size * args.non_col_reductions < 32) {
471+
auto out_grid_dims = output_grid_for_col_reduce(out, args);
472+
int threads_x =
473+
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
474+
int threadgroup_x = std::min(threads_x, 128);
475+
grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
476+
group_dims = MTL::Size(threadgroup_x, 1, 1);
477+
}
478+
479+
// Case 3: Long row medium column
469480
else {
470481
args.reduce_shape.push_back(args.reduction_size);
471482
args.reduce_strides.push_back(args.reduction_stride);
@@ -508,7 +519,7 @@ void strided_reduce_looped(
508519

509520
// Figure out the grid dims
510521
auto out_grid_size = output_grid_for_col_reduce(out, args);
511-
int BN = (args.reduction_stride <= 256) ? 32 : 128;
522+
int BN = (args.reduction_stride <= 1024) ? 32 : 128;
512523
int BM = 1024 / BN;
513524
int threadgroup_size = 4 * 32;
514525
MTL::Size grid_dims(
@@ -544,7 +555,8 @@ void strided_reduce_general_dispatch(
544555
// Prepare the arguments for the kernel
545556
ColReduceArgs args(in, plan, axes);
546557

547-
if (args.reduction_stride < 32) {
558+
if (args.reduction_stride < 32 ||
559+
args.reduction_size * args.non_col_reductions < 32) {
548560
return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
549561
}
550562

mlx/backend/metal/reduce.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ void all_reduce_dispatch(
1616
const std::string& op_name,
1717
CommandEncoder& compute_encoder,
1818
metal::Device& d,
19-
const Stream& s);
19+
const Stream& s,
20+
std::vector<array>& copies);
2021

2122
void row_reduce_general_dispatch(
2223
const array& in,

python/tests/test_reduce.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,6 @@ def test_expand_sums(self):
4343
z_npy = np.sum(y_npy, axis=a) / 1000
4444
z_mlx = mx.sum(y_mlx, axis=a) / 1000
4545
mx.eval(z_mlx)
46-
if not np.allclose(z_npy, np.array(z_mlx), atol=1e-4):
47-
import pdb
48-
49-
pdb.set_trace()
5046
self.assertTrue(
5147
np.allclose(z_npy, np.array(z_mlx), atol=1e-4)
5248
)

0 commit comments

Comments
 (0)