Skip to content

Commit 0eb56d5

Browse files
authored
Wired (#1510)
* expose residency sets as wire/unwire * returns wired size * fix * runtime support check * fix os check * fix test * fix no metal build * docs * nit * nits in docs * nits
1 parent f70764a commit 0eb56d5

File tree

13 files changed

+246
-14
lines changed

13 files changed

+246
-14
lines changed

docs/src/python/metal.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Metal
1414
get_cache_memory
1515
set_memory_limit
1616
set_cache_limit
17+
set_wired_limit
1718
clear_cache
1819
start_capture
1920
stop_capture

mlx/backend/metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ target_sources(
9999
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
100100
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
101101
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
102+
${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
102103
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
103104

104105
if(NOT MLX_METAL_PATH)

mlx/backend/metal/allocator.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "mlx/backend/metal/allocator.h"
33
#include "mlx/backend/metal/metal.h"
44
#include "mlx/backend/metal/metal_impl.h"
5+
#include "mlx/backend/metal/resident.h"
56

67
#include <mach/vm_page_size.h>
78
#include <unistd.h>
@@ -140,6 +141,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
140141

141142
MetalAllocator::MetalAllocator()
142143
: device_(device(mlx::core::Device::gpu).mtl_device()),
144+
residency_set_(device_),
143145
buffer_cache_(device_) {
144146
auto memsize = std::get<size_t>(device_info()["memory_size"]);
145147
block_limit_ =
@@ -148,6 +150,8 @@ MetalAllocator::MetalAllocator()
148150
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
149151
block_limit_);
150152
max_pool_size_ = block_limit_;
153+
device(mlx::core::Device::gpu)
154+
.set_residency_set(residency_set_.mtl_residency_set());
151155
}
152156

153157
size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -164,6 +168,12 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
164168
return limit;
165169
};
166170

171+
size_t MetalAllocator::set_wired_limit(size_t limit) {
172+
std::swap(limit, wired_limit_);
173+
residency_set_.resize(wired_limit_);
174+
return limit;
175+
};
176+
167177
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
168178
// Metal doesn't like empty buffers
169179
if (size == 0) {
@@ -220,6 +230,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
220230
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
221231
}
222232

233+
residency_set_.insert(buf);
234+
223235
return Buffer{static_cast<void*>(buf)};
224236
}
225237

