Bug description
I have a sharded checkpoint that was saved via trainer.save_checkpoint("/path/to/cp/dir/", weights_only=False)
and that I am now trying to load during testing via trainer.test(dataloaders=test_dataloader, ckpt_path="/path/to/cp/dir"),
but the load fails.
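For context, the flow is roughly the following (a sketch with a toy module; my actual model, dataloaders, and FSDP settings differ, but the sharded strategy is what produces the DCP-style checkpoint directory):

```python
# Sketch of the save/load flow that triggers the error. ToyModule, the loaders,
# and the paths are placeholders for my real setup.
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy


class ToyModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_loss", torch.nn.functional.cross_entropy(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def make_loader():
    return DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)


trainer = L.Trainer(
    accelerator="gpu",
    devices=8,
    strategy=FSDPStrategy(state_dict_type="sharded"),  # sharded state dict -> checkpoint is a DCP directory
    max_epochs=1,
)
trainer.fit(ToyModule(), make_loader())
trainer.save_checkpoint("/path/to/cp/dir/", weights_only=False)

# In the test run, loading the same directory fails inside DCP's read_data:
trainer.test(dataloaders=make_loader(), ckpt_path="/path/to/cp/dir")
```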
Here is a partial stack trace:
File "virtualenv/lib/python3.10/site-packages/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 226, in read_data
all_reads = storage_reader.read_data(final_local_plan, planner)
File "/tmp/ray/session_2024-10-23_03-46-10_056159_1/runtime_resources/pip/191a23b956de22e15c5c2fb65461774803fa6c56/virtualenv/lib/python3.10/site-packages/site-packages/torch/distributed/checkpoint/filesystem.py", line 665, in read_data
target_tensor.size() == tensor.size()
AssertionError: req MetadataIndex(fqn='model.model.model.embed_tokens.weight', offset=torch.Size([0, 0]), index=0) mismatch sizes torch.Size([1571]) vs torch.Size([1571, 3072])
I was able to load the checkpoint into the model's state dict from a notebook without hitting any exception.
I have also manually inspected the metadata file and it seems to be in order. The sizes of the shards are also well within expectations.
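For reference, a single-process check along the lines below (a sketch, not necessarily my exact notebook code; dcp_to_torch_save is just one way to consolidate the shards) is how the checkpoint can be inspected outside the Trainer:

```python
# Sketch of a single-process inspection of the sharded checkpoint.
# The /tmp output path is a placeholder.
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

ckpt_dir = "/path/to/cp/dir"

# Consolidate the sharded DCP directory into a single torch.save file, then load it on CPU.
dcp_to_torch_save(ckpt_dir, "/tmp/consolidated.ckpt")
ckpt = torch.load("/tmp/consolidated.ckpt", map_location="cpu")


def walk(obj, prefix=""):
    """Print the shape of every tensor leaf, whatever the exact key layout is."""
    if isinstance(obj, torch.Tensor):
        print(prefix, tuple(obj.shape))
    elif isinstance(obj, dict):
        for k, v in obj.items():
            walk(v, f"{prefix}.{k}" if prefix else str(k))


walk(ckpt)
```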
Would appreciate any help or tips for further debugging. Thanks!
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
(RayTrainWorker pid=2657016, ip=192.168.202.113) File "/tmp/ray/session_2024-10-23_03-46-10_056159_1/runtime_resources/pip/191a23b956de22e15c5c2fb65461774803fa6c56/virtualenv/lib/python3.10/site-packages/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 226, in read_data
(RayTrainWorker pid=2657016, ip=192.168.202.113) all_reads = storage_reader.read_data(final_local_plan, planner)
(RayTrainWorker pid=2657016, ip=192.168.202.113) File "/tmp/ray/session_2024-10-23_03-46-10_056159_1/runtime_resources/pip/191a23b956de22e15c5c2fb65461774803fa6c56/virtualenv/lib/python3.10/site-packages/site-packages/torch/distributed/checkpoint/filesystem.py", line 665, in read_data
(RayTrainWorker pid=2657016, ip=192.168.202.113) target_tensor.size() == tensor.size()
(RayTrainWorker pid=2657016, ip=192.168.202.113) AssertionError: req MetadataIndex(fqn='model.model.model.layers.26.self_attn.v_proj.weight', offset=torch.Size([0, 0]), index=0) mismatch sizes torch.Size([32]) vs torch.Size([32, 3072])
(RayTrainWorker pid=2657016, ip=192.168.202.113) Traceback (most recent call last): (RANK 31)
(RayTrainWorker pid=2657016, ip=192.168.202.113) File "/tmp/ray/session_2024-10-23_03-46-10_056159_1/runtime_resources/pip/191a23b956de22e15c5c2fb65461774803fa6c56/virtualenv/lib/python3.10/site-packages/site-packages/torch/distributed/checkpoint/utils.py", line 249, in all_gather
(RayTrainWorker pid=2657016, ip=192.168.202.113) result = map_fun()
(RayTrainWorker pid=2657016, ip=192.168.202.113) File "/tmp/ray/session_2024-10-23_03-46-10_056159_1/runtime_resources/pip/191a23b956de22e15c5c2fb65461774803fa6c56/virtualenv/lib/python3.10/site-packages/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
(RayTrainWorker pid=2657016, ip=192.168.202.113) result = func(*args, **kwargs)
(RayTrainWorker pid=2657016, ip=192.168.202.113) File "/tmp/ray/session_2024-10-23_03-46-10_056159_1/runtime_resources/pip/191a23b956de22e15c5c2fb65461774803fa6c56/virtualenv/lib/python3.10/site-packages/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 226, in read_data
(RayTrainWorker pid=2657016, ip=192.168.202.113) all_reads = storage_reader.read_data(final_local_plan, planner)
(RayTrainWorker pid=2657016, ip=192.168.202.113) File "/tmp/ray/session_2024-10-23_03-46-10_056159_1/runtime_resources/pip/191a23b956de22e15c5c2fb65461774803fa6c56/virtualenv/lib/python3.10/site-packages/site-packages/torch/distributed/checkpoint/filesystem.py", line 665, in read_data
(RayTrainWorker pid=2657016, ip=192.168.202.113) target_tensor.size() == tensor.size()
(RayTrainWorker pid=2657016, ip=192.168.202.113) AssertionError: req MetadataIndex(fqn='model.model.model.layers.27.self_attn.q_proj.weight', offset=torch.Size([0, 0]), index=0) mismatch sizes torch.Size([96]) vs torch.Size([96, 3072])
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA H100 80GB HBM3
- NVIDIA H100 80GB HBM3
- available: True
- version: 12.1
- Lightning:
- torch: 2.4.0
- Packages:
- appdirs: 1.4.4
- argon2-cffi: 21.1.0
- attrs: 21.2.0
- babel: 2.8.0
- backcall: 0.2.0
- beautifulsoup4: 4.10.0
- beniget: 0.4.1
- bleach: 4.1.0
- blinker: 1.4
- bluorion: 0.0.0
- brotli: 1.0.9
- chardet: 4.0.0
- cryptography: 3.4.8
- cycler: 0.11.0
- dbus-python: 1.2.18
- decorator: 4.4.2
- defusedxml: 0.7.1
- distro: 1.7.0
- distro-info: 1.1+ubuntu0.2
- entrypoints: 0.4
- filelock: 3.16.1
- fonttools: 4.29.1
- fs: 2.4.12
- fsspec: 2024.10.0
- gast: 0.5.2
- git-crecord: 20201025.0
- html5lib: 1.1
- httplib2: 0.20.2
- importlib-metadata: 4.6.4
- ipykernel: 6.7.0
- ipython: 7.31.1
- ipython-genutils: 0.2.0
- ipywidgets: 6.0.0
- jedi: 0.18.0
- jeepney: 0.7.1
- jinja2: 3.0.3
- jsonschema: 3.2.0
- jupyter-client: 7.1.2
- jupyter-core: 4.9.1
- jupyterlab-pygments: 0.1.2
- keyring: 23.5.0
- kiwisolver: 1.3.2
- launchpadlib: 1.10.16
- lazr.restfulclient: 0.14.4
- lazr.uri: 1.0.6
- lxml: 4.8.0
- lz4: 3.1.3+dfsg
- markupsafe: 2.0.1
- matplotlib: 3.5.1
- matplotlib-inline: 0.1.3
- more-itertools: 8.10.0
- mpmath: 0.0.0
- nbclient: 0.5.6
- nbconvert: 6.4.0
- nbformat: 5.1.3
- nest-asyncio: 1.5.4
- networkx: 3.4.2
- notebook: 6.4.8
- numpy: 1.21.5
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.6.77
- nvidia-nvtx-cu12: 12.1.105
- oauthlib: 3.2.0
- olefile: 0.46
- packaging: 21.3
- pandocfilters: 1.5.0
- parso: 0.8.1
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.0.1
- pip: 22.0.2
- ply: 3.11
- prometheus-client: 0.9.0
- prompt-toolkit: 3.0.28
- ptyprocess: 0.7.0
- py: 1.10.0
- pygments: 2.11.2
- pygobject: 3.42.1
- pyjwt: 2.3.0
- pyparsing: 2.4.7
- pyrsistent: 0.18.1
- python-apt: 2.4.0+ubuntu4
- python-dateutil: 2.8.1
- pythran: 0.10.0
- pytz: 2022.1
- pyyaml: 5.4.1
- pyzmq: 22.3.0
- scipy: 1.8.0
- secretstorage: 3.3.1
- send2trash: 1.8.1b0
- setuptools: 59.6.0
- six: 1.16.0
- soupsieve: 2.3.1
- sympy: 1.9
- terminado: 0.13.1
- testpath: 0.5.0
- torch: 2.4.0
- tornado: 6.1
- traitlets: 5.1.1
- triton: 3.0.0
- typing-extensions: 4.12.2
- ufolib2: 0.13.1
- unattended-upgrades: 0.1
- unicodedata2: 14.0.0
- wadllib: 1.3.6
- wcwidth: 0.2.5
- webencodings: 0.5.1
- wheel: 0.37.1
- widgetsnbextension: 2.0.0
- zipp: 1.0.0
- System:
- OS: Linux
- architecture:
- 64bit
-
- processor: x86_64
- python: 3.10.12
- release: 5.15.0-122-generic
- version: #132-Ubuntu SMP Thu Aug 29 13:45:52 UTC 2024
More info
I am using Ray for orchestration.
cc @awaelchli