Skip to content

Commit 98b6ce3

Browse files
awniangeloskath
andauthored
Refactor reductions and fix scatter atomics for large sizes (#1300)
Co-authored-by: Angelos Katharopoulos <[email protected]>
1 parent f9e00ef commit 98b6ce3

File tree

18 files changed

+1567
-1218
lines changed

18 files changed

+1567
-1218
lines changed

mlx/backend/common/reduce.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ struct ReductionPlan {
4949
ReductionPlan(ReductionOpType type_) : type(type_) {}
5050
};
5151

52-
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
52+
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
5353

5454
// Helper for the ndimensional strided loop
5555
// Should this be in utils?

mlx/backend/common/reduce_utils.cpp

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
1919
return std::make_pair(shape, strides);
2020
}
2121

22-
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
22+
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
2323
// The data is all there and we are reducing over everything
2424
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
2525
x.flags().contiguous) {
@@ -41,6 +41,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
4141
}
4242
}
4343

44+
// Remove singleton axes from the plan
45+
for (int i = shape.size() - 1; i >= 0; i--) {
46+
if (shape[i] == 1) {
47+
shape.erase(shape.begin() + i);
48+
strides.erase(strides.begin() + i);
49+
}
50+
}
51+
4452
if (strides.back() == 1) {
4553
return ReductionPlan(ContiguousReduce, shape, strides);
4654
} else if (strides.back() > 1) {
@@ -63,10 +71,14 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
6371
// have a contiguous reduction.
6472
std::vector<std::pair<int, size_t>> reductions;
6573
for (auto a : axes) {
66-
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
74+
if (x.shape(a) > 1) {
75+
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
76+
}
6777
}
6878
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
69-
return a.second > b.second;
79+
bool a_is_zero = a.second == 0;
80+
bool b_is_zero = b.second == 0;
81+
return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
7082
});
7183
// Extract the two smallest and try to merge them in case the contiguous
7284
// reduction can be bigger than just the last axis.
@@ -98,16 +110,33 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
98110
// strides.back() are contiguous.
99111
if (strides.back() > 1) {
100112
int size = 1;
113+
bool have_expand = false;
101114
for (int i = x.ndim() - 1; i >= 0; i--) {
102115
if (axes.back() == i) {
103116
continue;
104117
}
105-
if (x.strides()[i] != size) {
118+
119+
size_t stride_i = x.strides()[i];
120+
int shape_i = x.shape(i);
121+
if (stride_i == 0) {
122+
if (shape_i == 1) {
123+
continue;
124+
}
125+
126+
have_expand = true;
127+
break;
128+
}
129+
130+
if (stride_i != size && shape_i != 1) {
106131
break;
107132
}
108-
size *= x.shape(i);
133+
size *= shape_i;
109134
}
110-
if (size >= strides.back()) {
135+
// In the case of an expanded dimension we are being conservative and
136+
// require the smallest reduction stride to be smaller than the maximum row
137+
// contiguous size. The reason is that we can't easily know if the reduced
138+
// axis is before or after an expanded dimension.
139+
if (size > strides.back() || (size == strides.back() && !have_expand)) {
111140
return ReductionPlan(GeneralStridedReduce, shape, strides);
112141
}
113142
}

mlx/backend/common/utils.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,33 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
104104
std::vector<array>{std::forward<Arrays>(xs)...});
105105
}
106106

