Skip to content

Commit 11354d5

Browse files
authored
Avoid io timeout for large arrays (#1442)
1 parent 718aea3 commit 11354d5

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

mlx/backend/metal/primitives.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,19 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
200200

201201
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
202202
out.set_data(allocator::malloc_or_wait(out.nbytes()));
203-
204203
auto read_task = [out = out,
205204
offset = offset_,
206205
reader = reader_,
207206
swap_endianness = swap_endianness_]() mutable {
208207
load(out, offset, reader, swap_endianness);
209208
};
209+
210+
// Limit the size that the command buffer will wait on to avoid timing out
211+
// on the event (<4 seconds).
212+
if (out.nbytes() > (1 << 28)) {
213+
read_task();
214+
return;
215+
}
210216
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
211217
auto signal_task = [out = out, fut = std::move(fut)]() {
212218
fut.wait();

0 commit comments

Comments
 (0)