
FSDP2 + SHARDED_STATE_DICT: optimizer checkpoint load fails with missing step key. (RuntimeError: Missing key in checkpoint state_dict: optimizer.state.0.step.) #3971

@iavinas

Description

System Info

- `Accelerate` version: 1.12.0
- Platform: Linux-6.6.113+-x86_64-with-glibc2.35
- `accelerate` bash location: /usr/local/bin/accelerate
- Python version: 3.12.12
- Numpy version: 2.0.2
- PyTorch version: 2.9.0+cu126
- PyTorch accelerator: CUDA
- System RAM: 31.35 GB
- GPU type: Tesla T4

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Bug description

When using FSDP2 (fsdp_version: 2) with SHARDED_STATE_DICT, saving an optimizer checkpoint succeeds, but loading it back fails with:

RuntimeError: Missing key in checkpoint state_dict: optimizer.state.0.step.

Root cause

In load_fsdp_optimizer (fsdp_utils.py:L322), dist_cp.load() is called without a planner argument. When no planner is provided, dist_cp falls back to its default resolution logic, which cannot resolve flattened, nested optimizer-state keys such as optimizer.state.0.step (the AdamW step counter, stored as a 0-dimensional tensor in PyTorch 2.2+).

The fix is to explicitly pass planner=DefaultLoadPlanner() to the dist_cp.load() call, matching what dist_cp.save() already does with DefaultSavePlanner().

Note: The save_fsdp_optimizer function already uses DefaultSavePlanner() — only the load side is missing the corresponding planner.
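A sketch of the proposed one-line change in load_fsdp_optimizer (argument names here are illustrative, not the exact locals in fsdp_utils.py):

```diff
 dist_cp.load(
     state_dict=optim_state,
     storage_reader=dist_cp.FileSystemReader(ckpt_dir),
+    planner=dist_cp.DefaultLoadPlanner(),
 )
```

This mirrors the planner=DefaultSavePlanner() that save_fsdp_optimizer already passes on the save side.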

Steps to reproduce

Config (config.yaml):

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Script (repro.py):

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator

class BigMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(4096, 4096), nn.GELU(),
            nn.Linear(4096, 4096), nn.GELU(),
            nn.Linear(4096, 1),
        )
    def forward(self, x):
        return self.layers(x)

class RandomDataset(Dataset):
    def __len__(self): return 64
    def __getitem__(self, i):
        return torch.randn(4096), torch.randn(1)

def main():
    accelerator = Accelerator()
    model = BigMLP()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    loader = DataLoader(RandomDataset(), batch_size=8)
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    for step, (x, y) in enumerate(loader):
        if step >= 3: break
        accelerator.backward(model(x).mean())
        optimizer.step()
        optimizer.zero_grad()

    accelerator.wait_for_everyone()
    accelerator.save_state("./checkpoint")
    accelerator.wait_for_everyone()
    accelerator.load_state("./checkpoint")  # <-- crashes here

if __name__ == "__main__":
    main()

Run:

accelerate launch --config_file config.yaml repro.py

Expected behavior

accelerator.load_state("./checkpoint") should succeed and restore the optimizer state (including the step tensor) from the sharded checkpoint.

Actual behavior

RuntimeError: Missing key in checkpoint state_dict: optimizer.state.0.step.

The checkpoint files on disk are correct: inspecting the metadata with FileSystemReader confirms that optimizer.state.0.step is present. The issue is purely on the load side, where the missing DefaultLoadPlanner prevents dist_cp.load() from resolving the key.
