Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
7b7577a
Initial work.
mzient Sep 26, 2022
a8c913c
NewThreadPool - working
mzient Sep 26, 2022
2886c97
Add init/exit callbacks.
mzient Sep 28, 2022
2b527fd
Fix exception class.
mzient Sep 29, 2022
ac4beb5
[WIP]
mzient Oct 4, 2022
845b145
Add multi-error as a separate header.
mzient Oct 4, 2022
4f5f80d
[WIP]
mzient Dec 1, 2022
3b8bf46
[WIP]
mzient Dec 12, 2022
e0d46f1
Experiment.
mzient Feb 2, 2023
6085a3d
Remove the cc file, currently not used.
mzient Feb 3, 2023
92e9f60
Moving files.
mzient Feb 3, 2023
def2aee
Move thread_pool_base to core.
mzient Feb 27, 2023
7dd0557
Add ThreadedExecutionEngine that combines a thread pool reference and…
mzient Feb 27, 2023
89ba1a6
Fix after rebase.
mzient Dec 16, 2025
be74a1a
[WIP]
mzient Dec 19, 2025
5c6c5aa
Add incremental job. Validate in nvimgcodec.
mzient Jan 5, 2026
1147205
Make destructor virtual.
mzient Jan 5, 2026
164aca1
[WIP]
mzient Jan 8, 2026
bec7e50
[WIP]
mzient Jan 8, 2026
3e8dfe1
Use semaphore. Numerous fixes.
mzient Jan 12, 2026
04aa3f2
Fix: move notification and waiting to one translation unit. Refactoring.
mzient Jan 14, 2026
66577da
Revert nvimgcodec.
mzient Jan 14, 2026
364727d
Fix thread pool name handling.
mzient Jan 14, 2026
944f6d6
Tidy up includes.
mzient Jan 14, 2026
d5f32ed
Remove NewThreadPool as untested.
mzient Jan 14, 2026
5ba7aa2
Fix. Add more tests. Refactor tests.
mzient Jan 14, 2026
6420878
Use new thread pool in NvImgCodec
mzient Jan 14, 2026
ffd0c35
[WIP]
mzient Jan 14, 2026
aa23c3e
Bugfix.
mzient Jan 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions dali/core/exec/thread_pool_base.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
// Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dali/core/exec/thread_pool_base.h"
#include <stdexcept>
#include <thread>

namespace dali {

// Destroys the job.
//
// Deliberately declared noexcept(false): destroying a job that still holds
// tasks but was never waited for (or abandoned) is a programming error and is
// reported by throwing from the destructor.
// After the check, spins until `running_` is cleared - DoNotify() clears it as
// its very last access to `this`, so once the spin exits no task can touch
// this object anymore.
JobBase::~JobBase() noexcept(false) {
  if (total_tasks_ > 0 && !waited_for_) {
    throw std::logic_error("The job is not empty, but hasn't been abandoned or waited for.");
  }
  // Spin-wait for the last task's notification code to finish using `this`.
  while (running_)
    std::this_thread::yield();
}

// Blocks until all tasks of this job have completed.
//
// Two strategies are used:
//  * if the calling thread is itself a worker of a thread pool, it calls
//    WaitOrRunTasks, which executes queued tasks while waiting - this avoids
//    deadlocking the pool when a job is waited for from inside the pool;
//  * otherwise it waits on the num_pending_tasks_ atomic (C++20 atomic wait),
//    woken by DoNotify.
//
// Throws std::logic_error if the job was never run, was already waited for,
// or (in the in-pool path) if the pool was stopped before the job finished.
// Note that waited_for_ is set even when the "pool was stopped" error is
// thrown, so the destructor's not-waited-for check won't fire again.
void JobBase::DoWait() {
  if (executor_ == nullptr)
    throw std::logic_error("This job hasn't been run - cannot wait for it.");

  if (waited_for_)
    throw std::logic_error("This job has already been waited for.");

  auto ready = [&]() { return num_pending_tasks_ == 0; };
  if (ThreadPoolBase::this_thread_pool() != nullptr) {
    bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready);
    waited_for_ = true;
    if (!result)
      throw std::logic_error("The thread pool was stopped");
  } else {
    // Plain (non-worker) thread: wait on the atomic counter directly.
    int old = num_pending_tasks_.load();
    while (old != 0) {
      num_pending_tasks_.wait(old);
      old = num_pending_tasks_.load();
      assert(old >= 0);
    }
    waited_for_ = true;
  }
}

