
Commit 972d9a3

Vijay Krish and awni authored
Up to 10x faster scatter. (#709)
* Faster scatter. Add specialization for 1-d index tensors.

* Address review comments:
  - Check for row contiguity of the index and update tensors instead of checking strides.
  - Add support for the 1-d specialization with a col-contiguous update tensor, along with a test.

* Nit1

Co-authored-by: Awni Hannun <[email protected]>

* Nit2

Co-authored-by: Awni Hannun <[email protected]>

---------

Co-authored-by: Awni Hannun <[email protected]>
1 parent 7dcdd88 commit 972d9a3
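
For context, a minimal sketch of the pattern this commit speeds up, written against the MLX Python API (shapes are illustrative, borrowed from the benchmark below): a single 1-d index array scattering whole rows into a destination.

import mlx.core as mx

# One 1-d index array writing full, row-contiguous update rows into dst:
# the case the new scatter_1d_index fast path targets.
dst = mx.random.normal((100_000, 64))
x = mx.random.normal((1_000_000, 64))
idx = mx.random.randint(0, dst.shape[0], (1_000_000,))

dst[idx] = x  # index assignment lowers to a scatter
mx.eval(dst)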

File tree: 4 files changed, +244 −83 lines changed

benchmarks/python/scatter_bench.py

Lines changed: 49 additions & 9 deletions
@@ -7,26 +7,30 @@
 from time_utils import measure_runtime
 
 
-def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
+def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[idx] = x
+        dst[*idx] = x
         mx.eval(dst)
 
-    idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
+    idx = []
+    for idx_shape in idx_shapes:
+        idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
     x = mx.random.normal(x_shape).astype(mx.float32)
     dst = mx.random.normal(dst_shape).astype(mx.float32)
 
     runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
     print(f"MLX: {runtime:.3f}ms")
 
 
-def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
+def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
     def gather(dst, x, idx, device):
-        dst[idx] = x
+        dst[*idx] = x
         if device == torch.device("mps"):
             torch.mps.synchronize()
 
-    idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
+    idx = []
+    for idx_shape in idx_shapes:
+        idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
@@ -45,9 +49,45 @@ def gather(dst, x, idx, device):
     else:
         device = torch.device("mps")
 
-    dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
-    idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
-    x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
+    dst_shapes = [
+        (10, 64),
+        (100_000, 64),
+        (1_000_000, 64),
+        (100_000,),
+        (2_000_00,),
+        (20_000_000,),
+        (10000, 64),
+        (100, 64),
+        (100, 10_000, 64),
+        (10, 100, 100, 21),
+        (1_000, 1_000, 10),
+    ]
+    idx_shapes = [
+        [(1_000_000,)],
+        [(1_000_000,)],
+        [(100_000,)],
+        [(1_000_000,)],
+        [(20_000_000,)],
+        [(20_000_000,)],
+        [(1000000,)],
+        [(10000000,)],
+        [(1_000,)],
+        [(10_000,)],
+        [(1_000,), (1_000,)],
+    ]
+    x_shapes = [
+        (1_000_000, 64),
+        (1_000_000, 64),
+        (100_000, 64),
+        (1_000_000,),
+        (20_000_000,),
+        (20_000_000,),
+        (1000000, 64),
+        (10000000, 64),
+        (1_000, 10_000, 64),
+        (10_000, 100, 100, 21),
+        (1_000, 10),
+    ]
 
     for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
         print("=" * 20)
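
The last benchmark case above passes two index arrays; with the new list-valued idx_shapes, dst[*idx] = x unpacks them into a multi-axis scatter (the starred-subscript syntax needs Python 3.11+). A sketch of that case with the same illustrative shapes:

import mlx.core as mx

# Two 1-d index arrays select (row, col) pairs in a (1_000, 1_000, 10) dst;
# each pair receives a length-10 update vector.
dst = mx.random.normal((1_000, 1_000, 10))
idx = [
    mx.random.randint(0, dst.shape[0] - 1, (1_000,)),
    mx.random.randint(0, dst.shape[1] - 1, (1_000,)),
]
x = mx.random.normal((1_000, 10))

dst[*idx] = x  # equivalent to dst[idx[0], idx[1]] = x
mx.eval(dst)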

mlx/backend/metal/indexing.cpp

Lines changed: 108 additions & 66 deletions
@@ -142,7 +142,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Get kernel name
   std::ostringstream kname;
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
-  kname << "scatter" << type_to_name(out) << idx_type_name;
+
+  int idx_ndim = nidx ? inputs[1].ndim() : 0;
+  bool index_nd1_specialization = (idx_ndim == 1);
+
+  // Bail from fast path (1d index specialization) if scatter dims aren't
+  // the outermost dims and contiguous since update access won't be raster
+  // order.
+  for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
+    index_nd1_specialization &= (axes_[i] == i);
+  }
+
+  // Bail from fast path (1d index specialization) if any of the dims are
+  // broadcasted, since we can't rely on linear indexing in that case.
+  for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
+    index_nd1_specialization &= inputs[i].flags().row_contiguous;
+  }
+
+  if (index_nd1_specialization) {
+    kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
+  } else {
+    kname << "scatter" << type_to_name(out) << idx_type_name;
+  }
   switch (reduce_type_) {
     case Scatter::None:
       kname << "_none";
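
Restating the gate above: the fast path is taken only when the index shape is 1-d, the scatter axes are exactly the leading axes in order, and every index and update input is row-contiguous (i.e., nothing was broadcast). A hypothetical Python restatement of that predicate, for illustration only (not MLX API; the .ndim/.flags attributes mirror the C++ array flags):

def use_1d_index_fast_path(axes, index_arrays, updates):
    # The shared index shape must be 1-d.
    if not index_arrays or index_arrays[0].ndim != 1:
        return False
    # Scatter axes must be the outermost dims, in order, so the update
    # tensor is read and written in raster order.
    if any(ax != i for i, ax in enumerate(axes)):
        return False
    # A broadcast input is not row-contiguous and defeats linear indexing.
    return all(a.flags.row_contiguous for a in index_arrays + [updates])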
@@ -170,85 +191,106 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   compute_encoder->setComputePipelineState(kernel);
 
-  // Collect all idx shapes and strides into one place
-  int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  std::vector<int> idx_shapes;
-  std::vector<size_t> idx_strides;
-
-  for (int i = 0; i < nidx; ++i) {
-    idx_shapes.insert(
-        idx_shapes.end(),
-        inputs[i + 1].shape().begin(),
-        inputs[i + 1].shape().end());
-
-    idx_strides.insert(
-        idx_strides.end(),
-        inputs[i + 1].strides().begin(),
-        inputs[i + 1].strides().end());
-  }
-
   // Set all the buffers
   set_array_buffer(compute_encoder, upd, 1);
   set_array_buffer(compute_encoder, out, 2);
 
   // Set update info
-  size_t upd_ndim = upd.ndim();
+  uint upd_ndim = upd.ndim();
   size_t upd_size = 1;
   for (int i = idx_ndim; i < upd.ndim(); ++i) {
     upd_size *= upd.shape(i);
   }
-  if (upd_ndim == 0) {
-    // Need placeholders so Metal doesn't compalain
-    int shape_ = 0;
-    size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 3);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
-  } else {
-    compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
+
+  if (index_nd1_specialization) {
+    bool upd_col_contiguous = upd.flags().col_contiguous;
     compute_encoder->setBytes(
-        upd.strides().data(), upd_ndim * sizeof(size_t), 4);
-  }
-  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
-  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
-
-  // Set output info
-  size_t out_ndim = out.ndim();
-  if (out_ndim == 0) {
-    // Need placeholders so Metal doesn't compalain
-    int shape_ = 0;
-    size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 7);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
-  } else {
-    compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
+        out.shape().data(), out.shape().size() * sizeof(int), 3);
     compute_encoder->setBytes(
-        out.strides().data(), out_ndim * sizeof(size_t), 8);
-  }
-  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
-  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
+        out.strides().data(), out.strides().size() * sizeof(size_t), 4);
+    compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
+    compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
 
-  // Set index info
-  if (idx_ndim == 0) {
-    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
-    // error in the metal API.
-    idx_shapes.push_back(0);
-    idx_strides.push_back(0);
-  }
-  compute_encoder->setBytes(
-      idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
-  compute_encoder->setBytes(
-      idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
+    // Set index buffers
+    for (int i = 1; i < nidx + 1; ++i) {
+      set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    }
 
-  // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    set_array_buffer(compute_encoder, inputs[i], 20 + i);
-  }
+    // Launch grid
+    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
+    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
 
-  // Launch grid
-  MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-  MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  } else {
+    // Collect all idx shapes and strides into one place
+    std::vector<int> idx_shapes;
+    std::vector<size_t> idx_strides;
+
+    for (int i = 0; i < nidx; ++i) {
+      idx_shapes.insert(
+          idx_shapes.end(),
+          inputs[i + 1].shape().begin(),
+          inputs[i + 1].shape().end());
+
+      idx_strides.insert(
+          idx_strides.end(),
+          inputs[i + 1].strides().begin(),
+          inputs[i + 1].strides().end());
+    }
+
+    if (upd_ndim == 0) {
+      // Need placeholders so Metal doesn't compalain
+      int shape_ = 0;
+      size_t stride_ = 0;
+      compute_encoder->setBytes(&shape_, sizeof(int), 3);
+      compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
+    } else {
+      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
+      compute_encoder->setBytes(
+          upd.strides().data(), upd_ndim * sizeof(size_t), 4);
+    }
+    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
+    compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
+
+    // Set output info
+    size_t out_ndim = out.ndim();
+    if (out_ndim == 0) {
+      // Need placeholders so Metal doesn't compalain
+      int shape_ = 0;
+      size_t stride_ = 0;
+      compute_encoder->setBytes(&shape_, sizeof(int), 7);
+      compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
+    } else {
+      compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
+      compute_encoder->setBytes(
+          out.strides().data(), out_ndim * sizeof(size_t), 8);
+    }
+    compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
+    compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
+
+    // Set index info
+    if (idx_ndim == 0) {
+      // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
+      // error in the metal API.
+      idx_shapes.push_back(0);
+      idx_strides.push_back(0);
+    }
+    compute_encoder->setBytes(
+        idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
+    compute_encoder->setBytes(
+        idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
+    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
+
+    // Set index buffers
+    for (int i = 1; i < nidx + 1; ++i) {
+      set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    }
+
+    // Launch grid
+    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
+    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  }
 }
 
 } // namespace mlx::core
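
The payoff of the specialization: with row-contiguous inputs, the kernel can address both buffers linearly from upd_size alone instead of walking per-element shape/stride arrays. A NumPy reference of that addressing for the None-reduce case (illustrative only, not the Metal kernel itself):

import numpy as np

def scatter_1d_index_ref(out, idx, upd, upd_size):
    # View both buffers flat; every access is pure linear arithmetic.
    out_flat = out.reshape(-1)
    upd_flat = upd.reshape(-1)
    for i, row in enumerate(idx):
        for j in range(upd_size):
            out_flat[row * upd_size + j] = upd_flat[i * upd_size + j]
    return out

out = np.zeros((8, 4), dtype=np.float32)
upd = np.arange(12, dtype=np.float32).reshape(3, 4)
idx = np.array([6, 0, 6])  # rows 6 and 0; the duplicate 6 is overwritten
scatter_1d_index_ref(out, idx, upd, upd_size=4)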
