Skip to content

Commit c9934fe

Browse files
authored
Metal validation (#432)
* tests clear metal validation * add cpp test with metal validation to circleci * nit
1 parent 975e265 commit c9934fe

File tree

10 files changed

+142
-35
lines changed

10 files changed

+142
-35
lines changed

.circleci/config.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ jobs:
8080
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
8181
- store_test_results:
8282
path: test-results
83+
- run:
84+
name: Build CPP only
85+
command: |
86+
mkdir -p build && cd build && cmake .. && make -j
87+
- run:
88+
name: Run CPP tests
89+
command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
8390

8491
build_release:
8592
machine: true

mlx/backend/metal/allocator.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ MetalAllocator::MetalAllocator()
153153
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
154154

155155
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
156+
// Metal doesn't like empty buffers
157+
if (size == 0) {
158+
return Buffer{nullptr};
159+
}
160+
156161
// Align up memory
157162
if (size > vm_page_size) {
158163
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);

mlx/backend/metal/copy.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
2020
} else {
2121
out.set_data(allocator::malloc_or_wait(out.nbytes()));
2222
}
23+
if (out.size() == 0) {
24+
return;
25+
}
2326
if (ctype == CopyType::GeneralGeneral) {
2427
ctype = CopyType::General;
2528
}

mlx/backend/metal/indexing.cpp

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// Copyright © 2023 Apple Inc.
2-
32
#include <algorithm>
43
#include <cassert>
54
#include <numeric>
@@ -33,6 +32,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
3332
}
3433

3534
out.set_data(allocator::malloc_or_wait(out.nbytes()));
35+
if (out.size() == 0) {
36+
return;
37+
}
3638

