FSDP checkpoints don't load when run is restarted with greater world size #811

Closed

Description

@darkmirage

A checkpoint is saved from an 8-GPU run with dp_shard set to 8 and all other parallelism degrees set to 1. My understanding is that this configuration is a plain FSDP run.

The checkpoint is then resumed on 16 GPUs with dp_shard set to 16. When loading the checkpoint, we get this error:

```
[rank0]: Traceback (most recent call last): (RANK 15)
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 164, in reduce_scatter
[rank0]:     local_data = map_fun()
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 211, in local_step
[rank0]:     local_plan = planner.create_local_plan()
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 233, in create_local_plan
[rank0]:     return create_default_local_load_plan(
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 354, in create_default_local_load_plan
[rank0]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank0]: RuntimeError: Missing key in checkpoint state_dict: dataloader.dp_rank_15.
```

My understanding is that torch.distributed.checkpoint is supposed to support dynamic resharding at load time. Does this not work with torchtitan?
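
For reference, here is a minimal sketch (not torchtitan's actual checkpoint code, just the generic torch.distributed.checkpoint pattern) of the save/load flow I would expect to reshard automatically, since DTensor entries in the state dict carry their own sharding metadata:

```python
# Minimal sketch of the generic DCP pattern, not torchtitan's implementation.
# DTensor entries record their sharding, so dcp.load() can re-slice the saved
# tensors to match whatever sharding the resumed job uses.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

def save_model(model, path):
    # Sharded (DTensor) entries written here can be resharded at load time.
    dcp.save({"model": get_model_state_dict(model)}, checkpoint_id=path)

def load_model(model, path):
    state = {"model": get_model_state_dict(model)}
    dcp.load(state, checkpoint_id=path)  # loads in place, resharding tensor entries
    set_model_state_dict(model, state["model"])
```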

I was able to successfully resume a checkpoint going down from 32 GPUs to 16.
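
Note that the missing key in the traceback is a dataloader entry, not a model tensor. Assuming torchtitan stores one non-tensor dataloader entry per data-parallel rank (an assumption based only on the key name dataloader.dp_rank_15), plain key matching would explain both the scale-up failure and the scale-down success:

```python
# Illustration only, assuming one dataloader checkpoint entry per dp rank,
# as the key name "dataloader.dp_rank_15" suggests. Non-tensor entries are
# matched by exact key, so they cannot be resharded the way DTensors can.
def missing_dataloader_keys(save_world_size, load_world_size):
    saved = {f"dataloader.dp_rank_{r}" for r in range(save_world_size)}
    wanted = {f"dataloader.dp_rank_{r}" for r in range(load_world_size)}
    return wanted - saved

print(missing_dataloader_keys(8, 16))   # ranks 8..15 have no saved entry -> load error on scale-up
print(missing_dataloader_keys(32, 16))  # empty set -> scale-down loads (entries 16..31 go unused)
```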

Labels

bug, documentation, enhancement, module: fsdp
