Skip to content

Commit c1b91fd

Browse files
author
Paul T
authored
Merge pull request #3 from DeveloperPaul123/feature/minor-improvements
2 parents 6d01b06 + b7490c3 commit c1b91fd

3 files changed

Lines changed: 71 additions & 53 deletions

File tree

include/thread_pool/thread_pool.h

Lines changed: 42 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace dp {
1919
return std::forward<T>(v);
2020
}
2121

22-
// Bind F and args... into a nullary one-shot lambda. Lambda captures by value.
22+
// bind F and parameter pack into a nullary one shot. Lambda captures by value.
2323
template <typename... Args, typename F>
2424
auto bind(F &&f, Args &&...args) {
2525
return [f = decay_copy(std::forward<F>(f)),
@@ -28,63 +28,58 @@ namespace dp {
2828
};
2929
}
3030

31-
template <class Queue, class U = typename Queue::value_type>
32-
concept is_valid_queue = requires(Queue q) {
33-
{ q.empty() } -> std::convertible_to<bool>;
34-
{ q.front() } -> std::convertible_to<U &>;
35-
{ q.back() } -> std::convertible_to<U &>;
36-
q.pop();
37-
};
38-
39-
static_assert(detail::is_valid_queue<std::queue<int>>);
40-
static_assert(detail::is_valid_queue<dp::thread_safe_queue<int>>);
4131
} // namespace detail
4232

