Skip to content

Commit 4f72c66

Browse files
authored
improvements to scatter / gather (#1541)
1 parent 960e3f0 commit 4f72c66

File tree

9 files changed

+192
-245
lines changed

9 files changed

+192
-245
lines changed

benchmarks/python/scatter_bench.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
1111
def scatter(dst, x, idx):
12-
dst[*idx] = x
12+
dst[tuple(idx)] = x
1313
mx.eval(dst)
1414

1515
idx = []
@@ -23,8 +23,8 @@ def scatter(dst, x, idx):
2323

2424

2525
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
26-
def gather(dst, x, idx, device):
27-
dst[*idx] = x
26+
def scatter(dst, x, idx, device):
27+
dst[tuple(idx)] = x
2828
if device == torch.device("mps"):
2929
torch.mps.synchronize()
3030

@@ -34,7 +34,7 @@ def gather(dst, x, idx, device):
3434
x = torch.randn(x_shape, dtype=torch.float32).to(device)
3535
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
3636

37-
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
37+
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
3838
print(f"PyTorch: {runtime:.3f}ms")
3939

4040

@@ -54,7 +54,7 @@ def gather(dst, x, idx, device):
5454
(100_000, 64),
5555
(1_000_000, 64),
5656
(100_000,),
57-
(2_000_00,),
57+
(200_000,),
5858
(20_000_000,),
5959
(10000, 64),
6060
(100, 64),
@@ -91,6 +91,6 @@ def gather(dst, x, idx, device):
9191

9292
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
9393
print("=" * 20)
94-
print(f"X {x_shape}, Indices {idx_shape}")
94+
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
9595
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
9696
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

mlx/backend/metal/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
2626
make_jit_source(binary_ops)
2727
make_jit_source(ternary_ops)
2828
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
29-
make_jit_source(scatter)
30-
make_jit_source(gather)
29+
make_jit_source(scatter kernels/indexing.h)
30+
make_jit_source(gather kernels/indexing.h)
3131
make_jit_source(hadamard)
3232

3333
if(MLX_METAL_JIT)

mlx/backend/metal/indexing.cpp

Lines changed: 113 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -113,39 +113,38 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
113113
// Collect all idx shapes and strides into one place
114114
std::vector<int> idx_shapes;
115115
std::vector<size_t> idx_strides;
116-
116+
std::vector<char> idx_contigs;
117117
for (int i = 0; i < nidx; ++i) {
118118
idx_shapes.insert(
119119
idx_shapes.end(),
120120
inputs[i + 1].shape().begin(),
121121
inputs[i + 1].shape().end());
122-
123122
idx_strides.insert(
124123
idx_strides.end(),
125124
inputs[i + 1].strides().begin(),
126125
inputs[i + 1].strides().end());
126+
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
127127
}
128128

129129
// Set all the buffers
130130
compute_encoder.set_input_array(src, 0);
131131
compute_encoder.set_output_array(out, 1);
132132

133133
// Set source info
134-
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
135-
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
134+
set_vector_bytes(compute_encoder, src.shape(), 2);
135+
set_vector_bytes(compute_encoder, src.strides(), 3);
136136
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
137-
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
138-
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
137+
set_vector_bytes(compute_encoder, slice_sizes_, 5);
138+
set_vector_bytes(compute_encoder, axes_, 6);
139139

140140
// Set index info
141141
//
142142
// We don't need to check for empty idx_shapes because gather has a
143143
// idx_ndim == 0 specialization
144-
compute_encoder->setBytes(
145-
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
146-
compute_encoder->setBytes(
147-
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
148-
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
144+
set_vector_bytes(compute_encoder, idx_shapes, 7);
145+
set_vector_bytes(compute_encoder, idx_strides, 8);
146+
set_vector_bytes(compute_encoder, idx_contigs, 9);
147+
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
149148

150149
// Set index buffers
151150
for (int i = 0; i < nidx; ++i) {
@@ -172,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
172171
}
173172

174173
// Copy src into out
175-
auto copy_type =
176-
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
174+
CopyType copy_type;
175+
if (inputs[0].data_size() == 1) {
176+
copy_type = CopyType::Scalar;
177+
} else if (inputs[0].flags().row_contiguous) {
178+
copy_type = CopyType::Vector;
179+
} else {
180+
copy_type = CopyType::General;
181+
}
177182
copy_gpu(inputs[0], out, copy_type);
178183

184+
auto& upd = inputs.back();
185+
179186
// Empty update
180-
if (inputs.back().size() == 0) {
187+
if (upd.size() == 0) {
181188
return;
182189
}
183190

@@ -186,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
186193
auto& d = metal::device(s.device);
187194

188195
int idx_ndim = nidx ? inputs[1].ndim() : 0;
189-
bool index_nd1_specialization = (idx_ndim == 1);
190-
191-
// Bail from fast path (1d index specialization) if scatter dims aren't
192-
// the outermost dims and contiguous since update access won't be raster
193-
// order.
194-
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
195-
index_nd1_specialization &= (axes_[i] == i);
196-
}
197-
198-
// Bail from fast path (1d index specialization) if any of the dims are
199-
// broadcasted, since we can't rely on linear indexing in that case.
200-
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
201-
index_nd1_specialization &= inputs[i].flags().row_contiguous;
196+
size_t idx_size = nidx ? inputs[1].size() : 1;
197+
198+
auto idx_to_out = idx_size / out.size();
199+
int nwork;
200+
if (idx_ndim <= 1 || idx_to_out < 1) {
201+
nwork = 1;
202+
} else if (idx_to_out <= 4) {
203+
nwork = 4;
204+
} else if (idx_to_out < 16) {
205+
nwork = 8;
206+
} else if (idx_to_out < 32) {
207+
nwork = 16;
208+
} else {
209+
nwork = 32;
202210
}
203211

204212
std::string lib_name;
@@ -222,19 +230,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
222230
op_name = "min";
223231
break;
224232
}
225-
233+
auto upd_contig = upd.flags().row_contiguous;
226234
{
227235
std::ostringstream kname;
228-
if (index_nd1_specialization) {
229-
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
230-
} else {
231-
kname << "scatter" << type_to_name(out) << idx_type_name;
232-
}
233-
kname << "_" << op_name << "_" << nidx;
236+
kname << "scatter" << type_to_name(out) << idx_type_name;
237+
kname << "_" << op_name << "_" << nidx << "_"
238+
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
234239
lib_name = kname.str();
235240
kernel_name = kname.str();
236241
}
237-
238242
auto lib = d.get_library(lib_name, [&]() {
239243
std::ostringstream kernel_source;
240244
kernel_source << metal::utils() << metal::reduce_utils()
@@ -274,14 +278,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
274278
op_type,
275279
nidx,
276280
idx_args,
277-
idx_arr);
281+
idx_arr,
282+
upd_contig,
283+
nwork);
278284
return kernel_source.str();
279285
});
280286

