Skip to content

Commit a1a1070

Browse files
authored
[Feature] VideoClipRef.decode: honor the container device (#3839)
1 parent b01bdc4 commit a1a1070

2 files changed

Lines changed: 22 additions & 3 deletions

File tree

test/transforms/test_video_transforms.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,18 @@ def test_transform_device(self, video_path):
572572
)(td)
573573
assert out["pixels"].device.type == "cuda"
574574

575+
def test_tensordict_to_device_decodes_on_device(self, video_path):
576+
# Moving the tensordict that holds the reference makes decode() materialize
577+
# frames on that device (decode falls back to the container device).
578+
td = TensorDict({"frame": VideoClipRef.from_file(video_path)}, batch_size=[20])
579+
td = td.to("cpu")
580+
assert td["frame"].decode().device.type == "cpu"
581+
td = td.to("cuda")
582+
assert td["frame"].decode().device.type == "cuda"
583+
assert td["frame"][3:7].decode().device.type == "cuda"
584+
# an explicit override still wins
585+
assert td["frame"].decode(device="cpu").device.type == "cpu"
586+
575587

576588
if __name__ == "__main__":
577589
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/video.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -726,9 +726,11 @@ def decode(self, *, device: Any = None, dtype: Any = None) -> torch.Tensor:
726726
727727
Keyword Args:
728728
device (torch.device or str, optional): output device for the decoded
729-
frames, overriding ``out_device``. A CUDA device uses GPU (NVDEC)
730-
decoding when the torchcodec build supports it, and otherwise decodes
731-
on CPU and moves the frames to the device.
729+
frames, overriding ``out_device``. Defaults to ``out_device`` if set,
730+
else the reference's own device (so ``td.to("cuda")`` on a tensordict
731+
holding the reference decodes onto CUDA), else CPU. A CUDA device uses
732+
GPU (NVDEC) decoding when the torchcodec build supports it, and
733+
otherwise decodes on CPU and moves the frames to the device.
732734
dtype (torch.dtype, optional): dtype for the decoded frames, overriding
733735
``out_dtype``. Defaults to ``uint8``.
734736
@@ -739,6 +741,11 @@ def decode(self, *, device: Any = None, dtype: Any = None) -> torch.Tensor:
739741
raise ModuleNotFoundError(_TORCHCODEC_ERROR)
740742
if device is None:
741743
device = _first(self.out_device)
744+
if device is None:
745+
# Fall back to the reference's own (container) device, so moving the
746+
# tensordict that holds it -- e.g. ``td.to("cuda")`` -- makes
747+
# ``decode()`` materialize frames on that device.
748+
device = self.device
742749
if dtype is None:
743750
dtype = _first(self.out_dtype)
744751
stream = _first(self.stream)

0 commit comments

Comments
 (0)