Benchmark swscale for resize transforms #1130
@@ -39,15 +39,18 @@ void CpuDeviceInterface::initializeVideo(
   videoStreamOptions_ = videoStreamOptions;
   resizedOutputDims_ = resizedOutputDims;

-  // We can only use swscale when we have a single resize transform. Note that
-  // this means swscale will not support the case of having several,
-  // back-to-base resizes. There's no strong reason to even do that, but if
-  // someone does, it's more correct to implement that with filtergraph.
+  // We can use swscale when we have a single resize transform.
+  // With a single resize, we use swscale twice:
+  // first for color conversion (YUV->RGB24), then for resize in RGB24 space.
+  //
+  // Note that this means swscale will not support the case of having several,
+  // back-to-back resizes or other transforms.
+  //
-  // We calculate this value during initilization but we don't refer to it until
-  // getColorConversionLibrary() is called. Calculating this value during
+  // We calculate this value during initialization but we don't refer to it
+  // until getColorConversionLibrary() is called. Calculating this value during
   // initialization saves us from having to save all of the transforms.
-  areTransformsSwScaleCompatible_ = transforms.empty();
+  areTransformsSwScaleCompatible_ = transforms.empty() ||
+      (transforms.size() == 1 && transforms[0]->isResize());

   // Note that we do not expose this capability in the public API, only through
   // the core API.
@@ -57,6 +60,16 @@ void CpuDeviceInterface::initializeVideo(
   userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
       ColorConversionLibrary::SWSCALE;

+  // We can only use swscale when we have a single resize transform. Note that
+  // we actually decide on whether or not to actually use swscale at the last
+  // possible moment, when we actually convert the frame. This is because we
+  // need to know the actual frame dimensions.
+  if (transforms.size() == 1 && transforms[0]->isResize()) {
+    auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
+    TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!");
+    swsFlags_ = resize->getSwsFlags();
+  }
+
   // If we have any transforms, replace filters_ with the filter strings from
   // the transforms. As noted above, we decide between swscale and filtergraph
   // when we actually decode a frame.
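The resize->getSwsFlags() call above suggests each ResizeTransform translates its interpolation mode into an FFmpeg scaler flag. A minimal sketch of what such a mapping could look like, assuming a hypothetical InterpolationMode enum; only the SWS_* constants are real FFmpeg flags, the other names are illustrative, not the PR's actual API:

// Hypothetical sketch: translate an interpolation mode into an swscale flag.
// SWS_* constants come from <libswscale/swscale.h>; the enum and function are
// illustrative stand-ins for whatever ResizeTransform::getSwsFlags() does.
extern "C" {
#include <libswscale/swscale.h>
}

enum class InterpolationMode { BILINEAR, BICUBIC, NEAREST, LANCZOS };

int toSwsFlags(InterpolationMode mode) {
  switch (mode) {
    case InterpolationMode::BICUBIC:
      return SWS_BICUBIC;
    case InterpolationMode::NEAREST:
      return SWS_POINT;
    case InterpolationMode::LANCZOS:
      return SWS_LANCZOS;
    case InterpolationMode::BILINEAR:
    default:
      return SWS_BILINEAR; // matches the swsFlags_ default declared in the header
  }
}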
@@ -238,10 +251,9 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);

-  TORCH_CHECK(
-      avFrame->height == outputDims.height &&
-          avFrame->width == outputDims.width,
-      "Input dimensions are not equal to output dimensions; resize for sws_scale() is not yet supported.");
+  bool needsResize =
+      (avFrame->height != outputDims.height ||
+       avFrame->width != outputDims.width);

   // We need to compare the current frame context with our previous frame
   // context. If they are different, then we need to re-create our colorspace
@@ -254,8 +266,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
       avFrame->width,
       avFrame->height,
       frameFormat,
-      outputDims.width,
-      outputDims.height);
+      needsResize ? avFrame->width : outputDims.width,
+      needsResize ? avFrame->height : outputDims.height);

   if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
     swsContext_ = createSwsContext(
@@ -266,25 +278,97 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale( | |
| // pixel format. | ||
| /*outputFormat=*/AV_PIX_FMT_RGB24, | ||
|
|
||
| // We don't set any flags because we don't yet use sw_scale() for | ||
| // resizing. | ||
| // No flags for color conversion. When resizing is needed, we use a | ||
| // separate swscale context with the appropriate resize flags. | ||
| /*swsFlags=*/0); | ||
| prevSwsFrameContext_ = swsFrameContext; | ||
| } | ||
|
|
||
| uint8_t* pointers[4] = { | ||
| outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr}; | ||
| int expectedOutputWidth = outputTensor.sizes()[1]; | ||
| int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0}; | ||
| int resultHeight = sws_scale( | ||
| swsContext_.get(), | ||
| avFrame->data, | ||
| avFrame->linesize, | ||
| 0, | ||
| avFrame->height, | ||
| pointers, | ||
| linesizes); | ||
| return resultHeight; | ||
| if (needsResize) { | ||
| // Double swscale path: first convert to RGB24 at original resolution, | ||
| // then resize in RGB24 space. This ensures transforms happen in the | ||
| // output color space (RGB24) rather than the input color space (YUV). | ||
|
|
||
| // First pass: color conversion (YUV -> RGB24) at original resolution | ||
| torch::Tensor intermediateTensor = allocateEmptyHWCTensor( | ||
| FrameDims(avFrame->height, avFrame->width), torch::kCPU); | ||
|
|
||
| uint8_t* intermediatePointers[4] = { | ||
| intermediateTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr}; | ||
| int intermediateWidth = avFrame->width; | ||
| int intermediateLinesizes[4] = {intermediateWidth * 3, 0, 0, 0}; | ||
|
|
||
| int intermediateHeight = sws_scale( | ||
| swsContext_.get(), | ||
| avFrame->data, | ||
| avFrame->linesize, | ||
| 0, | ||
| avFrame->height, | ||
| intermediatePointers, | ||
| intermediateLinesizes); | ||
|
|
||
| TORCH_CHECK( | ||
| intermediateHeight == avFrame->height, | ||
| "First swscale pass failed: intermediateHeight != avFrame->height: ", | ||
| intermediateHeight, | ||
| " != ", | ||
| avFrame->height); | ||
|
|
||
| // Second pass: resize in RGB24 space | ||
| // Use cached swscale context for resizing, similar to the color conversion | ||
| // context caching above. | ||
| SwsFrameContext resizeSwsFrameContext( | ||
| avFrame->width, | ||
| avFrame->height, | ||
| AV_PIX_FMT_RGB24, | ||
| outputDims.width, | ||
| outputDims.height); | ||
|
|
||
| if (!resizeSwsContext_ || | ||
| prevResizeSwsFrameContext_ != resizeSwsFrameContext) { | ||
| resizeSwsContext_ = createSwsContext( | ||
| resizeSwsFrameContext, | ||
| AVCOL_SPC_RGB, | ||
| /*outputFormat=*/AV_PIX_FMT_RGB24, | ||
| /*swsFlags=*/swsFlags_); | ||
| prevResizeSwsFrameContext_ = resizeSwsFrameContext; | ||
| } | ||
|
|
||
| uint8_t* srcPointers[4] = { | ||
| intermediateTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr}; | ||
| int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0}; | ||
|
|
||
| uint8_t* dstPointers[4] = { | ||
| outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr}; | ||
| int expectedOutputWidth = outputTensor.sizes()[1]; | ||
| int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0}; | ||
|
|
||
| int resultHeight = sws_scale( | ||
| resizeSwsContext_.get(), | ||
| srcPointers, | ||
| srcLinesizes, | ||
| 0, | ||
| avFrame->height, | ||
| dstPointers, | ||
| dstLinesizes); | ||
|
|
||
| return resultHeight; | ||
| } else { | ||
| // No resize needed, just color conversion | ||
| uint8_t* pointers[4] = { | ||
| outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr}; | ||
| int expectedOutputWidth = outputTensor.sizes()[1]; | ||
| int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0}; | ||
| int resultHeight = sws_scale( | ||
| swsContext_.get(), | ||
| avFrame->data, | ||
| avFrame->linesize, | ||
| 0, | ||
| avFrame->height, | ||
| pointers, | ||
| linesizes); | ||
| return resultHeight; | ||
| } | ||
|
Contributor:

I think we can simplify this code by structuring it, roughly, as:

    // do color conversion
    result = sws_scale(/* color conversion params */);
    if (needsResize) {
      result = sws_scale(/* resize params */);
    }

Or is there something I'm missing that makes the above structure difficult to do?
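A rough sketch of the control flow the comment is proposing, under the assumption that the first pass still has to target either the output tensor or an intermediate RGB24 buffer depending on whether a resize follows; the buffer names below are illustrative, not the PR's actual variables:

// Sketch of the suggested structure. The first sws_scale() always does the
// color conversion; the wrinkle is that its destination must be either the
// output tensor (no resize) or an intermediate RGB24 buffer (resize), so the
// destination pointers still depend on needsResize.
uint8_t** firstPassPointers = needsResize ? intermediatePointers : outputPointers;
int* firstPassLinesizes = needsResize ? intermediateLinesizes : outputLinesizes;

// Pass 1: color conversion (YUV -> RGB24) at the frame's original resolution.
int resultHeight = sws_scale(
    swsContext_.get(),
    avFrame->data,
    avFrame->linesize,
    0,
    avFrame->height,
    firstPassPointers,
    firstPassLinesizes);

// Pass 2: resize in RGB24 space, only when needed.
if (needsResize) {
  resultHeight = sws_scale(
      resizeSwsContext_.get(),
      firstPassPointers,
      firstPassLinesizes,
      0,
      avFrame->height,
      outputPointers,
      outputLinesizes);
}
return resultHeight;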
 }

 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
@@ -100,6 +100,12 @@ class CpuDeviceInterface : public DeviceInterface {
   UniqueSwsContext swsContext_;
   SwsFrameContext prevSwsFrameContext_;

+  // Cached swscale context for resizing in RGB24 space (used in double swscale
+  // path). Like the color conversion context above, we cache this to avoid
+  // recreating it for every frame.
+  UniqueSwsContext resizeSwsContext_;
+  SwsFrameContext prevResizeSwsFrameContext_;
Contributor:

We should probably refactor the sws logic into its own class, similar to what we did for filtergraph. But let's tackle that as a follow-up PR to this functionality.
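For reference, a very rough sketch of what such a wrapper could look like; all names are hypothetical and the real design would likely differ. It simply bundles the cached SwsContext with the frame-context comparison used to decide when to recreate it, assuming the SwsFrameContext, UniqueSwsContext, and createSwsContext() utilities already present in this code:

// Hypothetical sketch only: a small helper owning one cached swscale context,
// mirroring how the filtergraph logic is wrapped in its own class.
class SwsScaler {
 public:
  explicit SwsScaler(int swsFlags) : swsFlags_(swsFlags) {}

  // Recreates the cached context only when the frame context changes, then
  // runs a single sws_scale() pass from src to dst.
  int scale(
      const uint8_t* const src[],
      const int srcLinesizes[],
      int srcHeight,
      const SwsFrameContext& frameContext,
      AVColorSpace colorspace,
      uint8_t* const dst[],
      const int dstLinesizes[]) {
    if (!swsContext_ || prevFrameContext_ != frameContext) {
      swsContext_ = createSwsContext(
          frameContext, colorspace, /*outputFormat=*/AV_PIX_FMT_RGB24, swsFlags_);
      prevFrameContext_ = frameContext;
    }
    return sws_scale(
        swsContext_.get(), src, srcLinesizes, 0, srcHeight, dst, dstLinesizes);
  }

 private:
  int swsFlags_;
  UniqueSwsContext swsContext_;
  SwsFrameContext prevFrameContext_;
};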
   // We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
   // of what FFmpeg calls "filters" to apply to decoded frames before returning
   // them. In the PyTorch ecosystem, we call these "transforms". During

@@ -119,6 +125,11 @@ class CpuDeviceInterface : public DeviceInterface {
   bool areTransformsSwScaleCompatible_;
   bool userRequestedSwScale_;

+  // The flags we supply to the resize swscale context. The flags control the
+  // resizing algorithm. We default to bilinear. Users can override this with a
+  // ResizeTransform that specifies a different interpolation mode.
+  int swsFlags_ = SWS_BILINEAR;

   bool initialized_ = false;

   // Audio-specific members
Using core API calls here to make it easier to benchmark the two code paths against each other. Will modify if we decide to switch to sw_scale.