diff --git a/.gitmodules b/.gitmodules index f0b325ee..5cf827e9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,6 +4,3 @@ [submodule "external/metal-cpp"] path = external/metal-cpp url = https://github.com/bkaradzic/metal-cpp.git -[submodule "external/nanothread"] - path = external/nanothread - url = https://github.com/skallweitNV/nanothread.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 983dd4e3..6cdbb5e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -374,12 +374,6 @@ if(APPLE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-assume -Wno-switch") endif() -# nanothread -# TODO: disabled for now as it introduces a dependency to libatomic which leads to issues with cross compilation -# set(NANOTHREAD_STATIC ON) -# add_subdirectory(external/nanothread) -# set_target_properties(nanothread PROPERTIES POSITION_INDEPENDENT_CODE ON) - # Setup compiler warnings target_compile_options(slang-rhi PRIVATE $<$: @@ -611,7 +605,6 @@ endif() target_include_directories(slang-rhi PUBLIC include) target_include_directories(slang-rhi PRIVATE src) -# target_link_libraries(slang-rhi PRIVATE nanothread) target_compile_definitions(slang-rhi PRIVATE SLANG_RHI_ENABLE_CPU=$ diff --git a/README.md b/README.md index 3a18a26b..c43b7552 100644 --- a/README.md +++ b/README.md @@ -20,4 +20,3 @@ This library is under active refactoring and development, and is not yet ready f - [stb](https://github.com/nothings/stb) (Public Domain) - [Vulkan-Headers](https://github.com/KhronosGroup/Vulkan-Headers) (MIT) - [OffsetAllocator](https://github.com/sebbbi/OffsetAllocator) (MIT) -- [nanothread](https://github.com/mitsuba-renderer/nanothread) (BSD 3-Clause) diff --git a/external/nanothread b/external/nanothread deleted file mode 160000 index 94d38922..00000000 --- a/external/nanothread +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 94d389228f34db4953f53309df4b64331a8ed77e diff --git a/include/slang-rhi.h b/include/slang-rhi.h index 390e1455..b6efee55 100644 --- a/include/slang-rhi.h +++ b/include/slang-rhi.h @@ -2552,26 +2552,50 @@ class IDevice : public ISlangUnknown convertCooperativeVectorMatrix(const ConvertCooperativeVectorMatrixDesc* descs, uint32_t descCount) = 0; }; -class ITaskScheduler : public ISlangUnknown +class ITaskPool : public ISlangUnknown { SLANG_COM_INTERFACE(0xab272cee, 0xa546, 0x4ae6, {0xbd, 0x0d, 0xcd, 0xab, 0xa9, 0x3f, 0x6d, 0xa6}); public: typedef void* TaskHandle; - /// Submit a task. - /// The scheduler needs to call the `run` function with the `payload` argument. - /// The `parentTasks` contains a list of tasks that need to be completed before the submitted task can run. - /// Every submitted task is released using `releaseTask` once the task handle is no longer used. - virtual SLANG_NO_THROW TaskHandle SLANG_MCALL - submitTask(TaskHandle* parentTasks, uint32_t parentTaskCount, void (*run)(void* /*payload*/), void* payload) = 0; + /// \brief Submit a new task. + /// The returned task must be released with `releaseTask()` when no longer needed + /// for specifying dependencies or issuing `waitTask()`. + /// \param func Function to execute. + /// \param payload Payload to pass to the function. + /// \param payloadDeleter Optional payload deleter (called when task is destroyed). + /// \param deps Parent tasks to wait for. + /// \param depsCount Number of parent tasks. + /// \return The new task. + virtual SLANG_NO_THROW TaskHandle SLANG_MCALL submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount + ) = 0; + + /// \brief Get the task payload data. + /// \param task Task to get the payload for. + /// \return The payload. + virtual SLANG_NO_THROW void* SLANG_MCALL getTaskPayload(TaskHandle task) = 0; - /// Release a task. - /// This is called when the task handle is no longer used. + /// \brief Release a task. + /// \param task Task to release. virtual SLANG_NO_THROW void SLANG_MCALL releaseTask(TaskHandle task) = 0; - // Wait for a task to complete. - virtual SLANG_NO_THROW void SLANG_MCALL waitForCompletion(TaskHandle task) = 0; + /// \brief Wait for a task to finish. + /// \param task Task to wait for. + virtual SLANG_NO_THROW void SLANG_MCALL waitTask(TaskHandle task) = 0; + + /// \brief Check if a task is done. + /// \param task Task to check. + /// \return True if the task is done. + virtual SLANG_NO_THROW bool SLANG_MCALL isTaskDone(TaskHandle task) = 0; + + /// \brief Wait for all tasks in the pool to finish. + virtual SLANG_NO_THROW void SLANG_MCALL waitAll() = 0; }; class IPersistentShaderCache : public ISlangUnknown @@ -2651,12 +2675,16 @@ class IRHI /// Set the global task pool worker count. /// Must be called before any devices are created. - /// This is ignored if the task scheduler is set. + /// This only affects the default task pool implementation and has no effect + /// if a custom task pool is used with `setTaskPool`. + /// If count is set to 0, a blocking task pool implementation is used. + /// If count is -1, a threaded task pool is used with a worker count equal to the number of logical cores. + /// If count is 1 or larger, a threaded task pool is used with the specified worker count. virtual SLANG_NO_THROW Result SLANG_MCALL setTaskPoolWorkerCount(uint32_t count) = 0; - /// Set the global task scheduler for the RHI. + /// Set the global task pool for the RHI. /// Must be called before any devices are created. - virtual SLANG_NO_THROW Result SLANG_MCALL setTaskScheduler(ITaskScheduler* scheduler) = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL setTaskPool(ITaskPool* taskPool) = 0; }; // Global public functions diff --git a/src/core/task-pool.cpp b/src/core/task-pool.cpp index d2a585c7..0cd7219c 100644 --- a/src/core/task-pool.cpp +++ b/src/core/task-pool.cpp @@ -1,202 +1,432 @@ #include "task-pool.h" -#if 0 -#include -#endif - +#include #include +#include +#include +#include namespace rhi { -class BlockingTaskScheduler : public ITaskScheduler, public ComObject +// BlockingTaskPool + +struct BlockingTaskPool::Task { -public: - SLANG_COM_OBJECT_IUNKNOWN_ALL + void* payload; + void (*payloadDeleter)(void*); +}; - ITaskScheduler* getInterface(const Guid& guid) - { - if (guid == ISlangUnknown::getTypeGuid() || guid == ITaskScheduler::getTypeGuid()) - return static_cast(this); - return nullptr; - } +ITaskPool* BlockingTaskPool::getInterface(const Guid& guid) +{ + if (guid == ISlangUnknown::getTypeGuid() || guid == ITaskPool::getTypeGuid()) + return static_cast(this); + return nullptr; +} + +ITaskPool::TaskHandle BlockingTaskPool::submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount +) +{ + SLANG_RHI_ASSERT(func); + SLANG_RHI_ASSERT(depsCount == 0 || deps); + + // Dependent tasks are guaranteed to be done. + SLANG_UNUSED(deps); + SLANG_UNUSED(depsCount); + + // Create task just to defer the payload deletion. + Task* task = new Task(); + task->payload = payload; + task->payloadDeleter = payloadDeleter; + + // Execute the task function. + func(payload); + + return task; +} - virtual SLANG_NO_THROW TaskHandle SLANG_MCALL - submitTask(TaskHandle* parentTasks, uint32_t parentTaskCount, void (*run)(void*), void* payload) override +void* BlockingTaskPool::getTaskPayload(TaskHandle task) +{ + SLANG_RHI_ASSERT(task); + + Task* taskImpl = static_cast(task); + return taskImpl->payload; +} + +void BlockingTaskPool::releaseTask(TaskHandle task) +{ + SLANG_RHI_ASSERT(task); + + Task* taskImpl = static_cast(task); + if (taskImpl->payloadDeleter) { - SLANG_UNUSED(parentTasks); - SLANG_UNUSED(parentTaskCount); - run(payload); - return payload; + taskImpl->payloadDeleter(taskImpl->payload); } +} + +void BlockingTaskPool::waitTask(TaskHandle task) +{ + SLANG_UNUSED(task); +} + +bool BlockingTaskPool::isTaskDone(TaskHandle task) +{ + return true; +} + +void BlockingTaskPool::waitAll() {} - virtual SLANG_NO_THROW void SLANG_MCALL releaseTask(TaskHandle task) override { SLANG_UNUSED(task); } +// ThreadedTaskPool - virtual SLANG_NO_THROW void SLANG_MCALL waitForCompletion(TaskHandle task) override { SLANG_UNUSED(task); } +struct ThreadedTaskPool::Task +{ + // Function to execute. + void (*func)(void*) = nullptr; + // Pointer to payload data. + void* payload = nullptr; + // Optional deleter for the payload. + void (*payloadDeleter)(void*) = nullptr; + + // Pool that owns the task. + Pool* pool = nullptr; + + // Reference counter. + std::atomic refCount{0}; + + // Number of dependencies that are not yet finished. + std::atomic depsRemaining{0}; + + // Flag indicating the task has finished. + std::atomic done{false}; + + // Mutex and condition variable for waitTask(). + std::mutex waitMutex; + std::condition_variable waitCV; + + // List of tasks that depend on this task. + std::vector children; + std::mutex childrenMutex; }; -#if 0 -class NanoThreadTaskScheduler : public ITaskScheduler, public ComObject +struct ThreadedTaskPool::Pool { -public: - SLANG_COM_OBJECT_IUNKNOWN_ALL + // Queue of tasks ready for execution. + std::queue m_queue; + std::mutex m_queueMutex; + std::condition_variable m_queueCV; + + // Flag to signal worker threads to stop. + std::atomic m_stop{false}; + + // Worker threads. + std::vector m_workerThreads; + + // Total number of tasks not yet completed. + std::atomic m_tasksRemaining{0}; + + // Mutex and condition variable for waitAll(). + std::mutex m_waitMutex; + std::condition_variable m_waitCV; + + void workerThread(); + + Pool(int workerCount) + { + if (workerCount <= 0) + { + workerCount = static_cast(std::thread::hardware_concurrency()); + if (workerCount <= 0) + workerCount = 1; + } + for (int i = 0; i < workerCount; i++) + { + m_workerThreads.emplace_back([this]() { workerThread(); }); + } + } - ITaskScheduler* getInterface(const Guid& guid) + ~Pool() { - if (guid == ISlangUnknown::getTypeGuid() || guid == ITaskScheduler::getTypeGuid()) - return static_cast(this); - return nullptr; + { + std::lock_guard lock(m_queueMutex); + m_stop.store(true); + } + m_queueCV.notify_all(); + for (std::thread& worker : m_workerThreads) + { + if (worker.joinable()) + worker.join(); + } + while (!m_queue.empty()) + { + Task* task = m_queue.front(); + m_queue.pop(); + releaseTask(task); + } } - NanoThreadTaskScheduler(uint32_t size) { m_pool = ::pool_create(size); } - ~NanoThreadTaskScheduler() { ::pool_destroy(m_pool); } + void retainTask(Task* task, size_t count = 1) + { + SLANG_RHI_ASSERT(task); + + task->refCount.fetch_add(count, std::memory_order_relaxed); + } - virtual SLANG_NO_THROW TaskHandle SLANG_MCALL - submitTask(TaskHandle* parentTasks, uint32_t parentTaskCount, void (*run)(void*), void* payload) override + void releaseTask(Task* task) { - TaskInfo taskInfo{run, payload}; - if (parentTasks && parentTaskCount > 0) + SLANG_RHI_ASSERT(task); + + if (task->refCount.fetch_sub(1, std::memory_order_acq_rel) == 1) { - return ::task_submit_dep( - m_pool, - (::Task**)parentTasks, - parentTaskCount, - 1, - runTask, - &taskInfo, - sizeof(TaskInfo), - nullptr, - 1 - ); + if (task->payloadDeleter) + { + task->payloadDeleter(task->payload); + } + delete task; } - else + } + + void enqueue(Task* task) + { + SLANG_RHI_ASSERT(task); + { - return ::task_submit(m_pool, 1, runTask, &taskInfo, sizeof(TaskInfo), nullptr, 1); + std::lock_guard lock(m_queueMutex); + m_queue.push(task); } + + m_queueCV.notify_one(); } - virtual SLANG_NO_THROW void SLANG_MCALL releaseTask(TaskHandle task) override { ::task_release((::Task*)task); } + Task* submitTask(void (*func)(void*), void* payload, void (*payloadDeleter)(void*), Task** deps, size_t depsCount) + { + SLANG_RHI_ASSERT(func); + SLANG_RHI_ASSERT(depsCount == 0 || deps); + + Task* task = new Task(); + + // Increment the reference count by 2. + // One reference is for the pool, the other is for the caller. + retainTask(task, 2); + + task->func = func; + task->payload = payload; + task->payloadDeleter = payloadDeleter; + task->pool = this; + task->depsRemaining = depsCount; - virtual SLANG_NO_THROW void SLANG_MCALL waitForCompletion(TaskHandle task) override { ::task_wait((::Task*)task); } + m_tasksRemaining.fetch_add(1, std::memory_order_relaxed); + + if (depsCount == 0) + { + // If there are no dependencies, enqueue the task immediately. + enqueue(task); + } + else + { + // Process dependencies. + for (size_t i = 0; i < depsCount; i++) + { + Task* dep = deps[i]; + SLANG_RHI_ASSERT(dep); + SLANG_RHI_ASSERT(dep->refCount.load(std::memory_order_acquire) > 0); + { + std::lock_guard lock(dep->childrenMutex); + if (!dep->done.load(std::memory_order_acquire)) + { + // Add an extra reference that will be released when the dependency finishes. + retainTask(task); + dep->children.push_back(task); + } + else + { + // Dependency is already done, decrement the counter and enqueue if necessary. + if (task->depsRemaining.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + enqueue(task); + } + } + } + } + } + return task; + } -private: - struct TaskInfo + bool isTaskDone(Task* task) { - void (*run)(void*); - void* payload; - }; + SLANG_RHI_ASSERT(task); - static void runTask(uint32_t index, void* payload) + return task->done.load(std::memory_order_acquire); + } + + void waitTask(Task* task) { - TaskInfo* taskInfo = (TaskInfo*)payload; - taskInfo->run(taskInfo->payload); + SLANG_RHI_ASSERT(task); + + std::unique_lock lock(task->waitMutex); + task->waitCV.wait(lock, [task] { return task->done.load(std::memory_order_acquire); }); } - ::Pool* m_pool; + void waitAll() + { + std::unique_lock lock(m_waitMutex); + m_waitCV.wait(lock, [this] { return m_tasksRemaining.load(std::memory_order_acquire) == 0; }); + } }; -#endif -class WaitTask : public Task +void ThreadedTaskPool::Pool::workerThread() { -public: - virtual void run() override {} -}; + while (true) + { + Task* task = nullptr; + // Fetch next task from queue. + { + std::unique_lock lock(m_queueMutex); + m_queueCV.wait(lock, [this] { return m_stop.load() || !m_queue.empty(); }); + if (m_stop.load() && m_queue.empty()) + return; + task = m_queue.front(); + m_queue.pop(); + } + // Execute the task function. + task->func(task->payload); + // Mark the task as done. + task->done.store(true, std::memory_order_release); + // Notify waiters. + { + std::lock_guard lock(task->waitMutex); + task->waitCV.notify_all(); + } + // Notify child tasks waiting on this dependency. + { + std::lock_guard lock(task->childrenMutex); + for (Task* child : task->children) + { + // Decrement the child's dependency counter. + if (child->depsRemaining.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + // All dependencies satisfied; enqueue the child. + enqueue(child); + } + // Release the extra reference taken when adding as a dependency. + releaseTask(child); + } + task->children.clear(); + } + // Decrement the remaining task counter and notify waiters. + if (m_tasksRemaining.fetch_sub(1, std::memory_order_acq_rel) == 1) + { + std::lock_guard lock(m_waitMutex); + m_waitCV.notify_all(); + } + // Release the pool's reference. + releaseTask(task); + } +} -static void runTask(void* task) +ITaskPool* ThreadedTaskPool::getInterface(const Guid& guid) { - ((Task*)task)->run(); - ((Task*)task)->releaseReference(); + if (guid == ISlangUnknown::getTypeGuid() || guid == ITaskPool::getTypeGuid()) + return static_cast(this); + return nullptr; } -TaskPool::TaskPool(uint32_t workerCount) +ThreadedTaskPool::ThreadedTaskPool(int workerCount) { - m_scheduler = new BlockingTaskScheduler(); -#if 0 - if (workerCount == 0) - { - m_scheduler = new BlockingTaskScheduler(); - } - else - { - m_scheduler = new NanoThreadTaskScheduler(workerCount == kAutoWorkerCount ? NANOTHREAD_AUTO : workerCount); - } -#endif - SLANG_RHI_ASSERT(m_scheduler); + m_pool = new Pool(workerCount); } -TaskPool::TaskPool(ITaskScheduler* scheduler) - : m_scheduler(scheduler) +ThreadedTaskPool::~ThreadedTaskPool() { - SLANG_RHI_ASSERT(m_scheduler); + delete m_pool; } -TaskPool::~TaskPool() {} +ITaskPool::TaskHandle ThreadedTaskPool::submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount +) +{ + return m_pool->submitTask(func, payload, payloadDeleter, (Task**)deps, depsCount); +} -TaskHandle TaskPool::submitTask(Task* task, TaskHandle* parentTaskHandles, uint32_t parentTaskHandleCount) +void* ThreadedTaskPool::getTaskPayload(TaskHandle task) { - task->addReference(); - return m_scheduler->submitTask(parentTaskHandles, parentTaskHandleCount, runTask, task); + return static_cast(task)->payload; } -void TaskPool::releaseTask(TaskHandle taskHandle) +void ThreadedTaskPool::releaseTask(TaskHandle task) { - m_scheduler->releaseTask(taskHandle); + m_pool->releaseTask(static_cast(task)); } -void TaskPool::waitForCompletion(TaskHandle taskHandle) +void ThreadedTaskPool::waitTask(TaskHandle task) { - m_scheduler->waitForCompletion(taskHandle); + m_pool->waitTask(static_cast(task)); } -void TaskPool::waitForCompletion(TaskHandle* taskHandles, uint32_t taskHandleCount) +bool ThreadedTaskPool::isTaskDone(TaskHandle task) { - if (taskHandles && taskHandleCount > 0) - { - RefPtr waitTask = new WaitTask(); - TaskHandle waitTaskHandle = submitTask(waitTask, taskHandles, taskHandleCount); - waitForCompletion(waitTaskHandle); - releaseTask(waitTaskHandle); - } + return m_pool->isTaskDone(static_cast(task)); +} + +void ThreadedTaskPool::waitAll() +{ + m_pool->waitAll(); } static std::mutex s_globalTaskPoolMutex; -static std::unique_ptr s_globalTaskPool; -static uint32_t s_globalTaskPoolWorkerCount = TaskPool::kAutoWorkerCount; -static ComPtr s_globalTaskScheduler; +static ComPtr s_globalTaskPool; +static uint32_t s_globalTaskPoolWorkerCount = -1; Result setGlobalTaskPoolWorkerCount(uint32_t count) { std::lock_guard lock(s_globalTaskPoolMutex); if (s_globalTaskPool) + { return SLANG_FAIL; + } s_globalTaskPoolWorkerCount = count; return SLANG_OK; } -Result setGlobalTaskScheduler(ITaskScheduler* scheduler) +Result setGlobalTaskPool(ITaskPool* taskPool) { std::lock_guard lock(s_globalTaskPoolMutex); if (s_globalTaskPool) + { return SLANG_FAIL; - s_globalTaskScheduler = scheduler; + } + s_globalTaskPool = taskPool; return SLANG_OK; } -TaskPool& globalTaskPool() +ITaskPool* globalTaskPool() { - static std::atomic taskPoolPtr; - if (taskPoolPtr) + static std::atomic taskPool; + if (taskPool) { - return *taskPoolPtr; + return taskPool; } std::lock_guard lock(s_globalTaskPoolMutex); if (!s_globalTaskPool) { - s_globalTaskPool.reset( - s_globalTaskScheduler ? new TaskPool(s_globalTaskScheduler) : new TaskPool(s_globalTaskPoolWorkerCount) - ); - taskPoolPtr = s_globalTaskPool.get(); + if (s_globalTaskPoolWorkerCount == 0) + { + s_globalTaskPool = new BlockingTaskPool(); + } + else + { + s_globalTaskPool = new ThreadedTaskPool(s_globalTaskPoolWorkerCount); + } } - return *taskPoolPtr; + taskPool = s_globalTaskPool.get(); + return taskPool; } } // namespace rhi diff --git a/src/core/task-pool.h b/src/core/task-pool.h index ba343682..874cf59b 100644 --- a/src/core/task-pool.h +++ b/src/core/task-pool.h @@ -4,37 +4,73 @@ namespace rhi { -class Task : public RefObject +class BlockingTaskPool : public ITaskPool, public ComObject { public: - virtual ~Task() {} + SLANG_COM_OBJECT_IUNKNOWN_ALL - virtual void run() = 0; + ITaskPool* getInterface(const Guid& guid); + +public: + TaskHandle submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount + ) override; + + void* getTaskPayload(TaskHandle task) override; + + void releaseTask(TaskHandle task) override; + + void waitTask(TaskHandle task) override; + + bool isTaskDone(TaskHandle task) override; + + void waitAll() override; private: - std::string m_name; + struct Task; }; -using TaskHandle = ITaskScheduler::TaskHandle; - -class TaskPool +class ThreadedTaskPool : public ITaskPool, public ComObject { public: - static constexpr uint32_t kAutoWorkerCount = uint32_t(-1); + SLANG_COM_OBJECT_IUNKNOWN_ALL - TaskPool(uint32_t workerCount = kAutoWorkerCount); - TaskPool(ITaskScheduler* scheduler); - ~TaskPool(); + ITaskPool* getInterface(const Guid& guid); - TaskHandle submitTask(Task* task, TaskHandle* parentTaskHandles = nullptr, uint32_t parentTaskHandleCount = 0); - void releaseTask(TaskHandle taskHandle); - void waitForCompletion(TaskHandle taskHandle); - void waitForCompletion(TaskHandle* taskHandles, uint32_t taskHandleCount); +public: + ThreadedTaskPool(int workerCount = -1); + ~ThreadedTaskPool() override; + + TaskHandle submitTask( + void (*func)(void*), + void* payload, + void (*payloadDeleter)(void*), + TaskHandle* deps, + size_t depsCount + ) override; + + void* getTaskPayload(TaskHandle task) override; + + void releaseTask(TaskHandle task) override; + + void waitTask(TaskHandle task) override; + + bool isTaskDone(TaskHandle task) override; + + void waitAll() override; private: - ComPtr m_scheduler; + struct Task; + struct Pool; + + Pool* m_pool; }; + /// Set the global task pool worker count. /// Must be called before first accessing the global task pool. /// This is ignored if the task scheduler is set. @@ -42,9 +78,9 @@ Result setGlobalTaskPoolWorkerCount(uint32_t count); /// Set the global task scheduler. /// Must be called before first accessing the global task pool. -Result setGlobalTaskScheduler(ITaskScheduler* scheduler); +Result setGlobalTaskPool(ITaskPool* taskPool); /// Returns the global task pool. -TaskPool& globalTaskPool(); +ITaskPool* globalTaskPool(); } // namespace rhi diff --git a/src/rhi.cpp b/src/rhi.cpp index 0a96982d..70b90852 100644 --- a/src/rhi.cpp +++ b/src/rhi.cpp @@ -240,7 +240,7 @@ class RHI : public IRHI void enableDebugLayers() override; Result reportLiveObjects() override; Result setTaskPoolWorkerCount(uint32_t count) override; - Result setTaskScheduler(ITaskScheduler* scheduler) override; + Result setTaskPool(ITaskPool* scheduler) override; static RHI* getInstance() { @@ -454,9 +454,9 @@ Result RHI::setTaskPoolWorkerCount(uint32_t count) return setGlobalTaskPoolWorkerCount(count); } -Result RHI::setTaskScheduler(ITaskScheduler* scheduler) +Result RHI::setTaskPool(ITaskPool* taskPool) { - return setGlobalTaskScheduler(scheduler); + return setGlobalTaskPool(taskPool); } bool isDebugLayersEnabled() diff --git a/tests/test-task-pool.cpp b/tests/test-task-pool.cpp index 791a8d22..851fcb8d 100644 --- a/tests/test-task-pool.cpp +++ b/tests/test-task-pool.cpp @@ -8,164 +8,297 @@ using namespace rhi; -class SimpleTask : public Task +// Create a number of tasks and wait for each of them individually. +void testSimple(ITaskPool* pool) { -public: - SimpleTask( - std::function onCreate = nullptr, - std::function onDestroy = nullptr, - std::function onRun = nullptr - ) - : m_onCreate(onCreate) - , m_onDestroy(onDestroy) - , m_onRun(onRun) + REQUIRE(pool != nullptr); + + static constexpr size_t N = 1000; + static size_t result[N]; + static bool deleted[N]; + ITaskPool::TaskHandle tasks[N]; + + ::memset(result, 0, sizeof(result)); + ::memset(deleted, 0, sizeof(deleted)); + + for (size_t i = 0; i < N; ++i) { - if (onCreate) - onCreate(); + size_t* payload = new size_t{i}; + tasks[i] = pool->submitTask( + [](void* payload) + { + size_t i = *static_cast(payload); + result[i] = i; + }, + payload, + [](void* payload) + { + size_t i = *static_cast(payload); + deleted[i] = true; + delete static_cast(payload); + }, + nullptr, + 0 + ); } - ~SimpleTask() + for (size_t i = 0; i < N; ++i) { - if (m_onDestroy) - m_onDestroy(); + CAPTURE(i); + CHECK(!deleted[i]); + pool->waitTask(tasks[i]); + pool->releaseTask(tasks[i]); + CHECK(result[i] == (size_t)i); } - void run() override + pool->waitAll(); + + for (size_t i = 0; i < N; ++i) { - if (m_onRun) - m_onRun(); + CAPTURE(i); + CHECK(deleted[i]); } +} + +// Create a number of tasks and wait for all of them at once. +void testWaitAll(ITaskPool* pool) +{ + REQUIRE(pool != nullptr); + + static constexpr size_t N = 1000; + static size_t result[N]; + static bool deleted[N]; + + ::memset(result, 0, sizeof(result)); + ::memset(deleted, 0, sizeof(deleted)); + + for (size_t i = 0; i < N; ++i) + { + size_t* payload = new size_t{i}; + ITaskPool::TaskHandle task = pool->submitTask( + [](void* payload) + { + size_t i = *static_cast(payload); + result[i] = i; + }, + payload, + [](void* payload) + { + size_t i = *static_cast(payload); + deleted[i] = true; + delete static_cast(payload); + }, + nullptr, + 0 + ); + CHECK(!deleted[i]); + pool->releaseTask(task); + } + + pool->waitAll(); + + for (size_t i = 0; i < N; ++i) + { + CAPTURE(i); + CHECK(result[i] == (size_t)i); + CHECK(deleted[i]); + } +} + +// Create a number of tasks and wait for all of them at once. +void testSimpleDependency(ITaskPool* pool) +{ + REQUIRE(pool != nullptr); + + static constexpr size_t N = 1000; + static size_t result[N]; + static ITaskPool::TaskHandle tasks[N]; + static std::atomic finished; + + finished = 0; + + for (size_t i = 0; i < N; ++i) + { + tasks[i] = pool->submitTask( + [](void* payload) + { + size_t i = (size_t)(uintptr_t)payload; + result[i] = i; + finished++; + }, + (void*)i, + nullptr, + nullptr, + 0 + ); + } + + ITaskPool::TaskHandle waitTask = pool->submitTask([](void*) { CHECK(finished == N); }, nullptr, nullptr, tasks, N); + + for (size_t i = 0; i < N; ++i) + { + pool->releaseTask(tasks[i]); + } + + pool->waitTask(waitTask); + pool->releaseTask(waitTask); + + for (size_t i = 0; i < N; ++i) + { + CAPTURE(i); + CHECK(result[i] == (size_t)i); + } +} + +inline ITaskPool::TaskHandle spawn(ITaskPool* pool, int depth) +{ + if (depth > 0) + { + ITaskPool::TaskHandle a = spawn(pool, depth - 1); + ITaskPool::TaskHandle b = spawn(pool, depth - 1); + ITaskPool::TaskHandle tasks[] = {a, b}; + ITaskPool::TaskHandle c = pool->submitTask([](void*) {}, nullptr, nullptr, tasks, 2); + pool->releaseTask(a); + pool->releaseTask(b); + return c; + } + else + { + return pool->submitTask([](void*) {}, nullptr, nullptr, nullptr, 0); + } +} + +void testRecursiveDependency(ITaskPool* pool) +{ + REQUIRE(pool != nullptr); + + ITaskPool::TaskHandle task = spawn(pool, 10); + pool->waitTask(task); + pool->releaseTask(task); +} -private: - std::function m_onCreate; - std::function m_onDestroy; - std::function m_onRun; +struct FibonacciPayload +{ + int result; + ITaskPool::TaskHandle a; + ITaskPool::TaskHandle b; }; -TEST_CASE("task-pool") +static ITaskPool* fibonacciPool; + +inline ITaskPool::TaskHandle fibonacciTask(int n) { - TaskPool pool; - - SUBCASE("wait single") - { - std::atomic alive(false); - std::atomic done(false); - RefPtr task = - new SimpleTask([&]() { alive = true; }, [&]() { alive = false; }, [&]() { done = true; }); - REQUIRE(task); - REQUIRE(alive); - REQUIRE(!done); - TaskHandle taskHandle = pool.submitTask(task); - REQUIRE(taskHandle); - task.setNull(); - pool.waitForCompletion(taskHandle); - CHECK(done); - pool.releaseTask(taskHandle); - CHECK(!alive); - } - - SUBCASE("wait multiple") - { - static constexpr int N = 100; - std::atomic alive[N]; - std::atomic done[N]; - TaskHandle taskHandles[N]; - for (int i = 0; i < N; ++i) + FibonacciPayload* payload = new FibonacciPayload{}; + + if (n <= 1) + { + payload->result = n; + payload->a = nullptr; + payload->b = nullptr; + return fibonacciPool->submitTask([](void* payload) {}, payload, ::free, nullptr, 0); + } + else + { + payload->a = fibonacciTask(n - 1); + payload->b = fibonacciTask(n - 2); + ITaskPool::TaskHandle tasks[] = {payload->a, payload->b}; + return fibonacciPool->submitTask( + [](void* payload) + { + FibonacciPayload* p = static_cast(payload); + FibonacciPayload* pa = static_cast(fibonacciPool->getTaskPayload(p->a)); + FibonacciPayload* pb = static_cast(fibonacciPool->getTaskPayload(p->b)); + p->result = pa->result + pb->result; + fibonacciPool->releaseTask(p->a); + fibonacciPool->releaseTask(p->b); + }, + payload, + ::free, + tasks, + 2 + ); + } +} + +inline int fibonacci(int n) +{ + return n <= 1 ? n : fibonacci(n - 1) + fibonacci(n - 2); +} + +void testFibonacci(ITaskPool* pool) +{ + REQUIRE(pool != nullptr); + + fibonacciPool = pool; + int N = 25; + int expected = fibonacci(N); + ITaskPool::TaskHandle task = fibonacciTask(N); + pool->waitTask(task); + int result = static_cast(pool->getTaskPayload(task))->result; + CHECK(result == expected); + pool->releaseTask(task); +} + +TEST_CASE("task-pool-blocking") +{ + ComPtr pool(new BlockingTaskPool()); + + SUBCASE("simple") + { + testSimple(pool); + } + SUBCASE("wait-all") + { + testWaitAll(pool); + } + SUBCASE("simple-dependency") + { + testSimpleDependency(pool); + } + SUBCASE("recursive-dependency") + { + testRecursiveDependency(pool); + } + SUBCASE("fibonacci") + { + testFibonacci(pool); + } +} + +TEST_CASE("task-pool-threaded") +{ + ComPtr pool(new ThreadedTaskPool()); + + SUBCASE("simple") + { + for (int i = 0; i < 100; ++i) { - alive[i] = false; - done[i] = false; - RefPtr task = new SimpleTask( - [&, i]() { alive[i] = true; }, - [&, i]() { alive[i] = false; }, - [&, i]() { done[i] = true; } - ); - taskHandles[i] = pool.submitTask(task); + testSimple(pool); } - pool.waitForCompletion(taskHandles, N); - for (int i = 0; i < N; ++i) + } + SUBCASE("wait-all") + { + for (int i = 0; i < 100; ++i) { - pool.releaseTask(taskHandles[i]); - CHECK(!alive[i]); - CHECK(done[i]); + testWaitAll(pool); } } - - SUBCASE("simple dependency") + SUBCASE("simple-dependency") { - std::atomic aliveA(false); - std::atomic doneA(false); - std::atomic aliveB(false); - std::atomic doneB(false); - RefPtr taskA = - new SimpleTask([&]() { aliveA = true; }, [&]() { aliveA = false; }, [&]() { doneA = true; }); - RefPtr taskB = new SimpleTask( - [&]() { aliveB = true; }, - [&]() { aliveB = false; }, - [&]() - { - CHECK(doneA); - doneB = true; - } - ); - TaskHandle taskHandleA = pool.submitTask(taskA); - taskA.setNull(); - TaskHandle taskHandleB = pool.submitTask(taskB, &taskHandleA, 1); - taskB.setNull(); - pool.releaseTask(taskHandleA); - pool.waitForCompletion(taskHandleB); - CHECK(doneB); - pool.releaseTask(taskHandleB); - CHECK(!aliveA); - CHECK(!aliveB); - } - - SUBCASE("complex dependency") - { - static constexpr int N = 100; - static constexpr int M = 10; - std::atomic aliveInner[N][M]; - std::atomic doneInner[N][M]; - std::atomic aliveOuter[N]; - std::atomic doneOuter[N]; - TaskHandle taskHandlesOuter[N]; - for (int i = 0; i < N; ++i) + for (int i = 0; i < 100; ++i) { - TaskHandle taskHandlesInner[M]; - for (int j = 0; j < M; j++) - { - aliveInner[i][j] = false; - doneInner[i][j] = false; - RefPtr taskInner = new SimpleTask( - [&, i, j]() { aliveInner[i][j] = true; }, - [&, i, j]() { aliveInner[i][j] = false; }, - [&, i, j]() { doneInner[i][j] = true; } - ); - taskHandlesInner[j] = pool.submitTask(taskInner); - } - aliveOuter[i] = false; - doneOuter[i] = false; - RefPtr taskOuter = new SimpleTask( - [&, i]() { aliveOuter[i] = true; }, - [&, i]() { aliveOuter[i] = false; }, - [&, i]() - { - for (int k = 0; k < M; ++k) - CHECK(doneInner[i][k]); - doneOuter[i] = true; - } - ); - taskHandlesOuter[i] = pool.submitTask(taskOuter, taskHandlesInner, M); - for (int j = 0; j < M; j++) - { - pool.releaseTask(taskHandlesInner[j]); - } + testSimpleDependency(pool); } - pool.waitForCompletion(taskHandlesOuter, N); - for (int i = 0; i < N; ++i) + } + SUBCASE("recursive-dependency") + { + for (int i = 0; i < 100; ++i) { - pool.releaseTask(taskHandlesOuter[i]); - CHECK(doneOuter[i]); - CHECK(!aliveOuter[i]); + testRecursiveDependency(pool); } } + SUBCASE("fibonacci") + { + testFibonacci(pool); + } }