Skip to content

Commit 69a2991

Browse files
authored
allow compiling lambdas in C++ (#1650)
* allow compiling lambdas in C++ * fix test * more tests * auto detect capture-less lambda
1 parent fd3377d commit 69a2991

File tree

4 files changed

+115
-12
lines changed

4 files changed

+115
-12
lines changed

mlx/compile.cpp

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,7 @@ std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
208208
using FunType = T (*)(U...);
209209
const FunType* fun_ptr = fun.template target<FunType>();
210210
if (fun_ptr == nullptr) {
211-
throw std::invalid_argument(
212-
"[compile] Cannot compile a non-addressable function.");
211+
return 0;
213212
}
214213
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
215214
}
@@ -817,17 +816,28 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
817816
}
818817
}
819818

819+
bool skip_compile() {
820+
return compile_mode() == CompileMode::disabled ||
821+
!(compile_available_for_device(default_device()));
822+
}
823+
820824
std::function<std::vector<array>(const std::vector<array>&)> compile(
821-
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
825+
std::function<std::vector<array>(const std::vector<array>&)> fun,
822826
std::uintptr_t fun_id,
823827
bool shapeless /* = false */,
824828
std::vector<uint64_t> constants /* = {} */) {
825-
if (compile_mode() == CompileMode::disabled ||
826-
!(compile_available_for_device(default_device()))) {
829+
if (skip_compile()) {
827830
return fun;
828831
}
829-
return [fun, fun_id, shapeless, constants = std::move(constants)](
830-
const std::vector<array>& inputs) {
832+
if (!fun) {
833+
throw std::invalid_argument(
834+
"[compile] Cannot compile a function without a target.");
835+
}
836+
837+
return [fun = std::move(fun),
838+
fun_id,
839+
shapeless,
840+
constants = std::move(constants)](const std::vector<array>& inputs) {
831841
// If the inputs are tracers, trace the original graph
832842
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
833843
return in.is_tracer();
@@ -889,13 +899,41 @@ void compile_clear_cache() {
889899
} // namespace detail
890900

891901
std::function<std::vector<array>(const std::vector<array>&)> compile(
892-
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
902+
std::function<std::vector<array>(const std::vector<array>&)> fun,
893903
bool shapeless /* false */) {
894-
if (detail::compile_mode() == CompileMode::disabled) {
904+
if (detail::skip_compile()) {
895905
return fun;
896906
}
897907
auto fun_id = detail::get_function_address(fun);
898-
return detail::compile(fun, fun_id, shapeless);
908+
if (fun_id) {
909+
// If the function has an addressable target then no need to manage it's
910+
// lifetime
911+
return detail::compile(std::move(fun), fun_id, shapeless);
912+
} else {
913+
auto pfun = std::shared_ptr<
914+
std::function<std::vector<array>(const std::vector<array>&)>>(
915+
new std::function<std::vector<array>(const std::vector<array>&)>{fun},
916+
[](auto p) {
917+
detail::compile_erase(reinterpret_cast<std::uintptr_t>(p));
918+
delete p;
919+
});
920+
fun_id = reinterpret_cast<std::uintptr_t>(pfun.get());
921+
return detail::compile(
922+
[pfun = std::move(pfun)](const auto& inputs) {
923+
return (*pfun)(inputs);
924+
},
925+
fun_id,
926+
shapeless);
927+
}
928+
}
929+
930+
std::function<std::vector<array>(const std::vector<array>&)> compile(
931+
std::vector<array>(fun)(const std::vector<array>&),
932+
bool shapeless /* = false */) {
933+
if (detail::skip_compile()) {
934+
return fun;
935+
}
936+
return detail::compile(fun, reinterpret_cast<std::uintptr_t>(fun), shapeless);
899937
}
900938

901939
void disable_compile() {

mlx/compile.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,24 @@ enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
1010

1111
/** Compile takes a function and returns a compiled function. */
1212
std::function<std::vector<array>(const std::vector<array>&)> compile(
13-
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
13+
std::function<std::vector<array>(const std::vector<array>&)> fun,
1414
bool shapeless = false);
1515

16+
std::function<std::vector<array>(const std::vector<array>&)> compile(
17+
std::vector<array>(fun)(const std::vector<array>&),
18+
bool shapeless = false);
19+
20+
// Convert capture-less lambdas to function pointers.
21+
template <
22+
typename F,
23+
typename = std::enable_if_t<
24+
std::is_convertible_v<F, decltype(+std::declval<F>())>>>
25+
std::function<std::vector<array>(const std::vector<array>&)> compile(
26+
F&& f,
27+
bool shapeless = false) {
28+
return compile(+f, shapeless);
29+
}
30+
1631
/** Globally disable compilation.
1732
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
1833
* be used to disable compilation.

mlx/compile_impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace mlx::core::detail {
99
// This is not part of the general C++ API as calling with a bad id is a bad
1010
// idea.
1111
std::function<std::vector<array>(const std::vector<array>&)> compile(
12-
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
12+
std::function<std::vector<array>(const std::vector<array>&)> fun,
1313
std::uintptr_t fun_id,
1414
bool shapeless = false,
1515
std::vector<uint64_t> constants = {});

tests/compile_tests.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,3 +730,53 @@ TEST_CASE("test compile change streams") {
730730
out = cfun({array(1.0f), array(2.0f)})[0];
731731
CHECK_EQ(out.primitive().stream(), s);
732732
}
733+
734+
TEST_CASE("test compile lambda") {
735+
auto fun = [](const std::vector<array>& inputs) {
736+
return std::vector<array>{abs(inputs[0])};
737+
};
738+
739+
auto out = compile(fun)({array(-1)});
740+
CHECK_EQ(out[0].item<int>(), 1);
741+
742+
decltype(compile(nullptr)) c_local_fun;
743+
{
744+
auto local_fun = [](const std::vector<array>& inputs) {
745+
return std::vector<array>{abs(inputs[0])};
746+
};
747+
c_local_fun = compile(local_fun);
748+
}
749+
750+
// This is ok even though local_fun is out of scope
751+
out = c_local_fun({array(-1)});
752+
CHECK_EQ(out[0].item<int>(), 1);
753+
754+
{
755+
int x = 2;
756+
auto local_fun = [x](const std::vector<array>& inputs) {
757+
return std::vector<array>{inputs[0] + x};
758+
};
759+
c_local_fun = compile(local_fun);
760+
}
761+
// Also ok even though local_fun is out of scope.
762+
out = c_local_fun({array(0)});
763+
CHECK_EQ(out[0].item<int>(), 2);
764+
765+
int x = 2;
766+
auto fun_with_capture = [&x](const std::vector<array>& inputs) {
767+
return std::vector<array>{inputs[0] + x};
768+
};
769+
auto cfun = compile(fun_with_capture);
770+
out = cfun({array(0)});
771+
CHECK_EQ(out[0].item<int>(), 2);
772+
773+
// Doesn't recompile
774+
x = 3;
775+
out = cfun({array(0)});
776+
CHECK_EQ(out[0].item<int>(), 2);
777+
778+
// Recompiles
779+
auto cfun2 = compile(fun_with_capture);
780+
out = cfun2({array(0)});
781+
CHECK_EQ(out[0].item<int>(), 3);
782+
}

0 commit comments

Comments
 (0)