diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index 70f46b7e4..3681f277f 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -20,6 +20,11 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
   TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
   TORCH_CHECK(
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
+  startWorkerThread();
+}
+
+CpuDeviceInterface::~CpuDeviceInterface() {
+  stopWorkerThread();
 }
 
 void CpuDeviceInterface::initialize(
@@ -433,4 +438,118 @@ std::string CpuDeviceInterface::getDetails() {
   return std::string("CPU Device Interface.");
 }
 
+// --------------------------------------------------------------------------
+// ASYNC COLOR CONVERSION THREADING
+// --------------------------------------------------------------------------
+
+void CpuDeviceInterface::startWorkerThread() {
+  workerShouldExit_ = false;
+  workerThread_ = std::thread(&CpuDeviceInterface::colorConversionWorker, this);
+}
+
+void CpuDeviceInterface::stopWorkerThread() {
+  {
+    std::lock_guard<std::mutex> lock(queueMutex_);
+    workerShouldExit_ = true;
+  }
+  workAvailable_.notify_one();
+  if (workerThread_.joinable()) {
+    workerThread_.join();
+  }
+}
+
+void CpuDeviceInterface::colorConversionWorker() {
+  while (true) {
+    ConversionWorkItem workItem;
+
+    // Wait for work or exit signal.
+    {
+      std::unique_lock<std::mutex> lock(queueMutex_);
+      workAvailable_.wait(lock, [this] {
+        return !workQueue_.empty() || workerShouldExit_;
+      });
+
+      if (workerShouldExit_ && workQueue_.empty()) {
+        return; // Exit thread.
+      }
+
+      workItem = std::move(workQueue_.front());
+      workQueue_.pop();
+      workerBusy_ = true;
+    }
+
+    // A work-queue slot was just freed: wake a producer blocked on
+    // backpressure in enqueueConversion(). Without this notification a
+    // full queue deadlocks, since the producer thread is the only caller
+    // of dequeueConversionResult().
+    workAvailable_.notify_one();
+
+    // Do the color conversion (outside the lock).
+    FrameOutput frameOutput;
+    if (avMediaType_ == AVMEDIA_TYPE_AUDIO) {
+      convertAudioAVFrameToFrameOutput(workItem.avFrame, frameOutput);
+    } else {
+      convertVideoAVFrameToFrameOutput(
+          workItem.avFrame,
+          frameOutput,
+          workItem.preAllocatedOutputTensor);
+    }
+
+    // Push result.
+    {
+      std::lock_guard<std::mutex> lock(queueMutex_);
+      resultQueue_.push(std::move(frameOutput));
+      workerBusy_ = false;
+    }
+    resultAvailable_.notify_one();
+  }
+}
+
+void CpuDeviceInterface::enqueueConversion(
+    UniqueAVFrame avFrame,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  std::unique_lock<std::mutex> lock(queueMutex_);
+
+  // Block if the queue is full (backpressure); the worker notifies
+  // workAvailable_ after popping an item.
+  workAvailable_.wait(
+      lock, [this] { return workQueue_.size() < kMaxQueueDepth; });
+
+  ConversionWorkItem workItem{
+      std::move(avFrame), std::move(preAllocatedOutputTensor)};
+  workQueue_.push(std::move(workItem));
+
+  workAvailable_.notify_one();
+}
+
+FrameOutput CpuDeviceInterface::dequeueConversionResult() {
+  std::unique_lock<std::mutex> lock(queueMutex_);
+
+  resultAvailable_.wait(lock, [this] { return !resultQueue_.empty(); });
+
+  FrameOutput result = std::move(resultQueue_.front());
+  resultQueue_.pop();
+
+  return result;
+}
+
+void CpuDeviceInterface::flushConversionQueue() {
+  std::unique_lock<std::mutex> lock(queueMutex_);
+
+  // Drop pending work first so the worker cannot pick up more items.
+  while (!workQueue_.empty()) {
+    workQueue_.pop();
+  }
+
+  // Wait out any in-flight conversion; otherwise the worker could push a
+  // stale frame into resultQueue_ after we clear it below.
+  resultAvailable_.wait(lock, [this] { return !workerBusy_; });
+
+  while (!resultQueue_.empty()) {
+    resultQueue_.pop();
+  }
+
+  // Wake a producer blocked on backpressure in enqueueConversion().
+  workAvailable_.notify_all();
+}
+
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index 801b83826..7f8b501c6 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -6,6 +6,10 @@
 
 #pragma once
 
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <thread>
 #include "DeviceInterface.h"
 #include "FFMPEGCommon.h"
 #include "FilterGraph.h"
@@ -16,7 +20,7 @@ class CpuDeviceInterface : public DeviceInterface {
  public:
   CpuDeviceInterface(const torch::Device& device);
 
-  virtual ~CpuDeviceInterface() {}
+  virtual ~CpuDeviceInterface();
 
   std::optional<const AVCodec*> findCodec(
@@ -45,7 +49,37 @@
 
   std::string getDetails() override;
 
+  // Async color conversion API (overrides the DeviceInterface defaults).
+  void enqueueConversion(
+      UniqueAVFrame avFrame,
+      std::optional<torch::Tensor> preAllocatedOutputTensor =
+          std::nullopt) override;
+  FrameOutput dequeueConversionResult() override;
+  void flushConversionQueue() override;
+
  private:
+  // Worker thread state for async color conversion.
+  struct ConversionWorkItem {
+    UniqueAVFrame avFrame;
+    std::optional<torch::Tensor> preAllocatedOutputTensor;
+  };
+
+  void colorConversionWorker();
+  void startWorkerThread();
+  void stopWorkerThread();
+
+  std::thread workerThread_;
+  std::mutex queueMutex_;
+  std::condition_variable workAvailable_;
+  std::condition_variable resultAvailable_;
+  std::queue<ConversionWorkItem> workQueue_;
+  std::queue<FrameOutput> resultQueue_;
+  bool workerShouldExit_ = false;
+  // True while the worker is converting an item it has already popped;
+  // flushConversionQueue() waits on this to avoid leaving a stale result.
+  bool workerBusy_ = false;
+  static constexpr size_t kMaxQueueDepth = 2;
+
   void convertAudioAVFrameToFrameOutput(
       UniqueAVFrame& srcAVFrame,
       FrameOutput& frameOutput);
diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h
index 319fe01a8..c0fa62ab1 100644
--- a/src/torchcodec/_core/DeviceInterface.h
+++ b/src/torchcodec/_core/DeviceInterface.h
@@ -138,6 +138,27 @@ class DeviceInterface {
     return "";
   }
 
+  // Async color conversion API (optional).
+  // The base class does not implement pipelining: enqueueConversion and
+  // dequeueConversionResult throw, and flushConversionQueue is a no-op.
+  // Device interfaces that support async conversion override all three.
+  virtual void enqueueConversion(
+      [[maybe_unused]] UniqueAVFrame avFrame,
+      [[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor =
+          std::nullopt) {
+    throw std::runtime_error(
+        "Async conversion not implemented for this device interface");
+  }
+
+  virtual FrameOutput dequeueConversionResult() {
+    throw std::runtime_error(
+        "Async conversion not implemented for this device interface");
+  }
+
+  virtual void flushConversionQueue() {
+    // Default: no-op.
+  }
+
 protected:
   torch::Device device_;
   SharedAVCodecContext codecContext_;
diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index 8375753ae..895edfc07 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -700,14 +700,59 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
       frameBatchOutput.durationSeconds[indexInOutput] =
           frameBatchOutput.durationSeconds[previousIndexInOutput];
     } else {
-      FrameOutput frameOutput = getFrameAtIndexInternal(
-          indexInVideo, frameBatchOutput.data[indexInOutput]);
-      frameBatchOutput.ptsSeconds[indexInOutput] = frameOutput.ptsSeconds;
-      frameBatchOutput.durationSeconds[indexInOutput] =
-          frameOutput.durationSeconds;
+      // Async conversion path: decode the current frame and enqueue its
+      // color conversion, then dequeue the (already converted) result of
+      // the previous frame, overlapping decode with conversion.
+      //
+      // FIXME(review): this enqueue/dequeue pairing is only correct when
+      // frameIndices has no repeated indices: the duplicate branch above
+      // copies ptsSeconds from an output slot that has not been dequeued
+      // yet, and a skipped enqueue shifts every later dequeue onto the
+      // wrong slot. Track pending output indices in an explicit FIFO
+      // declared before this loop to fix both.
+      // NOTE(review): lastDecodedFrameIndex_ is not declared anywhere in
+      // this patch -- confirm it already exists on SingleStreamDecoder.
+      if (indexInVideo != lastDecodedFrameIndex_ + 1) {
+        int64_t pts = getPts(indexInVideo);
+        setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase));
+      }
+
+      UniqueAVFrame avFrame =
+          decodeAVFrame([this](const UniqueAVFrame& avFrame) {
+            return getPtsOrDts(avFrame) >= cursor_;
+          });
+
+      // Enqueue conversion of the current frame; the worker writes the
+      // pixel data directly into the pre-allocated batch tensor slot.
+      deviceInterface_->enqueueConversion(
+          std::move(avFrame), frameBatchOutput.data[indexInOutput]);
+
+      // Dequeue the previous frame's result: its pixel data is already
+      // in frameBatchOutput.data, only the metadata is returned here.
+      if (f > 0) {
+        auto prevIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
+        FrameOutput result = deviceInterface_->dequeueConversionResult();
+        frameBatchOutput.ptsSeconds[prevIndexInOutput] = result.ptsSeconds;
+        frameBatchOutput.durationSeconds[prevIndexInOutput] =
+            result.durationSeconds;
+      }
+
+      lastDecodedFrameIndex_ = indexInVideo;
     }
     previousIndexInVideo = indexInVideo;
   }
+
+  // Dequeue the last enqueued frame's result (data is already in place).
+  if (frameIndices.numel() > 0) {
+    auto lastIndexInOutput = indicesAreSorted
+        ? frameIndices.numel() - 1
+        : argsort[frameIndices.numel() - 1];
+    FrameOutput lastResult = deviceInterface_->dequeueConversionResult();
+    frameBatchOutput.ptsSeconds[lastIndexInOutput] = lastResult.ptsSeconds;
+    frameBatchOutput.durationSeconds[lastIndexInOutput] =
+        lastResult.durationSeconds;
+  }
+
   frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
   return frameBatchOutput;
 }
@@ -1194,6 +1239,7 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
 
   decodeStats_.numFlushes++;
   deviceInterface_->flush();
+  deviceInterface_->flushConversionQueue();
 }
 
 // ---------------------------------------------------------------------------