// Wakes up anyone waiting for this job, after num_pending_tasks_ has been
// decremented by the caller.
//
// Both wake-up channels used by DoWait are notified: the atomic counter
// (for plain threads) and the condition variable (for in-pool waiters).
// The lock-then-notify on mtx_ ensures a waiter that has observed a stale
// counter value cannot miss the cv_ notification.
void JobBase::DoNotify() {
  num_pending_tasks_.notify_all();
  (void)std::lock_guard(mtx_);
  cv_.notify_all();
  // We need this second flag to avoid a race condition where the
  // destructor is called between decrementing num_pending_tasks_ and the
  // notification, without excessive use of mutexes. This must be the very
  // last operation in the task function that touches `this`.
  running_ = false;
}

///////////////////////////////////////////////////////////////////////////

// Submits all tasks of this job to the given thread pool and optionally
// waits for their completion.
//
// Can be called only once per job - a second call throws std::logic_error.
// running_ is set *before* the batch is submitted, because a task may finish
// (and clear running_ in DoNotify) as soon as it becomes visible to workers.
void Job::Run(ThreadPoolBase &tp, bool wait) {
  if (executor_ != nullptr)
    throw std::logic_error("This job has already been started.");
  executor_ = &tp;
  running_ = !tasks_.empty();
  {
    // Bulk-add keeps the pool's queue locked once for all tasks.
    auto batch = tp.BeginBulkAdd();
    for (auto &x : tasks_) {
      batch.Add(std::move(x.second.func));
    }
    int added = batch.Size();
    if (added) {
      // Account for the pending tasks before they can start completing.
      num_pending_tasks_ += added;
      running_ = true;
    }
    batch.Submit();
  }
  if (wait && !tasks_.empty())
    Wait();
}

// Waits for all tasks of this job and rethrows any errors they produced.
// A single error is rethrown directly; multiple errors are wrapped in
// MultipleErrors.
void Job::Wait() {
  DoWait();

  // Collect the exceptions recorded by individual tasks. The vector stays
  // unallocated in the common, error-free case.
  std::vector<std::exception_ptr> errors;
  for (auto &task_entry : tasks_) {
    auto &err = task_entry.second.error;
    if (err)
      errors.push_back(std::move(err));
  }

  if (errors.empty())
    return;
  if (errors.size() == 1)
    std::rethrow_exception(errors.front());
  throw MultipleErrors(std::move(errors));
}

// Discards all tasks without running them.
// Only legal before the job has been handed to an executor; afterwards it
// throws std::logic_error.
void Job::Abandon() {
  if (executor_ != nullptr)
    throw std::logic_error("Cannot abandon a job that has already been started");
  total_tasks_ = 0;
  tasks_.clear();
}

///////////////////////////////////////////////////////////////////////////

