diff --git a/include/oneapi/tbb/detail/_task.h b/include/oneapi/tbb/detail/_task.h
index e1bb70c5be..94e20f06dd 100644
--- a/include/oneapi/tbb/detail/_task.h
+++ b/include/oneapi/tbb/detail/_task.h
@@ -62,6 +62,7 @@ TBB_EXPORT d1::slot_id __TBB_EXPORTED_FUNC execution_slot(const d1::execution_da
 TBB_EXPORT d1::slot_id __TBB_EXPORTED_FUNC execution_slot(const d1::task_arena_base&);
 TBB_EXPORT d1::task_group_context* __TBB_EXPORTED_FUNC current_context();
 TBB_EXPORT d1::wait_tree_vertex_interface* get_thread_reference_vertex(d1::wait_tree_vertex_interface* wc);
+TBB_EXPORT d1::task* __TBB_EXPORTED_FUNC current_task();

 // Do not place under __TBB_RESUMABLE_TASKS. It is a stub for unsupported platforms.
 struct suspend_point_type;
@@ -213,6 +214,7 @@ class reference_vertex : public wait_tree_vertex_interface {
     }

 private:
     wait_tree_vertex_interface* my_parent;
+protected:
     std::atomic<std::uint64_t> m_ref_count;
 };

@@ -268,6 +270,7 @@ inline void wait(wait_context& wait_ctx, task_group_context& ctx) {
     call_itt_task_notify(destroy, &wait_ctx);
 }

