Description
System Info
- `Accelerate` version: 1.12.0
- Platform: Linux-6.6.113+-x86_64-with-glibc2.35
- `accelerate` bash location: /usr/local/bin/accelerate
- Python version: 3.12.12
- Numpy version: 2.0.2
- PyTorch version: 2.9.0+cu126
- PyTorch accelerator: CUDA
- System RAM: 31.35 GB
- GPU type: Tesla T4

Information
- The official example scripts
- My own modified scripts
Tasks
- One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- My own task or dataset (give details below)
Reproduction
Bug description
When using FSDP2 (`fsdp_version: 2`) with `SHARDED_STATE_DICT`, saving an optimizer checkpoint succeeds, but loading it back fails with:

```
RuntimeError: Missing key in checkpoint state_dict: optimizer.state.0.step.
```
Root cause
In `load_fsdp_optimizer` (`fsdp_utils.py:L322`), `dist_cp.load()` is called without a `planner` argument. When no planner is provided, `dist_cp` uses its default resolution logic, which cannot resolve nested optimizer state keys such as `optimizer.state.0.step` (the AdamW step counter, stored as a 0-dimensional tensor in PyTorch 2.2+).
The fix is to explicitly pass `planner=DefaultLoadPlanner()` to the `dist_cp.load()` call, matching what `dist_cp.save()` already does with `DefaultSavePlanner()`.
Note: the `save_fsdp_optimizer` function already uses `DefaultSavePlanner()`; only the load side is missing the corresponding planner.
Reproduction
Config (`config.yaml`):

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

Script (`repro.py`):
```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from accelerate import Accelerator


class BigMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(4096, 4096), nn.GELU(),
            nn.Linear(4096, 4096), nn.GELU(),
            nn.Linear(4096, 1),
        )

    def forward(self, x):
        return self.layers(x)


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, i):
        return torch.randn(4096), torch.randn(1)


def main():
    accelerator = Accelerator()
    model = BigMLP()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    loader = DataLoader(RandomDataset(), batch_size=8)
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    for step, (x, y) in enumerate(loader):
        if step >= 3:
            break
        accelerator.backward(model(x).mean())
        optimizer.step()
        optimizer.zero_grad()

    accelerator.wait_for_everyone()
    accelerator.save_state("./checkpoint")
    accelerator.wait_for_everyone()
    accelerator.load_state("./checkpoint")  # <-- crashes here


if __name__ == "__main__":
    main()
```

Run:
```sh
accelerate launch --config_file config.yaml repro.py
```
Expected behavior
`accelerator.load_state("./checkpoint")` should succeed and restore the optimizer state (including the `step` tensor) from the sharded checkpoint.
Actual behavior
```
RuntimeError: Missing key in checkpoint state_dict: optimizer.state.0.step.
```
The checkpoint files on disk are correct: inspecting the metadata with `FileSystemReader` confirms that `optimizer.state.0.step` is present. The issue is purely on the load side, where the missing `DefaultLoadPlanner` prevents `dist_cp.load()` from resolving the key.