// Submits the tasks added since the previous Run call to the thread pool.
//
// Unlike Job::Run, this may be called multiple times - but always with the
// same executor; switching pools throws std::logic_error.
// last_task_run_ marks the last task already submitted; submission resumes
// right after it (or at the beginning on the first run).
// NOTE(review): when wait is true, Wait() is called whenever tasks_ is
// non-empty, even if this call added no new tasks - a repeated waited Run
// would then hit DoWait's "already been waited for" check; confirm this is
// the intended usage contract.
void IncrementalJob::Run(ThreadPoolBase &tp, bool wait) {
  if (executor_ && executor_ != &tp)
    throw std::logic_error("This job is already running in a different executor.");
  executor_ = &tp;
  {
    // Resume after the last submitted task; tasks_ presumably has stable
    // iterators (list-like container) - TODO confirm against the header.
    auto it = last_task_run_.has_value() ? std::next(*last_task_run_) : tasks_.begin();
    auto batch = tp.BeginBulkAdd();
    for (; it != tasks_.end(); ++it) {
      batch.Add(std::move(it->func));
      last_task_run_ = it;
    }
    int added = batch.Size();
    if (added) {
      // Account for the pending tasks before they can start completing.
      num_pending_tasks_ += added;
      running_ = true;
    }
    batch.Submit();
  }
  if (wait && !tasks_.empty())
    Wait();
}

// Discards all tasks without running them.
// Only legal before the job has been handed to an executor; afterwards it
// throws std::logic_error.
void IncrementalJob::Abandon() {
  if (executor_)
    throw std::logic_error("Cannot abandon a job that has already been started");
  total_tasks_ = 0;
  tasks_.clear();
}

// Waits for all submitted tasks and rethrows any errors they produced.
// A single error is rethrown directly; multiple errors are wrapped in
// MultipleErrors.
void IncrementalJob::Wait() {
  DoWait();

  // Collect the exceptions recorded by individual tasks. The vector stays
  // unallocated in the common, error-free case.
  std::vector<std::exception_ptr> errors;
  for (auto &task_entry : tasks_) {
    if (task_entry.error)
      errors.push_back(std::move(task_entry.error));
  }

  if (errors.empty())
    return;
  if (errors.size() == 1)
    std::rethrow_exception(errors.front());
  throw MultipleErrors(std::move(errors));
}

///////////////////////////////////////////////////////////////////////////

// Per-thread identification of the owning pool: set by ThreadPoolBase::Run
// when a worker starts; null / -1 on threads that are not pool workers.
thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr;
thread_local int ThreadPoolBase::this_thread_idx_ = -1;

// Starts num_threads worker threads, each running ThreadPoolBase::Run.
//
// on_thread_start, if set, is invoked once in each worker at startup; its
// returned std::any is kept alive for the worker's lifetime (see Run).
// Throws std::logic_error if the pool is shutting down or already started.
// NOTE(review): shutdown_pending_ is read before mtx_ is taken - presumably
// it is an atomic (it is also read locklessly in Run); confirm in the header.
// NOTE(review): if thread construction throws mid-loop, already-started
// threads remain in threads_ - they are only reclaimed by Shutdown.
void ThreadPoolBase::Init(int num_threads, const std::function<OnThreadStartFn> &on_thread_start) {
  if (shutdown_pending_)
    throw std::logic_error("The thread pool is being shut down.");
  std::lock_guard<std::mutex> g(mtx_);
  if (!threads_.empty())
    throw std::logic_error("The thread pool is already started!");
  threads_.reserve(num_threads);
  for (int i = 0; i < num_threads; i++)
    threads_.push_back(std::thread(&ThreadPoolBase::Run, this, i, on_thread_start));
}

// Requests shutdown of the pool and (when join is true) joins all workers.
//
// Uses a double-checked pattern on shutdown_pending_: the cheap pre-lock
// check lets repeated non-joining calls return early, and the check is
// repeated under mtx_ before the flag is set.
// One semaphore permit is released per worker so that every worker blocked
// in sem_.acquire() wakes up and observes shutdown_pending_.
void ThreadPoolBase::Shutdown(bool join) {
  if ((shutdown_pending_ && !join) || threads_.empty())
    return;
  {
    std::lock_guard<std::mutex> g(mtx_);
    if (shutdown_pending_ && !join)
      return;
    shutdown_pending_ = true;
    // Wake all workers so they can notice the shutdown flag.
    sem_.release(threads_.size());
  }

  for (auto &t : threads_)
    t.join();
  threads_.clear();
}