43-
template <template <class T> class Queue, typename FunctionType = std::function<void()>>
33+
template <typename FunctionType = std::function<void()>>
4434
requires std::invocable<FunctionType> &&
45-
std::is_same_v<void, std::invoke_result_t<FunctionType>> &&
46-
detail::is_valid_queue<Queue<FunctionType>>
47-
class thread_pool_impl {
35+
std::is_same_v<void, std::invoke_result_t<FunctionType>>
36+
class thread_pool {
4837
public:
49-
thread_pool_impl(
50-
const unsigned int &number_of_threads = std::thread::hardware_concurrency()) {
38+
thread_pool(const unsigned int &number_of_threads = std::thread::hardware_concurrency())
39+
: queues_(number_of_threads) {
5140
for (std::size_t i = 0; i < number_of_threads; ++i) {
52-
queues_.push_back(std::make_unique<task_pair>());
5341
threads_.emplace_back([&, id = i](std::stop_token stop_tok) {
5442
do {
5543
// check if we have task
56-
if (queues_[id]->tasks.empty()) {
44+
if (queues_[id].tasks.empty()) {
5745
// no tasks, so we wait instead of spinning
58-
queues_[id]->semaphore.acquire();
46+
queues_[id].semaphore.acquire();
5947
}
6048

6149
// ensure we have a task before getting task
6250
// since the dtor releases the semaphore as well
63-
if (!queues_[id]->tasks.empty()) {
51+
if (!queues_[id].tasks.empty()) {
6452
// get the task
65-
auto &task = queues_[id]->tasks.front();
53+
auto &task = queues_[id].tasks.front();
6654
// invoke the task
6755
std::invoke(std::move(task));
56+
// decrement in-flight counter
57+
--in_flight_;
6858
// remove task from the queue
69-
queues_[id]->tasks.pop();
59+
queues_[id].tasks.pop();
7060
}
7161
} while (!stop_tok.stop_requested());
7262
});
7363
}
7464
}
7565

76-
~thread_pool_impl() {
66+
~thread_pool() {
67+
// wait for tasks to complete first
68+
do {
69+
std::this_thread::yield();
70+
} while (in_flight_ > 0);
71+
7772
// stop all threads
7873
for (std::size_t i = 0; i < threads_.size(); ++i) {
7974
threads_[i].request_stop();
80-
queues_[i]->semaphore.release();
75+
queues_[i].semaphore.release();
8176
threads_[i].join();
8277
}
8378
}
8479

8580
/// thread pool is non-copyable
86-
thread_pool_impl(const thread_pool_impl &) = delete;
87-
thread_pool_impl &operator=(const thread_pool_impl &) = delete;
81+
thread_pool(const thread_pool &) = delete;
82+
thread_pool &operator=(const thread_pool &) = delete;
8883

8984
/**
9085
* @brief Enqueue a task into the thread pool that returns a result.
@@ -98,11 +93,21 @@ namespace dp {
9893
template <typename Function, typename... Args,
9994
typename ReturnType = std::invoke_result_t<Function &&, Args &&...>>
10095
requires std::invocable<Function, Args...>
101-
[[nodiscard]] std::future<ReturnType> enqueue(Function &&f, Args &&...args) {
102-
// use shared promise here so that we don't break the promise later
96+
[[nodiscard]] std::future<ReturnType> enqueue(Function f, Args... args) {
97+
/*
98+
* use shared promise here so that we don't break the promise later (until C++23)
99+
*
100+
* with C++23 we can do the following:
101+
*
102+
* std::promise<ReturnType> promise;
103+
* auto future = promise.get_future();
104+
* auto task = [func = std::move(f), ... largs = std::move(args),
105+
promise = std::move(promise)]() mutable {...};
106+
*/
103107
auto shared_promise = std::make_shared<std::promise<ReturnType>>();
104108
auto task = [func = std::move(f), ... largs = std::move(args),
105109
promise = shared_promise]() { promise->set_value(func(largs...)); };
110+
106111
// get the future before enqueuing the task
107112
auto future = shared_promise->get_future();
108113
// enqueue the task
@@ -125,33 +130,25 @@ namespace dp {
125130
}
126131

127132
private:
128-
using semaphore_type = std::binary_semaphore;
129-
using task_type = FunctionType;
130-
struct task_pair {
131-
semaphore_type semaphore{0};
132-
Queue<task_type> tasks{};
133+
struct task_queue {
134+
std::binary_semaphore semaphore{0};
135+
dp::thread_safe_queue<FunctionType> tasks{};
133136
};
134137

135138
template <typename Function>
136139
void enqueue_task(Function &&f) {
137140
const std::size_t i = count_++ % queues_.size();
138-
queues_[i]->tasks.push(std::forward<Function>(f));
139-
queues_[i]->semaphore.release();
141+
++in_flight_;
142+
queues_[i].tasks.push(std::forward<Function>(f));
143+
queues_[i].semaphore.release();
140144
}
141145

142146
std::vector<std::jthread> threads_;
143-
// have to use unique_ptr here because std::binary_semaphore is not move/copy
144-
// assignable/constructible
145-
std::vector<std::unique_ptr<task_pair>> queues_;
147+
std::deque<task_queue> queues_;
146148
std::size_t count_ = 0;
149+
std::atomic<int64_t> in_flight_{0};
147150
};
148151

149-
/**
150-
* @brief Thread pool class capable of queuing detached tasks and value returning tasks.
151-
* @details This is a default alias for the dp::thread_pool_impl
152-
*/
153-
using thread_pool = thread_pool_impl<dp::thread_safe_queue>;
154-
155152
/**
156153
* @example mandelbrot/source/main.cpp
157154
* Example showing how to use thread pool with tasks that return a value. Outputs a PPM image of

include/thread_pool/thread_safe_queue.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
#pragma once
22

33
#include <condition_variable>
4+
#include <deque>
45
#include <mutex>
5-
#include <queue>
66

77
namespace dp {
88
template <typename T>
99
class thread_safe_queue {
1010
public:
1111
using value_type = T;
12-
using size_type = typename std::queue<T>::size_type;
12+
using size_type = typename std::deque<T>::size_type;
1313

1414
thread_safe_queue() = default;
15+
1516
void push(T&& value) {
16-
std::lock_guard lock(mutex_);
17-
data_.push(std::forward<T>(value));
17+
{
18+
std::lock_guard lock(mutex_);
19+
data_.push_back(std::forward<T>(value));
20+
}
1821
condition_variable_.notify_all();
1922
}
23+
2024
bool empty() {
2125
std::lock_guard lock(mutex_);
2226
return data_.empty();
@@ -42,12 +46,12 @@ namespace dp {
4246
void pop() {
4347
std::unique_lock lock(mutex_);
4448
condition_variable_.wait(lock, [this] { return !data_.empty(); });
45-
data_.pop();
49+
data_.pop_front();
4650
}
4751

4852
private:
4953
using mutex_type = std::mutex;
50-
std::queue<T> data_;
54+
std::deque<T> data_;
5155
mutable mutex_type mutex_{};
5256
std::condition_variable condition_variable_{};
5357
};

test/source/thread_pool.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
#include <string>
55

6-
TEST_CASE("Basic Return Types") {
6+
TEST_CASE("Basic task return types") {
77
dp::thread_pool pool(2);
8-
// TODO
98
auto future_value = pool.enqueue([](const int& value) { return value; }, 30);
109
auto future_negative = pool.enqueue([](int x) -> int { return x - 20; }, 3);
1110

@@ -30,3 +29,21 @@ TEST_CASE("Ensure input params are properly passed") {
3029
CHECK(j == futures[j].get());
3130
}
3231
}
32+
33+
TEST_CASE("Ensure work completes upon destruction") {
34+
std::atomic<int> counter;
35+
std::vector<std::future<int>> futures;
36+
const auto total_tasks = 20;
37+
{
38+
dp::thread_pool pool(4);
39+
for (auto i = 0; i < total_tasks; i++) {
40+
auto task = [index = i, &counter]() {
41+
counter++;
42+
return index;
43+
};
44+
futures.push_back(pool.enqueue(task));
45+
}
46+
}
47+
48+
CHECK_EQ(counter.load(), total_tasks);
49+
}

0 commit comments

Comments
 (0)