1 change: 1 addition & 0 deletions benchmarks/decoders/benchmark_transforms.py
@@ -133,6 +133,7 @@ def main():

input_height = metadata.height
input_width = metadata.width

for num_fraction in args.total_frame_fractions:
num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
print(
119 changes: 96 additions & 23 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -39,15 +39,18 @@ void CpuDeviceInterface::initializeVideo(
videoStreamOptions_ = videoStreamOptions;
resizedOutputDims_ = resizedOutputDims;

// We can only use swscale when we have a single resize transform. Note that
// this means swscale will not support the case of having several,
back-to-back resizes. There's no strong reason to even do that, but if
// someone does, it's more correct to implement that with filtergraph.
// We can use swscale when we have a single resize transform.
// With a single resize, we use swscale twice:
// first for color conversion (YUV->RGB24), then for resize in RGB24 space.
//
// Note that this means swscale will not support the case of having several,
// back-to-back resizes or other transforms.
//
// We calculate this value during initilization but we don't refer to it until
// getColorConversionLibrary() is called. Calculating this value during
// We calculate this value during initialization but we don't refer to it
// until getColorConversionLibrary() is called. Calculating this value during
// initialization saves us from having to save all of the transforms.
areTransformsSwScaleCompatible_ = transforms.empty();
areTransformsSwScaleCompatible_ = transforms.empty() ||
(transforms.size() == 1 && transforms[0]->isResize());

// Note that we do not expose this capability in the public API, only through
// the core API.
@@ -57,6 +60,16 @@
userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
ColorConversionLibrary::SWSCALE;

// We can only use swscale when we have a single resize transform. Note that
// we decide whether to use swscale at the last possible moment, when we
// actually convert the frame, because only then do we know the actual frame
// dimensions.
if (transforms.size() == 1 && transforms[0]->isResize()) {
auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!");
swsFlags_ = resize->getSwsFlags();
}

// If we have any transforms, replace filters_ with the filter strings from
// the transforms. As noted above, we decide between swscale and filtergraph
// when we actually decode a frame.
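
For orientation, here is a rough sketch of how these two members could feed the deferred decision. The body and signature of getColorConversionLibrary() are not part of this diff, so everything below is an assumption for illustration only:

// Hypothetical sketch only -- not the real getColorConversionLibrary().
// In particular, the default when the user expresses no preference is assumed.
ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary() const {
  if (!areTransformsSwScaleCompatible_) {
    // Anything other than "no transforms" or "a single resize" needs
    // filtergraph.
    return ColorConversionLibrary::FILTERGRAPH;
  }
  if (userRequestedSwScale_) {
    return ColorConversionLibrary::SWSCALE;
  }
  // Assumed default; the real code may pick based on frame dimensions.
  return ColorConversionLibrary::FILTERGRAPH;
}
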
@@ -238,10 +251,9 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);

TORCH_CHECK(
avFrame->height == outputDims.height &&
avFrame->width == outputDims.width,
"Input dimensions are not equal to output dimensions; resize for sws_scale() is not yet supported.");
bool needsResize =
(avFrame->height != outputDims.height ||
avFrame->width != outputDims.width);

// We need to compare the current frame context with our previous frame
// context. If they are different, then we need to re-create our colorspace
@@ -254,8 +266,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
avFrame->width,
avFrame->height,
frameFormat,
outputDims.width,
outputDims.height);
needsResize ? avFrame->width : outputDims.width,
needsResize ? avFrame->height : outputDims.height);

if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
swsContext_ = createSwsContext(
@@ -266,25 +278,86 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
// pixel format.
/*outputFormat=*/AV_PIX_FMT_RGB24,

// We don't set any flags because we don't yet use sw_scale() for
// resizing.
// No flags for color conversion. When resizing is needed, we use a
// separate swscale context with the appropriate resize flags.
/*swsFlags=*/0);
prevSwsFrameContext_ = swsFrameContext;
}

uint8_t* pointers[4] = {
outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
int expectedOutputWidth = outputTensor.sizes()[1];
int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
int resultHeight = sws_scale(
// When resizing is needed, we do sws_scale twice: first convert to RGB24 at
// original resolution, then resize in RGB24 space. This ensures transforms
// happen in the output color space (RGB24) rather than the input color space
// (YUV).
//
// When no resize is needed, we do color conversion directly into the output
// tensor.

torch::Tensor colorConvertedTensor = needsResize
? allocateEmptyHWCTensor(
FrameDims(avFrame->height, avFrame->width), torch::kCPU)
: outputTensor;

uint8_t* colorConvertedPointers[4] = {
colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
int colorConvertedWidth = colorConvertedTensor.sizes()[1];
int colorConvertedLinesizes[4] = {colorConvertedWidth * 3, 0, 0, 0};

int colorConvertedHeight = sws_scale(
swsContext_.get(),
avFrame->data,
avFrame->linesize,
0,
avFrame->height,
pointers,
linesizes);
return resultHeight;
colorConvertedPointers,
colorConvertedLinesizes);

TORCH_CHECK(
colorConvertedHeight == avFrame->height,
"Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ",
colorConvertedHeight,
" != ",
avFrame->height);

if (needsResize) {
// Use cached swscale context for resizing, similar to the color conversion
// context caching above.
SwsFrameContext resizeSwsFrameContext(
avFrame->width,
avFrame->height,
AV_PIX_FMT_RGB24,
outputDims.width,
outputDims.height);

if (!resizeSwsContext_ ||
prevResizeSwsFrameContext_ != resizeSwsFrameContext) {
resizeSwsContext_ = createSwsContext(
resizeSwsFrameContext,
AVCOL_SPC_RGB,
/*outputFormat=*/AV_PIX_FMT_RGB24,
/*swsFlags=*/swsFlags_);
prevResizeSwsFrameContext_ = resizeSwsFrameContext;
}

uint8_t* srcPointers[4] = {
colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0};

uint8_t* dstPointers[4] = {
outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
int expectedOutputWidth = outputTensor.sizes()[1];
int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};

colorConvertedHeight = sws_scale(
resizeSwsContext_.get(),
srcPointers,
srcLinesizes,
0,
avFrame->height,
dstPointers,
dstLinesizes);
}
Contributor:
I think we can simplify this code by structuring it, roughly, as:

// do color conversion
result = sws_scale(/* color conversion params */);
if (needsResize) {
  result = sws_scale(/* resize params */);
}

Or is there something I'm missing that makes the above structure difficult to do?


return colorConvertedHeight;
}

torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
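Stripped of the torchcodec-specific caching and tensor plumbing, the two-pass idea above can be shown with plain FFmpeg calls. This is a minimal sketch, assuming an 8-bit source frame that swscale can read and omitting all error handling; it is not code from this PR:

#include <cstdint>
#include <vector>
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

// Pass 1: color conversion (source format -> RGB24) at the source resolution.
// Pass 2: resize in RGB24 space, which is where the interpolation flag applies.
void convertThenResize(const AVFrame* src, int outW, int outH, uint8_t* outRGB24) {
  std::vector<uint8_t> rgbFullRes(src->width * src->height * 3);
  SwsContext* toRgb = sws_getContext(
      src->width, src->height, static_cast<AVPixelFormat>(src->format),
      src->width, src->height, AV_PIX_FMT_RGB24,
      /*flags=*/0, nullptr, nullptr, nullptr);  // no scaling in this pass
  uint8_t* rgbData[4] = {rgbFullRes.data(), nullptr, nullptr, nullptr};
  int rgbLinesize[4] = {src->width * 3, 0, 0, 0};
  sws_scale(toRgb, src->data, src->linesize, 0, src->height, rgbData, rgbLinesize);
  sws_freeContext(toRgb);

  SwsContext* resize = sws_getContext(
      src->width, src->height, AV_PIX_FMT_RGB24,
      outW, outH, AV_PIX_FMT_RGB24,
      SWS_BILINEAR, nullptr, nullptr, nullptr);  // resize flags apply here
  uint8_t* dstData[4] = {outRGB24, nullptr, nullptr, nullptr};
  int dstLinesize[4] = {outW * 3, 0, 0, 0};
  sws_scale(resize, rgbData, rgbLinesize, 0, src->height, dstData, dstLinesize);
  sws_freeContext(resize);
}

Structurally this also matches the simplification suggested in the review comment above: the second sws_scale call is just an optional second pass.
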
11 changes: 11 additions & 0 deletions src/torchcodec/_core/CpuDeviceInterface.h
@@ -101,6 +101,12 @@ class CpuDeviceInterface : public DeviceInterface {
UniqueSwsContext swsContext_;
SwsFrameContext prevSwsFrameContext_;

// Cached swscale context for resizing in RGB24 space (used in double swscale
// path). Like the color conversion context above, we cache this to avoid
// recreating it for every frame.
UniqueSwsContext resizeSwsContext_;
SwsFrameContext prevResizeSwsFrameContext_;
Contributor:
We should probably refactor the sws logic into its own class, similar to what we did for filtergraph. But let's tackle that as a follow-up PR to this functionality.


// We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
// of what FFmpeg calls "filters" to apply to decoded frames before returning
// them. In the PyTorch ecosystem, we call these "transforms". During
@@ -120,6 +126,11 @@
bool areTransformsSwScaleCompatible_;
bool userRequestedSwScale_;

// The flags we supply to the resize swscale context. The flags control the
// resizing algorithm. We default to bilinear. Users can override this with a
// ResizeTransform that specifies a different interpolation mode.
int swsFlags_ = SWS_BILINEAR;

bool initialized_ = false;

// Audio-specific members
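Both context/key pairs in this header follow the same caching discipline: the swscale context is keyed by the parameters it was created from and rebuilt only when that key changes. A self-contained sketch of the pattern with stand-in types (SwsFrameContext's real definition is not shown in this diff, so the key below is illustrative):

#include <memory>

// Illustration of the "rebuild only when the key changes" pattern used for
// both swsContext_ and resizeSwsContext_. Names and types are stand-ins.
struct FrameKey {
  int inWidth = 0, inHeight = 0, inFormat = -1;
  int outWidth = 0, outHeight = 0;
  bool operator!=(const FrameKey& other) const {
    return inWidth != other.inWidth || inHeight != other.inHeight ||
        inFormat != other.inFormat || outWidth != other.outWidth ||
        outHeight != other.outHeight;
  }
};

struct ScalerContext {};  // stand-in for SwsContext

struct CachedScaler {
  std::unique_ptr<ScalerContext> context;
  FrameKey prevKey;

  ScalerContext& get(const FrameKey& key) {
    if (!context || prevKey != key) {
      context = std::make_unique<ScalerContext>();  // createSwsContext(...) in the real code
      prevKey = key;
    }
    return *context;
  }
};
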
20 changes: 20 additions & 0 deletions src/torchcodec/_core/Transform.cpp
@@ -25,6 +25,18 @@ std::string toFilterGraphInterpolation(
}
}

int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
switch (mode) {
case ResizeTransform::InterpolationMode::BILINEAR:
return SWS_BILINEAR;
default:
TORCH_CHECK(
false,
"Unknown interpolation mode: " +
std::to_string(static_cast<int>(mode)));
}
}

} // namespace

std::string ResizeTransform::getFilterGraphCpu() const {
@@ -37,6 +49,14 @@ std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
return outputDims_;
}

bool ResizeTransform::isResize() const {
return true;
}

int ResizeTransform::getSwsFlags() const {
return toSwsInterpolation(interpolationMode_);
}

CropTransform::CropTransform(const FrameDims& dims) : outputDims_(dims) {}

CropTransform::CropTransform(const FrameDims& dims, int x, int y)
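A small usage sketch tying the two resize paths to the same transform object: filtergraph asks it for a filter string, the new swscale path asks it for interpolation flags. The ResizeTransform constructor used here is an assumption, since it is not part of this diff:

// Usage sketch only; the constructor signature is assumed, and the exact
// string returned by getFilterGraphCpu() is whatever the implementation builds.
void describeResize() {
  ResizeTransform resize(FrameDims(/*height=*/480, /*width=*/640));

  std::string filter = resize.getFilterGraphCpu();  // consumed by filtergraph
  int swsFlags = resize.getSwsFlags();              // SWS_BILINEAR by default
  bool isResize = resize.isResize();                // true

  (void)filter;
  (void)swsFlags;
  (void)isResize;
}
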
9 changes: 9 additions & 0 deletions src/torchcodec/_core/Transform.h
@@ -29,6 +29,12 @@ class Transform {
return std::nullopt;
}

// The ResizeTransform is special because it is the only transform
// that swscale can handle.
virtual bool isResize() const {
return false;
}

// The validity of some transforms depends on the characteristics of the
// AVStream they're being applied to. For example, some transforms will
// specify coordinates inside a frame, we need to validate that those are
@@ -51,6 +57,9 @@ class ResizeTransform : public Transform {

std::string getFilterGraphCpu() const override;
std::optional<FrameDims> getOutputFrameDims() const override;
bool isResize() const override;

int getSwsFlags() const;

private:
FrameDims outputDims_;
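Because the base class defaults isResize() to false, any other transform automatically keeps the CPU path on filtergraph. A sketch with a made-up transform; RotateTransform does not exist in the codebase, and this assumes getFilterGraphCpu() is the only member a subclass must provide:

// Hypothetical transform, purely illustrative. It inherits isResize() == false
// from Transform, so CpuDeviceInterface classifies any transform list that
// contains it as not swscale-compatible and falls back to filtergraph.
class RotateTransform : public Transform {
 public:
  explicit RotateTransform(int degrees) : degrees_(degrees) {}

  std::string getFilterGraphCpu() const override {
    // FFmpeg's rotate filter takes radians.
    return "rotate=" + std::to_string(degrees_ * 3.14159265 / 180.0);
  }

 private:
  int degrees_;
};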