
[BUG] MemoryMappedTensor Loading #1051

Open
@suessmann


Describe the bug

I collected a memmap tensordict on a cluster in a Jupyter notebook, following the guide in [1]. When I load the same memmap on my local machine with TensorDict.load_memmap(path), I get RuntimeError: Could not find name <class '__main__.ImageNetData'>. I suspect the problem is the memmap's meta.json file, which records the type as <class '__main__.ImageNetData'>. The class was defined in the notebook's __main__, but I call load_memmap(path) from a different module, and the __main__ of that process does not define ImageNetData.
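The recorded type string appears to be the str() of the class, which embeds the class's __module__. For a class defined in a notebook, that module is __main__. The sketch below illustrates why such a string cannot be resolved from a different entry point, under the assumption that the type is found by importing the recorded module. The helper resolve_type_string is hypothetical and is not tensordict's actual loading logic.

```python
import importlib
import re

# Matches the str() form of a class, e.g. "<class 'data.ImageNetData'>".
_CLASS_STR = re.compile(r"^<class '(?P<path>[\w.]+)'>$")


def resolve_type_string(type_str):
    """Hypothetical helper: turn "<class 'pkg.mod.Name'>" back into a class.

    Illustration only; NOT how tensordict resolves the "_type" entry.
    """
    match = _CLASS_STR.match(type_str)
    if match is None:
        raise ValueError(f"not a class string: {type_str!r}")
    module_name, _, class_name = match.group("path").rpartition(".")
    # Builtins such as int print without a module prefix.
    module = importlib.import_module(module_name or "builtins")
    try:
        return getattr(module, class_name)
    except AttributeError:
        raise RuntimeError(f"Could not find name {type_str}") from None


# A class whose __module__ is "data" prints with that module path...
Cls = type("ImageNetData", (), {"__module__": "data"})
print(str(Cls))  # <class 'data.ImageNetData'>

# ...while a class defined in a notebook records "__main__" instead.
NotebookCls = type("ImageNetData", (), {"__module__": "__main__"})
print(str(NotebookCls))  # <class '__main__.ImageNetData'>

# Resolving a class from an importable module works.
print(resolve_type_string("<class 'collections.OrderedDict'>"))

# Resolving the notebook class from a fresh script fails, because this
# script's __main__ never defined a global named ImageNetData.
try:
    resolve_type_string(str(NotebookCls))
except RuntimeError as err:
    print(err)
```

The same reasoning suggests that a class imported from a regular module (for example, from data import ImageNetData inside the notebook) would be recorded under that module's name rather than __main__.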

To Reproduce

Follow [1] in a Jupyter notebook and note the path where the memmap is saved. Then create main.py:

from data import Dataset

def main(path):
    data = Dataset(path)

if __name__ == '__main__':
    main('path/to/memmap')

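and data.py:

import torch
from tensordict import MemoryMappedTensor, tensorclass, TensorDict

@tensorclass
class ImageNetData:
    images: torch.Tensor
    targets: torch.Tensor

class Dataset:
    def __init__(self, path):
        # meta.json records the type as <class '__main__.ImageNetData'>,
        # but here ImageNetData lives in the module "data".
        self.data = TensorDict.load_memmap(path)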

Running python main.py then fails with:

RuntimeError: Could not find name <class '__main__.ImageNetData'>

Expected behavior

The memmap loads without errors, regardless of which module calls load_memmap.

System info

import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)

0.5.0 1.26.4 3.9.19 (main, May 6 2024, 19:43:03)
[GCC 11.2.0] linux 2.4.1+cu121

Reason and Possible fixes

As a workaround, I manually changed the "_type" entry in meta.json to

{"_type":"<class 'data.ImageNetData'>"}

This works, but it is not a robust solution. Another option is to use snapshots, but from the example in [2] it appears that loading a snapshot requires initializing the memmap every time, which is very time-consuming in my case (my data is over 500 GB).
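The manual meta.json edit can be scripted so it is repeatable. A minimal sketch, assuming the memmap directory contains one or more meta.json files with a "_type" key of the form <class '__main__.ImageNetData'>, and that the class is importable as data.ImageNetData on the loading machine. The helper patch_memmap_type and its defaults are hypothetical, it has not been tested against tensordict itself, and it is worth backing up the meta.json files first.

```python
import json
import tempfile
from pathlib import Path


def patch_memmap_type(root, old_module="__main__", new_module="data"):
    """Hypothetical helper: rewrite the "_type" entry of every meta.json
    under `root`, replacing `old_module` with `new_module` in the recorded
    class path. Other keys and other types are left untouched.

    Returns the list of meta.json paths that were modified.
    """
    old_prefix = f"<class '{old_module}."
    new_prefix = f"<class '{new_module}."
    changed = []
    for meta_path in sorted(Path(root).rglob("meta.json")):
        meta = json.loads(meta_path.read_text())
        type_str = meta.get("_type")
        if isinstance(type_str, str) and type_str.startswith(old_prefix):
            meta["_type"] = new_prefix + type_str[len(old_prefix):]
            meta_path.write_text(json.dumps(meta))
            changed.append(meta_path)
    return changed


# Usage on a throwaway directory that mimics the layout described above.
with tempfile.TemporaryDirectory() as tmp:
    meta_file = Path(tmp) / "meta.json"
    meta_file.write_text(json.dumps({"_type": "<class '__main__.ImageNetData'>"}))
    patch_memmap_type(tmp)
    print(json.loads(meta_file.read_text()))  # {'_type': "<class 'data.ImageNetData'>"}
```

Only entries that start with the old module prefix are rewritten, so plain TensorDict entries elsewhere in the tree are not affected.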

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

[1] https://pytorch.org/tensordict/main/tutorials/tensorclass_imagenet.html
[2] def load(cls, dataset, path):
