1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -83,6 +83,7 @@ def __call__(self, filename):
"sampling.py",
"parallel_decoding.py",
"custom_frame_mappings.py",
"transforms.py",
]
else:
assert "examples/encoding" in self.src_dir
318 changes: 318 additions & 0 deletions examples/decoding/transforms.py
@@ -0,0 +1,318 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
=======================================================
Decoder Transforms: Applying transforms during decoding
=======================================================

In this example, we will demonstrate how to use the ``transforms`` parameter of
the :class:`~torchcodec.decoders.VideoDecoder` class. This parameter allows us
to specify a list of :class:`~torchcodec.transforms.DecoderTransform` or
:class:`~torchvision.transforms.v2.Transform` objects. These objects serve as
transform specifications that the :class:`~torchcodec.decoders.VideoDecoder`
Contributor: nit: specifications

Contributor Author: I should probably start asking Claude to do spell check on the comments. 🤔

will apply during the decoding process.
"""

# %%
# First, a bit of boilerplate and definitions that we will use later:
Contributor: Regarding point 1 in the PR description about the demonstration starting a quarter of the way down the page: we have a pattern of having a link to skip past the boilerplate section; that might help this gap feel smaller. For example:

# %%
# First, a bit of boilerplate: we'll download a video from the web, and define a
# plotting utility. You can ignore that part and jump right below to
# :ref:`sampling_tuto_start`.



import torch
import requests
import tempfile
from pathlib import Path
import shutil
from time import perf_counter_ns


def store_video_to(url: str, local_video_path: Path):
    response = requests.get(url, headers={"User-Agent": ""})
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download video. {response.status_code = }.")

    with open(local_video_path, 'wb') as f:
        for chunk in response.iter_content():
            f.write(chunk)


def plot(frames: torch.Tensor, title: str | None = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return

    plt.rcParams["savefig.bbox"] = "tight"
    dpi = 300
    fig, ax = plt.subplots(figsize=(800 / dpi, 600 / dpi), dpi=dpi)
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title, fontsize=6)
    plt.tight_layout()

# %%
# Our example video
# -----------------
#
# We'll download a video from the internet and store it locally. We're
# purposefully retrieving a high resolution video to demonstrate using
# transforms to reduce the dimensions.


# Video source: https://www.pexels.com/video/an-african-penguin-at-the-beach-9140346/
# Author: Taryn Elliott.
url = "https://videos.pexels.com/video-files/9140346/9140346-uhd_3840_2160_25fps.mp4"

temp_dir = tempfile.mkdtemp()
penguin_video_path = Path(temp_dir) / "penguin.mp4"
store_video_to(url, penguin_video_path)

from torchcodec.decoders import VideoDecoder
print(f"Penguin video metadata: {VideoDecoder(penguin_video_path).metadata}")

# %%
# As shown above, the video is 37 seconds long and has a height of 2160 pixels
# and a width of 3840 pixels.
#
# .. note::
#
# The colloquial way to report the dimensions of this video would be as
# 3840x2160; that is, (`width`, `height`). In the PyTorch ecosystem, image
# dimensions are typically expressed as (`height`, `width`). The remainder
# of this tutorial uses the PyTorch convention of (`height`, `width`) to
# specify image dimensions.
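
# %%
# As a quick illustrative check (an aside, not part of the original tutorial;
# the shape noted in the comment is simply what we expect for this particular
# video), a frame returned by the decoder is a (C, H, W) tensor, so its last
# two dimensions follow the (height, width) convention just described:

print(VideoDecoder(penguin_video_path)[0].shape)  # expect torch.Size([3, 2160, 3840])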

# %%
# Applying transforms during pre-processing
# -----------------------------------------
#
# A pre-processing pipeline for videos during training will typically apply a
# set of transforms for three main reasons:
#
# 1. **Normalization**: Videos can have many different lengths, resolutions,
# and frame rates. Normalizing all videos to the same characteristics
# leads to better model performance.
# 2. **Data reduction**: Training on higher resolution frames may lead to better
# model performance, but it will be more expensive both at training and
# inference time. As a consequence, many video pre-processing pipelines reduce
# frame dimensions through resizing and cropping.
# 3. **Variety**: Applying random transforms (flips, crops, perspective shifts)
# to the same frames during training can improve model performance.
#
# Below is a simple example of applying the
# :class:`~torchvision.transforms.v2.Resize` transform to a single frame:

from torchvision.transforms import v2

full_decoder = VideoDecoder(penguin_video_path)
frame = full_decoder[5]
resized_after = v2.Resize(size=(480, 640))(frame)

plot(resized_after, title="Resized to 480x640 after decoding")

# %%
# In the example above, ``full_decoder`` returns a video frame that has the
# dimensions (2160, 3840) which is then resized down to (480, 640). But with the
# ``transforms`` parameter of :class:`~torchcodec.decoders.VideoDecoder` we can
# specify for the resize to happen during decoding:

resize_decoder = VideoDecoder(
    penguin_video_path,
    transforms=[v2.Resize(size=(480, 640))],
)
resized_during = resize_decoder[5]

plot(resized_during, title="Resized to 480x640 during decoding")

# %%
# TorchCodec's relationship to TorchVision transforms
# -----------------------------------------------------
#
# Notably, in our examples we are passing in TorchVision
# :class:`~torchvision.transforms.v2.Transform` objects as our transforms. We
# would have gotten equivalent behavior if we had passed in the
# :class:`~torchcodec.transforms.Resize` object that is a part of TorchCodec.
# :class:`~torchcodec.decoders.VideoDecoder` accepts both objects as a matter of
# convenience and to clarify the relationship between the transforms that TorchCodec
# applies and the transforms that TorchVision offers.
#
# Importantly, the two frames are not identical, even though we can see they
# *look* very similar:

abs_diff = (resized_after.float() - resized_during.float()).abs()
(abs_diff == 0).all()

# %%
# But they're close enough that models won't be able to tell a difference:
(abs_diff <= 1).float().mean() >= 0.998
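
# %%
# For illustration only (a sketch, not part of the original tutorial: we assume
# here that :class:`~torchcodec.transforms.Resize` can be constructed with a
# ``size`` argument matching its TorchVision counterpart, per the guarantees
# described below), the decoder-native spelling of the same resize would look
# roughly like this:

# Assumption: torchcodec.transforms.Resize exists and accepts ``size`` with the
# same meaning as torchvision.transforms.v2.Resize.
from torchcodec.transforms import Resize as DecoderResize

native_resize_decoder = VideoDecoder(
    penguin_video_path,
    transforms=[DecoderResize(size=(480, 640))],
)
print(native_resize_decoder[5].shape)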

# %%
# While :class:`~torchcodec.decoders.VideoDecoder` accepts TorchVision transforms as
# *specifications*, it is not actually using the TorchVision implementation of these
# transforms. Instead, it is mapping them to equivalent
# `FFmpeg filters <https://ffmpeg.org/ffmpeg-filters.html>`_. That is,
# :class:`torchvision.transforms.v2.Resize` is mapped to
# `scale <https://ffmpeg.org/ffmpeg-filters.html#scale-1>`_ and
# :class:`torchvision.transforms.v2.CenterCrop` is mapped to
# `crop <https://ffmpeg.org/ffmpeg-filters.html#crop>`_.
#
# The relationships we ensure between TorchCodec :class:`~torchcodec.transforms.DecoderTransform` objects
# and TorchVision :class:`~torchvision.transforms.v2.Transform` objects are:
#
# 1. The names are the same.
# 2. Default behaviors are the same.
# 3. The parameters for the :class:`~torchcodec.transforms.DecoderTransform` object are a subset of the
# parameters for the TorchVision :class:`~torchvision.transforms.v2.Transform` object.
# 4. Parameters with the same name control the same behavior and accept a
# subset of the same types.
# 5. The difference between the frames returned by a decoder transform and
# the complementary TorchVision transform is such that a model should
# not be able to tell the difference.
#
# .. note::
#
# We do not encourage *intentionally* mixing usage of TorchCodec's decoder
# transforms and TorchVision transforms. That is, if you use TorchCodec's
# decoder transforms during training, you should also use them during
# inference. And if you decode full frames and apply TorchVision's
# transforms to those fully decoded frames during training, you should also
# do the same during inference. We provide the similarity guarantees to mitigate
# the harm when the two techniques are *unintentionally* mixed.
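
# %%
# As a purely illustrative aside (an assumption based on the mapping described
# above, not TorchCodec's actual internal filtergraph), a hand-written FFmpeg
# equivalent of resizing to (480, 640) and then center-cropping to (315, 220)
# might look roughly like the string built below. Note that FFmpeg filters take
# width-then-height arguments, the reverse of the PyTorch (height, width)
# convention. The helper here is hypothetical, for illustration only:


def to_ffmpeg_size(size: tuple[int, int]) -> str:
    # PyTorch convention is (height, width); FFmpeg filter args are width:height.
    height, width = size
    return f"{width}:{height}"


print(f"scale={to_ffmpeg_size((480, 640))},crop={to_ffmpeg_size((315, 220))}")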

# %%
# Decoder transform pipelines
# ---------------------------
#
# So far, we've only provided a single transform to the ``transforms`` parameter of
# :class:`~torchcodec.decoders.VideoDecoder`. But it
# actually accepts a list of transforms, which become a pipeline of transforms.
# The order of the list matters: the first transform in the list will receive
# the originally decoded frame. The output of that transform becomes the input
# to the next transform in the list, and so on.
#
# A simple example:

crop_resize_decoder = VideoDecoder(
    penguin_video_path,
    transforms=[
        v2.Resize(size=(480, 640)),
        v2.CenterCrop(size=(315, 220)),
    ],
)
crop_resized_during = crop_resize_decoder[5]
plot(crop_resized_during, title="Resized to 480x640 during decoding then center cropped")

Contributor: It usually makes more sense to first crop and then resize, because resize will then work on a smaller surface.

Contributor Author: Indeed it does, and curiously, it actually makes decoder transforms faster than the TorchVision version now (at least on my dev machine).

Results with the old way:

0:
decoder transforms:    times_med = 1474.17ms +- 79.85
torchvision transform: times_med = 4683.55ms +- 28.71

1:
decoder transforms:    times_med = 18486.50ms +- 165.66
torchvision transform: times_med = 16066.02ms +- 164.19

Results with the new way:

0:
decoder transforms:    times_med = 1352.46ms +- 34.86
torchvision transform: times_med = 4077.44ms +- 45.63

1:
decoder transforms:    times_med = 14771.99ms +- 148.83
torchvision transform: times_med = 16112.88ms +- 62.15
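
# %%
# As a small illustrative aside (not in the original tutorial; the shapes in
# the comments are expectations), the ordering discussed above is also visible
# in the output shapes: resizing to (480, 640) and then center-cropping to
# (315, 220) ends at (315, 220), while the reverse order ends at (480, 640):

print(crop_resized_during.shape)  # expect torch.Size([3, 315, 220])

reversed_order_decoder = VideoDecoder(
    penguin_video_path,
    transforms=[
        v2.CenterCrop(size=(315, 220)),
        v2.Resize(size=(480, 640)),
    ],
)
print(reversed_order_decoder[5].shape)  # expect torch.Size([3, 480, 640])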

# %%
# Performance: memory efficiency and speed
# ----------------------------------------
#
# The main motivation for decoder transforms is *memory efficiency*,
# particularly when applying transforms that reduce the size of a frame, such
# as resize and crop. Because the transforms are applied during decoding, the
# full frame is never returned to the Python layer. As a result, there is
# significantly less pressure on the Python garbage collector.
Contributor: I think there's another core reason why that's more memory efficient: the decompressed RGB frame is never materialized in its original resolution.

Without the decoder-native transform we have:

YUV compressed frame in original res -> RGB decompressed frame in original res -> RGB decompressed frame in final (smaller) res

With the decoder-native transform we have:

YUV compressed frame in original res -> RGB decompressed frame in final (smaller) res

i.e. we can skip the "RGB decompressed frame in original res" materialization, which is the most memory-expensive bit. The reduced pressure on the garbage collector is a consequence of that.

Contributor Author (@scotts, Dec 12, 2025): That's not entirely accurate - we definitely never get the "RGB decompressed frame in original res" in the Python layer, but it exists in FFmpeg. This is because we ensure that the FFmpeg filters get applied in the output color space. So without decoder transforms we have (parentheses indicate where each step happens, TC or TV):

YUV compressed, original res (TC) ->
RGB decompressed, original res (TC) ->
RGB decompressed, smaller res (TV)

With decoder transforms it's:

YUV compressed, original res (TC) ->
RGB decompressed, original res (TC) ->
RGB decompressed, smaller res (TC)

So we really do go through the same steps in decoder transforms. That middle step - getting the RGB image in the original resolution - is because of this line:

filters_ = "format=rgb24," + filters.str();

Eliminating the explicit "format=rgb24" does improve performance a lot, but at the cost of similarity with using TorchVision transforms on full frames.

Since the filtergraph inputs and outputs are known statically, I suspect they're able to optimize things and reuse memory. That is, it's possible for them to allocate exactly the memory they need for each step and reuse it every time. But I don't know that that's the case. I'll try to say something about all this.
#
# In `benchmarks <https://github.com/meta-pytorch/torchcodec/blob/f6a816190cbcac417338c29d5e6fac99311d054f/benchmarks/decoders/benchmark_transforms.py>`_
# reducing frames from (1080, 1920) down to (135, 240), we have observed a
# reduction in peak resident set size from 4.3 GB to 0.4 MB.
#
# There is sometimes a runtime benefit, but it is dependent on the number of
# threads that the :class:`~torchcodec.decoders.VideoDecoder` tells FFmpeg
# to use. We define the following benchmark function, as well as the functions
# to benchmark:


def bench(f, average_over=3, warmup=1, **f_kwargs):
    for _ in range(warmup):
        f(**f_kwargs)

    times = []
    for _ in range(average_over):
        start_time = perf_counter_ns()
        f(**f_kwargs)
        end_time = perf_counter_ns()
        times.append(end_time - start_time)

    times = torch.tensor(times) * 1e-6  # ns to ms
    times_std = times.std().item()
    times_med = times.median().item()
    return f"{times_med = :.2f}ms +- {times_std:.2f}"


from torchcodec import samplers


def sample_decoder_transforms(num_threads: int):
    decoder = VideoDecoder(
        penguin_video_path,
        transforms=[
            v2.Resize(size=(480, 640)),
            v2.CenterCrop(size=(315, 220)),
        ],
        seek_mode="approximate",
        num_ffmpeg_threads=num_threads,
    )
    transformed_frames = samplers.clips_at_regular_indices(
        decoder,
        num_clips=1,
        num_frames_per_clip=200,
    )
    assert len(transformed_frames.data[0]) == 200


def sample_torchvision_transforms(num_threads: int):
    decoder = VideoDecoder(
        penguin_video_path,
        seek_mode="approximate",
        num_ffmpeg_threads=num_threads,
    )
    frames = samplers.clips_at_regular_indices(
        decoder,
        num_clips=1,
        num_frames_per_clip=200,
    )
    transformed_frames = []
    for frame in frames.data[0]:
        frame = v2.Resize(size=(480, 640))(frame)
        frame = v2.CenterCrop(size=(315, 220))(frame)
        transformed_frames.append(frame)
    assert len(transformed_frames) == 200

# %%
# When the :class:`~torchcodec.decoders.VideoDecoder` object sets the number of
# FFmpeg threads to 0, that tells FFmpeg to determine how many threads to use
# based on what is available on the current system. In such cases, decoder transforms
# will tend to outperform getting back a full frame and applying TorchVision transforms
# sequentially:


print(f"decoder transforms: {bench(sample_decoder_transforms, num_threads=0)}")
print(f"torchvision transform: {bench(sample_torchvision_transforms, num_threads=0)}")

# %%
# The reason is that FFmpeg is applying the decoder transforms in parallel.
# However, if the number of threads is 1 (as is the default), then there often is no
# runtime benefit to using decoder transforms. Using the TorchVision transforms may
# even be faster!

print(f"decoder transforms: {bench(sample_decoder_transforms, num_threads=1)}")
print(f"torchvision transform: {bench(sample_torchvision_transforms, num_threads=1)}")

# %%
# In brief, our performance guidance is:
Contributor: Would it be worth mentioning decoder native transforms in the performance tips docs?

Contributor Author: @mollyxu, yes, absolutely. I'd like to do that in a follow-up PR.

#
# 1. If you are applying a transform pipeline that significantly reduces
# the dimensions of your input frames and memory efficiency matters, use
# decoder transforms.
# 2. If you are using multiple FFmpeg threads, decoder transforms may be
# faster. Experiment with your setup to verify.
# 3. If you are using a single FFmpeg thread, then decoder transforms may
# be slower. Experiment with your setup to verify.

shutil.rmtree(temp_dir)
# %%