@@ -19,7 +19,7 @@ namespace dp {
1919 return std::forward<T>(v);
2020 }
2121
22- // Bind F and args... into a nullary one-shot lambda . Lambda captures by value.
22+ // bind F and parameter pack into a nullary one shot . Lambda captures by value.
2323 template <typename ... Args, typename F>
2424 auto bind (F &&f, Args &&...args) {
2525 return [f = decay_copy (std::forward<F>(f)),
@@ -28,63 +28,58 @@ namespace dp {
2828 };
2929 }
3030
31- template <class Queue , class U = typename Queue::value_type>
32- concept is_valid_queue = requires (Queue q) {
33- { q.empty () } -> std::convertible_to<bool >;
34- { q.front () } -> std::convertible_to<U &>;
35- { q.back () } -> std::convertible_to<U &>;
36- q.pop ();
37- };
38-
39- static_assert (detail::is_valid_queue<std::queue<int >>);
40- static_assert (detail::is_valid_queue<dp::thread_safe_queue<int >>);
4131 } // namespace detail
4232
43- template <template < class T > class Queue , typename FunctionType = std::function<void ()>>
33+ template <typename FunctionType = std::function<void ()>>
4434 requires std::invocable<FunctionType> &&
45- std::is_same_v<void , std::invoke_result_t <FunctionType>> &&
46- detail::is_valid_queue<Queue<FunctionType>>
47- class thread_pool_impl {
35+ std::is_same_v<void , std::invoke_result_t <FunctionType>>
36+ class thread_pool {
4837 public:
49- thread_pool_impl (
50- const unsigned int &number_of_threads = std::thread::hardware_concurrency() ) {
38+ thread_pool ( const unsigned int &number_of_threads = std::thread::hardware_concurrency())
39+ : queues_(number_of_threads ) {
5140 for (std::size_t i = 0 ; i < number_of_threads; ++i) {
52- queues_.push_back (std::make_unique<task_pair>());
5341 threads_.emplace_back ([&, id = i](std::stop_token stop_tok) {
5442 do {
5543 // check if we have task
56- if (queues_[id]-> tasks .empty ()) {
44+ if (queues_[id]. tasks .empty ()) {
5745 // no tasks, so we wait instead of spinning
58- queues_[id]-> semaphore .acquire ();
46+ queues_[id]. semaphore .acquire ();
5947 }
6048
6149 // ensure we have a task before getting task
6250 // since the dtor releases the semaphore as well
63- if (!queues_[id]-> tasks .empty ()) {
51+ if (!queues_[id]. tasks .empty ()) {
6452 // get the task
65- auto &task = queues_[id]-> tasks .front ();
53+ auto &task = queues_[id]. tasks .front ();
6654 // invoke the task
6755 std::invoke (std::move (task));
56+ // decrement in-flight counter
57+ --in_flight_;
6858 // remove task from the queue
69- queues_[id]-> tasks .pop ();
59+ queues_[id]. tasks .pop ();
7060 }
7161 } while (!stop_tok.stop_requested ());
7262 });
7363 }
7464 }
7565
76- ~thread_pool_impl () {
66+ ~thread_pool () {
67+ // wait for tasks to complete first
68+ do {
69+ std::this_thread::yield ();
70+ } while (in_flight_ > 0 );
71+
7772 // stop all threads
7873 for (std::size_t i = 0 ; i < threads_.size (); ++i) {
7974 threads_[i].request_stop ();
80- queues_[i]-> semaphore .release ();
75+ queues_[i]. semaphore .release ();
8176 threads_[i].join ();
8277 }
8378 }
8479
8580 // / thread pool is non-copyable
86- thread_pool_impl (const thread_pool_impl &) = delete ;
87- thread_pool_impl &operator =(const thread_pool_impl &) = delete ;
81+ thread_pool (const thread_pool &) = delete ;
82+ thread_pool &operator =(const thread_pool &) = delete ;
8883
8984 /* *
9085 * @brief Enqueue a task into the thread pool that returns a result.
@@ -98,11 +93,21 @@ namespace dp {
9893 template <typename Function, typename ... Args,
9994 typename ReturnType = std::invoke_result_t <Function &&, Args &&...>>
10095 requires std::invocable<Function, Args...>
101- [[nodiscard]] std::future<ReturnType> enqueue (Function &&f, Args &&...args) {
102- // use shared promise here so that we don't break the promise later
96+ [[nodiscard]] std::future<ReturnType> enqueue (Function f, Args... args) {
97+ /*
98+ * use shared promise here so that we don't break the promise later (until C++23)
99+ *
100+ * with C++23 we can do the following:
101+ *
102+ * std::promise<ReturnType> promise;
103+ * auto future = promise.get_future();
104+ * auto task = [func = std::move(f), ... largs = std::move(args),
105+ promise = std::move(promise)]() mutable {...};
106+ */
103107 auto shared_promise = std::make_shared<std::promise<ReturnType>>();
104108 auto task = [func = std::move (f), ... largs = std::move (args),
105109 promise = shared_promise]() { promise->set_value (func (largs...)); };
110+
106111 // get the future before enqueuing the task
107112 auto future = shared_promise->get_future ();
108113 // enqueue the task
@@ -125,33 +130,25 @@ namespace dp {
125130 }
126131
127132 private:
128- using semaphore_type = std::binary_semaphore;
129- using task_type = FunctionType;
130- struct task_pair {
131- semaphore_type semaphore{0 };
132- Queue<task_type> tasks{};
133+ struct task_queue {
134+ std::binary_semaphore semaphore{0 };
135+ dp::thread_safe_queue<FunctionType> tasks{};
133136 };
134137
135138 template <typename Function>
136139 void enqueue_task (Function &&f) {
137140 const std::size_t i = count_++ % queues_.size ();
138- queues_[i]->tasks .push (std::forward<Function>(f));
139- queues_[i]->semaphore .release ();
141+ ++in_flight_;
142+ queues_[i].tasks .push (std::forward<Function>(f));
143+ queues_[i].semaphore .release ();
140144 }
141145
142146 std::vector<std::jthread> threads_;
143- // have to use unique_ptr here because std::binary_semaphore is not move/copy
144- // assignable/constructible
145- std::vector<std::unique_ptr<task_pair>> queues_;
147+ std::deque<task_queue> queues_;
146148 std::size_t count_ = 0 ;
149+ std::atomic<int64_t > in_flight_{0 };
147150 };
148151
149- /* *
150- * @brief Thread pool class capable of queuing detached tasks and value returning tasks.
151- * @details This is a default alias for the dp::thread_pool_impl
152- */
153- using thread_pool = thread_pool_impl<dp::thread_safe_queue>;
154-
155152 /* *
156153 * @example mandelbrot/source/main.cpp
157154 * Example showing how to use thread pool with tasks that return a value. Outputs a PPM image of
0 commit comments