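Describe the bug
Exporting a module whose forward builds a TensorDict with batch_size=x.shape[0] fails with torch.export.export (strict=False) when the batch dimension is declared dynamic. In that case x.shape[0] is a SymInt, and TensorDict's batch-size parsing rejects it with the ValueError shown below.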
To Reproduce
```python
import torch
import tensordict


class Test(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return tensordict.TensorDict({
            'x': x,
            'y': y,
        },
            batch_size=x.shape[0]  # removing this line makes export succeed, but then batch_size is [] instead of x.shape[0]
        )


test = Test()
result = torch.export.export(
    test,
    args=(torch.zeros(2, 100), torch.zeros(2, 100)),
    strict=False,
    dynamic_shapes={
        'x': {0: torch.export.Dim('batch'), 1: torch.export.Dim('time')},
        'y': {0: torch.export.Dim('batch'), 1: torch.export.Dim('time')},
    },
)
print(result.module()(torch.zeros(5, 100), torch.zeros(5, 100)))
```
```
(myenv) egaznep@...@volta: $ python scripts/mwe.py
Traceback (most recent call last):
  File "./myenv/lib/python3.10/site-packages/tensordict/_td.py", line 2082, in _parse_batch_size
    return torch.Size(batch_size)
TypeError: 'SymInt' object is not iterable

During handling of the above exception, another exception occurred:

Traceback (without unrelated parts):
  File "./scripts/mwe.py", line 6, in forward
    return tensordict.TensorDict({
  File "./myenv/lib/python3.10/site-packages/tensordict/_td.py", line 285, in __init__
    self._batch_size = self._parse_batch_size(source, batch_size)
  File "./myenv/lib/python3.10/site-packages/tensordict/_td.py", line 2090, in _parse_batch_size
    raise ValueError(ERR)
ValueError: batch size was not specified when creating the TensorDict instance and it could not be retrieved from source.
```
Expected behavior
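I would expect the module to be exported successfully as an ExportedProgram, with the TensorDict's batch_size following the dynamic batch dimension (i.e. x.shape[0], which is 5 for the inputs in the last line of the example).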
System info
Describe the characteristic of your environment:
- Describe how the library was installed (pip, source, ...)
- Python version
- Versions of any other relevant libraries
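```
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
2024.09.19 2.0.2 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] linux 2.6.0.dev20240919
```
That is: tensordict 2024.09.19, numpy 2.0.2, Python 3.10.14 (conda-forge build), Linux, torch 2.6.0.dev20240919.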
Reason and Possible fixes
If you know or suspect the reason for this bug, paste the code lines and suggest modifications.
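From the traceback, _parse_batch_size first tries torch.Size(batch_size). A bare SymInt is not iterable, so that raises TypeError, and the fallback path then raises the ValueError instead of treating the symbolic size as a single batch dimension. One possible modification, sketched here in plain Python, is to wrap any non-iterable scalar rather than rejecting it. normalize_batch_size and FakeSymInt are hypothetical names used only for illustration: the helper is not tensordict's actual code, and FakeSymInt merely stands in for torch.SymInt (an int-like object that is not iterable) so the sketch runs without torch. In tensordict the resulting tuple would still be passed to torch.Size(...).

```python
from typing import Any, Tuple


def normalize_batch_size(batch_size: Any) -> Tuple[Any, ...]:
    """Hypothetical helper (not tensordict's code): normalize batch_size to a tuple.

    Like the original parser, it first tries to treat batch_size as a sequence.
    Unlike it, any non-iterable scalar (a plain int, or a symbolic size such as
    torch.SymInt during export) is wrapped as a single batch dimension.
    """
    if batch_size is None:
        # No batch size given: zero batch dimensions.
        return ()
    try:
        # Lists, tuples and torch.Size objects are converted element-wise.
        return tuple(batch_size)
    except TypeError:
        # Scalars are not iterable: treat them as one batch dimension.
        return (batch_size,)


class FakeSymInt:
    """Stand-in for torch.SymInt: an int-like object that is not iterable."""

    def __init__(self, hint: int) -> None:
        self.hint = hint

    def __repr__(self) -> str:
        return f"FakeSymInt({self.hint})"


if __name__ == "__main__":
    print(normalize_batch_size(None))             # ()
    print(normalize_batch_size(2))                # (2,)
    print(normalize_batch_size([2, 3]))           # (2, 3)
    print(normalize_batch_size(FakeSymInt(5)))    # (FakeSymInt(5),)
```

A trade-off of this duck-typed fallback is that it no longer checks that the scalar is numeric, so an invalid value would presumably only be rejected later by torch.Size. An explicit isinstance(batch_size, (int, torch.SymInt)) check in the fallback would be a stricter alternative.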
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)