diff --git a/benchmarks/decoders/benchmark_transforms.py b/benchmarks/decoders/benchmark_transforms.py
index 8342ef7f8..01222f403 100644
--- a/benchmarks/decoders/benchmark_transforms.py
+++ b/benchmarks/decoders/benchmark_transforms.py
@@ -133,6 +133,7 @@ def main():
     input_height = metadata.height
     input_width = metadata.width
 
+
     for num_fraction in args.total_frame_fractions:
         num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
         print(
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index 70f46b7e4..042591c71 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -39,15 +39,18 @@ void CpuDeviceInterface::initializeVideo(
   videoStreamOptions_ = videoStreamOptions;
   resizedOutputDims_ = resizedOutputDims;
 
-  // We can only use swscale when we have a single resize transform. Note that
-  // this means swscale will not support the case of having several,
-  // back-to-base resizes. There's no strong reason to even do that, but if
-  // someone does, it's more correct to implement that with filtergraph.
+  // We can use swscale when we have a single resize transform. With a single
+  // resize, we use swscale twice: first for color conversion (YUV->RGB24),
+  // then for the resize in RGB24 space.
+  //
+  // Note that this means swscale will not support the case of having several,
+  // back-to-back resizes or other transforms.
   //
-  // We calculate this value during initilization but we don't refer to it until
-  // getColorConversionLibrary() is called. Calculating this value during
+  // We calculate this value during initialization but we don't refer to it
+  // until getColorConversionLibrary() is called. Calculating this value during
   // initialization saves us from having to save all of the transforms.
-  areTransformsSwScaleCompatible_ = transforms.empty();
+  areTransformsSwScaleCompatible_ = transforms.empty() ||
+      (transforms.size() == 1 && transforms[0]->isResize());
 
   // Note that we do not expose this capability in the public API, only through
   // the core API.
@@ -57,6 +60,16 @@ void CpuDeviceInterface::initializeVideo(
   userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
       ColorConversionLibrary::SWSCALE;
 
+  // We can only use swscale when we have a single resize transform. Note that
+  // we decide whether to actually use swscale at the last possible moment,
+  // when we actually convert the frame. This is because we need to know the
+  // actual frame dimensions.
+  if (transforms.size() == 1 && transforms[0]->isResize()) {
+    auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
+    TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!");
+    swsFlags_ = resize->getSwsFlags();
+  }
+
   // If we have any transforms, replace filters_ with the filter strings from
   // the transforms. As noted above, we decide between swscale and filtergraph
   // when we actually decode a frame.
@@ -238,10 +251,9 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
-  TORCH_CHECK(
-      avFrame->height == outputDims.height &&
-          avFrame->width == outputDims.width,
-      "Input dimensions are not equal to output dimensions; resize for sws_scale() is not yet supported.");
+  bool needsResize =
+      (avFrame->height != outputDims.height ||
+       avFrame->width != outputDims.width);
 
   // We need to compare the current frame context with our previous frame
   // context. If they are different, then we need to re-create our colorspace
@@ -254,8 +266,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
       avFrame->width,
       avFrame->height,
       frameFormat,
-      outputDims.width,
-      outputDims.height);
+      needsResize ? avFrame->width : outputDims.width,
+      needsResize ? avFrame->height : outputDims.height);
 
   if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
     swsContext_ = createSwsContext(
@@ -266,25 +278,86 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
         // pixel format.
        /*outputFormat=*/AV_PIX_FMT_RGB24,
-        // We don't set any flags because we don't yet use sw_scale() for
-        // resizing.
+        // No flags for color conversion. When resizing is needed, we use a
+        // separate swscale context with the appropriate resize flags.
        /*swsFlags=*/0);
     prevSwsFrameContext_ = swsFrameContext;
   }
 
-  uint8_t* pointers[4] = {
-      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
-  int expectedOutputWidth = outputTensor.sizes()[1];
-  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
-  int resultHeight = sws_scale(
+  // When resizing is needed, we do sws_scale twice: first convert to RGB24 at
+  // the original resolution, then resize in RGB24 space. This ensures that
+  // transforms happen in the output color space (RGB24) rather than the input
+  // color space (YUV).
+  //
+  // When no resize is needed, we do color conversion directly into the output
+  // tensor.
+
+  torch::Tensor colorConvertedTensor = needsResize
+      ? allocateEmptyHWCTensor(
+            FrameDims(avFrame->height, avFrame->width), torch::kCPU)
+      : outputTensor;
+
+  uint8_t* colorConvertedPointers[4] = {
+      colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int colorConvertedWidth = colorConvertedTensor.sizes()[1];
+  int colorConvertedLinesizes[4] = {colorConvertedWidth * 3, 0, 0, 0};
+
+  int colorConvertedHeight = sws_scale(
       swsContext_.get(),
       avFrame->data,
       avFrame->linesize,
       0,
       avFrame->height,
-      pointers,
-      linesizes);
-  return resultHeight;
+      colorConvertedPointers,
+      colorConvertedLinesizes);
+
+  TORCH_CHECK(
+      colorConvertedHeight == avFrame->height,
+      "Color conversion swscale pass failed: colorConvertedHeight != avFrame->height: ",
+      colorConvertedHeight,
+      " != ",
+      avFrame->height);
+
+  if (needsResize) {
+    // Use a cached swscale context for resizing, similar to the color
+    // conversion context caching above.
+    SwsFrameContext resizeSwsFrameContext(
+        avFrame->width,
+        avFrame->height,
+        AV_PIX_FMT_RGB24,
+        outputDims.width,
+        outputDims.height);
+
+    if (!resizeSwsContext_ ||
+        prevResizeSwsFrameContext_ != resizeSwsFrameContext) {
+      resizeSwsContext_ = createSwsContext(
+          resizeSwsFrameContext,
+          AVCOL_SPC_RGB,
+          /*outputFormat=*/AV_PIX_FMT_RGB24,
+          /*swsFlags=*/swsFlags_);
+      prevResizeSwsFrameContext_ = resizeSwsFrameContext;
+    }
+
+    uint8_t* srcPointers[4] = {
+        colorConvertedTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+    int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0};
+
+    uint8_t* dstPointers[4] = {
+        outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+    int expectedOutputWidth = outputTensor.sizes()[1];
+    int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+
+    colorConvertedHeight = sws_scale(
+        resizeSwsContext_.get(),
+        srcPointers,
+        srcLinesizes,
+        0,
+        avFrame->height,
+        dstPointers,
+        dstLinesizes);
+  }
+
+  return colorConvertedHeight;
 }
 
 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index 2a6bceac3..ac853947a 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -101,6 +101,12 @@ class CpuDeviceInterface : public DeviceInterface {
   UniqueSwsContext swsContext_;
   SwsFrameContext prevSwsFrameContext_;
 
+  // Cached swscale context for resizing in RGB24 space (used in the double
+  // swscale path). Like the color conversion context above, we cache this to
+  // avoid recreating it for every frame.
+  UniqueSwsContext resizeSwsContext_;
+  SwsFrameContext prevResizeSwsFrameContext_;
+
   // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
   // of what FFmpeg calls "filters" to apply to decoded frames before returning
   // them. In the PyTorch ecosystem, we call these "transforms". During
@@ -120,6 +126,11 @@ class CpuDeviceInterface : public DeviceInterface {
   bool areTransformsSwScaleCompatible_;
   bool userRequestedSwScale_;
 
+  // The flags we supply to the resize swscale context. The flags control the
+  // resizing algorithm. We default to bilinear. Users can override this with a
+  // ResizeTransform that specifies a different interpolation mode.
+  int swsFlags_ = SWS_BILINEAR;
+
   bool initialized_ = false;
 
   // Audio-specific members
diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp
index e89f64cc0..a375ef427 100644
--- a/src/torchcodec/_core/Transform.cpp
+++ b/src/torchcodec/_core/Transform.cpp
@@ -25,6 +25,18 @@ std::string toFilterGraphInterpolation(
   }
 }
 
+int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
+  switch (mode) {
+    case ResizeTransform::InterpolationMode::BILINEAR:
+      return SWS_BILINEAR;
+    default:
+      TORCH_CHECK(
+          false,
+          "Unknown interpolation mode: " +
+              std::to_string(static_cast<int>(mode)));
+  }
+}
+
 } // namespace
 
 std::string ResizeTransform::getFilterGraphCpu() const {
@@ -37,6 +49,14 @@ std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
   return outputDims_;
 }
 
+bool ResizeTransform::isResize() const {
+  return true;
+}
+
+int ResizeTransform::getSwsFlags() const {
+  return toSwsInterpolation(interpolationMode_);
+}
+
 CropTransform::CropTransform(const FrameDims& dims) : outputDims_(dims) {}
 
 CropTransform::CropTransform(const FrameDims& dims, int x, int y)
diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h
index 71dad31be..84ebfe17e 100644
--- a/src/torchcodec/_core/Transform.h
+++ b/src/torchcodec/_core/Transform.h
@@ -29,6 +29,12 @@ class Transform {
     return std::nullopt;
   }
 
+  // The ResizeTransform is special because it is the only transform
+  // that swscale can handle.
+  virtual bool isResize() const {
+    return false;
+  }
+
   // The validity of some transforms depends on the characteristics of the
   // AVStream they're being applied to. For example, some transforms will
   // specify coordinates inside a frame, we need to validate that those are
@@ -51,6 +57,9 @@ class ResizeTransform : public Transform {
   std::string getFilterGraphCpu() const override;
   std::optional<FrameDims> getOutputFrameDims() const override;
 
+  bool isResize() const override;
+
+  int getSwsFlags() const;
 
  private:
   FrameDims outputDims_;
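
For readers who want to see the two-pass strategy above in isolation, here is a
minimal, self-contained C++ sketch of the same idea using raw libswscale: one
context does the YUV->RGB24 color conversion at the source resolution, and a
second context resizes in RGB24 space. The function name, buffer parameters,
and omitted error handling are illustrative assumptions, not code from the
patch; the patch itself uses torchcodec's cached createSwsContext helper and
output tensors instead of raw buffers.

extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

// Convert a decoded YUV frame to RGB24, then resize the RGB24 image.
// rgbTmp must hold src->width * src->height * 3 bytes; rgbOut must hold
// dstW * dstH * 3 bytes. Both are caller-allocated.
void convertThenResize(
    const AVFrame* src,
    uint8_t* rgbTmp,
    uint8_t* rgbOut,
    int dstW,
    int dstH) {
  // Pass 1: color conversion only, at the source resolution. No scaling
  // happens here, so no algorithm flags are needed (mirroring swsFlags=0 in
  // the patch's color conversion context).
  SwsContext* toRgb = sws_getContext(
      src->width,
      src->height,
      static_cast<AVPixelFormat>(src->format),
      src->width,
      src->height,
      AV_PIX_FMT_RGB24,
      /*flags=*/0,
      nullptr,
      nullptr,
      nullptr);
  uint8_t* tmpData[4] = {rgbTmp, nullptr, nullptr, nullptr};
  int tmpLinesize[4] = {src->width * 3, 0, 0, 0};
  sws_scale(
      toRgb, src->data, src->linesize, 0, src->height, tmpData, tmpLinesize);
  sws_freeContext(toRgb);

  // Pass 2: resize in RGB24 space. SWS_BILINEAR matches the patch's default;
  // a ResizeTransform with a different interpolation mode would supply
  // different flags here.
  SwsContext* resize = sws_getContext(
      src->width,
      src->height,
      AV_PIX_FMT_RGB24,
      dstW,
      dstH,
      AV_PIX_FMT_RGB24,
      SWS_BILINEAR,
      nullptr,
      nullptr,
      nullptr);
  uint8_t* outData[4] = {rgbOut, nullptr, nullptr, nullptr};
  int outLinesize[4] = {dstW * 3, 0, 0, 0};
  sws_scale(
      resize, tmpData, tmpLinesize, 0, src->height, outData, outLinesize);
  sws_freeContext(resize);
}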