@@ -37,59 +37,61 @@ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
3737
3838template <typename T, enable_if_t <is_metal_atomic<T>, bool > = true >
3939METAL_FUNC T
40- mlx_atomic_load_explicit (device mlx_atomic<T>* object, uint offset) {
40+ mlx_atomic_load_explicit (device mlx_atomic<T>* object, size_t offset) {
4141 return atomic_load_explicit (&(object[offset].val ), memory_order_relaxed);
4242}
4343
4444template <typename T, enable_if_t <is_metal_atomic<T>, bool > = true >
4545METAL_FUNC void
46- mlx_atomic_store_explicit (device mlx_atomic<T>* object, T val, uint offset) {
46+ mlx_atomic_store_explicit (device mlx_atomic<T>* object, T val, size_t offset) {
4747 atomic_store_explicit (&(object[offset].val ), val, memory_order_relaxed);
4848}
4949
5050template <typename T, enable_if_t <is_metal_atomic<T>, bool > = true >
5151METAL_FUNC void mlx_atomic_fetch_and_explicit (
5252 device mlx_atomic<T>* object,
5353 T val,
54- uint offset) {
54+ size_t offset) {
5555 atomic_fetch_and_explicit (&(object[offset].val ), val, memory_order_relaxed);
5656}
5757
5858template <typename T, enable_if_t <is_metal_atomic<T>, bool > = true >
59- METAL_FUNC void
60- mlx_atomic_fetch_or_explicit (device mlx_atomic<T>* object, T val, uint offset) {
59+ METAL_FUNC void mlx_atomic_fetch_or_explicit (
60+ device mlx_atomic<T>* object,
61+ T val,
62+ size_t offset) {
6163 atomic_fetch_or_explicit (&(object[offset].val ), val, memory_order_relaxed);
6264}
6365
6466template <typename T, enable_if_t <is_metal_atomic<T>, bool > = true >
6567METAL_FUNC void mlx_atomic_fetch_min_explicit (
6668 device mlx_atomic<T>* object,
6769 T val,
68- uint offset) {
70+ size_t offset) {
6971 atomic_fetch_min_explicit (&(object[offset].val ), val, memory_order_relaxed);
7072}
7173
7274template <typename T, enable_if_t <is_metal_atomic<T>, bool > = true >
7375METAL_FUNC void mlx_atomic_fetch_max_explicit (
7476 device mlx_atomic<T>* object,
7577 T val,
76- uint offset) {
78+ size_t offset) {
7779 atomic_fetch_max_explicit (&(object[offset].val ), val, memory_order_relaxed);
7880}
7981
8082template <typename T, enable_if_t <is_metal_atomic<T>, bool > = true >
8183METAL_FUNC void mlx_atomic_fetch_add_explicit (
8284 device mlx_atomic<T>* object,
8385 T val,
84- uint offset) {
86+ size_t offset) {
8587 atomic_fetch_add_explicit (&(object[offset].val ), val, memory_order_relaxed);
8688}
8789
8890template <typename T, enable_if_t <is_metal_atomic<T>, bool > = true >
8991METAL_FUNC void mlx_atomic_fetch_mul_explicit (
9092 device mlx_atomic<T>* object,
9193 T val,
92- uint offset) {
94+ size_t offset) {
9395 T expected = mlx_atomic_load_explicit (object, offset);
9496 while (!mlx_atomic_compare_exchange_weak_explicit (
9597 object, &expected, val * expected, offset)) {
@@ -101,7 +103,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
101103 device mlx_atomic<T>* object,
102104 thread T* expected,
103105 T val,
104- uint offset) {
106+ size_t offset) {
105107 return atomic_compare_exchange_weak_explicit (
106108 &(object[offset].val ),
107109 expected,
@@ -115,7 +117,7 @@ template <>
115117METAL_FUNC void mlx_atomic_fetch_min_explicit<float >(
116118 device mlx_atomic<float >* object,
117119 float val,
118- uint offset) {
120+ size_t offset) {
119121 float expected = mlx_atomic_load_explicit (object, offset);
120122 while (val < expected) {
121123 if (mlx_atomic_compare_exchange_weak_explicit (
@@ -130,7 +132,7 @@ template <>
130132METAL_FUNC void mlx_atomic_fetch_max_explicit<float >(
131133 device mlx_atomic<float >* object,
132134 float val,
133- uint offset) {
135+ size_t offset) {
134136 float expected = mlx_atomic_load_explicit (object, offset);
135137 while (val > expected) {
136138 if (mlx_atomic_compare_exchange_weak_explicit (
@@ -157,7 +159,7 @@ union uint_or_packed {
157159
158160template <typename T, typename Op>
159161struct mlx_atomic_update_helper {
160- uint operator ()(uint_or_packed<T> init, T update, uint elem_offset) {
162+ uint operator ()(uint_or_packed<T> init, T update, size_t elem_offset) {
161163 Op op;
162164 init.val [elem_offset] = op (update, init.val [elem_offset]);
163165 return init.bits ;
@@ -168,9 +170,9 @@ template <typename T, typename Op>
168170METAL_FUNC void mlx_atomic_update_and_store (
169171 device mlx_atomic<T>* object,
170172 T update,
171- uint offset) {
172- uint pack_offset = offset / packing_size<T>;
173- uint elem_offset = offset % packing_size<T>;
173+ size_t offset) {
174+ size_t pack_offset = offset / packing_size<T>;
175+ size_t elem_offset = offset % packing_size<T>;
174176
175177 mlx_atomic_update_helper<T, Op> helper;
176178 uint_or_packed<T> expected;
@@ -251,9 +253,9 @@ struct __Min {
251253
252254template <typename T, enable_if_t <!is_metal_atomic<T>, bool > = true >
253255METAL_FUNC T
254- mlx_atomic_load_explicit (device mlx_atomic<T>* object, uint offset) {
255- uint pack_offset = offset / sizeof (T);
256- uint elem_offset = offset % sizeof (T);
256+ mlx_atomic_load_explicit (device mlx_atomic<T>* object, size_t offset) {
257+ size_t pack_offset = offset / sizeof (T);
258+ size_t elem_offset = offset % sizeof (T);
257259 uint_or_packed<T> packed_val;
258260 packed_val.bits =
259261 atomic_load_explicit (&(object[pack_offset].val ), memory_order_relaxed);
@@ -262,17 +264,17 @@ mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
262264
263265template <typename T, enable_if_t <!is_metal_atomic<T>, bool > = true >
264266METAL_FUNC void
265- mlx_atomic_store_explicit (device mlx_atomic<T>* object, T val, uint offset) {
267+ mlx_atomic_store_explicit (device mlx_atomic<T>* object, T val, size_t offset) {
266268 mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
267269}
268270
269271template <typename T, enable_if_t <!is_metal_atomic<T>, bool > = true >
270272METAL_FUNC void mlx_atomic_fetch_and_explicit (
271273 device mlx_atomic<T>* object,
272274 T val,
273- uint offset) {
274- uint pack_offset = offset / packing_size<T>;
275- uint elem_offset = offset % packing_size<T>;
275+ size_t offset) {
276+ size_t pack_offset = offset / packing_size<T>;
277+ size_t elem_offset = offset % packing_size<T>;
276278 uint_or_packed<T> identity;
277279 identity.bits = __UINT32_MAX__;
278280 identity.val [elem_offset] = val;
@@ -282,10 +284,12 @@ METAL_FUNC void mlx_atomic_fetch_and_explicit(
282284}
283285
284286template <typename T, enable_if_t <!is_metal_atomic<T>, bool > = true >
285- METAL_FUNC void
286- mlx_atomic_fetch_or_explicit (device mlx_atomic<T>* object, T val, uint offset) {
287- uint pack_offset = offset / packing_size<T>;
288- uint elem_offset = offset % packing_size<T>;
287+ METAL_FUNC void mlx_atomic_fetch_or_explicit (
288+ device mlx_atomic<T>* object,
289+ T val,
290+ size_t offset) {
291+ size_t pack_offset = offset / packing_size<T>;
292+ size_t elem_offset = offset % packing_size<T>;
289293 uint_or_packed<T> identity;
290294 identity.bits = 0 ;
291295 identity.val [elem_offset] = val;
@@ -298,31 +302,31 @@ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
298302METAL_FUNC void mlx_atomic_fetch_min_explicit (
299303 device mlx_atomic<T>* object,
300304 T val,
301- uint offset) {
305+ size_t offset) {
302306 mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
303307}
304308
305309template <typename T, enable_if_t <!is_metal_atomic<T>, bool > = true >
306310METAL_FUNC void mlx_atomic_fetch_max_explicit (
307311 device mlx_atomic<T>* object,
308312 T val,
309- uint offset) {
313+ size_t offset) {
310314 mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
311315}
312316
313317template <typename T, enable_if_t <!is_metal_atomic<T>, bool > = true >
314318METAL_FUNC void mlx_atomic_fetch_add_explicit (
315319 device mlx_atomic<T>* object,
316320 T val,
317- uint offset) {
321+ size_t offset) {
318322 mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
319323}
320324
321325template <typename T, enable_if_t <!is_metal_atomic<T>, bool > = true >
322326METAL_FUNC void mlx_atomic_fetch_mul_explicit (
323327 device mlx_atomic<T>* object,
324328 T val,
325- uint offset) {
329+ size_t offset) {
326330 mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
327331}
328332
@@ -331,7 +335,7 @@ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
331335 device mlx_atomic<T>* object,
332336 thread uint* expected,
333337 uint val,
334- uint offset) {
338+ size_t offset) {
335339 return atomic_compare_exchange_weak_explicit (
336340 &(object[offset].val ),
337341 expected,
0 commit comments