Skip to content

Commit 884b4ed

Browse files
authored
Fix threadgroup memory in arg reduce (#723)
1 parent 972d9a3 commit 884b4ed

File tree

2 files changed, with 5 additions and 9 deletions.

mlx/backend/metal/kernels/arg_reduce.metal

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,6 @@ template <typename U>
 struct IndexValPair {
 uint32_t index;
 U val;
-
-IndexValPair(uint32_t _index, U _val) : index(_index), val(_val) {}
 };
 
 template <typename U>
template <typename U>
@@ -65,10 +63,10 @@ struct ArgMax {
 
 template <typename U>
 IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
-return IndexValPair<U>(
+return IndexValPair<U>{
 simd_shuffle_down(data.index, delta),
 simd_shuffle_down(data.val, delta)
-);
+};
 }
 
 
@@ -82,7 +80,6 @@ template <typename T, typename Op, int N_READS>
 const device size_t& ndim [[buffer(5)]],
 const device size_t& axis_stride [[buffer(6)]],
 const device size_t& axis_size [[buffer(7)]],
-threadgroup IndexValPair<T> *local_data [[threadgroup(0)]],
 uint gid [[thread_position_in_grid]],
 uint lid [[thread_position_in_threadgroup]],
 uint lsize [[threads_per_threadgroup]],
@@ -111,7 +108,9 @@ template <typename T, typename Op, int N_READS>
 auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim);
 auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim);
 
-IndexValPair<T> best(0, Op::init);
+IndexValPair<T> best{0, Op::init};
+
+threadgroup IndexValPair<T> local_data[32];
 
 // Loop over the reduction axis in lsize*N_READS buckets
 for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
@@ -172,7 +171,6 @@ template <typename T, typename Op, int N_READS>
 const device size_t& ndim [[buffer(5)]], \
 const device size_t& axis_stride [[buffer(6)]], \
 const device size_t& axis_size [[buffer(7)]], \
-threadgroup IndexValPair<itype> *local_data [[threadgroup(0)]], \
 uint gid [[thread_position_in_grid]], \
 uint lid [[thread_position_in_threadgroup]], \
 uint lsize [[threads_per_threadgroup]], \

mlx/backend/metal/primitives.cpp

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -430,8 +430,6 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
 compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
 compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
 compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
-compute_encoder->setThreadgroupMemoryLength(
-simd_size * (sizeof(uint32_t) + in.itemsize()), 0);
 compute_encoder->dispatchThreads(grid_dims, group_dims);
 }
 }

0 commit comments

Comments
 (0)