From e2d9d6f15b78cbb99fe09e0fe6877bf3ce77bd2d Mon Sep 17 00:00:00 2001 From: Simon Kallweit Date: Sun, 23 Feb 2025 22:05:58 +0100 Subject: [PATCH 1/6] remove nanothread --- .gitmodules | 3 --- CMakeLists.txt | 7 ------- README.md | 1 - external/nanothread | 1 - 4 files changed, 12 deletions(-) delete mode 160000 external/nanothread diff --git a/.gitmodules b/.gitmodules index f0b325ee..5cf827e9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,6 +4,3 @@ [submodule "external/metal-cpp"] path = external/metal-cpp url = https://github.com/bkaradzic/metal-cpp.git -[submodule "external/nanothread"] - path = external/nanothread - url = https://github.com/skallweitNV/nanothread.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 983dd4e3..6cdbb5e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -374,12 +374,6 @@ if(APPLE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-assume -Wno-switch") endif() -# nanothread -# TODO: disabled for now as it introduces a dependency to libatomic which leads to issues with cross compilation -# set(NANOTHREAD_STATIC ON) -# add_subdirectory(external/nanothread) -# set_target_properties(nanothread PROPERTIES POSITION_INDEPENDENT_CODE ON) - # Setup compiler warnings target_compile_options(slang-rhi PRIVATE $<$: @@ -611,7 +605,6 @@ endif() target_include_directories(slang-rhi PUBLIC include) target_include_directories(slang-rhi PRIVATE src) -# target_link_libraries(slang-rhi PRIVATE nanothread) target_compile_definitions(slang-rhi PRIVATE SLANG_RHI_ENABLE_CPU=$ diff --git a/README.md b/README.md index 3a18a26b..c43b7552 100644 --- a/README.md +++ b/README.md @@ -20,4 +20,3 @@ This library is under active refactoring and development, and is not yet ready f - [stb](https://github.com/nothings/stb) (Public Domain) - [Vulkan-Headers](https://github.com/KhronosGroup/Vulkan-Headers) (MIT) - [OffsetAllocator](https://github.com/sebbbi/OffsetAllocator) (MIT) -- [nanothread](https://github.com/mitsuba-renderer/nanothread) (BSD 3-Clause) diff --git 
a/external/nanothread b/external/nanothread deleted file mode 160000 index 94d38922..00000000 --- a/external/nanothread +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 94d389228f34db4953f53309df4b64331a8ed77e From 8c145a7c9e6cea9991082a0f8f26a96ead1539e5 Mon Sep 17 00:00:00 2001 From: Simon Kallweit Date: Wed, 26 Feb 2025 10:52:11 +0100 Subject: [PATCH 2/6] implement new task pool --- include/slang-rhi.h | 56 +++-- src/core/task-pool.cpp | 452 +++++++++++++++++++++++++++++---------- src/core/task-pool.h | 72 +++++-- src/rhi.cpp | 6 +- tests/test-task-pool.cpp | 395 ++++++++++++++++++++++------------ 5 files changed, 697 insertions(+), 284 deletions(-) diff --git a/include/slang-rhi.h b/include/slang-rhi.h index 390e1455..b6efee55 100644 --- a/include/slang-rhi.h +++ b/include/slang-rhi.h @@ -2552,26 +2552,50 @@ class IDevice : public ISlangUnknown convertCooperativeVectorMatrix(const ConvertCooperativeVectorMatrixDesc* descs, uint32_t descCount) = 0; }; -class ITaskScheduler : public ISlangUnknown +class ITaskPool : public ISlangUnknown { SLANG_COM_INTERFACE(0xab272cee, 0xa546, 0x4ae6, {0xbd, 0x0d, 0xcd, 0xab, 0xa9, 0x3f, 0x6d, 0xa6}); public: typedef void* TaskHandle; - /// Submit a task. - /// The scheduler needs to call the `run` function with the `payload` argument. - /// The `parentTasks` contains a list of tasks that need to be completed before the submitted task can run. - /// Every submitted task is released using `releaseTask` once the task handle is no longer used. - virtual SLANG_NO_THROW TaskHandle SLANG_MCALL - submitTask(TaskHandle* parentTasks, uint32_t parentTaskCount, void (*run)(void* /*payload*/), void* payload) = 0; + /// \brief Submit a new task. + /// The returned task must be released with `releaseTask()` when no longer needed + /// for specifying dependencies or issuing `waitTask()`. + /// \param func Function to execute. + /// \param payload Payload to pass to the function. 
+ /// \param payloadDeleter Optional payload deleter (called when task is destroyed). + /// \param deps Parent tasks to wait for. + /// \param depsCount Number of parent tasks. + /// \return The new task. + virtual SLANG_NO_THROW TaskHandle SLANG_MCALL submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount + ) = 0; + + /// \brief Get the task payload data. + /// \param task Task to get the payload for. + /// \return The payload. + virtual SLANG_NO_THROW void* SLANG_MCALL getTaskPayload(TaskHandle task) = 0; - /// Release a task. - /// This is called when the task handle is no longer used. + /// \brief Release a task. + /// \param task Task to release. virtual SLANG_NO_THROW void SLANG_MCALL releaseTask(TaskHandle task) = 0; - // Wait for a task to complete. - virtual SLANG_NO_THROW void SLANG_MCALL waitForCompletion(TaskHandle task) = 0; + /// \brief Wait for a task to finish. + /// \param task Task to wait for. + virtual SLANG_NO_THROW void SLANG_MCALL waitTask(TaskHandle task) = 0; + + /// \brief Check if a task is done. + /// \param task Task to check. + /// \return True if the task is done. + virtual SLANG_NO_THROW bool SLANG_MCALL isTaskDone(TaskHandle task) = 0; + + /// \brief Wait for all tasks in the pool to finish. + virtual SLANG_NO_THROW void SLANG_MCALL waitAll() = 0; }; class IPersistentShaderCache : public ISlangUnknown @@ -2651,12 +2675,16 @@ class IRHI /// Set the global task pool worker count. /// Must be called before any devices are created. - /// This is ignored if the task scheduler is set. + /// This only affects the default task pool implementation and has no effect + /// if a custom task pool is used with `setTaskPool`. + /// If count is set to 0, a blocking task pool implementation is used. + /// If count is -1, a threaded task pool is used with a worker count equal to the number of logical cores. 
+ /// If count is 1 or larger, a threaded task pool is used with the specified worker count. virtual SLANG_NO_THROW Result SLANG_MCALL setTaskPoolWorkerCount(uint32_t count) = 0; - /// Set the global task scheduler for the RHI. + /// Set the global task pool for the RHI. /// Must be called before any devices are created. - virtual SLANG_NO_THROW Result SLANG_MCALL setTaskScheduler(ITaskScheduler* scheduler) = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL setTaskPool(ITaskPool* taskPool) = 0; }; // Global public functions diff --git a/src/core/task-pool.cpp b/src/core/task-pool.cpp index d2a585c7..259070cc 100644 --- a/src/core/task-pool.cpp +++ b/src/core/task-pool.cpp @@ -1,202 +1,428 @@ #include "task-pool.h" -#if 0 -#include -#endif - #include +#include +#include +#include namespace rhi { -class BlockingTaskScheduler : public ITaskScheduler, public ComObject +// BlockingTaskPool + +struct BlockingTaskPool::Task { -public: - SLANG_COM_OBJECT_IUNKNOWN_ALL + void* payload; + void (*payloadDeleter)(void*); +}; - ITaskScheduler* getInterface(const Guid& guid) - { - if (guid == ISlangUnknown::getTypeGuid() || guid == ITaskScheduler::getTypeGuid()) - return static_cast(this); - return nullptr; - } +ITaskPool* BlockingTaskPool::getInterface(const Guid& guid) +{ + if (guid == ISlangUnknown::getTypeGuid() || guid == ITaskPool::getTypeGuid()) + return static_cast(this); + return nullptr; +} + +ITaskPool::TaskHandle BlockingTaskPool::submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount +) +{ + SLANG_RHI_ASSERT(func); + SLANG_RHI_ASSERT(depsCount == 0 || deps); + + // Dependent tasks are guaranteed to be done. + SLANG_UNUSED(deps); + SLANG_UNUSED(depsCount); + + // Create task just to defer the payload deletion. + Task* task = new Task(); + task->payload = payload; + task->payloadDeleter = payloadDeleter; + + // Execute the task function. 
+ func(payload); + + return task; +} - virtual SLANG_NO_THROW TaskHandle SLANG_MCALL - submitTask(TaskHandle* parentTasks, uint32_t parentTaskCount, void (*run)(void*), void* payload) override +void* BlockingTaskPool::getTaskPayload(TaskHandle task) +{ + SLANG_RHI_ASSERT(task); + + Task* taskImpl = checked_cast(task); + return taskImpl->payload; +} + +void BlockingTaskPool::releaseTask(TaskHandle task) +{ + SLANG_RHI_ASSERT(task); + + Task* taskImpl = checked_cast(task); + if (taskImpl->payloadDeleter) { - SLANG_UNUSED(parentTasks); - SLANG_UNUSED(parentTaskCount); - run(payload); - return payload; + taskImpl->payloadDeleter(taskImpl->payload); } +} + +void BlockingTaskPool::waitTask(TaskHandle task) +{ + SLANG_UNUSED(task); +} + +bool BlockingTaskPool::isTaskDone(TaskHandle task) +{ + return true; +} + +void BlockingTaskPool::waitAll() {} - virtual SLANG_NO_THROW void SLANG_MCALL releaseTask(TaskHandle task) override { SLANG_UNUSED(task); } +// ThreadedTaskPool - virtual SLANG_NO_THROW void SLANG_MCALL waitForCompletion(TaskHandle task) override { SLANG_UNUSED(task); } +struct ThreadedTaskPool::Task +{ + // Function to execute. + void (*func)(void*) = nullptr; + // Pointer to payload data. + void* payload = nullptr; + // Optional deleter for the payload. + void (*payloadDeleter)(void*) = nullptr; + + // Pool that owns the task. + Pool* pool = nullptr; + + // Reference counter. + std::atomic refCount{0}; + + // Number of dependencies that are not yet finished. + std::atomic depsRemaining{0}; + + // Flag indicating the task has finished. + std::atomic done{false}; + + // Mutex and condition variable for waitTask(). + std::mutex waitMutex; + std::condition_variable waitCV; + + // List of tasks that depend on this task. + std::vector children; + std::mutex childrenMutex; }; -#if 0 -class NanoThreadTaskScheduler : public ITaskScheduler, public ComObject +struct ThreadedTaskPool::Pool { -public: - SLANG_COM_OBJECT_IUNKNOWN_ALL + // Queue of tasks ready for execution. 
+ std::queue m_queue; + std::mutex m_queueMutex; + std::condition_variable m_queueCV; - ITaskScheduler* getInterface(const Guid& guid) + // Flag to signal worker threads to stop. + std::atomic m_stop{false}; + + // Worker threads. + std::vector m_workerThreads; + + // Total number of tasks not yet completed. + std::atomic m_tasksRemaining{0}; + + // Mutex and condition variable for waitAll(). + std::mutex m_waitMutex; + std::condition_variable m_waitCV; + + void workerThread(); + + Pool(int workerCount) { - if (guid == ISlangUnknown::getTypeGuid() || guid == ITaskScheduler::getTypeGuid()) - return static_cast(this); - return nullptr; + if (workerCount <= 0) + { + workerCount = static_cast(std::thread::hardware_concurrency()); + if (workerCount <= 0) + workerCount = 1; + } + for (int i = 0; i < workerCount; i++) + { + m_workerThreads.emplace_back([this]() { workerThread(); }); + } } - NanoThreadTaskScheduler(uint32_t size) { m_pool = ::pool_create(size); } - ~NanoThreadTaskScheduler() { ::pool_destroy(m_pool); } + ~Pool() + { + { + std::lock_guard lock(m_queueMutex); + m_stop.store(true); + } + m_queueCV.notify_all(); + for (std::thread& worker : m_workerThreads) + { + if (worker.joinable()) + worker.join(); + } + while (!m_queue.empty()) + { + Task* task = m_queue.front(); + m_queue.pop(); + releaseTask(task); + } + } - virtual SLANG_NO_THROW TaskHandle SLANG_MCALL - submitTask(TaskHandle* parentTasks, uint32_t parentTaskCount, void (*run)(void*), void* payload) override + void retainTask(Task* task, size_t count = 1) { - TaskInfo taskInfo{run, payload}; - if (parentTasks && parentTaskCount > 0) + SLANG_RHI_ASSERT(task); + + task->refCount.fetch_add(count, std::memory_order_relaxed); + } + + void releaseTask(Task* task) + { + SLANG_RHI_ASSERT(task); + + if (task->refCount.fetch_sub(1, std::memory_order_acq_rel) == 1) { - return ::task_submit_dep( - m_pool, - (::Task**)parentTasks, - parentTaskCount, - 1, - runTask, - &taskInfo, - sizeof(TaskInfo), - nullptr, - 1 - 
); + if (task->payloadDeleter) + { + task->payloadDeleter(task->payload); + } + delete task; } - else + } + + void enqueue(Task* task) + { + SLANG_RHI_ASSERT(task); + { - return ::task_submit(m_pool, 1, runTask, &taskInfo, sizeof(TaskInfo), nullptr, 1); + std::lock_guard lock(m_queueMutex); + m_queue.push(task); } + + m_queueCV.notify_one(); } - virtual SLANG_NO_THROW void SLANG_MCALL releaseTask(TaskHandle task) override { ::task_release((::Task*)task); } + Task* submitTask(void (*func)(void*), void* payload, void (*payloadDeleter)(void*), Task** deps, size_t depsCount) + { + SLANG_RHI_ASSERT(func); + SLANG_RHI_ASSERT(depsCount == 0 || deps); + + Task* task = new Task(); + + // Increment the reference count by 2. + // One reference is for the pool, the other is for the caller. + retainTask(task, 2); - virtual SLANG_NO_THROW void SLANG_MCALL waitForCompletion(TaskHandle task) override { ::task_wait((::Task*)task); } + task->func = func; + task->payload = payload; + task->payloadDeleter = payloadDeleter; + task->pool = this; + task->depsRemaining = depsCount; -private: - struct TaskInfo + m_tasksRemaining.fetch_add(1, std::memory_order_relaxed); + + if (depsCount == 0) + { + // If there are no dependencies, enqueue the task immediately. + enqueue(task); + } + else + { + // Process dependencies. + for (size_t i = 0; i < depsCount; i++) + { + Task* dep = deps[i]; + SLANG_RHI_ASSERT(dep); + SLANG_RHI_ASSERT(dep->refCount.load(std::memory_order_acquire) > 0); + { + std::lock_guard lock(dep->childrenMutex); + if (!dep->done.load(std::memory_order_acquire)) + { + // Add an extra reference that will be released when the dependency finishes. + retainTask(task); + dep->children.push_back(task); + } + else + { + // Dependency is already done, decrement the counter and enqueue if necessary. 
+ if (task->depsRemaining.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + enqueue(task); + } + } + } + } + } + return task; + } + + bool isTaskDone(Task* task) { - void (*run)(void*); - void* payload; - }; + SLANG_RHI_ASSERT(task); + + return task->done.load(std::memory_order_acquire); + } - static void runTask(uint32_t index, void* payload) + void waitTask(Task* task) { - TaskInfo* taskInfo = (TaskInfo*)payload; - taskInfo->run(taskInfo->payload); + SLANG_RHI_ASSERT(task); + + std::unique_lock lock(task->waitMutex); + task->waitCV.wait(lock, [task] { return task->done.load(std::memory_order_acquire); }); } - ::Pool* m_pool; + void waitAll() + { + std::unique_lock lock(m_waitMutex); + m_waitCV.wait(lock, [this] { return m_tasksRemaining.load(std::memory_order_acquire) == 0; }); + } }; -#endif -class WaitTask : public Task +void ThreadedTaskPool::Pool::workerThread() { -public: - virtual void run() override {} -}; + while (true) + { + Task* task = nullptr; + // Fetch next task from queue. + { + std::unique_lock lock(m_queueMutex); + m_queueCV.wait(lock, [this] { return m_stop.load() || !m_queue.empty(); }); + if (m_stop.load() && m_queue.empty()) + return; + task = m_queue.front(); + m_queue.pop(); + } + // Execute the task function. + task->func(task->payload); + // Mark the task as done. + task->done.store(true, std::memory_order_release); + // Notify waiters. + task->waitCV.notify_all(); + // Notify child tasks waiting on this dependency. + { + std::lock_guard lock(task->childrenMutex); + for (Task* child : task->children) + { + // Decrement the child's dependency counter. + if (child->depsRemaining.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + // All dependencies satisfied; enqueue the child. + enqueue(child); + } + // Release the extra reference taken when adding as a dependency. + releaseTask(child); + } + task->children.clear(); + } + // Decrement the remaining task counter and notify waiters. 
+ if (m_tasksRemaining.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + std::lock_guard lock(m_waitMutex); + m_waitCV.notify_all(); + } + // Release the pool's reference. + releaseTask(task); + } +} -static void runTask(void* task) +ITaskPool* ThreadedTaskPool::getInterface(const Guid& guid) { - ((Task*)task)->run(); - ((Task*)task)->releaseReference(); + if (guid == ISlangUnknown::getTypeGuid() || guid == ITaskPool::getTypeGuid()) + return static_cast(this); + return nullptr; } -TaskPool::TaskPool(uint32_t workerCount) +ThreadedTaskPool::ThreadedTaskPool(int workerCount) { - m_scheduler = new BlockingTaskScheduler(); -#if 0 - if (workerCount == 0) - { - m_scheduler = new BlockingTaskScheduler(); - } - else - { - m_scheduler = new NanoThreadTaskScheduler(workerCount == kAutoWorkerCount ? NANOTHREAD_AUTO : workerCount); - } -#endif - SLANG_RHI_ASSERT(m_scheduler); + m_pool = new Pool(workerCount); } -TaskPool::TaskPool(ITaskScheduler* scheduler) - : m_scheduler(scheduler) +ThreadedTaskPool::~ThreadedTaskPool() { - SLANG_RHI_ASSERT(m_scheduler); + delete m_pool; } -TaskPool::~TaskPool() {} +ITaskPool::TaskHandle ThreadedTaskPool::submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount +) +{ + return m_pool->submitTask(func, payload, payloadDeleter, (Task**)deps, depsCount); +} -TaskHandle TaskPool::submitTask(Task* task, TaskHandle* parentTaskHandles, uint32_t parentTaskHandleCount) +void* ThreadedTaskPool::getTaskPayload(TaskHandle task) { - task->addReference(); - return m_scheduler->submitTask(parentTaskHandles, parentTaskHandleCount, runTask, task); + return checked_cast(task)->payload; } -void TaskPool::releaseTask(TaskHandle taskHandle) +void ThreadedTaskPool::releaseTask(TaskHandle task) { - m_scheduler->releaseTask(taskHandle); + m_pool->releaseTask(checked_cast(task)); } -void TaskPool::waitForCompletion(TaskHandle taskHandle) +void ThreadedTaskPool::waitTask(TaskHandle task) { - 
m_scheduler->waitForCompletion(taskHandle); + m_pool->waitTask(checked_cast(task)); } -void TaskPool::waitForCompletion(TaskHandle* taskHandles, uint32_t taskHandleCount) +bool ThreadedTaskPool::isTaskDone(TaskHandle task) { - if (taskHandles && taskHandleCount > 0) - { - RefPtr waitTask = new WaitTask(); - TaskHandle waitTaskHandle = submitTask(waitTask, taskHandles, taskHandleCount); - waitForCompletion(waitTaskHandle); - releaseTask(waitTaskHandle); - } + return m_pool->isTaskDone(checked_cast(task)); +} + +void ThreadedTaskPool::waitAll() +{ + m_pool->waitAll(); } static std::mutex s_globalTaskPoolMutex; -static std::unique_ptr s_globalTaskPool; -static uint32_t s_globalTaskPoolWorkerCount = TaskPool::kAutoWorkerCount; -static ComPtr s_globalTaskScheduler; +static ComPtr s_globalTaskPool; +static uint32_t s_globalTaskPoolWorkerCount = -1; Result setGlobalTaskPoolWorkerCount(uint32_t count) { std::lock_guard lock(s_globalTaskPoolMutex); if (s_globalTaskPool) + { return SLANG_FAIL; + } s_globalTaskPoolWorkerCount = count; return SLANG_OK; } -Result setGlobalTaskScheduler(ITaskScheduler* scheduler) +Result setGlobalTaskPool(ITaskPool* taskPool) { std::lock_guard lock(s_globalTaskPoolMutex); if (s_globalTaskPool) + { return SLANG_FAIL; - s_globalTaskScheduler = scheduler; + } + s_globalTaskPool = taskPool; return SLANG_OK; } -TaskPool& globalTaskPool() +ITaskPool* globalTaskPool() { - static std::atomic taskPoolPtr; - if (taskPoolPtr) + static std::atomic taskPool; + if (taskPool) { - return *taskPoolPtr; + return taskPool; } std::lock_guard lock(s_globalTaskPoolMutex); if (!s_globalTaskPool) { - s_globalTaskPool.reset( - s_globalTaskScheduler ? 
new TaskPool(s_globalTaskScheduler) : new TaskPool(s_globalTaskPoolWorkerCount) - ); - taskPoolPtr = s_globalTaskPool.get(); + if (s_globalTaskPoolWorkerCount == 0) + { + s_globalTaskPool = new BlockingTaskPool(); + } + else + { + s_globalTaskPool = new ThreadedTaskPool(s_globalTaskPoolWorkerCount); + } } - return *taskPoolPtr; + taskPool = s_globalTaskPool.get(); + return taskPool; } } // namespace rhi diff --git a/src/core/task-pool.h b/src/core/task-pool.h index ba343682..c2578c81 100644 --- a/src/core/task-pool.h +++ b/src/core/task-pool.h @@ -4,37 +4,73 @@ namespace rhi { -class Task : public RefObject +class BlockingTaskPool : public ITaskPool, public ComObject { public: - virtual ~Task() {} + SLANG_COM_OBJECT_IUNKNOWN_ALL - virtual void run() = 0; + ITaskPool* getInterface(const Guid& guid); + +public: + TaskHandle submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount + ) override; + + void* getTaskPayload(TaskHandle task) override; + + void releaseTask(TaskHandle task) override; + + void waitTask(TaskHandle task) override; + + bool isTaskDone(TaskHandle task) override; + + void waitAll() override; private: - std::string m_name; + struct Task; }; -using TaskHandle = ITaskScheduler::TaskHandle; - -class TaskPool +class ThreadedTaskPool : public ITaskPool, public ComObject { public: - static constexpr uint32_t kAutoWorkerCount = uint32_t(-1); + SLANG_COM_OBJECT_IUNKNOWN_ALL - TaskPool(uint32_t workerCount = kAutoWorkerCount); - TaskPool(ITaskScheduler* scheduler); - ~TaskPool(); + ITaskPool* getInterface(const Guid& guid); - TaskHandle submitTask(Task* task, TaskHandle* parentTaskHandles = nullptr, uint32_t parentTaskHandleCount = 0); - void releaseTask(TaskHandle taskHandle); - void waitForCompletion(TaskHandle taskHandle); - void waitForCompletion(TaskHandle* taskHandles, uint32_t taskHandleCount); +public: + ThreadedTaskPool(int workerCount = -1); + ~ThreadedTaskPool(); + + TaskHandle 
submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount + ) override; + + void* getTaskPayload(TaskHandle task) override; + + void releaseTask(TaskHandle task) override; + + void waitTask(TaskHandle task) override; + + bool isTaskDone(TaskHandle task) override; + + void waitAll() override; private: - ComPtr m_scheduler; + struct Task; + struct Pool; + + Pool* m_pool; }; + /// Set the global task pool worker count. /// Must be called before first accessing the global task pool. /// This is ignored if the task scheduler is set. @@ -42,9 +78,9 @@ Result setGlobalTaskPoolWorkerCount(uint32_t count); /// Set the global task scheduler. /// Must be called before first accessing the global task pool. -Result setGlobalTaskScheduler(ITaskScheduler* scheduler); +Result setGlobalTaskPool(ITaskPool* taskPool); /// Returns the global task pool. -TaskPool& globalTaskPool(); +ITaskPool* globalTaskPool(); } // namespace rhi diff --git a/src/rhi.cpp b/src/rhi.cpp index 0a96982d..70b90852 100644 --- a/src/rhi.cpp +++ b/src/rhi.cpp @@ -240,7 +240,7 @@ class RHI : public IRHI void enableDebugLayers() override; Result reportLiveObjects() override; Result setTaskPoolWorkerCount(uint32_t count) override; - Result setTaskScheduler(ITaskScheduler* scheduler) override; + Result setTaskPool(ITaskPool* scheduler) override; static RHI* getInstance() { @@ -454,9 +454,9 @@ Result RHI::setTaskPoolWorkerCount(uint32_t count) return setGlobalTaskPoolWorkerCount(count); } -Result RHI::setTaskScheduler(ITaskScheduler* scheduler) +Result RHI::setTaskPool(ITaskPool* taskPool) { - return setGlobalTaskScheduler(scheduler); + return setGlobalTaskPool(taskPool); } bool isDebugLayersEnabled() diff --git a/tests/test-task-pool.cpp b/tests/test-task-pool.cpp index 791a8d22..234f6433 100644 --- a/tests/test-task-pool.cpp +++ b/tests/test-task-pool.cpp @@ -8,164 +8,287 @@ using namespace rhi; -class SimpleTask : public Task +// Create a number 
of tasks and wait for each of them individually. +void testSimple(ITaskPool* pool) { -public: - SimpleTask( - std::function onCreate = nullptr, - std::function onDestroy = nullptr, - std::function onRun = nullptr - ) - : m_onCreate(onCreate) - , m_onDestroy(onDestroy) - , m_onRun(onRun) + REQUIRE(pool != nullptr); + + static constexpr size_t N = 1000; + static size_t result[N]; + static bool deleted[N]; + ITaskPool::TaskHandle tasks[N]; + + ::memset(result, 0, sizeof(result)); + ::memset(deleted, 0, sizeof(deleted)); + + for (size_t i = 0; i < N; ++i) + { + size_t* payload = new size_t{i}; + tasks[i] = pool->submitTask( + [](void* payload) + { + size_t i = *static_cast(payload); + result[i] = i; + }, + payload, + [](void* payload) + { + size_t i = *static_cast(payload); + deleted[i] = true; + delete static_cast(payload); + }, + nullptr, + 0 + ); + } + + for (size_t i = 0; i < N; ++i) { - if (onCreate) - onCreate(); + CHECK(!deleted[i]); + pool->waitTask(tasks[i]); + pool->releaseTask(tasks[i]); + CHECK(result[i] == (size_t)i); + CHECK(deleted[i]); } +} + +// Create a number of tasks and wait for all of them at once. 
+void testWaitAll(ITaskPool* pool) +{ + REQUIRE(pool != nullptr); - ~SimpleTask() + static constexpr size_t N = 1000; + static size_t result[N]; + static bool deleted[N]; + + ::memset(result, 0, sizeof(result)); + ::memset(deleted, 0, sizeof(deleted)); + + for (size_t i = 0; i < N; ++i) { - if (m_onDestroy) - m_onDestroy(); + size_t* payload = new size_t{i}; + ITaskPool::TaskHandle task = pool->submitTask( + [](void* payload) + { + size_t i = *static_cast(payload); + result[i] = i; + }, + payload, + [](void* payload) + { + size_t i = *static_cast(payload); + deleted[i] = true; + delete static_cast(payload); + }, + nullptr, + 0 + ); + CHECK(!deleted[i]); + pool->releaseTask(task); } - void run() override + pool->waitAll(); + + for (size_t i = 0; i < N; ++i) { - if (m_onRun) - m_onRun(); + CHECK(result[i] == (size_t)i); + CHECK(deleted[i]); } +} + +// Create a number of tasks plus one task that depends on all of them, then wait for the dependent task. +void testSimpleDependency(ITaskPool* pool) +{ + REQUIRE(pool != nullptr); + + static constexpr size_t N = 1000; + static size_t result[N]; + static ITaskPool::TaskHandle tasks[N]; + static std::atomic finished; + + finished = 0; + + for (size_t i = 0; i < N; ++i) + { + tasks[i] = pool->submitTask( + [](void* payload) + { + size_t i = (size_t)(uintptr_t)payload; + result[i] = i; + finished++; + }, + (void*)i, + nullptr, + nullptr, + 0 + ); + } + + ITaskPool::TaskHandle waitTask = pool->submitTask([](void*) { CHECK(finished == N); }, nullptr, nullptr, tasks, N); + + for (size_t i = 0; i < N; ++i) + { + pool->releaseTask(tasks[i]); + } + + pool->waitTask(waitTask); + pool->releaseTask(waitTask); + + for (size_t i = 0; i < N; ++i) + { + CHECK(result[i] == (size_t)i); + } +} + +inline ITaskPool::TaskHandle spawn(ITaskPool* pool, int depth) +{ + if (depth > 0) + { + ITaskPool::TaskHandle a = spawn(pool, depth - 1); + ITaskPool::TaskHandle b = spawn(pool, depth - 1); + ITaskPool::TaskHandle tasks[] = {a, b}; + ITaskPool::TaskHandle c = pool->submitTask([](void*) {}, 
nullptr, nullptr, tasks, 2); + pool->releaseTask(a); + pool->releaseTask(b); + return c; + } + else + { + return pool->submitTask([](void*) {}, nullptr, nullptr, nullptr, 0); + } +} + +void testRecursiveDependency(ITaskPool* pool) +{ + REQUIRE(pool != nullptr); + + ITaskPool::TaskHandle task = spawn(pool, 10); + pool->waitTask(task); + pool->releaseTask(task); +} -private: - std::function m_onCreate; - std::function m_onDestroy; - std::function m_onRun; +struct FibonacciPayload +{ + int result; + ITaskPool::TaskHandle a; + ITaskPool::TaskHandle b; }; -TEST_CASE("task-pool") +static ITaskPool* fibonacciPool; + +inline ITaskPool::TaskHandle fibonacciTask(int n) +{ + FibonacciPayload* payload = new FibonacciPayload{}; + + if (n <= 1) + { + payload->result = n; + payload->a = nullptr; + payload->b = nullptr; + return fibonacciPool->submitTask([](void* payload) {}, payload, ::free, nullptr, 0); + } + else + { + payload->a = fibonacciTask(n - 1); + payload->b = fibonacciTask(n - 2); + ITaskPool::TaskHandle tasks[] = {payload->a, payload->b}; + return fibonacciPool->submitTask( + [](void* payload) + { + FibonacciPayload* p = static_cast(payload); + FibonacciPayload* pa = static_cast(fibonacciPool->getTaskPayload(p->a)); + FibonacciPayload* pb = static_cast(fibonacciPool->getTaskPayload(p->b)); + p->result = pa->result + pb->result; + fibonacciPool->releaseTask(p->a); + fibonacciPool->releaseTask(p->b); + }, + payload, + ::free, + tasks, + 2 + ); + } +} + +inline int fibonacci(int n) +{ + return n <= 1 ? 
n : fibonacci(n - 1) + fibonacci(n - 2); +} + +void testFibonacci(ITaskPool* pool) +{ + REQUIRE(pool != nullptr); + + fibonacciPool = pool; + int N = 25; + ITaskPool::TaskHandle task = fibonacciTask(N); + int expected = fibonacci(N); + pool->waitTask(task); + int result = static_cast(pool->getTaskPayload(task))->result; + CHECK(result == expected); + pool->releaseTask(task); +} + +TEST_CASE("task-pool-blocking") { - TaskPool pool; - - SUBCASE("wait single") - { - std::atomic alive(false); - std::atomic done(false); - RefPtr task = - new SimpleTask([&]() { alive = true; }, [&]() { alive = false; }, [&]() { done = true; }); - REQUIRE(task); - REQUIRE(alive); - REQUIRE(!done); - TaskHandle taskHandle = pool.submitTask(task); - REQUIRE(taskHandle); - task.setNull(); - pool.waitForCompletion(taskHandle); - CHECK(done); - pool.releaseTask(taskHandle); - CHECK(!alive); - } - - SUBCASE("wait multiple") - { - static constexpr int N = 100; - std::atomic alive[N]; - std::atomic done[N]; - TaskHandle taskHandles[N]; - for (int i = 0; i < N; ++i) + ITaskPool* pool = new BlockingTaskPool(); + + SUBCASE("simple") + { + testSimple(pool); + } + SUBCASE("wait-all") + { + testWaitAll(pool); + } + SUBCASE("simple-dependency") + { + testSimpleDependency(pool); + } + SUBCASE("recursive-dependency") + { + testRecursiveDependency(pool); + } + SUBCASE("fibonacci") + { + testFibonacci(pool); + } +} + +TEST_CASE("task-pool-threaded") +{ + ITaskPool* pool = new ThreadedTaskPool(); + + SUBCASE("simple") + { + for (int i = 0; i < 100; ++i) { - alive[i] = false; - done[i] = false; - RefPtr task = new SimpleTask( - [&, i]() { alive[i] = true; }, - [&, i]() { alive[i] = false; }, - [&, i]() { done[i] = true; } - ); - taskHandles[i] = pool.submitTask(task); + testSimple(pool); } - pool.waitForCompletion(taskHandles, N); - for (int i = 0; i < N; ++i) + } + SUBCASE("wait-all") + { + for (int i = 0; i < 100; ++i) { - pool.releaseTask(taskHandles[i]); - CHECK(!alive[i]); - CHECK(done[i]); + 
testWaitAll(pool); } } - - SUBCASE("simple dependency") + SUBCASE("simple-dependency") { - std::atomic aliveA(false); - std::atomic doneA(false); - std::atomic aliveB(false); - std::atomic doneB(false); - RefPtr taskA = - new SimpleTask([&]() { aliveA = true; }, [&]() { aliveA = false; }, [&]() { doneA = true; }); - RefPtr taskB = new SimpleTask( - [&]() { aliveB = true; }, - [&]() { aliveB = false; }, - [&]() - { - CHECK(doneA); - doneB = true; - } - ); - TaskHandle taskHandleA = pool.submitTask(taskA); - taskA.setNull(); - TaskHandle taskHandleB = pool.submitTask(taskB, &taskHandleA, 1); - taskB.setNull(); - pool.releaseTask(taskHandleA); - pool.waitForCompletion(taskHandleB); - CHECK(doneB); - pool.releaseTask(taskHandleB); - CHECK(!aliveA); - CHECK(!aliveB); - } - - SUBCASE("complex dependency") - { - static constexpr int N = 100; - static constexpr int M = 10; - std::atomic aliveInner[N][M]; - std::atomic doneInner[N][M]; - std::atomic aliveOuter[N]; - std::atomic doneOuter[N]; - TaskHandle taskHandlesOuter[N]; - for (int i = 0; i < N; ++i) + for (int i = 0; i < 100; ++i) { - TaskHandle taskHandlesInner[M]; - for (int j = 0; j < M; j++) - { - aliveInner[i][j] = false; - doneInner[i][j] = false; - RefPtr taskInner = new SimpleTask( - [&, i, j]() { aliveInner[i][j] = true; }, - [&, i, j]() { aliveInner[i][j] = false; }, - [&, i, j]() { doneInner[i][j] = true; } - ); - taskHandlesInner[j] = pool.submitTask(taskInner); - } - aliveOuter[i] = false; - doneOuter[i] = false; - RefPtr taskOuter = new SimpleTask( - [&, i]() { aliveOuter[i] = true; }, - [&, i]() { aliveOuter[i] = false; }, - [&, i]() - { - for (int k = 0; k < M; ++k) - CHECK(doneInner[i][k]); - doneOuter[i] = true; - } - ); - taskHandlesOuter[i] = pool.submitTask(taskOuter, taskHandlesInner, M); - for (int j = 0; j < M; j++) - { - pool.releaseTask(taskHandlesInner[j]); - } + testSimpleDependency(pool); } - pool.waitForCompletion(taskHandlesOuter, N); - for (int i = 0; i < N; ++i) + } + 
SUBCASE("recursive-dependency") + { + for (int i = 0; i < 100; ++i) { - pool.releaseTask(taskHandlesOuter[i]); - CHECK(doneOuter[i]); - CHECK(!aliveOuter[i]); + testRecursiveDependency(pool); } } + SUBCASE("fibonacci") + { + testFibonacci(pool); + } } From c5b034be55057ba956da1843a8298de26b028c9b Mon Sep 17 00:00:00 2001 From: Simon Kallweit Date: Thu, 27 Feb 2025 21:21:01 +0100 Subject: [PATCH 3/6] add missing include --- src/core/task-pool.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core/task-pool.cpp b/src/core/task-pool.cpp index 259070cc..bb588ec8 100644 --- a/src/core/task-pool.cpp +++ b/src/core/task-pool.cpp @@ -1,5 +1,6 @@ #include "task-pool.h" +#include #include #include #include From 56a4fa801d8183d60672aad8c004bbfc69a9c0da Mon Sep 17 00:00:00 2001 From: Simon Kallweit Date: Thu, 27 Feb 2025 21:23:41 +0100 Subject: [PATCH 4/6] use static_cast --- src/core/task-pool.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/core/task-pool.cpp b/src/core/task-pool.cpp index bb588ec8..e2d60129 100644 --- a/src/core/task-pool.cpp +++ b/src/core/task-pool.cpp @@ -53,7 +53,7 @@ void* BlockingTaskPool::getTaskPayload(TaskHandle task) { SLANG_RHI_ASSERT(task); - Task* taskImpl = checked_cast(task); + Task* taskImpl = static_cast(task); return taskImpl->payload; } @@ -61,7 +61,7 @@ void BlockingTaskPool::releaseTask(TaskHandle task) { SLANG_RHI_ASSERT(task); - Task* taskImpl = checked_cast(task); + Task* taskImpl = static_cast(task); if (taskImpl->payloadDeleter) { taskImpl->payloadDeleter(taskImpl->payload); @@ -354,22 +354,22 @@ ITaskPool::TaskHandle ThreadedTaskPool::submitTask( void* ThreadedTaskPool::getTaskPayload(TaskHandle task) { - return checked_cast(task)->payload; + return static_cast(task)->payload; } void ThreadedTaskPool::releaseTask(TaskHandle task) { - m_pool->releaseTask(checked_cast(task)); + m_pool->releaseTask(static_cast(task)); } void ThreadedTaskPool::waitTask(TaskHandle task) { - 
m_pool->waitTask(checked_cast(task)); + m_pool->waitTask(static_cast(task)); } bool ThreadedTaskPool::isTaskDone(TaskHandle task) { - return m_pool->isTaskDone(checked_cast(task)); + return m_pool->isTaskDone(static_cast(task)); } void ThreadedTaskPool::waitAll() From a659b4be84b813da090a4b35c39158f3d61f1c91 Mon Sep 17 00:00:00 2001 From: Simon Kallweit Date: Thu, 27 Feb 2025 22:57:39 +0100 Subject: [PATCH 5/6] test capture --- tests/test-task-pool.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test-task-pool.cpp b/tests/test-task-pool.cpp index 234f6433..2a9f4fef 100644 --- a/tests/test-task-pool.cpp +++ b/tests/test-task-pool.cpp @@ -44,6 +44,7 @@ void testSimple(ITaskPool* pool) for (size_t i = 0; i < N; ++i) { + CAPTURE(i); CHECK(!deleted[i]); pool->waitTask(tasks[i]); pool->releaseTask(tasks[i]); @@ -91,6 +92,7 @@ void testWaitAll(ITaskPool* pool) for (size_t i = 0; i < N; ++i) { + CAPTURE(i); CHECK(result[i] == (size_t)i); CHECK(deleted[i]); } @@ -136,6 +138,7 @@ void testSimpleDependency(ITaskPool* pool) for (size_t i = 0; i < N; ++i) { + CAPTURE(i); CHECK(result[i] == (size_t)i); } } From eb4fa7ff26117a7eccbde4e22af7c272e19d020a Mon Sep 17 00:00:00 2001 From: Simon Kallweit Date: Thu, 27 Feb 2025 23:11:36 +0100 Subject: [PATCH 6/6] fixes --- src/core/task-pool.cpp | 5 ++++- src/core/task-pool.h | 2 +- tests/test-task-pool.cpp | 13 ++++++++++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/core/task-pool.cpp b/src/core/task-pool.cpp index e2d60129..0cd7219c 100644 --- a/src/core/task-pool.cpp +++ b/src/core/task-pool.cpp @@ -296,7 +296,10 @@ void ThreadedTaskPool::Pool::workerThread() // Mark the task as done. task->done.store(true, std::memory_order_release); // Notify waiters. - task->waitCV.notify_all(); + { + std::lock_guard lock(task->waitMutex); + task->waitCV.notify_all(); + } // Notify child tasks waiting on this dependency. 
{ std::lock_guard lock(task->childrenMutex); diff --git a/src/core/task-pool.h b/src/core/task-pool.h index c2578c81..874cf59b 100644 --- a/src/core/task-pool.h +++ b/src/core/task-pool.h @@ -43,7 +43,7 @@ class ThreadedTaskPool : public ITaskPool, public ComObject public: ThreadedTaskPool(int workerCount = -1); - ~ThreadedTaskPool(); + ~ThreadedTaskPool() override; TaskHandle submitTask( void (*func)(void*), diff --git a/tests/test-task-pool.cpp b/tests/test-task-pool.cpp index 2a9f4fef..851fcb8d 100644 --- a/tests/test-task-pool.cpp +++ b/tests/test-task-pool.cpp @@ -49,6 +49,13 @@ void testSimple(ITaskPool* pool) pool->waitTask(tasks[i]); pool->releaseTask(tasks[i]); CHECK(result[i] == (size_t)i); + } + + pool->waitAll(); + + for (size_t i = 0; i < N; ++i) + { + CAPTURE(i); CHECK(deleted[i]); } } @@ -224,8 +231,8 @@ void testFibonacci(ITaskPool* pool) fibonacciPool = pool; int N = 25; - ITaskPool::TaskHandle task = fibonacciTask(N); int expected = fibonacci(N); + ITaskPool::TaskHandle task = fibonacciTask(N); pool->waitTask(task); int result = static_cast(pool->getTaskPayload(task))->result; CHECK(result == expected); @@ -234,7 +241,7 @@ void testFibonacci(ITaskPool* pool) TEST_CASE("task-pool-blocking") { - ITaskPool* pool = new BlockingTaskPool(); + ComPtr pool(new BlockingTaskPool()); SUBCASE("simple") { @@ -260,7 +267,7 @@ TEST_CASE("task-pool-blocking") TEST_CASE("task-pool-threaded") { - ITaskPool* pool = new ThreadedTaskPool(); + ComPtr pool(new ThreadedTaskPool()); SUBCASE("simple") {