[BUG] Running copy-related operations on instances of user-defined TensorDict subclasses returns a TensorDict object #1184
Describe the bug
I have created a `TensorDict` subclass named `AutoTensorDict` that fits my use case. However, running `.to()`, `.clone()`, and other copy-related operations on instances of this class returns a plain `TensorDict` object instead of an `AutoTensorDict` object.
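To make the symptom concrete, here is a toy, pure-Python model of one common way subclass information gets lost: a copy method that hard-codes the class it returns. `Container` and `AutoContainer` are hypothetical stand-ins for `TensorDict` and `AutoTensorDict`; no tensordict code is involved, and we have not confirmed that this is what tensordict does internally.

```python
from __future__ import annotations

from typing import Any


class Container:
    """Toy stand-in for a dict-like base class (not tensordict's real code)."""

    def __init__(self, data: dict[str, Any] | None = None) -> None:
        self.data = dict(data or {})

    def clone(self) -> Container:
        # The result class is hard-coded, so any subclass is silently dropped.
        return Container(self.data)


class AutoContainer(Container):
    """Toy subclass standing in for AutoTensorDict."""


original = AutoContainer({"a": 1})
duplicate = original.clone()
print(type(duplicate).__name__)  # Container, not AutoContainer
```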
To Reproduce
Here's an example to reproduce this:
```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from tensordict.tensordict import TensorDict

if TYPE_CHECKING:
    from collections.abc import Sequence

    from tensordict._nestedkey import NestedKey
    from tensordict.base import CompatibleType, T
    from tensordict.utils import DeviceType, IndexType
    from torch import Size


class AutoTensorDict(TensorDict):
    def __init__(
        self,
        source: T | dict[NestedKey, CompatibleType] = None,
        batch_size: Sequence[int] | Size | int | None = None,
        device: DeviceType | None = None,
        names: Sequence[str] | None = None,
        non_blocking: bool | None = None,
        lock: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(source, batch_size, device, names, non_blocking, lock, **kwargs)
        # Infer the leading batch dimension and, if unset, the device.
        self.auto_batch_size_(1)
        if self.device is None:
            self.auto_device_()

    def __setitem__(self, key: IndexType, value: Any) -> None:
        super().__setitem__(key, value)
        # Re-infer device and batch size after each assignment.
        if self.device is None:
            self.auto_device_()
        if not self.batch_size:
            self.auto_batch_size_(1)


if __name__ == "__main__":
    tt = AutoTensorDict()
    tt["a"] = torch.rand(3, 4)
    # "mps" needs a macOS machine with MPS support; use "cpu" elsewhere.
    print(tt.to("mps"))
```
This snippet prints:
```
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=mps:0, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=mps,
    is_shared=False)
```
Expected behavior
I would expect `.to()`, `.clone()`, and similar operations to return an instance of the subclass (here, `AutoTensorDict`) rather than a plain `TensorDict`.
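In plain Python, this expectation usually corresponds to building copies through `type(self)` instead of a fixed class. The toy sketch below (hypothetical `Container`/`AutoContainer` classes, not tensordict code) shows the behavior being asked for.

```python
from __future__ import annotations

from typing import Any


class Container:
    """Toy base class whose copies preserve the runtime subclass."""

    def __init__(self, data: dict[str, Any] | None = None) -> None:
        self.data = dict(data or {})

    def clone(self) -> Container:
        # type(self) is the runtime class, so subclasses survive the copy.
        return type(self)(self.data)


class AutoContainer(Container):
    """Toy subclass standing in for AutoTensorDict."""


duplicate = AutoContainer({"a": 1}).clone()
print(type(duplicate).__name__)  # AutoContainer
```

One trade-off of this idiom: `type(self)(...)` re-runs the subclass `__init__` with the base class's arguments. For `AutoTensorDict` that would re-run `auto_batch_size_(1)` and `auto_device_()`, and it would break for any subclass whose `__init__` takes a different signature.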
Reason and Possible fixes
At first, I thought this was related to the `to_tensordict` function being called inside `TensorDict` and `TensorDictBase`. However, monkey-patching it did not change the behavior.
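Until this is addressed upstream, one possible user-side workaround is to override each copy-producing method in the subclass and re-wrap its result. The sketch below models that idea in pure Python; `Container`, `AutoContainer`, and the `rewrap` helper are hypothetical, and whether rebuilding a real `AutoTensorDict` from a returned `TensorDict` preserves everything (batch size, names, locking) has not been verified here.

```python
from __future__ import annotations

import functools
from typing import Any, Callable


class Container:
    """Toy base class whose copy methods hard-code their return type."""

    def __init__(self, data: dict[str, Any] | None = None, device: str | None = None) -> None:
        self.data = dict(data or {})
        self.device = device

    def clone(self) -> Container:
        return Container(self.data, self.device)

    def to(self, device: str) -> Container:
        # Pretend to move the data; only the device label changes.
        return Container(self.data, device)


def rewrap(method: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical helper: convert a base-class result back into type(self)."""

    @functools.wraps(method)
    def wrapper(self: Container, *args: Any, **kwargs: Any) -> Any:
        out = method(self, *args, **kwargs)
        if isinstance(out, Container) and not isinstance(out, type(self)):
            return type(self)(out.data, out.device)
        return out

    return wrapper


class AutoContainer(Container):
    # Every copy-producing method must be listed explicitly.
    clone = rewrap(Container.clone)
    to = rewrap(Container.to)


moved = AutoContainer({"a": 1}).to("cpu")
print(type(moved).__name__, moved.device)  # AutoContainer cpu
```

The weakness of this approach is coverage: any method not listed in the subclass still returns the base class, which is why a fix inside the library itself would be preferable.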
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)