3739
auto& s = stream();
3840
auto& d = metal::device(s.device);
@@ -110,14 +112,18 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
110112
for (int i = 0; i < nidx; ++i) {
111113
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
112114
}
113-
arg_enc->setBuffer(
114-
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
115-
compute_encoder->useResource(
116-
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
117-
arg_enc->setBuffer(
118-
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
119-
compute_encoder->useResource(
120-
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
115+
if (idx_ndim > 0) {
116+
arg_enc->setBuffer(
117+
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
118+
compute_encoder->useResource(
119+
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
120+
MTL::ResourceUsageRead);
121+
arg_enc->setBuffer(
122+
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
123+
compute_encoder->useResource(
124+
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
125+
MTL::ResourceUsageRead);
126+
}
121127
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
122128

123129
// Set all the buffers
@@ -163,6 +169,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
163169
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
164170
copy_gpu(inputs[0], out, copy_type);
165171

172+
// Empty update
173+
if (inputs.back().size() == 0) {
174+
return;
175+
}
176+
166177
// Get stream
167178
auto& s = stream();
168179
auto& d = metal::device(s.device);
@@ -254,14 +265,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
254265
for (int i = 0; i < nidx; ++i) {
255266
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
256267
}
257-
arg_enc->setBuffer(
258-
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
259-
compute_encoder->useResource(
260-
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), MTL::ResourceUsageRead);
261-
arg_enc->setBuffer(
262-
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
263-
compute_encoder->useResource(
264-
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), MTL::ResourceUsageRead);
268+
if (idx_ndim > 0) {
269+
arg_enc->setBuffer(
270+
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
271+
compute_encoder->useResource(
272+
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
273+
MTL::ResourceUsageRead);
274+
arg_enc->setBuffer(
275+
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
276+
compute_encoder->useResource(
277+
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
278+
MTL::ResourceUsageRead);
279+
}
265280
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
266281

267282
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
@@ -272,14 +287,32 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
272287
}
273288
set_array_buffer(compute_encoder, upd, 1);
274289
set_array_buffer(compute_encoder, out, 2);
275-
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
276-
compute_encoder->setBytes(upd.strides().data(), upd_ndim * sizeof(size_t), 4);
290+
if (upd_ndim == 0) {
291+
// Need placeholders so Metal doesn't complain
292+
int shape_ = 0;
293+
size_t stride_ = 0;
294+
compute_encoder->setBytes(&shape_, sizeof(int), 3);
295+
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
296+
} else {
297+
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
298+
compute_encoder->setBytes(
299+
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
300+
}
277301
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
278302
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
279303

280304
size_t out_ndim = out.ndim();
281-
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
282-
compute_encoder->setBytes(out.strides().data(), out_ndim * sizeof(size_t), 8);
305+
if (out_ndim == 0) {
306+
// Need placeholders so Metal doesn't complain
307+
int shape_ = 0;
308+
size_t stride_ = 0;
309+
compute_encoder->setBytes(&shape_, sizeof(int), 7);
310+
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
311+
} else {
312+
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
313+
compute_encoder->setBytes(
314+
out.strides().data(), out_ndim * sizeof(size_t), 8);
315+
}
283316
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
284317
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
285318

mlx/backend/metal/primitives.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ void binary_op(
3131
set_binary_op_output_data(a, b, outputs[1], bopt);
3232

3333
auto& out = outputs[0];
34+
if (out.size() == 0) {
35+
return;
36+
}
3437

3538
// Try to collapse contiguous dims
3639
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
@@ -120,6 +123,9 @@ void binary_op(
120123
auto& b = inputs[1];
121124
auto bopt = get_binary_op_type(a, b);
122125
set_binary_op_output_data(a, b, out, bopt);
126+
if (out.size() == 0) {
127+
return;
128+
}
123129

124130
// Try to collapse contiguous dims
125131
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
@@ -214,6 +220,9 @@ void unary_op(
214220
} else {
215221
out.set_data(allocator::malloc_or_wait(out.nbytes()));
216222
}
223+
if (in.size() == 0) {
224+
return;
225+
}
217226

218227
auto& s = out.primitive().stream();
219228
auto& d = metal::device(s.device);
@@ -263,6 +272,9 @@ void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) {
263272
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
264273
assert(inputs.size() == 0);
265274
out.set_data(allocator::malloc_or_wait(out.nbytes()));
275+
if (out.size() == 0) {
276+
return;
277+
}
266278
auto& s = stream();
267279
auto& d = metal::device(s.device);
268280
auto kernel = d.get_kernel("arange" + type_to_name(out));
@@ -390,9 +402,18 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
390402
compute_encoder->setComputePipelineState(kernel);
391403
set_array_buffer(compute_encoder, in, 0);
392404
set_array_buffer(compute_encoder, out, 1);
393-
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
394-
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
395-
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
405+
if (ndim == 0) {
406+
// Pass placeholders so Metal doesn't complain
407+
int shape_ = 0;
408+
size_t stride_ = 0;
409+
compute_encoder->setBytes(&shape_, sizeof(int), 2);
410+
compute_encoder->setBytes(&stride_, sizeof(size_t), 3);
411+
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
412+
} else {
413+
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
414+
compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3);
415+
compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4);
416+
}
396417
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
397418
compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6);
398419
compute_encoder->setBytes(&axis_size, sizeof(size_t), 7);
@@ -629,6 +650,9 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
629650
size_t elems_per_key = out.size() / num_keys;
630651
size_t bytes_per_key = out.itemsize() * elems_per_key;
631652
out.set_data(allocator::malloc_or_wait(out.nbytes()));
653+
if (out.size() == 0) {
654+
return;
655+
}
632656

633657
size_t out_per_key = (bytes_per_key + 4 - 1) / 4;
634658
size_t half_size = out_per_key / 2;

mlx/backend/metal/reduce.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include <algorithm>
44
#include <cassert>
5-
#include <iostream>
65
#include <sstream>
76

87
#include "mlx/backend/common/reduce.h"
@@ -21,10 +20,14 @@ namespace mlx::core {
2120

2221
namespace {
2322

24-
inline auto safe_divup(size_t n, size_t m) {
23+
inline auto safe_div(size_t n, size_t m) {
2524
return m == 0 ? 0 : (n + m - 1) / m;
2625
}
2726

27+
inline auto safe_divup(size_t n, size_t m) {
28+
return safe_div(n, m) * m;
29+
}
30+
2831
// All Reduce
2932
void all_reduce_dispatch(
3033
const array& in,
@@ -56,7 +59,7 @@ void all_reduce_dispatch(
5659
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
5760

5861
// If the number of thread groups needed exceeds 1024, we reuse thread groups
59-
uint n_thread_groups = safe_divup(mod_in_size, thread_group_size);
62+
uint n_thread_groups = safe_div(mod_in_size, thread_group_size);
6063
n_thread_groups = std::min(n_thread_groups, 1024u);
6164
uint nthreads = n_thread_groups * thread_group_size;
6265

@@ -204,7 +207,8 @@ void strided_reduce_general_dispatch(
204207
// if we ever come to doubles. In that case, we should also cut
205208
// down the number of threads we launch in a threadgroup
206209
compute_encoder->setThreadgroupMemoryLength(
207-
threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 0);
210+
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
211+
0);
208212

209213
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
210214
}
@@ -231,7 +235,10 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
231235
assert(!axes_.empty());
232236

233237
// Continue with reduction operation
234-
out.set_data(allocator::malloc_or_wait(out.nbytes()));
238+
// Minimum of 4 bytes since we use size 4 structs for all reduce
239+
// and Metal will complain otherwise
240+
size_t min_bytes = std::max(out.nbytes(), 4ul);
241+
out.set_data(allocator::malloc_or_wait(min_bytes));
235242
std::string op_name;
236243
switch (reduce_type_) {
237244
case Reduce::And:
@@ -273,7 +280,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
273280
}
274281

275282
// Reduce
276-
{
283+
if (in.size() > 0) {
277284
std::vector<array> copies;
278285
ReductionPlan plan = get_reduction_plan(in, axes_);
279286

mlx/ops.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// Copyright © 2023 Apple Inc.
2-
32
#include <cmath>
43
#include <numeric>
54
#include <set>

tests/allocator_tests.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,4 @@ TEST_CASE("test large allocations") {
3838
auto buffer = allocator::malloc(size);
3939
allocator::free(buffer);
4040
}
41-
// Shouldn't be able to allocate an exabyte anytime soon.
42-
CHECK_THROWS_AS(allocator::malloc(1ull << 60), std::runtime_error);
4341
}

tests/arg_reduce_tests.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
// Copyright © 2023 Apple Inc.
22

3-
#include <iostream>
4-
53
#include "doctest/doctest.h"
64

75
#include "mlx/mlx.h"

tests/metal_tests.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,36 @@ TEST_CASE("test metal matmul") {
438438
CHECK(array_equal(out, full({3, 3, 2, 2}, 2.0f), Device::cpu).item<bool>());
439439
}
440440
}
441+
442+
TEST_CASE("test metal validation") {
443+
// Run this test with Metal validation enabled
444+
// METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./tests/tests \
445+
// -tc="test metal validation"
446+
447+
auto x = array({});
448+
eval(exp(x));
449+
450+
auto y = array({});
451+
eval(add(x, y));
452+
453+
eval(sum(x));
454+
455+
x = array({1, 2, 3});
456+
y = array(0);
457+
eval(gather(x, y, 0, {0}));
458+
eval(gather(x, y, 0, {2}));
459+
460+
eval(gather(x, y, 0, {0}));
461+
eval(gather(x, y, 0, {2}));
462+
463+
eval(scatter(x, y, array({2}), 0));
464+
465+
x = arange(0, -3, 1);
466+
eval(x);
467+
array_equal(x, array({})).item<bool>();
468+
469+
x = array({1.0, 0.0});
470+
eval(argmax(x));
471+
472+
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
473+
}

0 commit comments

Comments
 (0)