@@ -726,9 +726,11 @@ def decode(self, *, device: Any = None, dtype: Any = None) -> torch.Tensor:
726726
727727 Keyword Args:
728728 device (torch.device or str, optional): output device for the decoded
729- frames, overriding ``out_device``. A CUDA device uses GPU (NVDEC)
730- decoding when the torchcodec build supports it, and otherwise decodes
731- on CPU and moves the frames to the device.
729+ frames, overriding ``out_device``. Defaults to ``out_device`` if set,
730+ else the reference's own device (so ``td.to("cuda")`` on a tensordict
731+ holding the reference decodes onto CUDA), else CPU. A CUDA device uses
732+ GPU (NVDEC) decoding when the torchcodec build supports it, and
733+ otherwise decodes on CPU and moves the frames to the device.
732734 dtype (torch.dtype, optional): dtype for the decoded frames, overriding
733735 ``out_dtype``. Defaults to ``uint8``.
734736
@@ -739,6 +741,11 @@ def decode(self, *, device: Any = None, dtype: Any = None) -> torch.Tensor:
739741 raise ModuleNotFoundError (_TORCHCODEC_ERROR )
740742 if device is None :
741743 device = _first (self .out_device )
744+ if device is None :
745+ # Fall back to the reference's own (container) device, so moving the
746+ # tensordict that holds it -- e.g. ``td.to("cuda")`` -- makes
747+ # ``decode()`` materialize frames on that device.
748+ device = self .device
742749 if dtype is None :
743750 dtype = _first (self .out_dtype )
744751 stream = _first (self .stream )
0 commit comments