A checkpoint is saved from an 8-GPU run with dp_shard set to 8 and all other parallelisms set to 1. My understanding is that this configuration amounts to a pure FSDP run.
The checkpoint is then resumed on 16 GPUs with dp_shard now set to 16. When loading the checkpoint, we get this error:
```
[rank0]: Traceback (most recent call last): (RANK 15)
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 164, in reduce_scatter
[rank0]:     local_data = map_fun()
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 211, in local_step
[rank0]:     local_plan = planner.create_local_plan()
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 233, in create_local_plan
[rank0]:     return create_default_local_load_plan(
[rank0]:   File "/app/.venv/lib/python3.10/site-packages/torch/distributed/checkpoint/default_planner.py", line 354, in create_default_local_load_plan
[rank0]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank0]: RuntimeError: Missing key in checkpoint state_dict: dataloader.dp_rank_15.
```
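The keys actually recorded in the checkpoint can be listed from the DCP metadata, along these lines (a small sketch; `ckpt_dir` is just a placeholder for the real step directory):

```python
# Sketch: list the keys recorded in the checkpoint's DCP metadata.
import torch.distributed.checkpoint as dcp

ckpt_dir = "outputs/checkpoint/step-1000"  # placeholder for the real step directory
metadata = dcp.FileSystemReader(ckpt_dir).read_metadata()
for key in sorted(metadata.state_dict_metadata):
    print(key)
```

If the dataloader state is stored under per-rank keys, an 8-GPU checkpoint would only contain dataloader.dp_rank_0 through dataloader.dp_rank_7, which would explain why dp rank 15 cannot find its entry.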
My understanding is that torch.distributed checkpoint (DCP) is supposed to support dynamic resharding at load time. Does this not work with torchtitan?
I was able to successfully resume a checkpoint going down from 32 GPUs to 16.
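For reference, here is a minimal single-process sketch (made-up keys mimicking the ones in the traceback, not torchtitan's actual checkpoint code) that reproduces the same failure mode: the tensor entry loads fine, but a rank-keyed non-tensor entry has to exist verbatim in the checkpoint, so requesting a dp_rank that was never written trips the same "Missing key" error:

```python
# Minimal single-process sketch, not torchtitan's actual checkpoint code.
# The keys are hypothetical stand-ins mimicking the ones in the traceback.
import tempfile

import torch
import torch.distributed.checkpoint as dcp

ckpt_dir = tempfile.mkdtemp()

# "8-GPU save": one tensor plus one opaque per-rank blob for each of 8 dp ranks.
saved = {"model.weight": torch.randn(16, 16)}
saved.update({f"dataloader.dp_rank_{i}": {"step": 100} for i in range(8)})
dcp.save(saved, storage_writer=dcp.FileSystemWriter(ckpt_dir))

# "16-GPU load" as seen from dp rank 15: the tensor entry resolves fine, but
# dataloader.dp_rank_15 was never written, so load planning fails with the
# same RuntimeError as in the traceback above.
to_load = {
    "model.weight": torch.empty(16, 16),
    "dataloader.dp_rank_15": {"step": 0},
}
dcp.load(to_load, storage_reader=dcp.FileSystemReader(ckpt_dir))
# RuntimeError: Missing key in checkpoint state_dict: dataloader.dp_rank_15.
```

That would also be consistent with the 32-to-16 case working: a 32-rank checkpoint already contains dataloader.dp_rank_0 through dataloader.dp_rank_31, so every key a 16-rank resume asks for is present.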