281287
auto& compute_encoder = d.get_command_encoder(s.index);
282288
auto kernel = d.get_kernel(kernel_name, lib);
283289

284-
auto& upd = inputs.back();
285290
size_t nthreads = upd.size();
286291

287292
compute_encoder->setComputePipelineState(kernel);
@@ -291,109 +296,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
291296
compute_encoder.set_output_array(out, 2);
292297

293298
// Set update info
294-
uint upd_ndim = upd.ndim();
299+
size_t upd_ndim = upd.ndim();
295300
size_t upd_size = 1;
296301
for (int i = idx_ndim; i < upd.ndim(); ++i) {
297302
upd_size *= upd.shape(i);
298303
}
299-
if (index_nd1_specialization) {
300-
compute_encoder->setBytes(
301-
out.shape().data(), out.shape().size() * sizeof(int), 3);
302-
compute_encoder->setBytes(
303-
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
304-
305-
size_t out_ndim = out.ndim();
306-
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
307-
if (upd_ndim <= 1) {
308-
// Placeholder so Metal doesn't complain
309-
int shape_ = 0;
310-
compute_encoder->setBytes(&shape_, sizeof(int), 6);
311-
} else {
312-
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
313-
}
314-
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
315-
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
316-
317-
// Set index buffers
318-
for (int i = 0; i < nidx; ++i) {
319-
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
320-
}
321-
322-
// Launch grid
323-
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
324-
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
325-
compute_encoder.dispatchThreads(grid_dims, group_dims);
304+
// Collect all idx shapes and strides into one place
305+
std::vector<int> idx_shapes;
306+
std::vector<size_t> idx_strides;
307+
// To access .data() use char instead of bool
308+
// bool is 1 byte in Metal so this is safe
309+
std::vector<char> idx_contigs;
310+
for (int i = 0; i < nidx; ++i) {
311+
idx_shapes.insert(
312+
idx_shapes.end(),
313+
inputs[i + 1].shape().begin(),
314+
inputs[i + 1].shape().end());
315+
idx_strides.insert(
316+
idx_strides.end(),
317+
inputs[i + 1].strides().begin(),
318+
inputs[i + 1].strides().end());
319+
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
320+
}
326321

322+
if (upd_ndim == 0) {
323+
// Need placeholders so Metal doesn't complain
324+
int shape_ = 0;
325+
size_t stride_ = 0;
326+
compute_encoder->setBytes(&shape_, sizeof(int), 3);
327+
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
327328
} else {
328-
// Collect all idx shapes and strides into one place
329-
std::vector<int> idx_shapes;
330-
std::vector<size_t> idx_strides;
331-
332-
for (int i = 0; i < nidx; ++i) {
333-
idx_shapes.insert(
334-
idx_shapes.end(),
335-
inputs[i + 1].shape().begin(),
336-
inputs[i + 1].shape().end());
337-
338-
idx_strides.insert(
339-
idx_strides.end(),
340-
inputs[i + 1].strides().begin(),
341-
inputs[i + 1].strides().end());
342-
}
329+
set_vector_bytes(compute_encoder, upd.shape(), 3);
330+
set_vector_bytes(compute_encoder, upd.strides(), 4);
331+
}
332+
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
333+
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
334+
335+
// Set output info
336+
size_t out_ndim = out.ndim();
337+
if (out_ndim == 0) {
338+
// Need placeholders so Metal doesn't complain
339+
int shape_ = 0;
340+
size_t stride_ = 0;
341+
compute_encoder->setBytes(&shape_, sizeof(int), 7);
342+
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
343+
} else {
344+
set_vector_bytes(compute_encoder, out.shape(), 7);
345+
set_vector_bytes(compute_encoder, out.strides(), 8);
346+
}
347+
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
348+
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
343349

344-
if (upd_ndim == 0) {
345-
// Need placeholders so Metal doesn't complain
346-
int shape_ = 0;
347-
size_t stride_ = 0;
348-
compute_encoder->setBytes(&shape_, sizeof(int), 3);
349-
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
350-
} else {
351-
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
352-
compute_encoder->setBytes(
353-
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
354-
}
355-
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
356-
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
357-
358-
// Set output info
359-
size_t out_ndim = out.ndim();
360-
if (out_ndim == 0) {
361-
// Need placeholders so Metal doesn't complain
362-
int shape_ = 0;
363-
size_t stride_ = 0;
364-
compute_encoder->setBytes(&shape_, sizeof(int), 7);
365-
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
366-
} else {
367-
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
368-
compute_encoder->setBytes(
369-
out.strides().data(), out_ndim * sizeof(size_t), 8);
370-
}
371-
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
372-
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
373-
374-
// Set index info
375-
if (idx_ndim == 0) {
376-
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
377-
// error in the metal API.
378-
idx_shapes.push_back(0);
379-
idx_strides.push_back(0);
380-
}
381-
compute_encoder->setBytes(
382-
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
383-
compute_encoder->setBytes(
384-
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
385-
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
386-
387-
// Set index buffers
388-
for (int i = 0; i < nidx; ++i) {
389-
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
390-
}
350+
// Set index info
351+
if (idx_ndim == 0) {
352+
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
353+
// error in the metal API.
354+
idx_shapes.push_back(0);
355+
idx_strides.push_back(0);
356+
idx_contigs.push_back(false);
357+
}
358+
set_vector_bytes(compute_encoder, idx_shapes, 11);
359+
set_vector_bytes(compute_encoder, idx_strides, 12);
360+
set_vector_bytes(compute_encoder, idx_contigs, 13);
361+
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
362+
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
363+
364+
// Set index buffers
365+
for (int i = 0; i < nidx; ++i) {
366+
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
367+
}
391368

392-
// Launch grid
393-
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
394-
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
395-
compute_encoder.dispatchThreads(grid_dims, group_dims);
369+
// Launch grid
370+
auto grid_y = (nthreads / upd_size);
371+
grid_y = (grid_y + nwork - 1) / nwork;
372+
MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
373+
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
374+
if (thread_group_size != 1024) {
375+
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
396376
}
377+
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
378+
compute_encoder.dispatchThreads(grid_dims, group_dims);
397379
}
398380

399381
} // namespace mlx::core

0 commit comments

Comments
 (0)