69 changes: 69 additions & 0 deletions benchmarks/decoders/benchmark_transforms.py
@@ -5,6 +5,7 @@

import torch
from torch import Tensor
from torchcodec._core import _add_video_stream, create_from_file, get_frames_by_pts
from torchcodec.decoders import VideoDecoder
from torchvision.transforms import v2

@@ -93,6 +94,22 @@ def decoder_crop(
return transformed_frames


def decoder_resize_swscale(
path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
) -> Tensor:
height, width = dims
decoder = create_from_file(str(path), seek_mode="approximate")
Contributor Author commented:
Using core API calls here to make it easier to benchmark the two code paths against each other. Will modify if we decide to switch to swscale.

_add_video_stream(
decoder,
stream_index=None,
num_threads=num_threads,
transform_specs=f"resize, {height}, {width}",
color_conversion_library="swscale",
)
frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
return frames


def main():
parser = ArgumentParser()
parser.add_argument("--path", type=str, help="path to file", required=True)
@@ -120,6 +137,11 @@ def main():
type=float,
default=[0.5, 0.25, 0.125],
)
parser.add_argument(
"--benchmark-swscale-vs-filtergraph",
action="store_true",
help="Run swscale vs filtergraph benchmarks",
)

args = parser.parse_args()
path = Path(args.path)
@@ -133,6 +155,53 @@

input_height = metadata.height
input_width = metadata.width

if args.benchmark_swscale_vs_filtergraph:
print("\n" + "=" * 70)
print("SWSCALE VS FILTERGRAPH BENCHMARKS")
print("=" * 70)

for num_fraction in args.total_frame_fractions:
num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
print(
f"\nSampling {num_fraction * 100}%, {num_frames_to_sample}, of {metadata.num_frames} frames"
)

# Generate timestamps for decoder API
uniform_timestamps = [
i * duration / num_frames_to_sample for i in range(num_frames_to_sample)
]

# Resize at different dimensions
for dims_fraction in args.input_dimension_fractions:
dims = (
int(input_height * dims_fraction),
int(input_width * dims_fraction),
)

print(f"\n--- Resize to {dims} ---")
times = bench(
decoder_resize_swscale,
path,
uniform_timestamps,
dims,
args.num_threads,
num_exp=args.num_exp,
)
report_stats(times, prefix=f"swscale resize{dims}")

times = bench(
decoder_resize,
path,
uniform_timestamps,
dims,
args.num_threads,
num_exp=args.num_exp,
)
report_stats(times, prefix=f"filtergraph resize{dims}")

return

for num_fraction in args.total_frame_fractions:
num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
print(
…
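As an aside, a minimal sketch (not part of this PR) of how the two paths compare head to head, reusing the decoder_resize and decoder_resize_swscale functions defined in this file and assuming a local test video at video.mp4 and that the file is importable as benchmark_transforms:

import time
from pathlib import Path

from benchmark_transforms import decoder_resize, decoder_resize_swscale

path = Path("video.mp4")  # hypothetical local test video
pts_seconds = [0.0, 0.5, 1.0]  # a few sample timestamps
dims = (360, 640)  # (height, width) to resize to

for name, fn in [("swscale", decoder_resize_swscale), ("filtergraph", decoder_resize)]:
    start = time.perf_counter()
    frames = fn(path, pts_seconds, dims, 1)  # num_threads=1
    print(f"{name}: {time.perf_counter() - start:.3f}s, shape={tuple(frames.shape)}")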
140 changes: 112 additions & 28 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -39,15 +39,18 @@ void CpuDeviceInterface::initializeVideo(
videoStreamOptions_ = videoStreamOptions;
resizedOutputDims_ = resizedOutputDims;

- // We can only use swscale when we have a single resize transform. Note that
- // this means swscale will not support the case of having several,
- // back-to-base resizes. There's no strong reason to even do that, but if
- // someone does, it's more correct to implement that with filtergraph.
+ // We can use swscale when we have a single resize transform.
+ // With a single resize, we use swscale twice:
+ // first for color conversion (YUV->RGB24), then for resize in RGB24 space.
+ //
+ // Note that this means swscale will not support the case of having several,
+ // back-to-back resizes or other transforms.
//
- // We calculate this value during initilization but we don't refer to it until
- // getColorConversionLibrary() is called. Calculating this value during
+ // We calculate this value during initialization but we don't refer to it
+ // until getColorConversionLibrary() is called. Calculating this value during
// initialization saves us from having to save all of the transforms.
- areTransformsSwScaleCompatible_ = transforms.empty();
+ areTransformsSwScaleCompatible_ = transforms.empty() ||
+     (transforms.size() == 1 && transforms[0]->isResize());

// Note that we do not expose this capability in the public API, only through
// the core API.
@@ -57,6 +60,16 @@
userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
ColorConversionLibrary::SWSCALE;

// We can only use swscale when we have at most a single resize transform.
// Note that we decide whether to actually use swscale at the last possible
// moment, when we convert the frame. This is because we need to know the
// actual frame dimensions.
if (transforms.size() == 1 && transforms[0]->isResize()) {
auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!");
swsFlags_ = resize->getSwsFlags();
}

// If we have any transforms, replace filters_ with the filter strings from
// the transforms. As noted above, we decide between swscale and filtergraph
// when we actually decode a frame.
@@ -238,10 +251,9 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);

- TORCH_CHECK(
-     avFrame->height == outputDims.height &&
-         avFrame->width == outputDims.width,
-     "Input dimensions are not equal to output dimensions; resize for sws_scale() is not yet supported.");
+ bool needsResize =
+     (avFrame->height != outputDims.height ||
+      avFrame->width != outputDims.width);

// We need to compare the current frame context with our previous frame
// context. If they are different, then we need to re-create our colorspace
@@ -254,8 +266,8 @@
avFrame->width,
avFrame->height,
frameFormat,
- outputDims.width,
- outputDims.height);
+ needsResize ? avFrame->width : outputDims.width,
+ needsResize ? avFrame->height : outputDims.height);

if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
swsContext_ = createSwsContext(
@@ -266,25 +278,97 @@
// pixel format.
/*outputFormat=*/AV_PIX_FMT_RGB24,

- // We don't set any flags because we don't yet use sw_scale() for
- // resizing.
+ // No flags for color conversion. When resizing is needed, we use a
+ // separate swscale context with the appropriate resize flags.
/*swsFlags=*/0);
prevSwsFrameContext_ = swsFrameContext;
}

- uint8_t* pointers[4] = {
-     outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
- int expectedOutputWidth = outputTensor.sizes()[1];
- int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
- int resultHeight = sws_scale(
-     swsContext_.get(),
-     avFrame->data,
-     avFrame->linesize,
-     0,
-     avFrame->height,
-     pointers,
-     linesizes);
- return resultHeight;
+ if (needsResize) {
+   // Double swscale path: first convert to RGB24 at original resolution,
+   // then resize in RGB24 space. This ensures transforms happen in the
+   // output color space (RGB24) rather than the input color space (YUV).
+
+   // First pass: color conversion (YUV -> RGB24) at original resolution
+   torch::Tensor intermediateTensor = allocateEmptyHWCTensor(
+       FrameDims(avFrame->height, avFrame->width), torch::kCPU);
+
+   uint8_t* intermediatePointers[4] = {
+       intermediateTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+   int intermediateWidth = avFrame->width;
+   int intermediateLinesizes[4] = {intermediateWidth * 3, 0, 0, 0};
+
+   int intermediateHeight = sws_scale(
+       swsContext_.get(),
+       avFrame->data,
+       avFrame->linesize,
+       0,
+       avFrame->height,
+       intermediatePointers,
+       intermediateLinesizes);
+
+   TORCH_CHECK(
+       intermediateHeight == avFrame->height,
+       "First swscale pass failed: intermediateHeight != avFrame->height: ",
+       intermediateHeight,
+       " != ",
+       avFrame->height);
+
+   // Second pass: resize in RGB24 space
+   // Use cached swscale context for resizing, similar to the color conversion
+   // context caching above.
+   SwsFrameContext resizeSwsFrameContext(
+       avFrame->width,
+       avFrame->height,
+       AV_PIX_FMT_RGB24,
+       outputDims.width,
+       outputDims.height);
+
+   if (!resizeSwsContext_ ||
+       prevResizeSwsFrameContext_ != resizeSwsFrameContext) {
+     resizeSwsContext_ = createSwsContext(
+         resizeSwsFrameContext,
+         AVCOL_SPC_RGB,
+         /*outputFormat=*/AV_PIX_FMT_RGB24,
+         /*swsFlags=*/swsFlags_);
+     prevResizeSwsFrameContext_ = resizeSwsFrameContext;
+   }
+
+   uint8_t* srcPointers[4] = {
+       intermediateTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+   int srcLinesizes[4] = {avFrame->width * 3, 0, 0, 0};
+
+   uint8_t* dstPointers[4] = {
+       outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+   int expectedOutputWidth = outputTensor.sizes()[1];
+   int dstLinesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+
+   int resultHeight = sws_scale(
+       resizeSwsContext_.get(),
+       srcPointers,
+       srcLinesizes,
+       0,
+       avFrame->height,
+       dstPointers,
+       dstLinesizes);
+
+   return resultHeight;
+ } else {
+   // No resize needed, just color conversion
+   uint8_t* pointers[4] = {
+       outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+   int expectedOutputWidth = outputTensor.sizes()[1];
+   int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+   int resultHeight = sws_scale(
+       swsContext_.get(),
+       avFrame->data,
+       avFrame->linesize,
+       0,
+       avFrame->height,
+       pointers,
+       linesizes);
+   return resultHeight;
+ }
Contributor commented:
I think we can simplify this code by structuring it, roughly, as:

// do color conversion
result = sws_scale(/* color conversion params */);
if (needsResize) {
  result = sws_scale(/* resize params */);
}

Or is there something I'm missing that makes the above structure difficult to do?

}

torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
…
11 changes: 11 additions & 0 deletions src/torchcodec/_core/CpuDeviceInterface.h
@@ -100,6 +100,12 @@ class CpuDeviceInterface : public DeviceInterface {
UniqueSwsContext swsContext_;
SwsFrameContext prevSwsFrameContext_;

// Cached swscale context for resizing in RGB24 space (used in double swscale
// path). Like the color conversion context above, we cache this to avoid
// recreating it for every frame.
UniqueSwsContext resizeSwsContext_;
SwsFrameContext prevResizeSwsFrameContext_;
Contributor commented:
We should probably refactor the sws logic into its own class, similar to what we did for filtergraph. But let's tackle that as a follow-up PR to this functionality.


// We pass these filters to FFmpeg's filtergraph API. It is a simple pipeline
// of what FFmpeg calls "filters" to apply to decoded frames before returning
// them. In the PyTorch ecosystem, we call these "transforms". During
@@ -119,6 +125,11 @@
bool areTransformsSwScaleCompatible_;
bool userRequestedSwScale_;

// The flags we supply to the resize swscale context. The flags control the
// resizing algorithm. We default to bilinear. Users can override this with a
// ResizeTransform that specifies a different interpolation mode.
int swsFlags_ = SWS_BILINEAR;

bool initialized_ = false;

// Audio-specific members
…
20 changes: 20 additions & 0 deletions src/torchcodec/_core/Transform.cpp
@@ -25,6 +25,18 @@ std::string toFilterGraphInterpolation(
}
}

int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
switch (mode) {
case ResizeTransform::InterpolationMode::BILINEAR:
return SWS_BILINEAR;
default:
TORCH_CHECK(
false,
"Unknown interpolation mode: " +
std::to_string(static_cast<int>(mode)));
}
}

} // namespace

std::string ResizeTransform::getFilterGraphCpu() const {
@@ -37,6 +49,14 @@ std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
return outputDims_;
}

bool ResizeTransform::isResize() const {
return true;
}

int ResizeTransform::getSwsFlags() const {
return toSwsInterpolation(interpolationMode_);
}

CropTransform::CropTransform(const FrameDims& dims) : outputDims_(dims) {}

CropTransform::CropTransform(const FrameDims& dims, int x, int y)
…
9 changes: 9 additions & 0 deletions src/torchcodec/_core/Transform.h
@@ -29,6 +29,12 @@ class Transform {
return std::nullopt;
}

// The ResizeTransform is special because it is the only transform
// that swscale can handle.
virtual bool isResize() const {
return false;
}

// The validity of some transforms depends on the characteristics of the
// AVStream they're being applied to. For example, some transforms will
// specify coordinates inside a frame, we need to validate that those are
@@ -51,6 +57,9 @@ class ResizeTransform : public Transform {

std::string getFilterGraphCpu() const override;
std::optional<FrameDims> getOutputFrameDims() const override;
bool isResize() const override;

int getSwsFlags() const;

private:
FrameDims outputDims_;
…
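Taken together, the transform plumbing above means a single resize spec reaches the new swscale path through the core API. A rough end-to-end sketch, not from this diff: it assumes a local video.mp4, and since only "swscale" appears as a color_conversion_library value in this diff, the "filtergraph" string below is an assumption:

from torchcodec._core import _add_video_stream, create_from_file, get_frames_by_pts

for library in ("swscale", "filtergraph"):  # "filtergraph" value is assumed
    decoder = create_from_file("video.mp4", seek_mode="approximate")
    _add_video_stream(
        decoder,
        stream_index=None,
        num_threads=1,
        transform_specs="resize, 360, 640",  # single resize: swscale-compatible
        color_conversion_library=library,
    )
    frames, *_ = get_frames_by_pts(decoder, timestamps=[0.0])
    print(library, tuple(frames.shape))  # shapes should match across backends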