Skip to content

Commit 8a0677d

Browse files
authored
Shared events for synchronization + async eval (#998)
* more async eval * fix rebase * try correct async eval * fix async * more tests for async eval * use shared events for synchronization * comment + cleanup * with autorelease pool * fix no metal build * fix compile * fix patch * don't eval if asyn evale'd * don't use is_evaled * comments * more multi stream tests * try and cleanup use of is_evaled * use a status flag
1 parent b18468b commit 8a0677d

28 files changed

+424
-125
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,10 @@ elseif (MLX_BUILD_METAL)
8282
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
8383

8484
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
85+
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
8586
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
8687
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
88+
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
8789
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
8890
else()
8991
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
@@ -92,6 +94,7 @@ elseif (MLX_BUILD_METAL)
9294
FetchContent_Declare(
9395
metal_cpp
9496
URL ${METAL_CPP_URL}
97+
PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
9598
)
9699

97100
FetchContent_MakeAvailable(metal_cpp)

cmake/metal.14.0.diff

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
2+
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
3+
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
4+
@@ -62,6 +62,7 @@
5+
6+
uint64_t signaledValue() const;
7+
void setSignaledValue(uint64_t signaledValue);
8+
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
9+
};
10+
11+
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
12+
@@ -138,6 +139,11 @@
13+
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
14+
{
15+
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
16+
+}
17+
+
18+
+// method: waitUntilSignaledValue
19+
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
20+
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
21+
}
22+
23+
// static method: alloc
24+
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
25+
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
26+
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
27+
@@ -1906,6 +1906,9 @@
28+
"setShouldMaximizeConcurrentCompilation:");
29+
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
30+
"setSignaledValue:");
31+
+_MTL_PRIVATE_DEF_SEL(
32+
+ waitUntilSignaledValue_timeoutMS_,
33+
+ "waitUntilSignaledValue:timeoutMS:");
34+
_MTL_PRIVATE_DEF_SEL(setSize_,
35+
"setSize:");
36+
_MTL_PRIVATE_DEF_SEL(setSlice_,

cmake/metal.14.2.diff

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
2+
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
3+
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
4+
@@ -62,6 +62,7 @@
5+
6+
uint64_t signaledValue() const;
7+
void setSignaledValue(uint64_t signaledValue);
8+
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
9+
};
10+
11+
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
12+
@@ -138,6 +139,11 @@
13+
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
14+
{
15+
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
16+
+}
17+
+
18+
+// method: waitUntilSignaledValue
19+
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
20+
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
21+
}
22+
23+
// static method: alloc
24+
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
25+
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
26+
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
27+
@@ -1918,6 +1918,9 @@
28+
"setShouldMaximizeConcurrentCompilation:");
29+
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
30+
"setSignaledValue:");
31+
+_MTL_PRIVATE_DEF_SEL(
32+
+ waitUntilSignaledValue_timeoutMS_,
33+
+ "waitUntilSignaledValue:timeoutMS:");
34+
_MTL_PRIVATE_DEF_SEL(setSize_,
35+
"setSize:");
36+
_MTL_PRIVATE_DEF_SEL(setSlice_,

mlx/array.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,11 @@ void array::detach() {
9393
}
9494

9595
void array::eval() {
96-
if (!is_evaled()) {
96+
// Ensure the array is ready to be read
97+
if (status() == Status::scheduled) {
98+
event().wait();
99+
set_status(Status::available);
100+
} else if (status() == Status::unscheduled) {
97101
mlx::core::eval({*this});
98102
}
99103
}
@@ -176,7 +180,7 @@ void array::ArrayDesc::init() {
176180
}
177181

178182
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
179-
: shape(std::move(shape)), dtype(dtype) {
183+
: shape(std::move(shape)), dtype(dtype), status(Status::available) {
180184
init();
181185
}
182186

@@ -187,6 +191,7 @@ array::ArrayDesc::ArrayDesc(
187191
std::vector<array> inputs)
188192
: shape(std::move(shape)),
189193
dtype(dtype),
194+
status(Status::unscheduled),
190195
primitive(std::move(primitive)),
191196
inputs(std::move(inputs)) {
192197
init();

mlx/array.h

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "mlx/allocator.h"
1111
#include "mlx/dtype.h"
12+
#include "mlx/event.h"
1213

1314
namespace mlx::core {
1415

@@ -315,9 +316,27 @@ class array {
315316
return static_cast<T*>(array_desc_->data_ptr);
316317
};
317318

318-
// Check if the array has been evaluated
319-
bool is_evaled() const {
320-
return array_desc_->data != nullptr;
319+
enum Status { unscheduled, scheduled, available };
320+
321+
bool is_available() const {
322+
return status() == Status::available;
323+
}
324+
const Status status() const {
325+
return array_desc_->status;
326+
}
327+
328+
void set_status(Status s) const {
329+
array_desc_->status = s;
330+
}
331+
332+
// Get the array's shared event
333+
Event& event() const {
334+
return array_desc_->event;
335+
}
336+
337+
// Attach an event to a not yet evaluated array
338+
void attach_event(Event e) const {
339+
array_desc_->event = std::move(e);
321340
}
322341

323342
// Mark the array as a tracer array (true) or not.
@@ -370,6 +389,11 @@ class array {
370389
Dtype dtype;
371390
std::shared_ptr<Primitive> primitive;
372391

392+
Status status;
393+
394+
// An event on the array used for synchronization
395+
Event event;
396+
373397
// Indicates an array is being used in a graph transform
374398
// and should not be detached from the graph
375399
bool is_tracer{false};
@@ -470,10 +494,11 @@ T array::item() const {
470494
if (size() != 1) {
471495
throw std::invalid_argument("item can only be called on arrays of size 1.");
472496
}
473-
if (!is_evaled()) {
497+
if (status() == Status::unscheduled) {
474498
throw std::invalid_argument(
475499
"item() const can only be called on evaled arrays");
476500
}
501+
const_cast<array*>(this)->eval();
477502
return *data<T>();
478503
}
479504

mlx/backend/metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ target_sources(
2626
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
2727
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
2828
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
29+
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
2930
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
3031
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp

mlx/backend/metal/device.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,11 +544,12 @@ Device& device(mlx::core::Device) {
544544
return metal_device;
545545
}
546546

547-
std::shared_ptr<void> new_scoped_memory_pool() {
547+
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
548548
auto dtor = [](void* ptr) {
549549
static_cast<NS::AutoreleasePool*>(ptr)->release();
550550
};
551-
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
551+
return std::unique_ptr<void, std::function<void(void*)>>(
552+
NS::AutoreleasePool::alloc()->init(), dtor);
552553
}
553554

554555
void new_stream(Stream stream) {

mlx/backend/metal/event.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
#include "mlx/event.h"
4+
#include "mlx/backend/metal/device.h"
5+
#include "mlx/backend/metal/metal_impl.h"
6+
7+
namespace mlx::core {
8+
9+
Event::Event(const Stream& stream) : stream_(stream) {
10+
auto dtor = [](void* ptr) {
11+
auto p = metal::new_scoped_memory_pool();
12+
static_cast<MTL::SharedEvent*>(ptr)->release();
13+
};
14+
auto p = metal::new_scoped_memory_pool();
15+
event_ = std::shared_ptr<void>(
16+
metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
17+
}
18+
19+
void Event::wait() {
20+
if (!static_cast<MTL::SharedEvent*>(raw_event().get())
21+
->waitUntilSignaledValue(value(), -1)) {
22+
throw std::runtime_error("[Event::wait] Timed out");
23+
}
24+
}
25+
26+
void Event::signal() {
27+
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
28+
}
29+
30+
} // namespace mlx::core

mlx/backend/metal/metal.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,20 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
5555
}
5656
}
5757

58-
std::function<void()> make_task(
59-
array& arr,
60-
std::vector<std::shared_future<void>> deps,
61-
std::shared_ptr<std::promise<void>> p) {
62-
auto task = [arr, deps = std::move(deps), p = std::move(p)]() mutable {
58+
std::function<void()> make_task(array arr, bool signal) {
59+
auto task = [arr = std::move(arr), signal]() mutable {
6360
auto pool = new_scoped_memory_pool();
64-
for (auto& d : deps) {
65-
d.wait();
66-
}
6761
auto s = arr.primitive().stream();
6862
auto command_buffer = increment_command_buffer(s);
63+
for (auto& input : arr.inputs()) {
64+
if (input.event().valid() &&
65+
input.event().stream() != arr.primitive().stream()) {
66+
// TODO, consider committing the buffer and encoding a wait in the new
67+
// buffer rather than on the task thread
68+
input.event().wait();
69+
}
70+
}
71+
6972
auto outputs = arr.outputs();
7073
{
7174
// If the array is a tracer hold a reference
@@ -88,13 +91,16 @@ std::function<void()> make_task(
8891
if (!arr.is_tracer()) {
8992
arr.detach();
9093
}
91-
if (p) {
94+
95+
if (signal) {
9296
metal::device(s.device).end_encoding(s.index);
97+
command_buffer->encodeSignalEvent(
98+
static_cast<MTL::Event*>(arr.event().raw_event().get()),
99+
arr.event().value());
93100
scheduler::notify_new_task(s);
94101
command_buffer->addCompletedHandler(
95-
[s, buffers = std::move(buffers), p = std::move(p)](
102+
[s, buffers = std::move(buffers), event = arr.event()](
96103
MTL::CommandBuffer* cbuf) {
97-
p->set_value();
98104
scheduler::notify_task_completion(s);
99105
check_error(cbuf);
100106
});

mlx/backend/metal/metal_impl.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,17 @@
22

33
#pragma once
44

5-
#include <future>
65
#include <memory>
7-
#include <vector>
86

97
#include "mlx/array.h"
108
#include "mlx/stream.h"
119

1210
namespace mlx::core::metal {
1311

1412
void new_stream(Stream stream);
15-
std::shared_ptr<void> new_scoped_memory_pool();
1613

17-
std::function<void()> make_task(
18-
array& arr,
19-
std::vector<std::shared_future<void>> deps,
20-
std::shared_ptr<std::promise<void>> p);
14+
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
15+
16+
std::function<void()> make_task(array arr, bool signal);
2117

2218
} // namespace mlx::core::metal

0 commit comments

Comments
 (0)