@@ -231,6 +243,7 @@ void MetalAllocator::clear_cache() {
231243
void MetalAllocator::free(Buffer buffer) {
232244
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
233245
std::unique_lock lk(mutex_);
246+
residency_set_.erase(buf);
234247
active_memory_ -= buf->length();
235248
if (get_cache_memory() < max_pool_size_) {
236249
buffer_cache_.recycle_to_cache(buf);
@@ -246,15 +259,9 @@ size_t MetalAllocator::size(Buffer buffer) const {
246259
}
247260

248261
MetalAllocator& allocator() {
249-
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
250-
// not be called on exit and all the buffers will be leaked. This is necessary
251-
// because releasing buffers can take more than 30sec when the program holds a
252-
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
253-
// users when exiting.
254-
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
255-
// when applying this pattern to more places, or when introducing sanitizers
256-
// to MLX.
257-
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
262+
// By creating the |allocator_| on heap, the destructor of MetalAllocator
263+
// will not be called on exit and buffers in the cache will be leaked. This
264+
// can save some time at program exit.
258265
static MetalAllocator* allocator_ = new MetalAllocator;
259266
return *allocator_;
260267
}
@@ -265,6 +272,15 @@ size_t set_cache_limit(size_t limit) {
265272
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
266273
return allocator().set_memory_limit(limit, relaxed);
267274
}
275+
size_t set_wired_limit(size_t limit) {
276+
if (limit >
277+
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
278+
throw std::invalid_argument(
279+
"[metal::set_wired_limit] Setting a wired limit larger than "
280+
"the maximum working set size is not allowed.");
281+
}
282+
return allocator().set_wired_limit(limit);
283+
}
268284
size_t get_active_memory() {
269285
return allocator().get_active_memory();
270286
}

mlx/backend/metal/allocator.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlx/allocator.h"
1010
#include "mlx/backend/metal/device.h"
11+
#include "mlx/backend/metal/resident.h"
1112

1213
namespace mlx::core::metal {
1314

@@ -72,6 +73,7 @@ class MetalAllocator : public allocator::Allocator {
7273
};
7374
size_t set_cache_limit(size_t limit);
7475
size_t set_memory_limit(size_t limit, bool relaxed);
76+
size_t set_wired_limit(size_t limit);
7577
void clear_cache();
7678

7779
private:
@@ -82,12 +84,15 @@ class MetalAllocator : public allocator::Allocator {
8284
// Caching allocator
8385
BufferCache buffer_cache_;
8486

87+
ResidencySet residency_set_;
88+
8589
// Allocation stats
8690
size_t block_limit_;
8791
size_t gc_limit_;
8892
size_t active_memory_{0};
8993
size_t peak_memory_{0};
9094
size_t max_pool_size_;
95+
size_t wired_limit_{0};
9196
bool relaxed_{true};
9297

9398
std::mutex mutex_;

mlx/backend/metal/device.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ void Device::new_queue(int index) {
206206
"[metal::Device] Failed to make new command queue.");
207207
}
208208
stream_map_.emplace(index, q);
209+
if (residency_set_ != nullptr) {
210+
q->addResidencySet(residency_set_);
211+
}
209212
}
210213

211214
int Device::get_command_buffer_ops(int index) {
@@ -351,7 +354,7 @@ MTL::Library* Device::build_library_(const std::string& source_string) {
351354
// Throw error if unable to compile library
352355
if (!mtl_lib) {
353356
std::ostringstream msg;
354-
msg << "[metal::Device] Unable to build metal library from source" << "\n";
357+
msg << "[metal::Device] Unable to build metal library from source\n";
355358
if (error) {
356359
msg << error->localizedDescription()->utf8String() << "\n";
357360
}
@@ -593,6 +596,21 @@ MTL::ComputePipelineState* Device::get_kernel(
593596
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
594597
}
595598

599+
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
600+
if (residency_set_ != nullptr) {
601+
throw std::runtime_error(
602+
"[Device::set_residency_set] Can only be set once.");
603+
}
604+
if (residency_set == nullptr) {
605+
return;
606+
}
607+
residency_set_ = residency_set;
608+
// Attach residency set to existing command queues
609+
for (auto& [_, stream] : stream_map_) {
610+
stream.queue->addResidencySet(residency_set_);
611+
}
612+
}
613+
596614
Device& device(mlx::core::Device) {
597615
static Device metal_device;
598616
return metal_device;

mlx/backend/metal/device.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ class Device {
181181
void add_temporary(array arr, int index);
182182
void add_temporaries(std::vector<array> arrays, int index);
183183

184+
void set_residency_set(const MTL::ResidencySet* residency_set);
185+
184186
private:
185187
DeviceStream& get_stream_(int index) {
186188
return stream_map_.find(index)->second;
@@ -225,6 +227,7 @@ class Device {
225227

226228
std::shared_mutex library_mtx_;
227229
std::unordered_map<std::string, MTL::Library*> library_map_;
230+
const MTL::ResidencySet* residency_set_{nullptr};
228231
};
229232

230233
Device& device(mlx::core::Device);

mlx/backend/metal/matmul.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// Copyright © 2023 Apple Inc.
22

3+
#pragma once
4+
35
#include "mlx/backend/metal/device.h"
46

57
namespace mlx::core {

mlx/backend/metal/metal.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,19 @@ size_t set_cache_limit(size_t limit);
6363
/* Clear the memory cache. */
6464
void clear_cache();
6565

66+
/* Set the wired size limit.
67+
*
68+
* Note, this function is only useful for macOS 15.0 or higher.
69+
*
70+
* The wired limit is the total size in bytes of memory that will be kept
71+
* resident. The default value is ``0``.
72+
*
73+
* Setting a wired limit larger than system wired limit is an error.
74+
*
75+
* Returns the previous wired limit.
76+
* */
77+
size_t set_wired_limit(size_t limit);
78+
6679
/** Capture a GPU trace, saving it to an absolute file `path` */
6780
void start_capture(std::string path = "");
6881
void stop_capture();

mlx/backend/metal/resident.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
#include "mlx/backend/metal/resident.h"
4+
#include "mlx/backend/metal/metal_impl.h"
5+
6+
namespace mlx::core::metal {
7+
8+
// TODO maybe worth including tvos / visionos
9+
#define supported __builtin_available(macOS 15, iOS 18, *)
10+
11+
ResidencySet::ResidencySet(MTL::Device* d) {
12+
if (supported) {
13+
auto pool = new_scoped_memory_pool();
14+
auto desc = MTL::ResidencySetDescriptor::alloc()->init();
15+
NS::Error* error;
16+
wired_set_ = d->newResidencySet(desc, &error);
17+
desc->release();
18+
if (!wired_set_) {
19+
std::ostringstream msg;
20+
msg << "[metal::Device] Unable to construct residency set.\n";
21+
if (error) {
22+
msg << error->localizedDescription()->utf8String() << "\n";
23+
}
24+
throw std::runtime_error(msg.str());
25+
}
26+
}
27+
}
28+
29+
void ResidencySet::insert(MTL::Allocation* buf) {
30+
if (supported) {
31+
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
32+
wired_set_->addAllocation(buf);
33+
wired_set_->commit();
34+
wired_set_->requestResidency();
35+
} else {
36+
unwired_set_.insert(buf);
37+
}
38+
}
39+
}
40+
41+
void ResidencySet::erase(MTL::Allocation* buf) {
42+
if (supported) {
43+
if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) {
44+
unwired_set_.erase(it);
45+
} else {
46+
wired_set_->removeAllocation(buf);
47+
wired_set_->commit();
48+
}
49+
}
50+
}
51+
52+
void ResidencySet::resize(size_t size) {
53+
if (supported) {
54+
if (capacity_ == size) {
55+
return;
56+
}
57+
capacity_ = size;
58+
59+
size_t current_size = wired_set_->allocatedSize();
60+
61+
if (current_size < size) {
62+
// Add unwired allocations to the set
63+
for (auto it = unwired_set_.begin(); it != unwired_set_.end();) {
64+
auto buf_size = (*it)->allocatedSize();
65+
if (current_size + buf_size > size) {
66+
it++;
67+
} else {
68+
current_size += buf_size;
69+
wired_set_->addAllocation(*it);
70+
unwired_set_.erase(it++);
71+
}
72+
}
73+
wired_set_->commit();
74+
wired_set_->requestResidency();
75+
} else if (current_size > size) {
76+
// Remove wired allocations until under capacity
77+
auto allocations = wired_set_->allAllocations();
78+
auto num_allocations = wired_set_->allocationCount();
79+
for (int i = 0; i < num_allocations && current_size > size; ++i) {
80+
auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
81+
wired_set_->removeAllocation(buf);
82+
current_size -= buf->allocatedSize();
83+
unwired_set_.insert(buf);
84+
}
85+
wired_set_->commit();
86+
}
87+
}
88+
}
89+
90+
ResidencySet::~ResidencySet() {
91+
if (supported) {
92+
wired_set_->release();
93+
}
94+
}
95+
96+
} // namespace mlx::core::metal

mlx/backend/metal/resident.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
#pragma once
4+
5+
#include "mlx/backend/metal/device.h"
6+
7+
namespace mlx::core::metal {
8+
9+
class ResidencySet {
10+
public:
11+
ResidencySet(MTL::Device* d);
12+
~ResidencySet();
13+
14+
ResidencySet(const ResidencySet&) = delete;
15+
ResidencySet& operator=(const ResidencySet&) = delete;
16+
17+
const MTL::ResidencySet* mtl_residency_set() {
18+
return wired_set_;
19+
}
20+
21+
void insert(MTL::Allocation* buf);
22+
void erase(MTL::Allocation* buf);
23+
24+
void resize(size_t size);
25+
26+
private:
27+
MTL::ResidencySet* wired_set_{nullptr};
28+
std::unordered_set<const MTL::Allocation*> unwired_set_;
29+
size_t capacity_{0};
30+
};
31+
32+
} // namespace mlx::core::metal

0 commit comments

Comments
 (0)