Skip to content

Commit b05bcfd

Browse files
authored
Fixes segfault when compiling checkpointed functions (#1235)
1 parent 2615660 commit b05bcfd

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

mlx/compile.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,10 @@ class CompilerCache {
266266
cache_.erase(fun_id);
267267
}
268268

269+
void clear() {
270+
cache_.clear();
271+
}
272+
269273
private:
270274
CompilerCache() {
271275
// Make sure the allocator is fully
@@ -859,6 +863,10 @@ void compile_erase(std::uintptr_t fun_id) {
859863
detail::compiler_cache().erase(fun_id);
860864
}
861865

866+
void compile_clear_cache() {
867+
detail::compiler_cache().clear();
868+
}
869+
862870
} // namespace detail
863871

864872
std::function<std::vector<array>(const std::vector<array>&)> compile(

mlx/transforms_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
2727
// Erase cached compile functions
2828
void compile_erase(std::uintptr_t fun_id);
2929

30+
// Clear the compiler cache causing a recompilation of all compiled functions
31+
// when called again.
32+
void compile_clear_cache();
33+
3034
// Create an InTracing object during tracing operations to signify to the rest
3135
// of the codebase that we are during tracing so evals should not throw away
3236
// the graph.

python/src/transforms.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,6 @@ struct PyCompiledFun {
536536
class PyCheckpointedFun {
537537
public:
538538
PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
539-
540539
~PyCheckpointedFun() {
541540
nb::gil_scoped_acquire gil;
542541

@@ -968,5 +967,8 @@ void init_transforms(nb::module_& m) {
968967

969968
// Register static Python object cleanup before the interpreter exits
970969
auto atexit = nb::module_::import_("atexit");
971-
atexit.attr("register")(nb::cpp_function([]() { tree_cache().clear(); }));
970+
atexit.attr("register")(nb::cpp_function([]() {
971+
tree_cache().clear();
972+
detail::compile_clear_cache();
973+
}));
972974
}

0 commit comments

Comments
 (0)