@@ -11,8 +11,6 @@ template <typename U>
1111struct IndexValPair {
1212 uint32_t index;
1313 U val;
14-
15- IndexValPair (uint32_t _index, U _val) : index(_index), val(_val) {}
1614};
1715
1816template <typename U>
@@ -65,10 +63,10 @@ struct ArgMax {
6563
6664template <typename U>
6765IndexValPair<U> simd_shuffle_down (IndexValPair<U> data, uint16_t delta) {
68- return IndexValPair<U>(
66+ return IndexValPair<U>{
6967 simd_shuffle_down (data.index , delta),
7068 simd_shuffle_down (data.val , delta)
71- ) ;
69+ } ;
7270}
7371
7472
@@ -82,7 +80,6 @@ template <typename T, typename Op, int N_READS>
8280 const device size_t& ndim [[buffer(5 )]],
8381 const device size_t& axis_stride [[buffer(6 )]],
8482 const device size_t& axis_size [[buffer(7 )]],
85- threadgroup IndexValPair<T> *local_data [[threadgroup(0 )]],
8683 uint gid [[thread_position_in_grid]],
8784 uint lid [[thread_position_in_threadgroup]],
8885 uint lsize [[threads_per_threadgroup]],
@@ -111,7 +108,9 @@ template <typename T, typename Op, int N_READS>
111108 auto in_idx = elem_to_loc (gid / lsize, shape, in_strides, ndim);
112109 auto out_idx = elem_to_loc (gid / lsize, shape, out_strides, ndim);
113110
114- IndexValPair<T> best (0 , Op::init);
111+ IndexValPair<T> best{0 , Op::init};
112+
113+ threadgroup IndexValPair<T> local_data[32 ];
115114
116115 // Loop over the reduction axis in lsize*N_READS buckets
117116 for (uint r=0 ; r < ceildiv (axis_size, N_READS*lsize); r++) {
@@ -172,7 +171,6 @@ template <typename T, typename Op, int N_READS>
172171 const device size_t & ndim [[buffer(5 )]], \
173172 const device size_t & axis_stride [[buffer(6 )]], \
174173 const device size_t & axis_size [[buffer(7 )]], \
175- threadgroup IndexValPair<itype> *local_data [[threadgroup(0 )]], \
176174 uint gid [[thread_position_in_grid]], \
177175 uint lid [[thread_position_in_threadgroup]], \
178176 uint lsize [[threads_per_threadgroup]], \
0 commit comments