
[BUG] Running copy-related operations on instances of user-defined TensorDict subclasses returns a TensorDict object #1184

Open
@alex-bene

Description

Describe the bug

I have created a TensorDict subclass named AutoTensorDict that fits my use case. However, running .to(), .clone(), and other copy-related operations on instances of this class returns a TensorDict object rather than an AutoTensorDict object.
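One plausible mechanism, sketched below with hypothetical toy classes (`Container` and `AutoContainer` are illustrative names, not tensordict's code): if a base-class copy method builds its result by naming the base class explicitly, every subclass instance is downgraded to the base class.

```python
# Hypothetical toy classes, NOT tensordict's implementation. They only
# illustrate how a copy-like method can drop a subclass.


class Container:
    def __init__(self, data=None, device=None):
        self.data = dict(data or {})
        self.device = device

    def clone(self):
        # Names the base class explicitly, so the copy is a plain
        # Container no matter what type `self` is.
        return Container(self.data, self.device)

    def to(self, device):
        # Same pattern for a (simulated) device move.
        return Container(self.data, device)


class AutoContainer(Container):
    pass


tt = AutoContainer({"a": 1})
print(type(tt.clone()).__name__)    # Container
print(type(tt.to("mps")).__name__)  # Container
```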

To Reproduce

Here's an example to reproduce this:

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from tensordict.tensordict import TensorDict

if TYPE_CHECKING:
    from collections.abc import Sequence

    from tensordict._nestedkey import NestedKey
    from tensordict.base import CompatibleType, T
    from tensordict.utils import DeviceType, IndexType
    from torch import Size


class AutoTensorDict(TensorDict):
    def __init__(
        self,
        source: T | dict[NestedKey, CompatibleType] = None,
        batch_size: Sequence[int] | Size | int | None = None,
        device: DeviceType | None = None,
        names: Sequence[str] | None = None,
        non_blocking: bool | None = None,
        lock: bool = False,
        **kwargs: dict[str, Any] | None,
    ) -> None:
        super().__init__(source, batch_size, device, names, non_blocking, lock, **kwargs)
        self.auto_batch_size_(1)
        if self.device is None:
            self.auto_device_()

    def __setitem__(self, key: IndexType, value: Any) -> None:
        super().__setitem__(key, value)
        if self.device is None:
            self.auto_device_()
        if not self.batch_size:
            self.auto_batch_size_(1)


if __name__ == "__main__":
    tt = AutoTensorDict()
    tt["a"] = torch.rand(3, 4)
    print(tt.to("mps"))

This snippet prints:

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=mps:0, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=mps,
    is_shared=False)

Expected behavior

I would expect these operations to return an instance of the subclass (here, AutoTensorDict).
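For comparison, a minimal sketch of the expected behavior, again with hypothetical toy classes rather than tensordict's actual implementation: building the result through `type(self)` keeps the subclass.

```python
# Hypothetical toy classes, NOT tensordict's implementation.


class Container:
    def __init__(self, data=None, device=None):
        self.data = dict(data or {})
        self.device = device

    def clone(self):
        # type(self) is resolved at call time, so a subclass instance
        # produces a copy of the same subclass.
        return type(self)(self.data, self.device)

    def to(self, device):
        return type(self)(self.data, device)


class AutoContainer(Container):
    def __init__(self, data=None, device=None):
        super().__init__(data, device)
        # Subclass-specific setup (akin to auto_batch_size_ in
        # AutoTensorDict) runs again for every copy built via type(self).
        self.setup_ran = True


tt = AutoContainer({"a": 1})
moved = tt.to("mps")
print(type(moved).__name__, moved.device, moved.setup_ran)  # AutoContainer mps True
```

This pattern assumes the subclass constructor accepts the same arguments as the base one. AutoTensorDict's `__init__` forwards the same parameters to `TensorDict.__init__`, so it would fit; the trade-off is that subclass setup code runs again on every copy.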

Reason and Possible fixes

At first, I thought this was related to the to_tensordict method being called in various places inside TensorDict and TensorDictBase. However, monkey-patching it had no effect.
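Since monkey-patching to_tensordict had no effect, a user-side workaround is to re-wrap results in the subclass itself. The sketch below does this on hypothetical toy classes; `rewrap_as_caller` is an illustrative helper, not a tensordict API.

```python
import functools


class Container:
    """Hypothetical stand-in for a base class that exhibits the bug."""

    def __init__(self, data=None, device=None):
        self.data = dict(data or {})
        self.device = device

    def clone(self):
        return Container(self.data, self.device)  # base class hardcoded

    def to(self, device):
        return Container(self.data, device)  # base class hardcoded


def rewrap_as_caller(base_method):
    """Wrap a base-class method so its result is rebuilt as type(self)."""

    @functools.wraps(base_method)
    def wrapper(self, *args, **kwargs):
        out = base_method(self, *args, **kwargs)
        # Rebuild only when the base method returned a different class.
        if isinstance(out, Container) and type(out) is not type(self):
            out = type(self)(out.data, out.device)
        return out

    return wrapper


class AutoContainer(Container):
    clone = rewrap_as_caller(Container.clone)
    to = rewrap_as_caller(Container.to)


tt = AutoContainer({"a": 1})
moved = tt.to("cuda")
print(type(moved).__name__, moved.device)  # AutoContainer cuda
```

Adapting this to AutoTensorDict would mean overriding .to(), .clone(), and similar methods to rebuild an AutoTensorDict from the TensorDict that super() returns. Which constructor arguments (batch size, device, names) must be carried over is an assumption to check against the tensordict version in use.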

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
