diff --git a/src/Init/System/IO.lean b/src/Init/System/IO.lean index 159d446fb234..8cb5641fca90 100644 --- a/src/Init/System/IO.lean +++ b/src/Init/System/IO.lean @@ -557,6 +557,13 @@ Waits for the task to finish, then returns its result. @[extern "lean_io_wait"] opaque wait (t : Task α) : BaseIO α := return t.get +/-- +Waits until any of the tasks in the list has finished, then return its result. +-/ +@[extern "lean_io_wait_any"] opaque waitAny (tasks : @& List (Task α)) + (h : tasks.length > 0 := by exact Nat.zero_lt_succ _) : BaseIO α := + return tasks[0].get + /-- Returns the number of _heartbeats_ that have occurred during the current thread's execution. The heartbeat count is the number of “small” memory allocations performed in a thread. diff --git a/src/Init/Task.lean b/src/Init/Task.lean index a03bd6688860..3240299424fa 100644 --- a/src/Init/Task.lean +++ b/src/Init/Task.lean @@ -12,17 +12,6 @@ public import Init.System.Promise public section -/-- -Waits until any of the tasks in the list has finished, then return its result. --/ -@[noinline] -def IO.waitAny (tasks : @& List (Task α)) (h : tasks.length > 0 := by exact Nat.zero_lt_succ _) : - BaseIO α := do - have : Nonempty α := ⟨tasks[0].get⟩ - let promise : IO.Promise α ← IO.Promise.new - tasks.forM <| fun t => BaseIO.chainTask (sync := true) t promise.resolve - return promise.result!.get - namespace Task /-- diff --git a/src/include/lean/lean.h b/src/include/lean/lean.h index b3ae906c7935..644f684e24a7 100644 --- a/src/include/lean/lean.h +++ b/src/include/lean/lean.h @@ -1225,6 +1225,8 @@ LEAN_EXPORT bool lean_io_check_canceled_core(void); LEAN_EXPORT void lean_io_cancel_core(b_lean_obj_arg t); /* primitive for implementing `IO.getTaskState : Task a -> IO TaskState` */ LEAN_EXPORT uint8_t lean_io_get_task_state_core(b_lean_obj_arg t); +/* primitive for implementing `IO.waitAny : List (Task a) -> IO (Task a)` */ +LEAN_EXPORT b_lean_obj_res lean_io_wait_any_core(b_lean_obj_arg task_list); /* External objects */ diff --git a/src/runtime/io.cpp b/src/runtime/io.cpp index 6834383a09d1..e1e16df4f003 100644 --- a/src/runtime/io.cpp +++ b/src/runtime/io.cpp @@ -1554,6 +1554,13 @@ extern "C" LEAN_EXPORT obj_res lean_io_wait(obj_arg t) { return lean_task_get_own(t); } +extern "C" LEAN_EXPORT obj_res lean_io_wait_any(b_obj_arg task_list) { + object * t = lean_io_wait_any_core(task_list); + object * v = lean_task_get(t); + lean_inc(v); + return v; +} + extern "C" LEAN_EXPORT obj_res lean_io_exit(uint8_t code) { exit(code); } diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index 3a20f242f6ff..f595049f4a8f 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -847,6 +847,17 @@ class task_manager { } } + object * wait_any_check(object * task_list) { + object * it = task_list; + while (!is_scalar(it)) { + object * head = lean_ctor_get(it, 0); + if (lean_to_task(head)->m_value) + return head; + it = cnstr_get(it, 1); + } + return nullptr; + } + public: task_manager(unsigned max_std_workers): m_max_std_workers(max_std_workers) { @@ -929,6 +940,17 @@ class task_manager { } } + object * wait_any(object * task_list) { + if (object * t = wait_any_check(task_list)) + return t; + unique_lock lock(m_mutex); + while (true) { + if (object * t = wait_any_check(task_list)) + return t; + m_task_finished_cv.wait(lock); + } + } + void deactivate_task(lean_task_object * t) { unique_lock lock(m_mutex); if (object * v = t->m_value) { @@ -1166,6 +1188,10 @@ extern "C" LEAN_EXPORT uint8_t lean_io_get_task_state_core(b_obj_arg t) { return g_task_manager->get_task_state(o); } +extern "C" LEAN_EXPORT b_obj_res lean_io_wait_any_core(b_obj_arg task_list) { + return g_task_manager->wait_any(task_list); +} + obj_res lean_promise_new() { lean_always_assert(g_task_manager); diff --git a/src/runtime/object.h b/src/runtime/object.h index a386eb73a37a..e8508f3d02b3 100644 --- a/src/runtime/object.h +++ b/src/runtime/object.h @@ -287,6 +287,7 @@ inline b_obj_res task_get(b_obj_arg t) { return lean_task_get(t); } inline bool io_check_canceled_core() { return lean_io_check_canceled_core(); } inline void io_cancel_core(b_obj_arg t) { return lean_io_cancel_core(t); } inline bool io_get_task_state_core(b_obj_arg t) { return lean_io_get_task_state_core(t); } +inline b_obj_res io_wait_any_core(b_obj_arg task_list) { return lean_io_wait_any_core(task_list); } // ======================================= // External