@@ -208,8 +208,7 @@ std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
208208 using FunType = T (*)(U...);
209209 const FunType* fun_ptr = fun.template target <FunType>();
210210 if (fun_ptr == nullptr ) {
211- throw std::invalid_argument (
212- " [compile] Cannot compile a non-addressable function." );
211+ return 0 ;
213212 }
214213 return reinterpret_cast <std::uintptr_t >(*fun_ptr);
215214}
@@ -817,17 +816,28 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
817816 }
818817}
819818
819+ bool skip_compile () {
820+ return compile_mode () == CompileMode::disabled ||
821+ !(compile_available_for_device (default_device ()));
822+ }
823+
820824std::function<std::vector<array>(const std::vector<array>&)> compile (
821- const std::function<std::vector<array>(const std::vector<array>&)>& fun,
825+ std::function<std::vector<array>(const std::vector<array>&)> fun,
822826 std::uintptr_t fun_id,
823827 bool shapeless /* = false */ ,
824828 std::vector<uint64_t> constants /* = {} */ ) {
825- if (compile_mode () == CompileMode::disabled ||
826- !(compile_available_for_device (default_device ()))) {
829+ if (skip_compile ()) {
827830 return fun;
828831 }
829- return [fun, fun_id, shapeless, constants = std::move (constants)](
830- const std::vector<array>& inputs) {
832+ if (!fun) {
833+ throw std::invalid_argument (
834+ " [compile] Cannot compile a function without a target." );
835+ }
836+
837+ return [fun = std::move (fun),
838+ fun_id,
839+ shapeless,
840+ constants = std::move (constants)](const std::vector<array>& inputs) {
831841 // If the inputs are tracers, trace the original graph
832842 if (std::any_of (inputs.begin (), inputs.end (), [](auto & in) {
833843 return in.is_tracer ();
@@ -889,13 +899,41 @@ void compile_clear_cache() {
889899} // namespace detail
890900
891901std::function<std::vector<array>(const std::vector<array>&)> compile (
892- const std::function<std::vector<array>(const std::vector<array>&)>& fun,
902+ std::function<std::vector<array>(const std::vector<array>&)> fun,
893903 bool shapeless /* false */ ) {
894- if (detail::compile_mode () == CompileMode::disabled ) {
904+ if (detail::skip_compile () ) {
895905 return fun;
896906 }
897907 auto fun_id = detail::get_function_address (fun);
898- return detail::compile (fun, fun_id, shapeless);
908+ if (fun_id) {
909+ // If the function has an addressable target then no need to manage it's
910+ // lifetime
911+ return detail::compile (std::move (fun), fun_id, shapeless);
912+ } else {
913+ auto pfun = std::shared_ptr<
914+ std::function<std::vector<array>(const std::vector<array>&)>>(
915+ new std::function<std::vector<array>(const std::vector<array>&)>{fun},
916+ [](auto p) {
917+ detail::compile_erase (reinterpret_cast <std::uintptr_t >(p));
918+ delete p;
919+ });
920+ fun_id = reinterpret_cast <std::uintptr_t >(pfun.get ());
921+ return detail::compile (
922+ [pfun = std::move (pfun)](const auto & inputs) {
923+ return (*pfun)(inputs);
924+ },
925+ fun_id,
926+ shapeless);
927+ }
928+ }
929+
930+ std::function<std::vector<array>(const std::vector<array>&)> compile (
931+ std::vector<array>(fun)(const std::vector<array>&),
932+ bool shapeless /* = false */ ) {
933+ if (detail::skip_compile ()) {
934+ return fun;
935+ }
936+ return detail::compile (fun, reinterpret_cast <std::uintptr_t >(fun), shapeless);
899937}
900938
901939void disable_compile () {
0 commit comments