+using r1::current_task;
 using r1::current_context;

 class task_traits {
diff --git a/include/oneapi/tbb/detail/_task_handle.h b/include/oneapi/tbb/detail/_task_handle.h
index 26212b462c..7c20a189ec 100644
--- a/include/oneapi/tbb/detail/_task_handle.h
+++ b/include/oneapi/tbb/detail/_task_handle.h
@@ -22,7 +22,11 @@
 #include "_task.h"
 #include "_small_object_pool.h"
 #include "_utils.h"
+#include "oneapi/tbb/mutex.h"
+
 #include <memory>
+#include <forward_list>
+#include <utility>

 namespace tbb {
 namespace detail {
@@ -31,12 +35,89 @@ namespace d1 { class task_group_context; class wait_context; struct execution_da
 namespace d2 {

 class task_handle;
+class task_handle_task;
+class task_state_handler;

-class task_handle_task : public d1::task {
-    std::uint64_t m_version_and_traits{};
-    d1::wait_tree_vertex_interface* m_wait_tree_vertex;
+class continuation_vertex : public d1::reference_vertex {
+public:
+    continuation_vertex(task_handle_task* continuation_task, d1::task_group_context& ctx, d1::small_object_allocator& alloc)
+        : d1::reference_vertex(nullptr, 1)
+        , m_continuation_task(continuation_task)
+        , m_ctx(ctx)
+        , m_allocator(alloc)
+    {}
+
+    void release(std::uint32_t delta = 1) override;
+
+private:
+    task_handle_task* m_continuation_task;
     d1::task_group_context& m_ctx;
     d1::small_object_allocator m_allocator;
+};
+
+class transfer_vertex : public d1::reference_vertex {
+public:
+    transfer_vertex(task_state_handler* handler, d1::small_object_allocator& alloc)
+        : d1::reference_vertex(nullptr, 1)
+        , m_handler(handler)
+        , m_allocator(alloc)
+    {}
+
+    void release(std::uint32_t) override;
+
+    void add_successor(d1::wait_tree_vertex_interface* successor) {
+        m_wait_tree_vertex_successors.push_front(successor);
+    }
+
+private:
+    task_state_handler* m_handler{nullptr};
+    std::forward_list<d1::wait_tree_vertex_interface*> m_wait_tree_vertex_successors;
+    d1::small_object_allocator m_allocator;
+};
+
+class task_state_handler {
+public:
+    task_state_handler(task_handle_task* task, d1::small_object_allocator& alloc) : m_task(task), m_alloc(alloc) {}
+    void release() {
+        d1::mutex::scoped_lock lock(m_mutex);
+        release_impl(lock);
+    }
+
+    void complete_task(bool is_from_transfer = false) {
+        d1::mutex::scoped_lock lock(m_mutex);
+        if (m_transfer == nullptr || is_from_transfer) {
+            m_is_finished = true;
+        }
+        release_impl(lock);
+    }
+
+    void add_successor(task_handle_task& successor);
+    void transfer_successors_to(task_handle_task* target);
+
+    transfer_vertex* create_transfer_vertex() {
+        d1::small_object_allocator alloc;
+        ++m_num_references;
+        m_transfer = alloc.new_object<transfer_vertex>(this, alloc);
+        return m_transfer;
+    }
+
+private:
+    void release_impl(d1::mutex::scoped_lock& lock) {
+        if (--m_num_references == 0) {
+            lock.release();
+            m_alloc.delete_object(this);
+        }
+    }
+
+    task_handle_task* m_task;
+    bool m_is_finished{false};
+    transfer_vertex* m_transfer{nullptr};
+    int m_num_references{2};
+    d1::mutex m_mutex;
+    d1::small_object_allocator m_alloc;
+};
+
+class task_handle_task : public d1::task {
 public:
     void finalize(const d1::execution_data* ed = nullptr) {
         if (ed) {
@@ -47,18 +128,56 @@ class task_handle_task : public d1::task {
     }

     task_handle_task(d1::wait_tree_vertex_interface* vertex, d1::task_group_context& ctx, d1::small_object_allocator& alloc)
-        : m_wait_tree_vertex(vertex)
+        : m_wait_tree_vertex_successors{vertex}
         , m_ctx(ctx)
-        , m_allocator(alloc) {
+        , m_allocator(alloc)
+    {
         suppress_unused_warning(m_version_and_traits);
-        m_wait_tree_vertex->reserve();
+        vertex->reserve();
     }

-    ~task_handle_task() override {
-        m_wait_tree_vertex->release();
+    ~task_handle_task() {
+        if (m_state_holder) {
+            m_state_holder->complete_task();
+        }
+        release_successors();
     }

     d1::task_group_context& ctx() const { return m_ctx; }
+
+    bool has_dependency() const { return m_continuation != nullptr; }
+
+    void release_continuation() { m_continuation->release(); }
+
+    void unset_continuation() { m_continuation = nullptr; }
+
+    void transfer_successors_to(task_handle_task* target) {
+        // TODO: What if we set current task as a dependency later?
+        if (m_state_holder) {
+            m_state_holder->transfer_successors_to(target);
+        }
+    }
+
+    task_state_handler* get_state_holder() {
+        d1::small_object_allocator alloc;
+        m_state_holder = alloc.new_object<task_state_handler>(this, alloc);
+        return m_state_holder;
+    }
+
+private:
+    void release_successors() {
+        for (auto successor : m_wait_tree_vertex_successors) {
+            successor->release();
+        }
+    }
+
+    friend task_state_handler;
+    std::uint64_t m_version_and_traits{};
+    task_state_handler* m_state_holder{nullptr};
+    std::forward_list<d1::wait_tree_vertex_interface*> m_wait_tree_vertex_successors;
+    continuation_vertex* m_continuation{nullptr};
+    d1::task_group_context& m_ctx;
+    d1::small_object_allocator m_allocator;
 };
@@ -69,10 +188,39 @@ class task_handle {
     using handle_impl_t = std::unique_ptr<task_handle_task, task_handle_task_finalizer_t>;

     handle_impl_t m_handle = {nullptr};
+    task_state_handler* m_state_holder = {nullptr};

 public:
     task_handle() = default;
-    task_handle(task_handle&&) = default;
-    task_handle& operator=(task_handle&&) = default;
+    task_handle(task_handle&& th) : m_handle(std::move(th.m_handle)), m_state_holder(th.m_state_holder) {
+        th.m_state_holder = nullptr;
+    }
+
+    task_handle& operator=(task_handle&& th) {
+        if (this != &th) {
+            m_handle = std::move(th.m_handle);
+            m_state_holder = th.m_state_holder;
+            th.m_state_holder = nullptr;
+        }
+        return *this;
+    }
+
+    ~task_handle() {
+        if (m_state_holder) {
+            m_state_holder->release();
+        }
+    }
+
+    void add_predecessor(task_handle& th) {
+        if (m_state_holder) {
+            th.m_state_holder->add_successor(*m_handle);
+        }
+    }
+
+    void add_successor(task_handle& th) {
+        if (m_state_holder) {
+            m_state_holder->add_successor(*th.m_handle);
+        }
+    }

     explicit operator bool() const noexcept { return static_cast<bool>(m_handle); }

@@ -85,7 +233,7 @@ class task_handle {
 private:
     friend struct task_handle_accessor;

-    task_handle(task_handle_task* t) : m_handle {t}{};
+    task_handle(task_handle_task* t) : m_handle{t}, m_state_holder(t->get_state_holder()) {};

     d1::task* release() {
         return m_handle.release();
@@ -99,6 +247,14 @@ static d1::task_group_context& ctx_of(task_handle& th) {
         __TBB_ASSERT(th.m_handle, "ctx_of does not expect empty task_handle.");
         return th.m_handle->ctx();
     }
+    static bool has_dependency(task_handle& th) { return th.m_handle->has_dependency(); }
+    static void release_continuation(task_handle& th) {
+        th.m_handle->release_continuation();
+        th.release();
+    }
+    static void transfer_successors_to(task_handle& th, task_handle_task* task) {
+        task->transfer_successors_to(th.m_handle.get());
+    }
 };

 inline bool operator==(task_handle const& th, std::nullptr_t) noexcept {
diff --git a/include/oneapi/tbb/task_group.h b/include/oneapi/tbb/task_group.h
index c0811c8502..a89cff6f07 100644
--- a/include/oneapi/tbb/task_group.h
+++ b/include/oneapi/tbb/task_group.h
@@ -414,6 +414,7 @@ class task_group_context : no_copy {
     friend struct r1::task_arena_impl;
     friend struct r1::task_group_context_impl;
     friend class d2::task_group_base;
+    friend class d2::continuation_vertex;
 }; // class task_group_context

 static_assert(sizeof(task_group_context) == 128, "Wrong size of task_group_context");
@@ -445,7 +446,9 @@ class function_stack_task : public d1::task {
     d1::wait_tree_vertex_interface* m_wait_tree_vertex;

     void finalize() {
-        m_wait_tree_vertex->release();
+        if (m_wait_tree_vertex) {
+            m_wait_tree_vertex->release();
+        }
     }
     task* execute(d1::execution_data&) override {
         task* res = d2::task_ptr_or_nullptr(m_func);
@@ -581,13 +584,16 @@ class task_group : public task_group_base {
         using acs = d2::task_handle_accessor;
         __TBB_ASSERT(&acs::ctx_of(h) == &context(), "Attempt to schedule task_handle into different task_group");

-        d1::spawn(*acs::release(h), context());
+        if (!acs::has_dependency(h)) {
+            d1::spawn(*acs::release(h), context());
+        } else {
+            acs::release_continuation(h);
+        }
     }

     template<typename F>
     d2::task_handle defer(F&& f) {
         return prepare_task_handle(std::forward<F>(f));
-
     }

     template<typename F>
@@ -600,6 +606,57 @@ class task_group : public task_group_base {
     }
 }; // class task_group

+inline void continuation_vertex::release(std::uint32_t delta) {
+    std::uint64_t ref = m_ref_count.fetch_sub(static_cast<std::uint64_t>(delta)) - static_cast<std::uint64_t>(delta);
+    if (ref == 0) {
+        m_continuation_task->unset_continuation();
+        d1::spawn(*m_continuation_task, m_ctx.actual_context());
+        m_allocator.delete_object(this);
+    }
+}
+
+inline void task_state_handler::add_successor(task_handle_task& successor) {
+    if (successor.m_continuation == nullptr) {
+        d1::small_object_allocator alloc;
+        successor.m_continuation = alloc.new_object<continuation_vertex>(&successor, successor.m_ctx, alloc);
+    }
+
+    d1::mutex::scoped_lock lock(m_mutex);
+    if (!m_is_finished && m_transfer) {
+        successor.m_continuation->reserve();
+        m_transfer->add_successor(successor.m_continuation);
+    } else if (!m_is_finished) {
+        successor.m_continuation->reserve();
+        m_task->m_wait_tree_vertex_successors.push_front(successor.m_continuation);
+    }
+}
+
+inline void task_state_handler::transfer_successors_to(task_handle_task* target) {
+    d1::mutex::scoped_lock lock(m_mutex);
+
+    auto task_finalizer = create_transfer_vertex();
+    target->m_wait_tree_vertex_successors.push_front(task_finalizer);
+    target->m_wait_tree_vertex_successors.splice_after(target->m_wait_tree_vertex_successors.begin(), m_task->m_wait_tree_vertex_successors);
+    m_task->m_wait_tree_vertex_successors.clear();
+}
+
+inline void transfer_vertex::release(std::uint32_t) {
+    m_handler->complete_task(true);
+    for (auto successor : m_wait_tree_vertex_successors) {
+        successor->release();
+    }
+    m_allocator.delete_object(this);
+}
+
+inline void transfer_successors_to(d2::task_handle& h) {
+    task_handle_task* task = dynamic_cast<task_handle_task*>(d1::current_task());
+    __TBB_ASSERT_RELEASE(task, "Attempt to transfer successors from non-task_handle_task");
+    using acs = d2::task_handle_accessor;
+    __TBB_ASSERT(&acs::ctx_of(h) == &task->ctx(), "Attempt to transfer successors to task_handle into different task_group");
+
+    acs::transfer_successors_to(h, task);
+}
+
 #if TBB_PREVIEW_ISOLATED_TASK_GROUP
 class spawn_delegate : public d1::delegate_base {
     d1::task* task_to_spawn;
@@ -701,6 +758,13 @@ using detail::d1::is_current_task_group_canceling;
 using detail::r1::missing_wait;
 using detail::d2::task_handle;

+
+namespace this_task_group {
+namespace current_task {
+    using detail::d2::transfer_successors_to;
+}
+}
+
 }
 } // namespace tbb
diff --git a/src/tbb/def/lin32-tbb.def b/src/tbb/def/lin32-tbb.def
index 737e8ec2af..294982b1c8 100644
--- a/src/tbb/def/lin32-tbb.def
+++ b/src/tbb/def/lin32-tbb.def
@@ -78,6 +78,7 @@ _ZN3tbb6detail2r16resumeEPNS1_18suspend_point_typeE;
 _ZN3tbb6detail2r121current_suspend_pointEv;
 _ZN3tbb6detail2r114notify_waitersEj;
 _ZN3tbb6detail2r127get_thread_reference_vertexEPNS0_2d126wait_tree_vertex_interfaceE;
+_ZN3tbb6detail2r112current_taskEv;

 /* Task dispatcher (task_dispatcher.cpp) */
 _ZN3tbb6detail2r114execution_slotEPKNS0_2d114execution_dataE;
diff --git a/src/tbb/def/lin64-tbb.def b/src/tbb/def/lin64-tbb.def
index 41aca2e932..d5b4e9dba4 100644
--- a/src/tbb/def/lin64-tbb.def
+++ b/src/tbb/def/lin64-tbb.def
@@ -78,6 +78,7 @@ _ZN3tbb6detail2r16resumeEPNS1_18suspend_point_typeE;
 _ZN3tbb6detail2r121current_suspend_pointEv;
 _ZN3tbb6detail2r114notify_waitersEm;
 _ZN3tbb6detail2r127get_thread_reference_vertexEPNS0_2d126wait_tree_vertex_interfaceE;
+_ZN3tbb6detail2r112current_taskEv;

 /* Task dispatcher (task_dispatcher.cpp) */
 _ZN3tbb6detail2r114execution_slotEPKNS0_2d114execution_dataE;
diff --git a/src/tbb/def/mac64-tbb.def b/src/tbb/def/mac64-tbb.def
index 38bc48d30e..84f060685c 100644
--- a/src/tbb/def/mac64-tbb.def
+++ b/src/tbb/def/mac64-tbb.def
@@ -80,6 +80,7 @@ __ZN3tbb6detail2r16resumeEPNS1_18suspend_point_typeE
 __ZN3tbb6detail2r121current_suspend_pointEv
 __ZN3tbb6detail2r114notify_waitersEm
 __ZN3tbb6detail2r127get_thread_reference_vertexEPNS0_2d126wait_tree_vertex_interfaceE
+__ZN3tbb6detail2r112current_taskEv

 # Task dispatcher (task_dispatcher.cpp)
 __ZN3tbb6detail2r114execution_slotEPKNS0_2d114execution_dataE
diff --git a/src/tbb/def/win32-tbb.def b/src/tbb/def/win32-tbb.def
index 94b5441701..05797adff2 100644
--- a/src/tbb/def/win32-tbb.def
+++ b/src/tbb/def/win32-tbb.def
@@ -72,6 +72,7 @@ EXPORTS
 ?suspend@r1@detail@tbb@@YAXP6AXPAXPAUsuspend_point_type@123@@Z0@Z
 ?notify_waiters@r1@detail@tbb@@YAXI@Z
 ?get_thread_reference_vertex@r1@detail@tbb@@YAPAVwait_tree_vertex_interface@d1@23@PAV4523@@Z
+?current_task@r1@detail@tbb@@YAPAVtask@d1@23@XZ

 ; Task dispatcher (task_dispatcher.cpp)
 ?spawn@r1@detail@tbb@@YAXAAVtask@d1@23@AAVtask_group_context@523@G@Z
diff --git a/src/tbb/def/win64-tbb.def b/src/tbb/def/win64-tbb.def
index 96bafc0163..333a8f0a01 100644
--- a/src/tbb/def/win64-tbb.def
+++ b/src/tbb/def/win64-tbb.def
@@ -72,6 +72,7 @@ EXPORTS
 ?current_suspend_point@r1@detail@tbb@@YAPEAUsuspend_point_type@123@XZ
 ?notify_waiters@r1@detail@tbb@@YAX_K@Z
 ?get_thread_reference_vertex@r1@detail@tbb@@YAPEAVwait_tree_vertex_interface@d1@23@PEAV4523@@Z
+?current_task@r1@detail@tbb@@YAPEAVtask@d1@23@XZ

 ; Task dispatcher (task_dispatcher.cpp)
 ?spawn@r1@detail@tbb@@YAXAEAVtask@d1@23@AEAVtask_group_context@523@@Z
diff --git a/src/tbb/scheduler_common.h b/src/tbb/scheduler_common.h
index e4686e1673..348402e8cc 100644
--- a/src/tbb/scheduler_common.h
+++ b/src/tbb/scheduler_common.h
@@ -483,6 +483,9 @@ class alignas (max_nfs_size) task_dispatcher {
     > m_reference_vertex_map;

+    //! Innermost task whose task::execute() is running. A nullptr on the outermost level.
+    d1::task* m_innermost_running_task{ nullptr };
+
     //! Attempt to get a task from the mailbox.
     /** Gets a task only if it has not been executed by its sender or a thief
         that has stolen it from the sender's task pool. Otherwise returns nullptr. */
diff --git a/src/tbb/task.cpp b/src/tbb/task.cpp
index 84b4278f0a..73866ece6d 100644
--- a/src/tbb/task.cpp
+++ b/src/tbb/task.cpp
@@ -252,6 +252,10 @@ d1::wait_tree_vertex_interface* get_thread_reference_vertex(d1::wait_tree_vertex
     return ref_counter;
 }

+d1::task* current_task() {
+    return governor::get_thread_data()->get_current_task();
+}
+
 } // namespace r1
 } // namespace detail
 } // namespace tbb
diff --git a/src/tbb/task_dispatcher.h b/src/tbb/task_dispatcher.h
index c818934e5a..58c5aaeef6 100644
--- a/src/tbb/task_dispatcher.h
+++ b/src/tbb/task_dispatcher.h
@@ -249,11 +249,13 @@ d1::task* task_dispatcher::local_wait_for_all(d1::task* t, Waiter& waiter ) {
         task_dispatcher& task_disp;
         execution_data_ext old_execute_data_ext;
         properties old_properties;
+        d1::task* old_innermost_running_task;
         bool is_initially_registered;

         ~dispatch_loop_guard() {
             task_disp.m_execute_data_ext = old_execute_data_ext;
             task_disp.m_properties = old_properties;
+            task_disp.m_innermost_running_task = old_innermost_running_task;

             if (!is_initially_registered) {
                 task_disp.m_thread_data->my_arena->my_tc_client.get_pm_client()->unregister_thread();
@@ -263,7 +265,7 @@ d1::task* task_dispatcher::local_wait_for_all(d1::task* t, Waiter& waiter ) {
             __TBB_ASSERT(task_disp.m_thread_data && governor::is_thread_data_set(task_disp.m_thread_data), nullptr);
             __TBB_ASSERT(task_disp.m_thread_data->my_task_dispatcher == &task_disp, nullptr);
         }
-    } dl_guard{ *this, m_execute_data_ext, m_properties, m_thread_data->my_is_registered };
+    } dl_guard{ *this, m_execute_data_ext, m_properties, m_innermost_running_task, m_thread_data->my_is_registered };

     // The context guard to track fp setting and itt tasks.
     context_guard_helper<ITTPossible> context_guard;
@@ -328,6 +330,7 @@ d1::task* task_dispatcher::local_wait_for_all(d1::task* t, Waiter& waiter ) {

             ITT_CALLEE_ENTER(ITTPossible, t, itt_caller);

+            m_innermost_running_task = t;
             if (ed.context->is_group_execution_cancelled()) {
                 t = t->cancel(ed);
             } else {
diff --git a/src/tbb/thread_data.h b/src/tbb/thread_data.h
index 422ec694ec..f6baf8abe4 100644
--- a/src/tbb/thread_data.h
+++ b/src/tbb/thread_data.h
@@ -139,6 +139,7 @@ class thread_data : public ::rml::job
     void enter_task_dispatcher(task_dispatcher& task_disp, std::uintptr_t stealing_threshold);
     void leave_task_dispatcher();
     void propagate_task_group_state(std::atomic<uint32_t> d1::task_group_context::* mptr_state, d1::task_group_context& src, uint32_t new_state);
+    d1::task* get_current_task();

     //! Index of the arena slot the scheduler occupies now, or occupied last time
     unsigned short my_arena_index;

@@ -254,6 +255,10 @@ inline void thread_data::propagate_task_group_state(std::atomic<uint32_t> d
     my_context_list->epoch.store(the_context_state_propagation_epoch.load(std::memory_order_relaxed), std::memory_order_release);
 }

+inline d1::task* thread_data::get_current_task() {
+    return my_task_dispatcher->m_innermost_running_task;
+}
+
 } // namespace r1
 } // namespace detail
 } // namespace tbb
diff --git a/test/common/memory_usage.h b/test/common/memory_usage.h
index cf8b4180d4..06490f03db 100644
--- a/test/common/memory_usage.h
+++ b/test/common/memory_usage.h
@@ -1,5 +1,5 @@
 /*
-    Copyright (c) 2005-2022 Intel Corporation
+    Copyright (c) 2005-2024 Intel Corporation

     Licensed under the Apache License, Version 2.0 (the "License");
     you may not use this file except in compliance with the License.
@@ -38,6 +38,8 @@
 #elif __APPLE__ && !__ARM_ARCH
 #include <unistd.h>
 #include <mach/mach.h>
+// Undef due to conflict with library API
+#undef current_task
 #include <AvailabilityMacros.h>
 #if MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_6 || __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_8_0
 #include <mach/shared_region.h>
diff --git a/test/conformance/conformance_task_group.cpp b/test/conformance/conformance_task_group.cpp
index ef2ac39de5..9bf01ae340 100644
--- a/test/conformance/conformance_task_group.cpp
+++ b/test/conformance/conformance_task_group.cpp
@@ -1,5 +1,5 @@
 /*
-    Copyright (c) 2021-2022 Intel Corporation
+    Copyright (c) 2021-2024 Intel Corporation

     Licensed under the Apache License, Version 2.0 (the "License");
     you may not use this file except in compliance with the License.
@@ -281,3 +281,431 @@ TEST_CASE("Respect task_group_context passed from outside") {
     accept_task_group_context::test();
 }

+//! \brief \ref interface \ref requirement
+TEST_CASE("Test continuation") {
+    tbb::task_group tg;
+
+    int x{}, y{};
+    int sum{};
+    auto sum_task = tg.defer([&] {
+        sum = x + y;
+    });
+
+    auto y_task = tg.defer([&] {
+        y = 2;
+    });
+
+    auto x_task = tg.defer([&] {
+        x = 40;
+    });
+
+    sum_task.add_predecessor(x_task);
+    sum_task.add_predecessor(y_task);
+
+    tg.run(std::move(sum_task));
+    tg.run(std::move(x_task));
+    tg.run(std::move(y_task));
+
+    tg.wait();
+
+    REQUIRE(sum == 42);
+}
+
+//! \brief \ref interface \ref requirement
+TEST_CASE("Test continuation - small tree") {
+    tbb::task_group tg;
+
+    int x{}, y{};
+    int sum{};
+    auto sum_task = tg.defer([&] {
+        sum = x + y;
+    });
+
+    auto y_task = tg.defer([&] {
+        y = 2;
+    });
+
+    auto x_task = tg.defer([&] {
+        x = 40;
+    });
+
+    sum_task.add_predecessor(x_task);
+    sum_task.add_predecessor(y_task);
+
+    int multiplier{};
+    int product{};
+    auto mult_task = tg.defer([&] {
+        multiplier = 42;
+    });
+
+    auto product_task = tg.defer([&] {
+        product = sum * multiplier;
+    });
+
+    product_task.add_predecessor(sum_task);
+    product_task.add_predecessor(mult_task);
+
+    tg.run(std::move(sum_task));
+    tg.run(std::move(x_task));
+    tg.run(std::move(y_task));
+    tg.run(std::move(mult_task));
+    tg.run(std::move(product_task));
+
+    tg.wait();
+
+    REQUIRE(product == 42 * 42);
+}
+
+long serial_fib(int n) {
+    return n < 2 ? n : serial_fib(n - 1) + serial_fib(n - 2);
+}
+
+struct fib_continuation {
+    void operator()() const {
+        m_data->m_sum = m_data->m_x + m_data->m_y;
+        delete m_data;
+    }
+
+    struct data {
+        data(int& sum) : m_sum(sum) {}
+        int m_x{ 0 }, m_y{ 0 };
+        int& m_sum;
+    }* m_data;
+
+    fib_continuation(fib_continuation::data* d) : m_data(d) {}
+};
+
+struct fib_computation {
+    fib_computation(int n, int* x, tbb::task_group& tg) : m_n(n), m_x(x), m_tg(tg) {}
+
+    void operator()() const {
+        if (m_n < 16) {
+            *m_x = serial_fib(m_n);
+        } else {
+            // Continuation passing
+            fib_continuation::data* data = new fib_continuation::data(*m_x);
+            auto continuation = m_tg.defer(fib_continuation{data});
+            tbb::this_task_group::current_task::transfer_successors_to(continuation);
+            // auto& c = *this->allocate_continuation(/* children_counter = */ 2, *x);
+            auto right = m_tg.defer(fib_computation(m_n - 1, &data->m_x, m_tg));
+            auto left = m_tg.defer(fib_computation(m_n - 2, &data->m_y, m_tg));
+
+            continuation.add_predecessor(left);
+            continuation.add_predecessor(right);
+            m_tg.run(std::move(continuation));
+            m_tg.run(std::move(left));
+            m_tg.run(std::move(right));
+        }
+    }
+
+    int m_n;
+    int* m_x;
+    tbb::task_group& m_tg;
+};
+
+//! \brief \ref interface \ref requirement
+TEST_CASE("Test continuation - fibonacci") {
+    tbb::task_group tg;
+    int N = 0;
+    tg.run(fib_computation(30, &N, tg));
+    tg.wait();
+    REQUIRE_MESSAGE(N == 832040, "Fibonacci(30) should be 832040");
+}
+
+//! \brief \ref interface \ref requirement
+TEST_CASE("Test continuation - multiple successors") {
+    tbb::task_group tg;
+
+    int x{}, y{};
+    int sum{};
+    auto sum_task = tg.defer([&] {
+        sum = x + y;
+    });
+
+    auto y_task = tg.defer([&] {
+        y = 2;
+    });
+
+    auto x_task = tg.defer([&] {
+        x = 40;
+    });
+
+    sum_task.add_predecessor(x_task);
+    sum_task.add_predecessor(y_task);
+
+    int multiplier{};
+    int product{};
+    auto mult_task = tg.defer([&] {
+        multiplier = 42;
+    });
+
+    auto product_task = tg.defer([&] {
+        product = sum * multiplier;
+    });
+
+    product_task.add_predecessor(sum_task);
+    product_task.add_predecessor(mult_task);
+
+    int product_plus_sum{};
+    auto total_results = tg.defer([&] {
+        product_plus_sum = product + sum;
+    });
+
+    product_task.add_successor(total_results);
+    sum_task.add_successor(total_results);
+    // total_results.add_predecessor(product_task);
+    // total_results.add_predecessor(sum_task);
+
+    tg.run(std::move(total_results));
+    tg.run(std::move(sum_task));
+    tg.run(std::move(x_task));
+    tg.run(std::move(y_task));
+    tg.run(std::move(mult_task));
+    tg.run(std::move(product_task));
+
+    tg.wait();
+
+    REQUIRE(product_plus_sum == 42 * 42 + 42);
+}
+
+//! \brief \ref interface \ref requirement
+TEST_CASE("Test add_predecessor after task submission") {
+    tbb::task_group tg;
+
+    int x{}, y{};
+    int sum{};
+    auto sum_task = tg.defer([&] {
+        sum = x + y;
+    });
+
+    std::atomic<bool> is_ready{false};
+    auto y_task = tg.defer([&] {
+        y = 2;
+        while (!is_ready) ;
+    });
+
+    auto x_task = tg.defer([&] {
+        x = 40;
+        while (!is_ready) ;
+    });
+
+    tg.run(std::move(x_task));
+    tg.run(std::move(y_task));
+
+    sum_task.add_predecessor(x_task);
+    sum_task.add_predecessor(y_task);
+
+    is_ready = true;
+    tg.run(std::move(sum_task));
+
+    tg.wait();
+
+    REQUIRE(sum == 42);
+}
+
+//! \brief \ref interface \ref requirement
+TEST_CASE("Test add_predecessor after task finish") {
+    tbb::task_group tg;
+
+    int x{}, y{};
+    int sum{};
+    auto sum_task = tg.defer([&] {
+        sum = x + y;
+    });
+
+    std::atomic<bool> is_ready{};
+    auto y_task = tg.defer([&] {
+        y = 2;
+    });
+
+    auto x_task = tg.defer([&] {
+        x = 40;
+        is_ready = true;
+    });
+
+    tg.run(std::move(x_task));
+
+    while (!is_ready) ;
+    sum_task.add_predecessor(x_task);
+    sum_task.add_predecessor(y_task);
+
+    tg.run(std::move(y_task));
+    tg.run(std::move(sum_task));
+
+    tg.wait();
+
+    REQUIRE(sum == 42);
+}
+
+namespace users {
+    template <typename T>
+    bool range_is_too_small(T b, T e) {
+        return std::distance(b, e) <= 4;
+    }
+
+    template <typename T>
+    void do_serial_sort(T b, T e) {
+        std::sort(b, e);
+    }
+
+    template <typename T>
+    void create_left_range(T& lb, T& le, T b, T e) {
+        lb = b;
+        le = b + std::distance(b, e) / 2;
+    }
+
+    template <typename T>
+    void create_right_range(T& rb, T& re, T b, T e) {
+        rb = b + std::distance(b, e) / 2;
+        re = e;
+    }
+
+    template <typename T>
+    void do_merge(T lb, T le, T rb, T re) {
+        std::vector<typename std::iterator_traits<T>::value_type> merged(std::distance(lb, le) + std::distance(rb, re));
+        std::merge(lb, le, rb, re, merged.begin());
+        std::copy(merged.begin(), merged.end(), lb);
+    }
+}
+
+template <typename T>
+void merge_sort(tbb::task_group& tg, T b, T e) {
+    if (users::range_is_too_small(b, e)) {
+        // base-case when range is small
+        users::do_serial_sort(b, e);
+    } else {
+        // calculate left and right ranges
+        T lb, le, rb, re;
+        users::create_left_range(lb, le, b, e);
+        users::create_right_range(rb, re, b, e);
+
+        // create the three tasks
+        tbb::task_handle sortleft = tg.defer([lb, le, &tg] { merge_sort(tg, lb, le); });
+
+        tbb::task_handle sortright = tg.defer([rb, re, &tg] { merge_sort(tg, rb, re); });
+        tbb::task_handle merge = tg.defer([rb, lb, le, re] { users::do_merge(lb, le, rb, re); });
+
+        // add predecessors for new merge task
+        merge.add_predecessor(sortleft);
+        merge.add_predecessor(sortright);
+
+        // insert new subgraph between currently executing
+        // task and its successors
+        tbb::this_task_group::current_task::transfer_successors_to(merge);
+
+        tg.run(std::move(sortleft));
+        tg.run(std::move(sortright));
+        tg.run(std::move(merge));
+    }
+}
+
+//! \brief \ref interface \ref requirement
+TEST_CASE("Test continuation - recursive decomposition") {
+    tbb::task_group tg;
+    int size = 10000;
+    std::vector<int> v(size);
+    for (int i = 0; i < size; ++i) {
+        v[i] = size - i;
+    }
+
+    tg.run_and_wait(tg.defer([&] { merge_sort(tg, v.begin(), v.end()); }));
+
+    for (int i = 0; i < size - 1; ++i) {
+        REQUIRE(v[i] <= v[i + 1]);
+    }
+}
+
+//! \brief \ref interface \ref requirement
+TEST_CASE("Test continuation - add dependency on pre-submitted task tree") {
+    tbb::task_group tg;
+    int sum{};
+    int x{}, y{};
+
+    auto init_sum_task = tg.defer([&] {
+
+        auto final_sum = tg.defer([&] {
+            sum = x + y;
+        });
+
+        tbb::this_task_group::current_task::transfer_successors_to(final_sum);
+        auto x_task = tg.defer([&] {
+            x = 40;
+        });
+
+        auto y_task = tg.defer([&] {
+            y = 2;
+        });
+
+        final_sum.add_predecessor(x_task);
+        final_sum.add_predecessor(y_task);
+
+        tg.run(std::move(final_sum));
+        tg.run(std::move(x_task));
+        tg.run(std::move(y_task));
+    });
+
+    int product{};
+    auto product_task = tg.defer([&] {
+        product = sum * sum;
+    });
+
+    product_task.add_predecessor(init_sum_task);
+
+    tg.run(std::move(init_sum_task));
+    tg.run(std::move(product_task));
+
+    tg.wait();
+
+    REQUIRE(product == 42 * 42);
+}
+
+//! \brief \ref interface \ref requirement
+TEST_CASE("Test continuation - add dependency after transfer_successors_to") {
+    tbb::task_group tg;
+    int sum{};
+    int x{}, y{};
+
+    std::atomic<bool> is_ready{false};
+    auto init_sum_task = tg.defer([&] {
+
+        auto final_sum = tg.defer([&] {
+            std::this_thread::sleep_for(std::chrono::milliseconds(200));
+            sum = x + y;
+        });
+
+        tbb::this_task_group::current_task::transfer_successors_to(final_sum);
+        is_ready = true;
+
+        auto x_task = tg.defer([&] {
+            x = 40;
+        });
+
+        auto y_task = tg.defer([&] {
+            y = 2;
+        });
+
+        final_sum.add_predecessor(x_task);
+        final_sum.add_predecessor(y_task);
+
+        tg.run(std::move(final_sum));
+        tg.run(std::move(x_task));
+        tg.run(std::move(y_task));
+    });
+
+    int product{};
+    auto product_task = tg.defer([&] {
+        product = sum * sum;
+    });
+    tg.run(std::move(init_sum_task));
+
+    while (!is_ready) ;
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    product_task.add_predecessor(init_sum_task);
+
+    tg.run(std::move(product_task));
+
+    tg.wait();
+
+    REQUIRE(product == 42 * 42);
+}