Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
TORCH_CHECK(
device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
startWorkerThread();
}

// Signals the color-conversion worker to exit and joins it, so the thread
// cannot outlive the queues, mutex, and condition variables it touches.
CpuDeviceInterface::~CpuDeviceInterface() {
  stopWorkerThread();
}

void CpuDeviceInterface::initialize(
Expand Down Expand Up @@ -433,4 +438,112 @@ std::string CpuDeviceInterface::getDetails() {
return std::string("CPU Device Interface.");
}

// --------------------------------------------------------------------------
// ASYNC COLOR CONVERSION THREADING
// --------------------------------------------------------------------------

// Launches the background color-conversion worker thread.
void CpuDeviceInterface::startWorkerThread() {
  // Reset the exit flag first: stopWorkerThread() sets it to true, so a
  // stop/start cycle must not leave a stale exit request behind.
  workerShouldExit_ = false;
  workerThread_ = std::thread(&CpuDeviceInterface::colorConversionWorker, this);
}

// Requests worker shutdown and blocks until the worker thread has exited.
void CpuDeviceInterface::stopWorkerThread() {
  // Raise the exit flag under the queue lock so the worker observes it
  // atomically with its queue-state checks inside the wait predicate.
  {
    std::scoped_lock guard(queueMutex_);
    workerShouldExit_ = true;
  }
  workAvailable_.notify_one();
  if (!workerThread_.joinable()) {
    return;
  }
  workerThread_.join();
}

// Worker loop: repeatedly takes one ConversionWorkItem off workQueue_, runs
// the (potentially expensive) color conversion outside the lock, and
// publishes the FrameOutput on resultQueue_. Exits when workerShouldExit_
// is set and all queued work has been drained.
void CpuDeviceInterface::colorConversionWorker() {
  while (true) {
    ConversionWorkItem workItem;

    // Wait for work or exit signal
    {
      std::unique_lock<std::mutex> lock(queueMutex_);
      workAvailable_.wait(lock, [this] {
        return !workQueue_.empty() || workerShouldExit_;
      });

      // Drain remaining work before exiting, so enqueued frames are not
      // silently dropped on shutdown.
      if (workerShouldExit_ && workQueue_.empty()) {
        return; // Exit thread
      }

      workItem = std::move(workQueue_.front());
      workQueue_.pop();
    }

    // BUGFIX: popping freed a slot in the bounded workQueue_. Wake a
    // producer that may be blocked on the kMaxQueueDepth backpressure wait
    // in enqueueConversion(). Previously only dequeueConversionResult()
    // signaled workAvailable_, so a producer could keep sleeping even
    // though space already existed (lost wakeup / potential deadlock when
    // producer and consumer are the same thread).
    workAvailable_.notify_one();

    // Do the color conversion (outside the lock)
    FrameOutput frameOutput;
    if (avMediaType_ == AVMEDIA_TYPE_AUDIO) {
      convertAudioAVFrameToFrameOutput(workItem.avFrame, frameOutput);
    } else {
      convertVideoAVFrameToFrameOutput(
          workItem.avFrame,
          frameOutput,
          workItem.preAllocatedOutputTensor);
    }

    // Push result and wake a consumer blocked in dequeueConversionResult()
    {
      std::lock_guard<std::mutex> lock(queueMutex_);
      resultQueue_.push(std::move(frameOutput));
    }
    resultAvailable_.notify_one();
  }
}

// Queues one decoded frame for asynchronous color conversion on the worker
// thread. Blocks the caller while the work queue is at capacity.
void CpuDeviceInterface::enqueueConversion(
    UniqueAVFrame avFrame,
    std::optional<torch::Tensor> preAllocatedOutputTensor) {
  std::unique_lock<std::mutex> lock(queueMutex_);

  // Backpressure: wait until the work queue drops below kMaxQueueDepth.
  workAvailable_.wait(
      lock, [this] { return workQueue_.size() < kMaxQueueDepth; });

  workQueue_.push(ConversionWorkItem{
      std::move(avFrame), std::move(preAllocatedOutputTensor)});

  // Wake the worker thread waiting for work.
  workAvailable_.notify_one();
}

// Blocks until the worker publishes a converted frame, then returns it.
// Callers must have enqueued at least one conversion beforehand; otherwise
// this waits indefinitely.
FrameOutput CpuDeviceInterface::dequeueConversionResult() {
  std::unique_lock<std::mutex> lock(queueMutex_);

  resultAvailable_.wait(lock, [this] { return !resultQueue_.empty(); });

  FrameOutput frameOutput = std::move(resultQueue_.front());
  resultQueue_.pop();

  // Give a producer blocked on the backpressure wait in enqueueConversion()
  // a chance to re-check the work-queue size.
  workAvailable_.notify_one();

  return frameOutput;
}

// Discards all pending work items and all unconsumed results (called on
// decoder flush/seek).
void CpuDeviceInterface::flushConversionQueue() {
  std::lock_guard<std::mutex> lock(queueMutex_);

  // Clear both queues
  // NOTE(review): an item the worker has already popped and is currently
  // converting is not covered by this flush — its result may be pushed onto
  // resultQueue_ after this function returns. Confirm callers tolerate (or
  // drain) such a late result.
  while (!workQueue_.empty()) {
    workQueue_.pop();
  }
  while (!resultQueue_.empty()) {
    resultQueue_.pop();
  }

  // Notify waiting threads
  workAvailable_.notify_all();
  resultAvailable_.notify_all();
}

} // namespace facebook::torchcodec
31 changes: 30 additions & 1 deletion src/torchcodec/_core/CpuDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

