Skip to content

Commit 6661387

Browse files
authored
Fix fft for integer overflow (#2161)
1 parent a7fae8a commit 6661387

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

mlx/backend/metal/fft.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ void fft_op(
632632
func_consts.push_back(make_int(&rader_m, 3));
633633

634634
// The overall number of FFTs we're going to compute for this input
635-
int size = out.dtype() == float32 ? out.size() : in.size();
635+
size_t size = out.dtype() == float32 ? out.size() : in.size();
636636
if (real && inverse && four_step_params.required) {
637637
size = out.size();
638638
}
@@ -659,8 +659,6 @@ void fft_op(
659659
// We can perform 2 RFFTs at once so the batch size is halved.
660660
batch_size = (batch_size + 2 - 1) / 2;
661661
}
662-
int out_buffer_size = out.size();
663-
664662
auto& compute_encoder = d.get_command_encoder(s.index);
665663
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
666664
auto out_type_str = out.dtype() == float32 ? "float" : "float2";

mlx/backend/metal/kernels/fft/readwrite.h

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ struct ReadWriter {
9898
}
9999

100100
METAL_FUNC void load() const {
101-
int batch_idx = elem.x * grid.y * n;
101+
size_t batch_idx = size_t(elem.x * grid.y) * n;
102102
short tg_idx = elem.y * grid.z + elem.z;
103103
short max_index = grid.y * n - 2;
104104

@@ -121,7 +121,7 @@ struct ReadWriter {
121121
}
122122

123123
METAL_FUNC void write() const {
124-
int batch_idx = elem.x * grid.y * n;
124+
size_t batch_idx = size_t(elem.x * grid.y) * n;
125125
short tg_idx = elem.y * grid.z + elem.z;
126126
short max_index = grid.y * n - 2;
127127

@@ -144,7 +144,7 @@ struct ReadWriter {
144144

145145
// Padded IO for Bluestein's algorithm
146146
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
147-
int batch_idx = elem.x * grid.y * length + elem.y * length;
147+
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
148148
int fft_idx = elem.z;
149149
int m = grid.z;
150150

@@ -161,7 +161,7 @@ struct ReadWriter {
161161
}
162162

163163
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
164-
int batch_idx = elem.x * grid.y * length + elem.y * length;
164+
size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length;
165165
int fft_idx = elem.z;
166166
int m = grid.z;
167167
float2 inv_factor = {1.0f / n, -1.0f / n};
@@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
261261

262262
template <>
263263
METAL_FUNC void ReadWriter<float, float2>::load() const {
264-
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
264+
size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2;
265265
threadgroup float2* seq_buf = buf + elem.y * n;
266266

267267
// No out of bounds accesses on odd batch sizes
@@ -283,7 +283,8 @@ template <>
283283
METAL_FUNC void ReadWriter<float, float2>::write() const {
284284
short n_over_2 = (n / 2) + 1;
285285

286-
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
286+
size_t batch_idx =
287+
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
287288
threadgroup float2* seq_buf = buf + elem.y * n;
288289

289290
int grid_index = elem.x * grid.y + elem.y;
@@ -317,7 +318,7 @@ template <>
317318
METAL_FUNC void ReadWriter<float, float2>::load_padded(
318319
int length,
319320
const device float2* w_k) const {
320-
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
321+
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
321322
threadgroup float2* seq_buf = buf + elem.y * n;
322323

323324
// No out of bounds accesses on odd batch sizes
@@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter<float, float2>::write_padded(
345346
int length,
346347
const device float2* w_k) const {
347348
int length_over_2 = (length / 2) + 1;
348-
int batch_idx =
349-
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
349+
size_t batch_idx =
350+
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
350351
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
351352

352353
int grid_index = elem.x * grid.y + elem.y;
@@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
397398
template <>
398399
METAL_FUNC void ReadWriter<float2, float>::load() const {
399400
short n_over_2 = (n / 2) + 1;
400-
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
401+
size_t batch_idx =
402+
size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2;
401403
threadgroup float2* seq_buf = buf + elem.y * n;
402404

403405
// No out of bounds accesses on odd batch sizes
@@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter<float2, float>::load_padded(
458460
int n_over_2 = (n / 2) + 1;
459461
int length_over_2 = (length / 2) + 1;
460462

461-
int batch_idx =
462-
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
463+
size_t batch_idx =
464+
size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2;
463465
threadgroup float2* seq_buf = buf + elem.y * n;
464466

465467
// No out of bounds accesses on odd batch sizes
@@ -503,7 +505,7 @@ template <>
503505
METAL_FUNC void ReadWriter<float2, float>::write_padded(
504506
int length,
505507
const device float2* w_k) const {
506-
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
508+
size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2;
507509
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
508510

509511
int grid_index = elem.x * grid.y + elem.y;

0 commit comments

Comments
 (0)