
FSDP checkpoints don't load when run is restarted with greater world size #811

Closed
@darkmirage

Description


A checkpoint is saved from an 8-GPU run with dp_shard set to 8 and all other parallelism degrees set to 1. My understanding is that this is configured as an FSDP run.

The checkpoint is then resumed on 16 GPUs with dp_shard set to 16. When loading the checkpoint, we get this error:

[rank0]: Traceback (most recent call last): (RANK 15)
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 164, in reduce_scatter
[rank0]:     local_data = map_fun()
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 211, in local_step
[rank0]:     local_plan = planner.create_local_plan()
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 233, in create_local_plan
[rank0]:     return create_default_local_load_plan(
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 354, in create_default_local_load_plan
[rank0]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank0]: RuntimeError: Missing key in checkpoint state_dict: dataloader.dp_rank_15.
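
For context, the same failure can be reproduced with plain torch.distributed.checkpoint (DCP) outside torchtitan: the default load planner raises "Missing key in checkpoint state_dict" for any entry in the local state_dict that is absent from the checkpoint metadata, and per-rank object entries such as the dataloader state are stored under their literal key names rather than as reshardable tensors. A minimal single-process sketch (the key names just mirror the error above; the path and values are made up, and this is not torchtitan's actual save/load code):

# repro_missing_key.py -- single-process sketch with hypothetical keys and path.
import torch.distributed.checkpoint as dcp

ckpt_dir = "/tmp/dcp_missing_key_demo"

# "8-GPU" run: only dataloader.dp_rank_0 .. dataloader.dp_rank_7 get written.
save_state = {f"dataloader.dp_rank_{r}": {"step": 100} for r in range(8)}
dcp.save(save_state, storage_writer=dcp.FileSystemWriter(ckpt_dir))

# "16-GPU" resume: rank 15 asks for a key the checkpoint never contained.
load_state = {"dataloader.dp_rank_15": {}}
dcp.load(load_state, storage_reader=dcp.FileSystemReader(ckpt_dir))
# -> RuntimeError: Missing key in checkpoint state_dict: dataloader.dp_rank_15.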

My understanding is that torch distributed checkpoints are supposed to support dynamic resharding at load time. Does this not work with torchtitan?

I was able to successfully resume a checkpoint going down from 32 GPUs to 16.
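
For comparison, the resharding that DCP does support applies to tensor entries: a DTensor is saved with its global shape and shard offsets, so the same key can be re-sliced across a different world size at load time. A rough sketch of that path, assuming a recent PyTorch where torch.distributed.tensor is public, and using made-up names and paths:

# dtensor_reshard.py -- sketch of the tensor path DCP can reshard.
# Launch with: torchrun --nproc_per_node=8 dtensor_reshard.py
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

def main() -> None:
    dist.init_process_group(backend="gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    # A DTensor entry records the global shape and each shard's offset,
    # so "model.weight" can be re-split over a different number of ranks.
    weight = distribute_tensor(torch.zeros(1024, 64), mesh, [Shard(0)])
    dcp.save({"model.weight": weight}, checkpoint_id="/tmp/dcp_reshard_demo")

    # Re-launching with a different --nproc_per_node and calling
    # dcp.load({"model.weight": new_weight}, checkpoint_id="/tmp/dcp_reshard_demo")
    # succeeds, because resharding is done per tensor, not for per-rank
    # object keys like dataloader.dp_rank_N.
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

If the dataloader keys are what torchtitan trips over here, that would also explain why going from 32 down to 16 works: dataloader.dp_rank_0 through dp_rank_15 all exist in the 32-rank checkpoint, so every requested key is found.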


Metadata

Labels

bug, documentation, enhancement, module: fsdp