#pragma once

#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include "DeviceInterface.h"
#include "FFMPEGCommon.h"
#include "FilterGraph.h"
Expand All @@ -16,7 +20,7 @@ class CpuDeviceInterface : public DeviceInterface {
public:
CpuDeviceInterface(const torch::Device& device);

virtual ~CpuDeviceInterface() {}
virtual ~CpuDeviceInterface();

std::optional<const AVCodec*> findCodec(
[[maybe_unused]] const AVCodecID& codecId) override {
Expand Down Expand Up @@ -45,7 +49,32 @@ class CpuDeviceInterface : public DeviceInterface {

std::string getDetails() override;

// Async color conversion API
void enqueueConversion(
UniqueAVFrame avFrame,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
FrameOutput dequeueConversionResult();
void flushConversionQueue();

private:
// Worker thread for async color conversion
struct ConversionWorkItem {
UniqueAVFrame avFrame;
std::optional<torch::Tensor> preAllocatedOutputTensor;
};

void colorConversionWorker();
void startWorkerThread();
void stopWorkerThread();

std::thread workerThread_;
std::mutex queueMutex_;
std::condition_variable workAvailable_;
std::condition_variable resultAvailable_;
std::queue<ConversionWorkItem> workQueue_;
std::queue<FrameOutput> resultQueue_;
bool workerShouldExit_ = false;
static constexpr size_t kMaxQueueDepth = 2;
void convertAudioAVFrameToFrameOutput(
UniqueAVFrame& srcAVFrame,
FrameOutput& frameOutput);
Expand Down
22 changes: 22 additions & 0 deletions src/torchcodec/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,28 @@ class DeviceInterface {
return "";
}

// Async color conversion API (optional)
// Default implementations fall back to synchronous conversion
virtual void enqueueConversion(
UniqueAVFrame avFrame,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) {
// Default: synchronous conversion (no pipelining)
FrameOutput output;
convertAVFrameToFrameOutput(avFrame, output, preAllocatedOutputTensor);
// Store the result for dequeue
throw std::runtime_error(
"Async conversion not implemented for this device interface");
}

  // Default: async retrieval is unsupported on this interface; overridden by
  // interfaces that implement the enqueue/dequeue pipeline.
  virtual FrameOutput dequeueConversionResult() {
    throw std::runtime_error(
        "Async conversion not implemented for this device interface");
  }

  virtual void flushConversionQueue() {
    // Default: no-op — a synchronous interface has no conversion queue to
    // flush, and flushing must always be safe to call.
  }

protected:
torch::Device device_;
SharedAVCodecContext codecContext_;
Expand Down
44 changes: 39 additions & 5 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,14 +700,47 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
frameBatchOutput.durationSeconds[indexInOutput] =
frameBatchOutput.durationSeconds[previousIndexInOutput];
} else {
FrameOutput frameOutput = getFrameAtIndexInternal(
indexInVideo, frameBatchOutput.data[indexInOutput]);
frameBatchOutput.ptsSeconds[indexInOutput] = frameOutput.ptsSeconds;
frameBatchOutput.durationSeconds[indexInOutput] =
frameOutput.durationSeconds;
// Async conversion path: decode + enqueue conversion
// Then dequeue previous result
if (indexInVideo != lastDecodedFrameIndex_ + 1) {
int64_t pts = getPts(indexInVideo);
setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase));
}

UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) {
return getPtsOrDts(avFrame) >= cursor_;
});

// Enqueue conversion for current frame
deviceInterface_->enqueueConversion(
std::move(avFrame), frameBatchOutput.data[indexInOutput]);

// Dequeue result from previous frame (if exists)
if (f > 0) {
auto prevIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1];
FrameOutput result = deviceInterface_->dequeueConversionResult();
// Data is already in frameBatchOutput.data[prevIndexInOutput]
frameBatchOutput.ptsSeconds[prevIndexInOutput] = result.ptsSeconds;
frameBatchOutput.durationSeconds[prevIndexInOutput] =
result.durationSeconds;
}

lastDecodedFrameIndex_ = indexInVideo;
}
previousIndexInVideo = indexInVideo;
}

// Dequeue the last frame's result
if (frameIndices.numel() > 0) {
auto lastIndexInOutput =
indicesAreSorted ? frameIndices.numel() - 1 : argsort[frameIndices.numel() - 1];
FrameOutput lastResult = deviceInterface_->dequeueConversionResult();
// Data is already in the right place
frameBatchOutput.ptsSeconds[lastIndexInOutput] = lastResult.ptsSeconds;
frameBatchOutput.durationSeconds[lastIndexInOutput] =
lastResult.durationSeconds;
}

frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data);
return frameBatchOutput;
}
Expand Down Expand Up @@ -1194,6 +1227,7 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {

decodeStats_.numFlushes++;
deviceInterface_->flush();
deviceInterface_->flushConversionQueue();
}

// --------------------------------------------------------------------------
Expand Down