Skip to content

Commit 884af42

Browse files
authored
Fix thread group for large arrays (#1543)
* fix thread group for large arrays
* comment
* one more
1 parent 048fabd commit 884af42

File tree

6 files changed: +24 additions, -21 deletions (lines changed)

6 files changed

+24
-21
lines changed

mlx/backend/metal/binary.cpp

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
// Copyright © 2024 Apple Inc.
2-
32
#include "mlx/backend/common/binary.h"
43
#include "mlx/backend/metal/device.h"
54
#include "mlx/backend/metal/kernels.h"
@@ -110,6 +109,7 @@ void binary_op_gpu_inplace(
110109
compute_encoder.set_output_array(outputs[1], arg_idx++);
111110
}
112111

112+
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
113113
if (bopt == BinaryOpType::General) {
114114
// Launch up to 3D grid of threads
115115
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -132,7 +132,6 @@ void binary_op_gpu_inplace(
132132
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
133133
}
134134

135-
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
136135
if (thread_group_size != 1024) {
137136
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
138137
}
@@ -142,13 +141,12 @@ void binary_op_gpu_inplace(
142141
} else {
143142
// Launch a 1D or 2D grid of threads
144143
size_t nthreads = out.data_size();
145-
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
146-
: MTL::Size(nthreads, 1, 1);
147-
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
148144
if (thread_group_size > nthreads) {
149145
thread_group_size = nthreads;
150146
}
151147
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
148+
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
149+
: MTL::Size(nthreads, 1, 1);
152150
compute_encoder.dispatchThreads(grid_dims, group_dims);
153151
}
154152
}

mlx/backend/metal/compiled.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -421,11 +421,12 @@ void Compiled::eval_gpu(
421421
// Launch the kernel
422422
if (contiguous) {
423423
size_t nthreads = outputs[0].data_size();
424+
MTL::Size group_dims(
425+
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
426+
424427
MTL::Size grid_dims = use_2d
425428
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
426429
: MTL::Size(nthreads, 1, 1);
427-
MTL::Size group_dims(
428-
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
429430
compute_encoder.dispatchThreads(grid_dims, group_dims);
430431
} else {
431432
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;

mlx/backend/metal/copy.cpp

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,7 @@ void copy_gpu_inplace(
120120
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
121121
compute_encoder.set_output_array(out, 1, out_offset);
122122

123+
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
123124
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
124125
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
125126
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
@@ -145,7 +146,6 @@ void copy_gpu_inplace(
145146
}
146147

147148
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
148-
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
149149
if (thread_group_size != 1024) {
150150
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
151151
}
@@ -155,13 +155,12 @@ void copy_gpu_inplace(
155155
compute_encoder.dispatchThreads(grid_dims, group_dims);
156156
} else {
157157
size_t nthreads = out.data_size();
158-
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
159-
: MTL::Size(nthreads, 1, 1);
160-
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
161158
if (thread_group_size > nthreads) {
162159
thread_group_size = nthreads;
163160
}
164161
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
162+
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
163+
: MTL::Size(nthreads, 1, 1);
165164
compute_encoder.dispatchThreads(grid_dims, group_dims);
166165
}
167166
}
@@ -205,14 +204,14 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
205204
compute_encoder.set_input_array(val, 0);
206205
compute_encoder.set_output_array(out, 1);
207206

207+
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
208208
size_t nthreads = out.data_size();
209-
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
210-
: MTL::Size(nthreads, 1, 1);
211-
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
212209
if (thread_group_size > nthreads) {
213210
thread_group_size = nthreads;
214211
}
215212
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
213+
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
214+
: MTL::Size(nthreads, 1, 1);
216215
compute_encoder.dispatchThreads(grid_dims, group_dims);
217216
}
218217

mlx/backend/metal/ternary.cpp

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ void ternary_op_gpu_inplace(
7272
compute_encoder.set_input_array(donate_c ? out : c, 2);
7373
compute_encoder.set_output_array(out, 3);
7474

75+
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
7576
if (topt == TernaryOpType::General) {
7677
// Launch up to 3D grid of threads
7778
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -93,7 +94,6 @@ void ternary_op_gpu_inplace(
9394
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
9495
}
9596

96-
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
9797
if (thread_group_size != 1024) {
9898
throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
9999
}
@@ -103,13 +103,12 @@ void ternary_op_gpu_inplace(
103103
} else {
104104
// Launch a 1D or 2D grid of threads
105105
size_t nthreads = out.data_size();
106-
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
107-
: MTL::Size(nthreads, 1, 1);
108-
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
109106
if (thread_group_size > nthreads) {
110107
thread_group_size = nthreads;
111108
}
112109
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
110+
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
111+
: MTL::Size(nthreads, 1, 1);
113112
compute_encoder.dispatchThreads(grid_dims, group_dims);
114113
}
115114
}

mlx/backend/metal/unary.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -47,9 +47,7 @@ void unary_op_gpu_inplace(
4747
kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
4848
auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);
4949

50-
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
51-
: MTL::Size(nthreads, 1, 1);
52-
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
50+
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
5351
auto& compute_encoder = d.get_command_encoder(s.index);
5452
compute_encoder->setComputePipelineState(kernel);
5553
compute_encoder.set_input_array(
@@ -75,6 +73,8 @@ void unary_op_gpu_inplace(
7573
thread_group_size = nthreads;
7674
}
7775
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
76+
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
77+
: MTL::Size(nthreads, 1, 1);
7878
compute_encoder.dispatchThreads(grid_dims, group_dims);
7979
}
8080
}

mlx/backend/metal/utils.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,9 @@ MTL::Size get_2d_grid_dims(
103103
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
104104
throw std::runtime_error("Unable to safely factor shape.");
105105
}
106+
if (grid_y > grid_x) {
107+
std::swap(grid_x, grid_y);
108+
}
106109
return MTL::Size(
107110
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
108111
}
@@ -145,6 +148,9 @@ MTL::Size get_2d_grid_dims(
145148
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
146149
throw std::runtime_error("Unable to safely factor shape.");
147150
}
151+
if (grid_y > grid_x) {
152+
std::swap(grid_x, grid_y);
153+
}
148154
return MTL::Size(
149155
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
150156
}

0 commit comments

Comments
 (0)