Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions benchmark/source/matrix_multiplication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,12 @@ TEST_CASE("matrix_multiplication") {
}
#endif
{
dp::thread_pool<fu2::unique_function<void()>> pool{};
dp::thread_pool<fu2::unique_function<void() &&>> pool{};
run_benchmark<int>(&bench, array_size, iterations,
"dp::thread_pool - fu2::unique_function",
[&](const std::vector<int>& a, const std::vector<int>& b) -> void {
pool.enqueue_detach(thread_task, a, b);
[&pool, task = thread_task](const std::vector<int>& a,
const std::vector<int>& b) -> void {
pool.enqueue_detach(std::move(task), a, b);
});
}

Expand Down
2 changes: 1 addition & 1 deletion cmake/CPM.cmake
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
set(CPM_DOWNLOAD_VERSION 0.38.1)
set(CPM_DOWNLOAD_VERSION 0.42.0)

if(CPM_SOURCE_CACHE)
# Expand relative path. This is important if the provided path contains a tilde (~)
Expand Down
2 changes: 1 addition & 1 deletion examples/mandelbrot/source/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void mandelbrot_threadpool(int image_width, int image_height, int max_iterations

std::cout << "calculating mandelbrot" << std::endl;

dp::thread_pool pool;
dp::thread_pool pool{};
std::vector<std::future<std::vector<rgb>>> futures;
futures.reserve(source.height());
const auto start = std::chrono::steady_clock::now();
Expand Down
90 changes: 74 additions & 16 deletions include/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@
#include <semaphore>
#include <thread>
#include <type_traits>
#include <unordered_map>
#ifdef __has_include
# if __has_include(<version>)
# include <version>
# endif
#endif

#include "thread_pool/thread_safe_queue.h"
#include "thread_pool/work_stealing_deque.h"

namespace dp {
namespace details {

#ifdef __cpp_lib_move_only_function
#if __cpp_lib_move_only_function
using default_function_type = std::move_only_function<void()>;
#else
using default_function_type = std::function<void()>;
Expand All @@ -40,12 +41,16 @@ namespace dp {
const unsigned int &number_of_threads = std::thread::hardware_concurrency(),
InitializationFunction init = [](std::size_t) {})
: tasks_(number_of_threads) {
producer_id_ = std::this_thread::get_id();
std::size_t current_id = 0;
for (std::size_t i = 0; i < number_of_threads; ++i) {
priority_queue_.push_back(size_t(current_id));
try {
threads_.emplace_back([&, id = current_id,
init](const std::stop_token &stop_tok) {
tasks_[id].thread_id = std::this_thread::get_id();
add_thread_id_to_map(tasks_[id].thread_id, id);

// invoke the init function on the thread
try {
std::invoke(init, id);
Expand All @@ -58,8 +63,17 @@ namespace dp {
tasks_[id].signal.acquire();

do {
// invoke the task
while (auto task = tasks_[id].tasks.pop_front()) {
// execute work from the global queue
// all threads can pull from the top, but the producer thread owns
// the bottom
while (auto task = global_tasks_.pop_top()) {
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
std::invoke(std::move(task.value()));
in_flight_tasks_.fetch_sub(1, std::memory_order_release);
}

// invoke any tasks from the queue that this thread owns
while (auto task = tasks_[id].tasks.pop_top()) {
// decrement the unassigned tasks as the task is now going
// to be executed
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
Expand All @@ -71,10 +85,10 @@ namespace dp {
in_flight_tasks_.fetch_sub(1, std::memory_order_release);
}

// try to steal a task
// try to steal a task from other threads
for (std::size_t j = 1; j < tasks_.size(); ++j) {
const std::size_t index = (id + j) % tasks_.size();
if (auto task = tasks_[index].tasks.steal()) {
if (auto task = tasks_[index].tasks.pop_top()) {
// steal a task
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
std::invoke(std::move(task.value()));
Expand All @@ -87,6 +101,8 @@ namespace dp {
// front and waiting for more work
} while (unassigned_tasks_.load(std::memory_order_acquire) > 0);

// the thread finished all its work, so we "notify" by putting this
// thread in front in the priority queue
priority_queue_.rotate_to_front(id);
// check if all tasks are completed and release the barrier (binary
// semaphore)
Expand Down Expand Up @@ -141,7 +157,7 @@ namespace dp {
typename ReturnType = std::invoke_result_t<Function &&, Args &&...>>
requires std::invocable<Function, Args...>
[[nodiscard]] std::future<ReturnType> enqueue(Function f, Args... args) {
#ifdef __cpp_lib_move_only_function
#if __cpp_lib_move_only_function
// we can do this in C++23 because we now have support for move only functions
std::promise<ReturnType> promise;
auto future = promise.get_future();
Expand Down Expand Up @@ -244,13 +260,45 @@ namespace dp {
private:
template <typename Function>
void enqueue_task(Function &&f) {
auto i_opt = priority_queue_.copy_front_and_rotate_to_back();
if (!i_opt.has_value()) {
// would only be a problem if there are zero threads
return;
// are we enquing from the producer thread? Or is a worker thread
// enquing to the pool?
auto current_id = std::this_thread::get_id();
auto is_producer = current_id == producer_id_;
// assign the work
if (is_producer) {
// we push to the global task queue
global_tasks_.emplace(std::forward<Function>(f));
} else {
// This is a violation of the pre-condition.
// We cannot accept work from an arbitrary thread that is not the root producer or a
// worker in the pool
assert(thread_id_to_index_.contains(current_id));
// assign the task
tasks_[thread_id_to_index_.at(current_id)].tasks.emplace(
std::forward<Function>(f));
}
// get the index
auto i = *(i_opt);

/**
* Now we need to wake up the correct thread. If the thread that is enqueuing the task
* is a worker from the pool, then that thread needs to execute the work. Otherwise we
* need to use the priority queue to use the next available thread.
*/

// immediately invoked lambda
auto thread_wakeup_index = [&]() -> std::size_t {
if (is_producer) {
auto i_opt = priority_queue_.copy_front_and_rotate_to_back();
if (!i_opt.has_value()) {
// would only be a problem if there are zero threads
return std::size_t{0};
}
// get the index
return *(i_opt);
} else {
// get the worker thread id index
return thread_id_to_index_.at(current_id);
}
}();

// increment the unassigned tasks and in flight tasks
unassigned_tasks_.fetch_add(1, std::memory_order_release);
Expand All @@ -262,13 +310,18 @@ namespace dp {
}

// assign work
tasks_[i].tasks.push_back(std::forward<Function>(f));
tasks_[i].signal.release();
tasks_[thread_wakeup_index].signal.release();
}

void add_thread_id_to_map(std::thread::id thread_id, std::size_t index) {
std::lock_guard lock(thread_id_map_mutex_);
thread_id_to_index_.insert_or_assign(thread_id, index);
}

struct task_item {
dp::thread_safe_queue<FunctionType> tasks{};
dp::work_stealing_deque<FunctionType> tasks{};
std::binary_semaphore signal{0};
std::thread::id thread_id;
};

std::vector<ThreadType> threads_;
Expand All @@ -277,6 +330,11 @@ namespace dp {
// guarantee these get zero-initialized
std::atomic_int_fast64_t unassigned_tasks_{0}, in_flight_tasks_{0};
std::atomic_bool threads_complete_signal_{false};

std::thread::id producer_id_;
dp::work_stealing_deque<FunctionType> global_tasks_{};
std::mutex thread_id_map_mutex_{};
std::unordered_map<std::thread::id, std::size_t> thread_id_to_index_{};
};

/**
Expand Down
Loading