// Enqueues a task without taking the queue mutex.
// Precondition: the caller already holds mtx_ (as AddTask and BulkAdd do).
// Throws std::logic_error if the pool is shutting down.
void ThreadPoolBase::AddTaskNoLock(TaskFunc &&f) {
  if (!shutdown_pending_) {
    tasks_.push(std::move(f));
  } else {
    throw std::logic_error("The thread pool is stopped and no longer accepts new tasks.");
  }
}

// Enqueues a single task and wakes one worker.
// The semaphore is released only after mtx_ is dropped, so the awakened
// worker doesn't immediately block on the mutex we still hold.
void ThreadPoolBase::AddTask(TaskFunc &&f) {
  {
    std::lock_guard g(mtx_);
    AddTaskNoLock(std::move(f));
  }
  sem_.release(1);
}

// Worker thread main loop.
//
// Records the thread-local pool pointer and worker index, runs the optional
// startup callback (its std::any result lives in `scope` for the thread's
// lifetime, acting as per-thread RAII state), then repeatedly acquires a
// semaphore permit and runs one task per permit.
// Declared noexcept - a task that lets an exception escape would terminate
// the process; presumably task wrappers catch and store exceptions
// (see Job's per-task error slots) - confirm against the header.
// NOTE(review): the inner `if (shutdown_pending_) break;` exits before
// draining the queue, although the outer condition
// (!shutdown_pending_ || !tasks_.empty()) suggests draining was considered;
// tasks still queued at shutdown are therefore never executed here.
// NOTE(review): tasks_.empty() in the outer condition is read without
// holding mtx_ - verify tasks_/shutdown_pending_ access rules in the header.
void ThreadPoolBase::Run(
      int index,
      const std::function<OnThreadStartFn> &on_thread_start) noexcept {
  this_thread_pool_ = this;
  this_thread_idx_ = index;
  std::any scope;
  if (on_thread_start)
    scope = on_thread_start(index);
  while (!shutdown_pending_ || !tasks_.empty()) {
    sem_.acquire();
    std::unique_lock lock(mtx_);
    if (shutdown_pending_)
      break;
    assert(!tasks_.empty() && "Semaphore acquired but no tasks present.");
    PopAndRunTask(lock);
  }
}

void ThreadPoolBase::PopAndRunTask(std::unique_lock<std::mutex> &lock) {
TaskFunc t = std::move(tasks_.front());
tasks_.pop();
lock.unlock();
t();
lock.lock();
}

// Waits for `condition` to become true while opportunistically running queued
// tasks - used when a worker thread waits for a job, to avoid deadlock when
// the job's own tasks would otherwise never be picked up.
//
// Must be called from a worker of this very pool (asserted).
// Returns true when the condition was satisfied; once shutdown is pending the
// final return value is whatever condition() evaluates to at that point.
// The 100 us wait_for acts as a polling interval: it bounds how long the
// thread sleeps before re-checking for newly queued tasks, since task
// arrival is signalled through sem_, not through `cv`.
// (Template defined in this .cc file - implicit instantiation works because
// the only caller, JobBase::DoWait, lives in the same translation unit.)
template <typename Condition>
bool ThreadPoolBase::WaitOrRunTasks(std::condition_variable &cv, Condition &&condition) {
  assert(this_thread_pool() == this);
  std::unique_lock lock(mtx_);
  while (!shutdown_pending_ || !tasks_.empty()) {
    bool ret;
    while (!(ret = condition()) && tasks_.empty())
      cv.wait_for(lock, std::chrono::microseconds(100));

    if (ret || condition()) // re-evaluate the condition, just in case
      return true;
    if (shutdown_pending_)
      return condition();
    // Only run a task if a permit is available; otherwise another worker
    // already claimed the queued task - go back to waiting.
    if (!sem_.try_acquire())
      continue;

    assert(!tasks_.empty() && "Semaphore acquired but no tasks present.");
    PopAndRunTask(lock);
  }
  return condition();
}

} // namespace dali
Loading