Description
Describe the bug
I collected a memmap TensorDict on the cluster in a Jupyter notebook, following the guide [1]. When I load the same memmap on my local machine with TensorDict.load_memmap(path), I get the error RuntimeError: Could not find name <class '__main__.ImageNetData'>. On the loading side, both the class definition and the load_memmap(path) call live in data.py, not in __main__. I suspect the issue is in the memmap's meta.json file, where the type is recorded as <class '__main__.ImageNetData'>.
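The mismatch can be illustrated without tensordict at all. The sketch below assumes two things, based only on the meta.json contents quoted above. First, the saved type is the str() of the class. Second, loading resolves the class by comparing against that string. find_class_by_name is a hypothetical helper and is not tensordict's actual lookup. A class defined in a notebook has __module__ == '__main__', so its string no longer matches the same class imported from data.py.

```python
# Hypothetical illustration of the failure mode; not tensordict's implementation.

# The class as defined in the notebook on the cluster: its module is __main__.
SavedImageNetData = type("ImageNetData", (), {"__module__": "__main__"})
# The "same" class as imported from data.py on the local machine.
LocalImageNetData = type("ImageNetData", (), {"__module__": "data"})

# str() of a class includes its module, which is what appears in meta.json.
saved_type = str(SavedImageNetData)
print(saved_type)              # <class '__main__.ImageNetData'>
print(str(LocalImageNetData))  # <class 'data.ImageNetData'>


def find_class_by_name(type_name, candidates):
    """Hypothetical lookup that matches classes by their str() representation."""
    for cls in candidates:
        if str(cls) == type_name:
            return cls
    raise RuntimeError(f"Could not find name {type_name}")


try:
    find_class_by_name(saved_type, [LocalImageNetData])
except RuntimeError as err:
    print(err)  # Could not find name <class '__main__.ImageNetData'>
```

Any lookup keyed on the full module-qualified name will fail in the same way once the defining module changes between saving and loading.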
To Reproduce
Follow [1] and save the memmap to a path. Then create main.py:
from data import Dataset

def main(path):
    data = Dataset(path)

if __name__ == '__main__':
    main('path/to/memmap')
and data.py:
import torch
from tensordict import MemoryMappedTensor, tensorclass, TensorDict

@tensorclass
class ImageNetData:
    images: torch.Tensor
    targets: torch.Tensor

class Dataset:
    def __init__(self, path):
        self.data = TensorDict.load_memmap(path)
Running python main.py then fails with:
RuntimeError: Could not find name <class '__main__.ImageNetData'>
Expected behavior
The memmap loads cleanly, regardless of which module defines ImageNetData or calls load_memmap.
System info
import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
0.5.0 1.26.4 3.9.19 (main, May 6 2024, 19:43:03)
[GCC 11.2.0] linux 2.4.1+cu121
Reason and Possible fixes
I manually changed the type in meta.json to
{"_type":"<class 'data.ImageNetData'>"}
but this is not a robust solution. Snapshots are another option. However, the example [2] suggests that loading a snapshot requires initializing the memmap every time, which is very time-consuming in my case (my data is larger than 500 GB).
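The manual meta.json edit could at least be scripted so it is repeatable. This is a minimal sketch under the following assumptions. As observed in this issue, each meta.json is a JSON object whose "_type" is a string like <class '__main__.ImageNetData'>. The module that defines ImageNetData on the loading machine is data. rewrite_type_module is a hypothetical helper, not part of tensordict. The script only touches the small metadata files, not the tensor data. meta.json is a tensordict internal whose layout may change between versions, so keep a copy of the originals.

```python
import json
from pathlib import Path


def rewrite_type_module(memmap_root, old_module="__main__", new_module="data"):
    """Hypothetical helper: rewrite the module in every meta.json "_type" entry.

    Assumes each meta.json under memmap_root is a JSON object whose "_type"
    value is a string such as "<class '__main__.ImageNetData'>".
    Returns the paths of the files that were modified.
    """
    old_prefix = f"<class '{old_module}."
    new_prefix = f"<class '{new_module}."
    changed = []
    # rglob also matches meta.json files in nested subdirectories, if any.
    for meta_path in sorted(Path(memmap_root).rglob("meta.json")):
        meta = json.loads(meta_path.read_text())
        type_name = meta.get("_type") if isinstance(meta, dict) else None
        if isinstance(type_name, str) and type_name.startswith(old_prefix):
            # Keep everything after the module prefix, e.g. "ImageNetData'>".
            meta["_type"] = new_prefix + type_name[len(old_prefix):]
            meta_path.write_text(json.dumps(meta))  # other keys are preserved
            changed.append(meta_path)
    return changed
```

For example, rewrite_type_module("path/to/memmap") would run once before TensorDict.load_memmap("path/to/memmap"). This is still a workaround, though. It fixes the symptom in the saved files rather than how the type is recorded or resolved.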
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
[1] https://pytorch.org/tensordict/main/tutorials/tensorclass_imagenet.html
[2]