107+
// The single array version of the above.
108+
inline std::tuple<std::vector<int>, std::vector<size_t>>
109+
collapse_contiguous_dims(
110+
const std::vector<int>& shape,
111+
const std::vector<size_t>& strides) {
112+
std::vector<int> collapsed_shape;
113+
std::vector<size_t> collapsed_strides;
114+
115+
if (shape.size() > 0) {
116+
collapsed_shape.push_back(shape[0]);
117+
collapsed_strides.push_back(strides[0]);
118+
for (int i = 1; i < shape.size(); i++) {
119+
if (strides[i] * shape[i] != collapsed_strides.back() ||
120+
collapsed_shape.back() * static_cast<size_t>(shape[i]) >
121+
std::numeric_limits<int>::max()) {
122+
collapsed_shape.push_back(shape[i]);
123+
collapsed_strides.push_back(strides[i]);
124+
} else {
125+
collapsed_shape.back() *= shape[i];
126+
collapsed_strides.back() = strides[i];
127+
}
128+
}
129+
}
130+
131+
return std::make_tuple(collapsed_shape, collapsed_strides);
132+
}
133+
107134
template <typename stride_t>
108135
inline auto check_contiguity(
109136
const std::vector<int>& shape,

mlx/backend/metal/kernels/atomic.h

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -37,59 +37,61 @@ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
3737

3838
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
3939
METAL_FUNC T
40-
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
40+
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
4141
return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
4242
}
4343

4444
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
4545
METAL_FUNC void
46-
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
46+
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
4747
atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
4848
}
4949

5050
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
5151
METAL_FUNC void mlx_atomic_fetch_and_explicit(
5252
device mlx_atomic<T>* object,
5353
T val,
54-
uint offset) {
54+
size_t offset) {
5555
atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
5656
}
5757

5858
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
59-
METAL_FUNC void
60-
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
59+
METAL_FUNC void mlx_atomic_fetch_or_explicit(
60+
device mlx_atomic<T>* object,
61+
T val,
62+
size_t offset) {
6163
atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
6264
}
6365

6466
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
6567
METAL_FUNC void mlx_atomic_fetch_min_explicit(
6668
device mlx_atomic<T>* object,
6769
T val,
68-
uint offset) {
70+
size_t offset) {
6971
atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
7072
}
7173

7274
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
7375
METAL_FUNC void mlx_atomic_fetch_max_explicit(
7476
device mlx_atomic<T>* object,
7577
T val,
76-
uint offset) {
78+
size_t offset) {
7779
atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
7880
}
7981

8082
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
8183
METAL_FUNC void mlx_atomic_fetch_add_explicit(
8284
device mlx_atomic<T>* object,
8385
T val,
84-
uint offset) {
86+
size_t offset) {
8587
atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
8688
}
8789

