Skip to content

Commit a749a91

Browse files
authored
Support disable metal buffer cache to prevent performance degradation caused by large memory caching (#390)
* support disable metal buffer cache, due to large unused memory buffered when llm generated long context tokens * Run format and add "cache_enabled" feature tests
1 parent 49a5261 commit a749a91

File tree

4 files changed

+67
-1
lines changed

4 files changed

+67
-1
lines changed

mlx/backend/metal/allocator.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ void* Buffer::raw_ptr() {
2323

2424
namespace metal {
2525

26+
static bool cache_enabled_ = true;
27+
28+
bool cache_enabled() {
29+
return cache_enabled_;
30+
}
31+
32+
void set_cache_enabled(bool enabled) {
33+
cache_enabled_ = enabled;
34+
}
35+
2636
namespace {
2737

2838
BufferCache::BufferCache(MTL::Device* device)
@@ -196,7 +206,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
196206

197207
void MetalAllocator::free(Buffer buffer) {
198208
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
199-
buffer_cache_.recycle_to_cache(buf);
209+
if (cache_enabled()) {
210+
buffer_cache_.recycle_to_cache(buf);
211+
} else {
212+
buf->release();
213+
}
200214
}
201215

202216
MetalAllocator& allocator() {

mlx/backend/metal/metal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ constexpr bool is_available() {
1919
#endif
2020
}
2121

22+
bool cache_enabled(void);
23+
void set_cache_enabled(bool enabled);
24+
2225
void new_stream(Stream stream);
2326
std::shared_ptr<void> new_scoped_memory_pool();
2427

python/src/metal.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,12 @@ using namespace mlx::core;
1111
void init_metal(py::module_& m) {
1212
py::module_ metal = m.def_submodule("metal", "mlx.metal");
1313
metal.def("is_available", &metal::is_available);
14+
metal.def(
15+
"cache_enabled",
16+
&metal::cache_enabled,
17+
"check if metal buffer cache is enabled, default is true");
18+
metal.def(
19+
"set_cache_enabled",
20+
&metal::set_cache_enabled,
21+
"enable or disable metal buffer cache");
1422
}

tests/metal_tests.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "mlx/backend/metal/device.h"
77
#include "mlx/backend/metal/metal.h"
8+
#include "mlx/backend/metal/allocator.h"
89
#include "mlx/mlx.h"
910

1011
using namespace mlx::core;
@@ -471,3 +472,43 @@ TEST_CASE("test metal validation") {
471472

472473
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
473474
}
475+
476+
TEST_CASE("test metal enable/disable cache") {
477+
// Test enable metal cache
478+
{
479+
metal::set_cache_enabled(true);
480+
CHECK(metal::cache_enabled());
481+
482+
auto &a = metal::allocator();
483+
auto size = 100;
484+
auto buf = a.malloc(size, false);
485+
486+
// Release a
487+
a.free(buf);
488+
489+
// Check size should equals to size
490+
CHECK_EQ(static_cast<MTL::Buffer*>(buf.ptr())->length(), size);
491+
}
492+
493+
// Test disable metal cache
494+
{
495+
metal::set_cache_enabled(false);
496+
CHECK(!metal::cache_enabled());
497+
498+
auto &a = metal::allocator();
499+
auto size = 100;
500+
auto buf = a.malloc(size, false);
501+
auto buf_ptr = static_cast<MTL::Buffer*>(buf.ptr());
502+
unsigned char first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
503+
printf("first byte: %d\n", first_byte);
504+
505+
// Release a
506+
a.free(buf);
507+
508+
// If release successfully, the first byte should be different from the first byte before release
509+
unsigned char new_first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
510+
printf("new first byte: %d\n", new_first_byte);
511+
512+
CHECK_NE(new_first_byte, first_byte);
513+
}
514+
}

0 commit comments

Comments
 (0)