[BUG] Stacking NonTensorData does not appear to return a NonTensorStack
#1047
Open
Describe the bug
Hi, please let me know if I'm using this feature incorrectly or if this is a known issue.
I've been unable to get `NonTensorStack` to work in various contexts.
The simplest example I can come up with is this one:
```python
import torch
from tensordict import *

a = NonTensorData({})
b = NonTensorData({}, batch_size=[1])
a_stack = NonTensorStack.from_nontensordata(a)
b_stack = NonTensorStack.from_nontensordata(b)
```
I expected each of the following `torch.stack` calls to produce a `NonTensorStack`, yet only the one on `b_stack` appears to produce what I was expecting:
```python
>>> torch.stack((a,a), dim=0)
NonTensorData(data={}, batch_size=torch.Size([2]), device=None)
>>> torch.stack((b,b), dim=0)
NonTensorData(data={}, batch_size=torch.Size([2, 1]), device=None)
>>> torch.stack((a_stack,a_stack), dim=0)
NonTensorData(data={}, batch_size=torch.Size([2]), device=None)
>>> torch.stack((b_stack,b_stack), dim=0)
NonTensorStack(
    [[{}], [{}]],
    batch_size=torch.Size([2, 1]),
    device=None)
```
I would have expected:
```python
torch.stack((a,a), dim=0).data == [{}, {}]
torch.stack((b,b), dim=0).data == [[{}], [{}]]
torch.stack((a_stack,a_stack), dim=0).data == [{}, {}]
```
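To make that expectation concrete, here is a rough pure-Python model of the semantics I had in mind. It does not touch tensordict at all: a `NonTensorData` is modelled as a `(value, batch_size)` tuple whose single value is broadcast over its batch shape, and `expand` / `stack_non_tensor` are hypothetical helpers, not library API.

```python
from typing import Any, List, Sequence, Tuple


def expand(value: Any, shape: Sequence[int]) -> Any:
    """Broadcast one non-tensor value into nested lists of the given shape.

    A fresh list is built at every level; the leaf value itself is shared, not copied.
    """
    if not shape:
        return value
    return [expand(value, shape[1:]) for _ in range(shape[0])]


def stack_non_tensor(
    items: Sequence[Tuple[Any, Tuple[int, ...]]]
) -> Tuple[List[Any], Tuple[int, ...]]:
    """Stack (value, batch_size) pairs along a new leading dim, like torch.stack(..., dim=0)."""
    shapes = {shape for _, shape in items}
    if len(shapes) != 1:
        raise ValueError(f"all batch sizes must match, got {shapes}")
    (shape,) = shapes
    # Each item becomes one entry of the new leading dimension.
    data = [expand(value, shape) for value, _ in items]
    return data, (len(items), *shape)


a = ({}, ())    # stands in for NonTensorData({})
b = ({}, (1,))  # stands in for NonTensorData({}, batch_size=[1])
print(stack_non_tensor([a, a]))  # ([{}, {}], (2,))
print(stack_non_tensor([b, b]))  # ([[{}], [{}]], (2, 1))
```

The point is only that the stacked data should be a nested list whose shape matches the new batch size, rather than a single `{}` reported with batch size `[2]`.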
This may be a separate issue, but even for the final case, which appears to somewhat work (the batch size and indexing behave as expected)...
```python
>>> torch.stack((b_stack,b_stack), dim=0).batch_size
torch.Size([2, 1])
>>> torch.stack((b_stack,b_stack), dim=0)[...,0]
NonTensorStack(
    [{}, {}],
    batch_size=torch.Size([2]),
    device=None)
>>> torch.stack((b_stack,b_stack), dim=0)[0,0]
NonTensorData(data={}, batch_size=torch.Size([]), device=None)
```
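there are still a number of issues that make it unusable for even the most basic use cases: `contiguous()` and `reshape()` return an empty `TensorDict`, `squeeze()` raises, and the stack cannot be used as a tensorclass field: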
```python
>>> torch.stack((b_stack,b_stack), dim=0).contiguous()
TensorDict(
    fields={
    },
    batch_size=torch.Size([2, 1]),
    device=None,
    is_shared=False)
>>> torch.stack((b_stack,b_stack), dim=0).reshape(-1)
TensorDict(
    fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
>>> torch.stack((b_stack,b_stack), dim=0).reshape(2)
TensorDict(
    fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
```
```python
>>> torch.stack((b_stack,b_stack), dim=0).squeeze(dim=1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/utils.py", line 1255, in new_func
    out = func(_self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/base.py", line 2070, in squeeze
    result = self._squeeze(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/_lazy.py", line 2927, in _squeeze
    [td.squeeze(dim) for td in self.tensordicts],
     ^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/utils.py", line 1257, in new_func
    out._last_op = (new_func.__name__, (args, kwargs, _self))
    ^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 1062, in wrapper
    out = self.set(key, value)
          ^^^^^^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 1482, in _set
    raise AttributeError(
AttributeError: Cannot set the attribute '_last_op', expected attributes are {'_is_non_tensor', '_metadata', 'data'}.
```
```python
>>> @tensorclass
... class B:
...     b: NonTensorStack
...
>>> B(b=torch.stack((b_stack,b_stack), dim=0))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 679, in wrapper
    key: value.data if is_non_tensor(value) else value
         ^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 3095, in data
    raise AttributeError
AttributeError. Did you mean: '_data'?
```
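For comparison, this is roughly what I'd expect `reshape(-1)` and `squeeze(dim=1)` to do to the data of the `[2, 1]` stack: flatten the nested list, or drop its singleton level, instead of returning an empty `TensorDict` or raising. This is again a hypothetical pure-Python sketch; `flatten` and `squeeze` are made-up helper names operating on nested lists, not tensordict API.

```python
from typing import Any, List, Sequence, Tuple


def flatten(data: Any, ndim: int) -> List[Any]:
    """Flatten the first `ndim` nesting levels of `data` (like reshape(-1))."""
    if ndim == 0:
        return [data]
    out: List[Any] = []
    for sub in data:
        out.extend(flatten(sub, ndim - 1))
    return out


def squeeze(data: Any, shape: Sequence[int], dim: int) -> Tuple[Any, Tuple[int, ...]]:
    """Drop batch dimension `dim` from nested-list data if its size is 1.

    Assumes a non-negative `dim`. Like torch.squeeze(x, dim), a dimension whose
    size is not 1 is left unchanged.
    """
    if shape[dim] != 1:
        return data, tuple(shape)

    def drop(node: Any, level: int) -> Any:
        if level == dim:
            return node[0]  # unwrap the singleton level
        return [drop(child, level + 1) for child in node]

    new_shape = tuple(s for i, s in enumerate(shape) if i != dim)
    return drop(data, 0), new_shape


stacked = [[{}], [{}]]  # data of the [2, 1] stack above
print(flatten(stacked, 2))          # [{}, {}]
print(squeeze(stacked, (2, 1), 1))  # ([{}, {}], (2,))
```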
Thanks!
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)