8890
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
8991
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
9092
device mlx_atomic<T>* object,
9193
T val,
92-
uint offset) {
94+
size_t offset) {
9395
T expected = mlx_atomic_load_explicit(object, offset);
9496
while (!mlx_atomic_compare_exchange_weak_explicit(
9597
object, &expected, val * expected, offset)) {
@@ -101,7 +103,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
101103
device mlx_atomic<T>* object,
102104
thread T* expected,
103105
T val,
104-
uint offset) {
106+
size_t offset) {
105107
return atomic_compare_exchange_weak_explicit(
106108
&(object[offset].val),
107109
expected,
@@ -115,7 +117,7 @@ template <>
115117
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
116118
device mlx_atomic<float>* object,
117119
float val,
118-
uint offset) {
120+
size_t offset) {
119121
float expected = mlx_atomic_load_explicit(object, offset);
120122
while (val < expected) {
121123
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -130,7 +132,7 @@ template <>
130132
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
131133
device mlx_atomic<float>* object,
132134
float val,
133-
uint offset) {
135+
size_t offset) {
134136
float expected = mlx_atomic_load_explicit(object, offset);
135137
while (val > expected) {
136138
if (mlx_atomic_compare_exchange_weak_explicit(
@@ -157,7 +159,7 @@ union uint_or_packed {
157159

158160
template <typename T, typename Op>
159161
struct mlx_atomic_update_helper {
160-
uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
162+
uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {
161163
Op op;
162164
init.val[elem_offset] = op(update, init.val[elem_offset]);
163165
return init.bits;
@@ -168,9 +170,9 @@ template <typename T, typename Op>
168170
METAL_FUNC void mlx_atomic_update_and_store(
169171
device mlx_atomic<T>* object,
170172
T update,
171-
uint offset) {
172-
uint pack_offset = offset / packing_size<T>;
173-
uint elem_offset = offset % packing_size<T>;
173+
size_t offset) {
174+
size_t pack_offset = offset / packing_size<T>;
175+
size_t elem_offset = offset % packing_size<T>;
174176

175177
mlx_atomic_update_helper<T, Op> helper;
176178
uint_or_packed<T> expected;
@@ -251,9 +253,9 @@ struct __Min {
251253

252254
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
253255
METAL_FUNC T
254-
mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
255-
uint pack_offset = offset / sizeof(T);
256-
uint elem_offset = offset % sizeof(T);
256+
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
257+
size_t pack_offset = offset / sizeof(T);
258+
size_t elem_offset = offset % sizeof(T);
257259
uint_or_packed<T> packed_val;
258260
packed_val.bits =
259261
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
@@ -262,17 +264,17 @@ mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
262264

263265
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
264266
METAL_FUNC void
265-
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
267+
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
266268
mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
267269
}
268270

269271
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
270272
METAL_FUNC void mlx_atomic_fetch_and_explicit(
271273
device mlx_atomic<T>* object,
272274
T val,
273-
uint offset) {
274-
uint pack_offset = offset / packing_size<T>;
275-
uint elem_offset = offset % packing_size<T>;
275+
size_t offset) {
276+
size_t pack_offset = offset / packing_size<T>;
277+
size_t elem_offset = offset % packing_size<T>;
276278
uint_or_packed<T> identity;
277279
identity.bits = __UINT32_MAX__;
278280
identity.val[elem_offset] = val;
@@ -282,10 +284,12 @@ METAL_FUNC void mlx_atomic_fetch_and_explicit(
282284
}
283285

284286
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
285-
METAL_FUNC void
286-
mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
287-
uint pack_offset = offset / packing_size<T>;
288-
uint elem_offset = offset % packing_size<T>;
287+
METAL_FUNC void mlx_atomic_fetch_or_explicit(
288+
device mlx_atomic<T>* object,
289+
T val,
290+
size_t offset) {
291+
size_t pack_offset = offset / packing_size<T>;
292+
size_t elem_offset = offset % packing_size<T>;
289293
uint_or_packed<T> identity;
290294
identity.bits = 0;
291295
identity.val[elem_offset] = val;
@@ -298,31 +302,31 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
298302
METAL_FUNC void mlx_atomic_fetch_min_explicit(
299303
device mlx_atomic<T>* object,
300304
T val,
301-
uint offset) {
305+
size_t offset) {
302306
mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
303307
}
304308

305309
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
306310
METAL_FUNC void mlx_atomic_fetch_max_explicit(
307311
device mlx_atomic<T>* object,
308312
T val,
309-
uint offset) {
313+
size_t offset) {
310314
mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
311315
}
312316

313317
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
314318
METAL_FUNC void mlx_atomic_fetch_add_explicit(
315319
device mlx_atomic<T>* object,
316320
T val,
317-
uint offset) {
321+
size_t offset) {
318322
mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
319323
}
320324

321325
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
322326
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
323327
device mlx_atomic<T>* object,
324328
T val,
325-
uint offset) {
329+
size_t offset) {
326330
mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
327331
}
328332

@@ -331,7 +335,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
331335
device mlx_atomic<T>* object,
332336
thread uint* expected,
333337
uint val,
334-
uint offset) {
338+
size_t offset) {
335339
return atomic_compare_exchange_weak_explicit(
336340
&(object[offset].val),
337341
expected,

mlx/backend/metal/kernels/complex.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ struct complex64_t {
2323

2424
// Constructors
2525
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
26+
constexpr complex64_t() : real(0), imag(0) {};
27+
constexpr complex64_t() threadgroup : real(0), imag(0) {};
2628

2729
// Conversions to complex64_t
2830
template <

mlx/backend/metal/kernels/defines.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
#endif
1010

1111
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
12-
static MTL_CONST constexpr int REDUCE_N_READS = 16;
12+
static MTL_CONST constexpr int REDUCE_N_READS = 4;
13+
static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
1314
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
1415
static MTL_CONST constexpr int RMS_N_READS = 4;
1516
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

0 commit comments

Comments
 (0)