Skip to content

Commit 771575d

Browse files
authored
Expose function to clear memory cache (#1032)
* expose function to clear memory cache * fix linux build * fix metal tests
1 parent 20a01bb commit 771575d

File tree

9 files changed

+31
-2
lines changed

9 files changed

+31
-2
lines changed

docs/src/python/metal.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ Metal
1212
get_cache_memory
1313
set_memory_limit
1414
set_cache_limit
15+
clear_cache
1516
start_capture
1617
stop_capture

docs/src/python/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Operations
9292
moveaxis
9393
multiply
9494
negative
95+
not_equal
9596
ones
9697
ones_like
9798
outer

mlx/backend/metal/allocator.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
209209
return Buffer{static_cast<void*>(buf)};
210210
}
211211

212+
void MetalAllocator::clear_cache() {
213+
std::unique_lock lk(mutex_);
214+
buffer_cache_.clear();
215+
}
216+
212217
void MetalAllocator::free(Buffer buffer) {
213218
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
214219
std::unique_lock lk(mutex_);
@@ -242,6 +247,9 @@ size_t get_peak_memory() {
242247
size_t get_cache_memory() {
243248
return allocator().get_cache_memory();
244249
}
250+
void clear_cache() {
251+
return allocator().clear_cache();
252+
}
245253

246254
} // namespace metal
247255

mlx/backend/metal/allocator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class BufferCache {
2626
size_t cache_size() {
2727
return pool_size_;
2828
}
29+
void clear();
2930

3031
private:
3132
struct BufferHolder {
@@ -37,7 +38,6 @@ class BufferCache {
3738
MTL::Buffer* buf;
3839
};
3940

40-
void clear();
4141
void add_at_head(BufferHolder* to_add);
4242
void remove_from_list(BufferHolder* to_remove);
4343

@@ -67,6 +67,7 @@ class MetalAllocator : public allocator::Allocator {
6767
};
6868
size_t set_cache_limit(size_t limit);
6969
size_t set_memory_limit(size_t limit, bool relaxed);
70+
void clear_cache();
7071

7172
private:
7273
MTL::Device* device_;

mlx/backend/metal/metal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ size_t set_memory_limit(size_t limit, bool relaxed = true);
5454
* */
5555
size_t set_cache_limit(size_t limit);
5656

57+
/* Clear the memory cache. */
58+
void clear_cache();
59+
5760
/** Capture a GPU trace, saving it to an absolute file `path` */
5861
void start_capture(std::string path = "");
5962
void stop_capture();

mlx/backend/no_metal/metal.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
#include "mlx/backend/metal/metal.h"
66
#include "mlx/backend/metal/metal_impl.h"
7-
87
namespace mlx::core::metal {
98

109
bool is_available() {
@@ -48,5 +47,6 @@ size_t set_cache_limit(size_t) {
4847
}
4948
void start_capture(std::string path) {}
5049
void stop_capture() {}
50+
void clear_cache() {}
5151

5252
} // namespace mlx::core::metal

python/src/metal.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,15 @@ void init_metal(nb::module_& m) {
9090
Returns:
9191
int: The previous cache limit in bytes.
9292
)pbdoc");
93+
metal.def(
94+
"clear_cache",
95+
&metal::clear_cache,
96+
R"pbdoc(
97+
Clear the memory cache.
98+
99+
After calling this, :func:`get_cache_memory` should return ``0``.
100+
)pbdoc");
101+
93102
metal.def(
94103
"start_capture",
95104
&metal::start_capture,

python/tests/test_metal.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def test_memory_info(self):
4242
cache_mem = mx.metal.get_cache_memory()
4343
self.assertTrue(cache_mem >= 4096 * 4)
4444

45+
mx.metal.clear_cache()
46+
self.assertEqual(mx.metal.get_cache_memory(), 0)
47+
4548

4649
if __name__ == "__main__":
4750
unittest.main()

tests/metal_tests.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,4 +513,7 @@ TEST_CASE("test metal memory info") {
513513
auto cache_mem = metal::get_cache_memory();
514514
CHECK(cache_mem >= 4096 * 4);
515515
}
516+
517+
metal::clear_cache();
518+
CHECK_EQ(metal::get_cache_memory(), 0);
516519
}

0 commit comments

Comments
 (0)