@@ -98,7 +98,7 @@ struct ReadWriter {
9898 }
9999
100100 METAL_FUNC void load () const {
101- int batch_idx = elem.x * grid.y * n;
101+ size_t batch_idx = size_t ( elem.x * grid.y ) * n;
102102 short tg_idx = elem.y * grid.z + elem.z ;
103103 short max_index = grid.y * n - 2 ;
104104
@@ -121,7 +121,7 @@ struct ReadWriter {
121121 }
122122
123123 METAL_FUNC void write () const {
124- int batch_idx = elem.x * grid.y * n;
124+ size_t batch_idx = size_t ( elem.x * grid.y ) * n;
125125 short tg_idx = elem.y * grid.z + elem.z ;
126126 short max_index = grid.y * n - 2 ;
127127
@@ -144,7 +144,7 @@ struct ReadWriter {
144144
145145 // Padded IO for Bluestein's algorithm
146146 METAL_FUNC void load_padded (int length, const device float2* w_k) const {
147- int batch_idx = elem.x * grid.y * length + elem.y * length;
147+ size_t batch_idx = size_t ( elem.x * grid.y ) * length + elem.y * length;
148148 int fft_idx = elem.z ;
149149 int m = grid.z ;
150150
@@ -161,7 +161,7 @@ struct ReadWriter {
161161 }
162162
163163 METAL_FUNC void write_padded (int length, const device float2* w_k) const {
164- int batch_idx = elem.x * grid.y * length + elem.y * length;
164+ size_t batch_idx = size_t ( elem.x * grid.y ) * length + elem.y * length;
165165 int fft_idx = elem.z ;
166166 int m = grid.z ;
167167 float2 inv_factor = {1 .0f / n, -1 .0f / n};
@@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
261261
262262template <>
263263METAL_FUNC void ReadWriter<float , float2>::load() const {
264- int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2 ;
264+ size_t batch_idx = size_t ( elem.x * grid.y ) * n * 2 + elem.y * n * 2 ;
265265 threadgroup float2* seq_buf = buf + elem.y * n;
266266
267267 // No out of bounds accesses on odd batch sizes
@@ -283,7 +283,8 @@ template <>
283283METAL_FUNC void ReadWriter<float , float2>::write() const {
284284 short n_over_2 = (n / 2 ) + 1 ;
285285
286- int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2 ;
286+ size_t batch_idx =
287+ size_t (elem.x * grid.y ) * n_over_2 * 2 + elem.y * n_over_2 * 2 ;
287288 threadgroup float2* seq_buf = buf + elem.y * n;
288289
289290 int grid_index = elem.x * grid.y + elem.y ;
@@ -317,7 +318,7 @@ template <>
317318METAL_FUNC void ReadWriter<float , float2>::load_padded(
318319 int length,
319320 const device float2* w_k) const {
320- int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2 ;
321+ size_t batch_idx = size_t ( elem.x * grid.y ) * length * 2 + elem.y * length * 2 ;
321322 threadgroup float2* seq_buf = buf + elem.y * n;
322323
323324 // No out of bounds accesses on odd batch sizes
@@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
345346 int length,
346347 const device float2* w_k) const {
347348 int length_over_2 = (length / 2 ) + 1 ;
348- int batch_idx =
349- elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2 ;
349+ size_t batch_idx =
350+ size_t ( elem.x * grid.y ) * length_over_2 * 2 + elem.y * length_over_2 * 2 ;
350351 threadgroup float2* seq_buf = buf + elem.y * n + length - 1 ;
351352
352353 int grid_index = elem.x * grid.y + elem.y ;
@@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
397398template <>
398399METAL_FUNC void ReadWriter<float2, float >::load() const {
399400 short n_over_2 = (n / 2 ) + 1 ;
400- int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2 ;
401+ size_t batch_idx =
402+ size_t (elem.x * grid.y ) * n_over_2 * 2 + elem.y * n_over_2 * 2 ;
401403 threadgroup float2* seq_buf = buf + elem.y * n;
402404
403405 // No out of bounds accesses on odd batch sizes
@@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
458460 int n_over_2 = (n / 2 ) + 1 ;
459461 int length_over_2 = (length / 2 ) + 1 ;
460462
461- int batch_idx =
462- elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2 ;
463+ size_t batch_idx =
464+ size_t ( elem.x * grid.y ) * length_over_2 * 2 + elem.y * length_over_2 * 2 ;
463465 threadgroup float2* seq_buf = buf + elem.y * n;
464466
465467 // No out of bounds accesses on odd batch sizes
@@ -503,7 +505,7 @@ template <>
503505METAL_FUNC void ReadWriter<float2, float >::write_padded(
504506 int length,
505507 const device float2* w_k) const {
506- int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2 ;
508+ size_t batch_idx = size_t ( elem.x * grid.y ) * length * 2 + elem.y * length * 2 ;
507509 threadgroup float2* seq_buf = buf + elem.y * n + length - 1 ;
508510
509511 int grid_index = elem.x * grid.y + elem.y ;
0 commit comments