
[BUG] TensorDict with dynamic, input-dependent batch_size is not torch.export.exportable #1003

Open
@egaznep

Description


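Describe the bug

Creating a TensorDict inside forward() with batch_size=x.shape[0], where dimension 0 of x is marked dynamic, makes torch.export.export(..., strict=False) fail with the errors shown below.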

To Reproduce

import torch
import tensordict


class Test(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return tensordict.TensorDict(
            {'x': x, 'y': y},
            # Export fails on this argument. Without it export succeeds,
            # but batch_size is then [] instead of [x.shape[0]].
            batch_size=x.shape[0],
        )


# One Dim object per symbolic dimension, shared by x and y.
batch = torch.export.Dim('batch')
time = torch.export.Dim('time')

test = Test()
result = torch.export.export(
    test,
    args=(torch.zeros(2, 100), torch.zeros(2, 100)),
    strict=False,  # non-strict: trace by running forward() as plain Python
    dynamic_shapes={
        'x': {0: batch, 1: time},
        'y': {0: batch, 1: time},
    },
)
# The exported module should accept a different batch size (5 instead of 2).
print(result.module()(torch.zeros(5, 100), torch.zeros(5, 100)))

(myenv) egaznep@...@volta: $ python scripts/mwe.py
Traceback (most recent call last):
  File "./myenv/lib/python3.10/site-packages/tensordict/_td.py", line 2082, in _parse_batch_size
    return torch.Size(batch_size)
TypeError: 'SymInt' object is not iterable

During handling of the above exception, another exception occurred:

Traceback (unrelated frames omitted):
  File "./scripts/mwe.py", line 6, in forward
    return tensordict.TensorDict({
  File "./myenv/lib/python3.10/site-packages/tensordict/_td.py", line 285, in __init__
    self._batch_size = self._parse_batch_size(source, batch_size)
  File "./myenv/lib/python3.10/site-packages/tensordict/_td.py", line 2090, in _parse_batch_size
    raise ValueError(ERR)
ValueError: batch size was not specified when creating the TensorDict instance and it could not be retrieved from source.
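The two tracebacks suggest how the error arises: torch.Size(batch_size) is tried first and raises TypeError because a single SymInt is not iterable; the fallback then does not treat the SymInt as a scalar size, so it reports that no batch size was given, even though one was passed. The plain-Python sketch models that path under stated assumptions: FakeSymInt and parse_batch_size are hypothetical stand-ins, tuple() plays the role of torch.Size, and the fallback is assumed to accept only None and numbers.Number values. It is not tensordict's actual code.

```python
import numbers


class FakeSymInt:
    """Hypothetical stand-in for torch.SymInt: an int-like scalar that is
    neither iterable nor a numbers.Number."""

    def __init__(self, hint: int) -> None:
        self.hint = hint


def parse_batch_size(batch_size):
    """Simplified, assumed model of the failing path; not tensordict's code."""
    try:
        # Plays the role of torch.Size(batch_size): requires an iterable.
        return tuple(batch_size)
    except TypeError:
        if batch_size is None:
            return ()
        if isinstance(batch_size, numbers.Number):
            # A plain int is wrapped into a one-element size.
            return (batch_size,)
        # Anything else, including the SymInt stand-in, ends up here.
        raise ValueError(
            "batch size was not specified when creating the TensorDict "
            "instance and it could not be retrieved from source."
        )


print(parse_batch_size(2))         # (2,)
print(parse_batch_size([2, 100]))  # (2, 100)
try:
    parse_batch_size(FakeSymInt(2))
except ValueError as err:
    print("ValueError:", err)
```

In this model an eager-mode int batch size works, while the symbolic stand-in reaches the same misleading ValueError as in the traceback.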

Expected behavior

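I would expect torch.export.export to return an ExportedProgram whose output TensorDict has batch_size [x.shape[0]] with a dynamic batch dimension, so that calling result.module() with a batch of 5 works.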

System info

Output of print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__):

  • tensordict: 2024.09.19
  • numpy: 2.0.2
  • Python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
  • platform: linux
  • torch: 2.6.0.dev20240919

Reason and Possible fixes

If you know or suspect the reason for this bug, paste the code lines and suggest modifications.
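One possible modification, sketched in plain Python under stated assumptions: when batch_size is a single non-iterable value, wrap it unchanged in a one-element tuple instead of requiring a plain number, so nothing forces a symbolic size to a concrete int. normalize_batch_size and SymbolicSize are hypothetical names; in tensordict the result would still be handed to torch.Size, which is assumed (not verified here) to accept a tuple holding a SymInt. On the user side, passing a sequence such as batch_size=x.shape[:1] (slicing a torch.Size returns a torch.Size) should avoid the scalar branch, though whether the rest of the export then succeeds has not been checked.

```python
from collections.abc import Iterable


def normalize_batch_size(batch_size):
    """Hypothetical helper: return batch_size as a tuple of dimensions.

    A scalar is wrapped unchanged (no int() or operator.index() call), so a
    symbolic size such as torch.SymInt is never forced to a concrete int.
    """
    if batch_size is None:
        return ()
    if isinstance(batch_size, Iterable):
        return tuple(batch_size)
    # Scalar case: a plain int or any int-like object, kept as-is.
    return (batch_size,)


class SymbolicSize:
    """Hypothetical placeholder for a SymInt-like object in this sketch."""


s = SymbolicSize()
print(normalize_batch_size(None))       # ()
print(normalize_batch_size(5))          # (5,)
print(normalize_batch_size([5, 100]))   # (5, 100)
print(normalize_batch_size(s) == (s,))  # True
```

Checking for Iterable rather than for numbers.Number lets any scalar through, which is the case the traceback shows failing.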

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Metadata


Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
