System Info
- `Accelerate` version: 1.13.0
- Platform: Linux-6.1.134-152.225.amzn2023.x86_64-x86_64-with-glibc2.39
- `accelerate` bash location: /mnt/.venv/bin/accelerate
- Python version: 3.12.3
- Numpy version: 2.4.4
- PyTorch version: 2.11.0+cu128
- PyTorch accelerator: CUDA
- System RAM: 1999.96 GB
- GPU type: NVIDIA H100 80GB HBM3
- `Accelerate` config passed:
- compute_environment: LOCAL_MACHINE
- distributed_type: FSDP
- mixed_precision: bf16
- use_cpu: False
- debug: True
- num_processes: 2
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- fsdp_config: {'fsdp_activation_checkpointing': False, 'fsdp_auto_wrap_policy': 'TRANSFORMER_BASED_WRAP', 'fsdp_cpu_ram_efficient_loading': False, 'fsdp_offload_params': False, 'fsdp_reshard_after_forward': True, 'fsdp_state_dict_type': 'SHARDED_STATE_DICT', 'fsdp_version': 2}
- parallelism_config: {'parallelism_config_cp_size': 1, 'parallelism_config_dp_replicate_size': 1, 'parallelism_config_dp_shard_size': 2, 'parallelism_config_tp_size': 1}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
Bug description
After saving a checkpoint with `accelerator.save_state()` and resuming with `accelerator.load_state()`, the DataLoader produces the same shuffle order as epochs 0, 1, 2, ... instead of continuing from the correct epoch. This means the model re-trains on the same data ordering it already saw, rather than getting fresh shuffles.

There are two related bugs depending on configuration:
Bug 1: `use_seedable_sampler=True`

`DataLoaderShard.iteration` is initialized to `0` in `__init__` (data_loader.py:558) and incremented after each completed epoch (data_loader.py:601). At the start of each epoch, `__iter__` calls `self.set_epoch(self.iteration)` (data_loader.py:573), which propagates to `SeedableRandomSampler.set_epoch()`, setting its `self.epoch`. The sampler then computes its seed as `self.initial_seed + self.epoch` (data_loader.py:99).

The problem: `DataLoaderShard.iteration` is not saved by `save_state()` and not restored by `load_state()`. On resume, it resets to `0`, so `SeedableRandomSampler` seeds with `initial_seed + 0`, `initial_seed + 1`, etc., replaying the exact same shuffle sequence from the start of training.
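A minimal sketch of that seeding scheme, paraphrased from the cited lines (illustrative, not the verbatim accelerate implementation):

```python
import torch

# Paraphrase of the per-epoch seeding described above (sketch only).
class SeedableSamplerSketch:
    def __init__(self, data_source, initial_seed):
        self.data_source = data_source
        self.initial_seed = initial_seed
        self.epoch = 0  # set via set_epoch(); never written to a checkpoint

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.initial_seed + self.epoch)  # same epoch -> same permutation
        yield from torch.randperm(len(self.data_source), generator=g).tolist()
```

Because `DataLoaderShard.iteration` restarts at 0 after resume, `set_epoch(0)`, `set_epoch(1)`, ... are replayed and the permutations repeat from the beginning.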
Bug 2: `use_seedable_sampler=False` (default) + multi-GPU

In `prepare_data_loader()` with `num_processes > 1`, if the sampler's generator is `None`, accelerate creates a private `torch.Generator` and assigns it to the sampler (data_loader.py:1235-1241). This generator is seeded deterministically from the global torch RNG (which was set by `set_seed()`).

The problem: this private generator is not saved or restored by `save_state()`/`load_state()` for map-style datasets. The checkpointing code only saves the sampler when `isinstance(dataloader.dataset, IterableDatasetShard)` (checkpointing.py:137-140). On resume, `prepare()` re-creates the generator with the same initial seed, replaying the same shuffle sequence.

(In single-process mode this case happens to work because the `RandomSampler` falls back to the global torch RNG, which is restored by `load_state`.)
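A sketch of the generator hand-off described above, paraphrased from the cited lines (the helper name is made up for illustration):

```python
import torch
from torch.utils.data import RandomSampler

def attach_private_generator(sampler: RandomSampler) -> None:
    """Hypothetical helper mirroring what prepare_data_loader() does when
    num_processes > 1 and the sampler has no generator of its own."""
    if sampler.generator is None:
        # The seed is drawn from the global torch RNG, whose state was just
        # fixed by set_seed(), so every fresh process derives the same seed.
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        sampler.generator = torch.Generator()
        sampler.generator.manual_seed(seed)
    # save_state() never serializes sampler.generator.get_state() for
    # map-style datasets, so a resumed run rebuilds the generator from
    # scratch and replays the same shuffle sequence.
```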
Reproduction
Bug 1: `use_seedable_sampler=True` (single-process)

Save as `reproduce_bug1.py` and run with `python reproduce_bug1.py`:
"""Reproduces Bug 1 (use_seedable_sampler=True) in single-process mode."""
import os, shutil, tempfile
import torch
from torch.utils.data import TensorDataset
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import set_seed
SEED, NUM_ITEMS, BATCH_SIZE = 42, 16, 4
CKPT_EPOCH, TOTAL = 2, 5
def epoch_order(dl):
return [x for batch in dl for x in batch[0].tolist()]
def run_original(tmpdir):
set_seed(SEED)
acc = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
dl = acc.prepare(torch.utils.data.DataLoader(TensorDataset(torch.arange(NUM_ITEMS)), batch_size=BATCH_SIZE, shuffle=True))
orders = {}
for e in range(TOTAL):
orders[e] = epoch_order(dl)
if e == CKPT_EPOCH - 1:
acc.save_state(os.path.join(tmpdir, "ckpt"))
return orders
def run_resumed(tmpdir):
set_seed(SEED)
acc = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
dl = acc.prepare(torch.utils.data.DataLoader(TensorDataset(torch.arange(NUM_ITEMS)), batch_size=BATCH_SIZE, shuffle=True))
acc.load_state(os.path.join(tmpdir, "ckpt"))
# BUG: dl.iteration is 0 here, not CKPT_EPOCH
# Workaround: dl.set_epoch(CKPT_EPOCH)
return {e: epoch_order(dl) for e in range(CKPT_EPOCH, TOTAL)}
tmpdir = tempfile.mkdtemp()
orig = run_original(tmpdir)
resumed = run_resumed(tmpdir)
shutil.rmtree(tmpdir)
print("Original:")
for e, o in sorted(orig.items()): print(f" epoch {e}: {o}")
print("\nResumed:")
for e, o in sorted(resumed.items()): print(f" epoch {e}: {o}")
print("\nComparison:")
for e in range(CKPT_EPOCH, TOTAL):
restart_e = e - CKPT_EPOCH
print(f" epoch {e}: matches original epoch {e}? {orig[e] == resumed[e]} | replays epoch {restart_e}? {orig[restart_e] == resumed[e]}")
Output:

```text
Original:
  epoch 0: [6, 3, 0, 7, 10, 15, 2, 12, 14, 9, 8, 5, 4, 13, 11, 1]
  epoch 1: [4, 0, 9, 5, 3, 2, 6, 7, 1, 12, 11, 10, 14, 8, 15, 13]
  epoch 2: [4, 9, 3, 2, 7, 8, 12, 13, 11, 15, 6, 5, 1, 10, 14, 0]
  epoch 3: [11, 12, 9, 1, 4, 6, 13, 14, 0, 15, 5, 3, 2, 8, 7, 10]
  epoch 4: [13, 2, 1, 11, 6, 12, 14, 9, 0, 7, 15, 8, 10, 5, 4, 3]

Resumed:
  epoch 2: [6, 3, 0, 7, 10, 15, 2, 12, 14, 9, 8, 5, 4, 13, 11, 1]  <-- same as epoch 0
  epoch 3: [4, 0, 9, 5, 3, 2, 6, 7, 1, 12, 11, 10, 14, 8, 15, 13]  <-- same as epoch 1
  epoch 4: [4, 9, 3, 2, 7, 8, 12, 13, 11, 15, 6, 5, 1, 10, 14, 0]  <-- same as epoch 2

Comparison:
  epoch 2: matches original epoch 2? False | replays epoch 0? True
  epoch 3: matches original epoch 3? False | replays epoch 1? True
  epoch 4: matches original epoch 4? False | replays epoch 2? True
```
Bug 2: `use_seedable_sampler=False` (default) + multi-GPU

Save as `reproduce_bug2.py` and launch with accelerate on 2+ GPUs:

```bash
accelerate launch --config_file accelerate-config-2gpu.yaml reproduce_bug2.py
```
"""Reproduces Bug 2 (default sampler + multi-GPU)."""
import os, shutil, tempfile
import torch
from torch.utils.data import TensorDataset
from accelerate import Accelerator
from accelerate.utils import set_seed
SEED, NUM_ITEMS, BATCH_SIZE = 42, 64, 4
CKPT_EPOCH, TOTAL = 2, 5
def epoch_order(dl):
return [x for batch in dl for x in batch[0].tolist()]
def run_original(tmpdir):
acc = Accelerator()
set_seed(SEED, device_specific=True)
dl = acc.prepare(torch.utils.data.DataLoader(
TensorDataset(torch.arange(NUM_ITEMS)), batch_size=BATCH_SIZE, shuffle=True))
orders = {}
for e in range(TOTAL):
orders[e] = epoch_order(dl)
if e == CKPT_EPOCH - 1:
acc.save_state(os.path.join(tmpdir, "ckpt"))
acc.free_memory()
return orders
def run_resumed(tmpdir):
acc = Accelerator()
set_seed(SEED, device_specific=True)
dl = acc.prepare(torch.utils.data.DataLoader(
TensorDataset(torch.arange(NUM_ITEMS)), batch_size=BATCH_SIZE, shuffle=True))
acc.load_state(os.path.join(tmpdir, "ckpt"))
# BUG: sampler's private generator was re-created with same initial seed
orders = {e: epoch_order(dl) for e in range(CKPT_EPOCH, TOTAL)}
is_main = acc.is_main_process
acc.free_memory()
return orders, is_main
torch.distributed.init_process_group(backend="nccl")
tmpdir = os.path.join(tempfile.gettempdir(), "accel_bug2_repro")
if int(os.environ.get("RANK", "0")) == 0:
shutil.rmtree(tmpdir, ignore_errors=True)
os.makedirs(tmpdir)
torch.distributed.barrier()
orig = run_original(tmpdir)
torch.distributed.barrier()
resumed, is_main = run_resumed(tmpdir)
if is_main:
print("Original:"); [print(f" epoch {e}: {o}") for e, o in sorted(orig.items())]
print("\nResumed:"); [print(f" epoch {e}: {o}") for e, o in sorted(resumed.items())]
print("\nComparison:")
for e in range(CKPT_EPOCH, TOTAL):
r = e - CKPT_EPOCH
print(f" epoch {e}: matches original epoch {e}? {orig[e] == resumed[e]} | replays epoch {r}? {orig[r] == resumed[e]}")
shutil.rmtree(tmpdir, ignore_errors=True)
torch.distributed.destroy_process_group()
Root cause

In `data_loader.py`:
- `DataLoaderShard.__init__` sets `self.iteration = 0` (line 558)
- `DataLoaderShard.__iter__` calls `self.set_epoch(self.iteration)` (line 573), then increments `self.iteration` (line 601)
- `DataLoaderShard.set_epoch` propagates to `SeedableRandomSampler.set_epoch` (lines 619-620)
- `SeedableRandomSampler.__iter__` seeds with `self.epoch + self.initial_seed` (line 99)

But `self.iteration` is never included in `state_dict()`/`load_state_dict()`. `DataLoaderAdapter.state_dict()` only delegates to the base dataloader's state dict, which tracks within-epoch position but not the epoch counter.
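This is easy to observe directly (hypothetical interactive session; `accelerator` and `dataloader` set up as in the repro scripts):

```python
dl = accelerator.prepare(dataloader)  # DataLoaderShard
print(dl.iteration)                   # 0
for _ in dl:                          # epoch 0
    pass
for _ in dl:                          # epoch 1
    pass
print(dl.iteration)                   # 2
accelerator.save_state("ckpt")

# ... in a fresh process, after prepare() ...
accelerator.load_state("ckpt")
print(dl.iteration)                   # 0  <-- the epoch counter was not restored
```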
Potential fix

`DataLoaderShard.iteration` (and equivalently `DataLoaderDispatcher.iteration`) should be saved and restored as part of the dataloader state. Possible approaches:
- Include `iteration` in `state_dict()`/`load_state_dict()`: override these in `DataLoaderShard` to include `self.iteration` alongside the base dataloader state (see the sketch after this list).
- Save/restore `SeedableRandomSampler.epoch`: the checkpointing code already has a path for saving the sampler (currently gated on `IterableDatasetShard`). Extending it to also cover map-style datasets would fix bug 2 but not bug 1.
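A minimal sketch of the first approach, assuming `DataLoaderShard` keeps its current base classes and attribute names (an illustration of the idea, not a tested patch; the state-dict key is made up):

```python
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
    ...

    def state_dict(self):
        sd = super().state_dict()                    # base within-epoch state
        sd["accelerate_iteration"] = self.iteration  # add the epoch counter
        return sd

    def load_state_dict(self, state_dict):
        state_dict = dict(state_dict)
        self.iteration = state_dict.pop("accelerate_iteration", 0)
        super().load_state_dict(state_dict)
        # Re-propagate so SeedableRandomSampler reseeds for the correct epoch.
        self.set_epoch(self.iteration)
```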
Workaround without a fix

After `load_state`, manually set the epoch counter:

```python
accelerator.load_state(checkpoint_path)
first_epoch = global_step // steps_per_epoch
train_dataloader.set_epoch(first_epoch)
```

This fixes bug 1. For bug 2, additionally enable `use_seedable_sampler=True` so that the shuffle is deterministic per epoch rather than depending on the generator's running state.
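Combined, a resumable setup for bug 2 could look like this (sketch reusing the names from the snippet above; `DataLoaderConfiguration` is the existing accelerate API):

```python
from accelerate import Accelerator, DataLoaderConfiguration

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True)
)
train_dataloader = accelerator.prepare(train_dataloader)

accelerator.load_state(checkpoint_path)
first_epoch = global_step // steps_per_epoch
train_dataloader.set_epoch(first_epoch)  # same manual epoch fix as bug 1
```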
Expected behavior

After `accelerator.load_state()`, the DataLoader should produce the same shuffle order that the original uninterrupted run would have produced at the corresponding and following epochs.
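In terms of the repro scripts above, this is the invariant a fix should make hold:

```python
# orig, resumed, CKPT_EPOCH and TOTAL as in the repro scripts
for e in range(CKPT_EPOCH, TOTAL):
    assert resumed[e] == orig[e], f"epoch {e} diverged from the uninterrupted run"
```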