DataLoader shuffle sequence replays from epoch 0 after resuming from a checkpoint #3996

@XiangY-Q

Description

System Info

- `Accelerate` version: 1.13.0
- Platform: Linux-6.1.134-152.225.amzn2023.x86_64-x86_64-with-glibc2.39
- `accelerate` bash location: /mnt/.venv/bin/accelerate
- Python version: 3.12.3
- Numpy version: 2.4.4
- PyTorch version: 2.11.0+cu128
- PyTorch accelerator: CUDA
- System RAM: 1999.96 GB
- GPU type: NVIDIA H100 80GB HBM3
- `Accelerate` config passed:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: FSDP
        - mixed_precision: bf16
        - use_cpu: False
        - debug: True
        - num_processes: 2
        - machine_rank: 0
        - num_machines: 1
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - enable_cpu_affinity: False
        - fsdp_config: {'fsdp_activation_checkpointing': False, 'fsdp_auto_wrap_policy': 'TRANSFORMER_BASED_WRAP', 'fsdp_cpu_ram_efficient_loading': False, 'fsdp_offload_params': False, 'fsdp_reshard_after_forward': True, 'fsdp_state_dict_type': 'SHARDED_STATE_DICT', 'fsdp_version': 2}
        - parallelism_config: {'parallelism_config_cp_size': 1, 'parallelism_config_dp_replicate_size': 1, 'parallelism_config_dp_shard_size': 2, 'parallelism_config_tp_size': 1}
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Bug description

After saving a checkpoint with accelerator.save_state() and resuming with accelerator.load_state(), the DataLoader replays the shuffle orders of epochs 0, 1, 2, ... instead of continuing from the epoch at which training stopped. The model therefore re-trains on data orderings it has already seen, rather than getting fresh shuffles.

There are two related bugs depending on configuration:

Bug 1: use_seedable_sampler=True

DataLoaderShard.iteration is initialized to 0 in __init__ (data_loader.py:558) and incremented after each completed epoch (data_loader.py:601). At the start of each epoch, __iter__ calls self.set_epoch(self.iteration) (data_loader.py:573), which propagates to SeedableRandomSampler.set_epoch(), setting its self.epoch. The sampler then computes its seed as self.initial_seed + self.epoch (data_loader.py:99).
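The seeding path, condensed into a runnable sketch (paraphrased from the accelerate internals cited above; this is not the actual source, just the shape of the logic):

import torch

# Condensed stand-in for accelerate's SeedableRandomSampler: the shuffle
# for a given epoch is fully determined by initial_seed + epoch.
class SeedableRandomSamplerSketch:
    def __init__(self, data_source, initial_seed):
        self.data_source = data_source
        self.initial_seed = initial_seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.initial_seed + self.epoch)
        yield from torch.randperm(len(self.data_source), generator=g).tolist()

# DataLoaderShard.__iter__ effectively calls sampler.set_epoch(self.iteration)
# before each epoch, so whoever controls `iteration` controls the shuffle.
sampler = SeedableRandomSamplerSketch(range(8), initial_seed=42)
sampler.set_epoch(0)
print(list(sampler))  # identical every time the epoch counter resets to 0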

The problem: DataLoaderShard.iteration is not saved by save_state() and not restored by load_state(). On resume, it resets to 0, so SeedableRandomSampler seeds with initial_seed + 0, initial_seed + 1, etc. — replaying the exact same shuffle sequence from the start of training.

Bug 2: use_seedable_sampler=False (default) + multi-GPU

In prepare_data_loader() with num_processes > 1, if the sampler's generator is None, accelerate creates a private torch.Generator and assigns it to the sampler (data_loader.py:1235-1241). This generator is seeded deterministically from the global torch RNG (which was set by set_seed()).

The problem: this private generator is not saved or restored by save_state()/load_state() for map-style datasets. The checkpointing code only saves the sampler when isinstance(dataloader.dataset, IterableDatasetShard) (checkpointing.py:137-140). On resume, prepare() re-creates the generator with the same initial seed, replaying the same shuffle sequence.
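A condensed sketch of what prepare_data_loader() does in this case (paraphrased, not the actual accelerate source):

import torch
from torch.utils.data import RandomSampler

# Paraphrase of the cited data_loader.py logic: if the sampler has no
# generator, attach a private one seeded from the global torch RNG. Since
# set_seed() fixed the global RNG, a fresh prepare() after resume re-creates
# an identical generator and hence an identical shuffle sequence.
def attach_private_generator(sampler):
    if getattr(sampler, "generator", None) is None:
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
        sampler.generator = torch.Generator()
        sampler.generator.manual_seed(seed)
    return sampler

torch.manual_seed(42)  # stands in for set_seed(42)
sampler = attach_private_generator(RandomSampler(range(64)))
print(sampler.generator.initial_seed())  # same value on every fresh run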

(In single-process mode this case happens to work because the RandomSampler falls back to the global torch RNG, which is restored by load_state.)
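For contrast, this is roughly what torch's RandomSampler does per epoch when its generator is None (paraphrased from torch.utils.data.sampler):

import torch

# With generator=None, each epoch draws a fresh seed from the *global* torch
# RNG state. load_state() restores that global state, so in single-process
# mode the post-resume draws continue exactly where the original run left off.
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)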

Reproduction

Bug 1: use_seedable_sampler=True (single-process)

Save as reproduce_bug1.py and run with python reproduce_bug1.py:

"""Reproduces Bug 1 (use_seedable_sampler=True) in single-process mode."""

import os, shutil, tempfile
import torch
from torch.utils.data import TensorDataset
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import set_seed

SEED, NUM_ITEMS, BATCH_SIZE = 42, 16, 4
CKPT_EPOCH, TOTAL = 2, 5

def epoch_order(dl):
    return [x for batch in dl for x in batch[0].tolist()]

def run_original(tmpdir):
    set_seed(SEED)
    acc = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
    dl = acc.prepare(torch.utils.data.DataLoader(TensorDataset(torch.arange(NUM_ITEMS)), batch_size=BATCH_SIZE, shuffle=True))
    orders = {}
    for e in range(TOTAL):
        orders[e] = epoch_order(dl)
        if e == CKPT_EPOCH - 1:
            acc.save_state(os.path.join(tmpdir, "ckpt"))
    return orders

def run_resumed(tmpdir):
    set_seed(SEED)
    acc = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
    dl = acc.prepare(torch.utils.data.DataLoader(TensorDataset(torch.arange(NUM_ITEMS)), batch_size=BATCH_SIZE, shuffle=True))
    acc.load_state(os.path.join(tmpdir, "ckpt"))
    # BUG: dl.iteration is 0 here, not CKPT_EPOCH
    # Workaround: dl.set_epoch(CKPT_EPOCH)
    return {e: epoch_order(dl) for e in range(CKPT_EPOCH, TOTAL)}

tmpdir = tempfile.mkdtemp()
orig = run_original(tmpdir)
resumed = run_resumed(tmpdir)
shutil.rmtree(tmpdir)

print("Original:")
for e, o in sorted(orig.items()): print(f"  epoch {e}: {o}")
print("\nResumed:")
for e, o in sorted(resumed.items()): print(f"  epoch {e}: {o}")
print("\nComparison:")
for e in range(CKPT_EPOCH, TOTAL):
    restart_e = e - CKPT_EPOCH
    print(f"  epoch {e}: matches original epoch {e}? {orig[e] == resumed[e]}  |  replays epoch {restart_e}? {orig[restart_e] == resumed[e]}")

Output:

Original:
  epoch 0: [6, 3, 0, 7, 10, 15, 2, 12, 14, 9, 8, 5, 4, 13, 11, 1]
  epoch 1: [4, 0, 9, 5, 3, 2, 6, 7, 1, 12, 11, 10, 14, 8, 15, 13]
  epoch 2: [4, 9, 3, 2, 7, 8, 12, 13, 11, 15, 6, 5, 1, 10, 14, 0]
  epoch 3: [11, 12, 9, 1, 4, 6, 13, 14, 0, 15, 5, 3, 2, 8, 7, 10]
  epoch 4: [13, 2, 1, 11, 6, 12, 14, 9, 0, 7, 15, 8, 10, 5, 4, 3]

Resumed:
  epoch 2: [6, 3, 0, 7, 10, 15, 2, 12, 14, 9, 8, 5, 4, 13, 11, 1]   <-- same as epoch 0
  epoch 3: [4, 0, 9, 5, 3, 2, 6, 7, 1, 12, 11, 10, 14, 8, 15, 13]   <-- same as epoch 1
  epoch 4: [4, 9, 3, 2, 7, 8, 12, 13, 11, 15, 6, 5, 1, 10, 14, 0]   <-- same as epoch 2

Comparison:
  epoch 2: matches original epoch 2? False  |  replays epoch 0? True
  epoch 3: matches original epoch 3? False  |  replays epoch 1? True
  epoch 4: matches original epoch 4? False  |  replays epoch 2? True

Bug 2: use_seedable_sampler=False (default) + multi-GPU

Save as reproduce_bug2.py and launch with accelerate on 2+ GPUs:

accelerate launch --config_file accelerate-config-2gpu.yaml reproduce_bug2.py

"""Reproduces Bug 2 (default sampler + multi-GPU)."""

import os, shutil, tempfile
import torch
from torch.utils.data import TensorDataset
from accelerate import Accelerator
from accelerate.utils import set_seed

SEED, NUM_ITEMS, BATCH_SIZE = 42, 64, 4
CKPT_EPOCH, TOTAL = 2, 5

def epoch_order(dl):
    return [x for batch in dl for x in batch[0].tolist()]

def run_original(tmpdir):
    acc = Accelerator()
    set_seed(SEED, device_specific=True)
    dl = acc.prepare(torch.utils.data.DataLoader(
        TensorDataset(torch.arange(NUM_ITEMS)), batch_size=BATCH_SIZE, shuffle=True))
    orders = {}
    for e in range(TOTAL):
        orders[e] = epoch_order(dl)
        if e == CKPT_EPOCH - 1:
            acc.save_state(os.path.join(tmpdir, "ckpt"))
    acc.free_memory()
    return orders

def run_resumed(tmpdir):
    acc = Accelerator()
    set_seed(SEED, device_specific=True)
    dl = acc.prepare(torch.utils.data.DataLoader(
        TensorDataset(torch.arange(NUM_ITEMS)), batch_size=BATCH_SIZE, shuffle=True))
    acc.load_state(os.path.join(tmpdir, "ckpt"))
    # BUG: sampler's private generator was re-created with same initial seed
    orders = {e: epoch_order(dl) for e in range(CKPT_EPOCH, TOTAL)}
    is_main = acc.is_main_process
    acc.free_memory()
    return orders, is_main

torch.distributed.init_process_group(backend="nccl")
tmpdir = os.path.join(tempfile.gettempdir(), "accel_bug2_repro")
if int(os.environ.get("RANK", "0")) == 0:
    shutil.rmtree(tmpdir, ignore_errors=True)
    os.makedirs(tmpdir)
torch.distributed.barrier()

orig = run_original(tmpdir)
torch.distributed.barrier()
resumed, is_main = run_resumed(tmpdir)

if is_main:
    print("Original:"); [print(f"  epoch {e}: {o}") for e, o in sorted(orig.items())]
    print("\nResumed:"); [print(f"  epoch {e}: {o}") for e, o in sorted(resumed.items())]
    print("\nComparison:")
    for e in range(CKPT_EPOCH, TOTAL):
        r = e - CKPT_EPOCH
        print(f"  epoch {e}: matches original epoch {e}? {orig[e] == resumed[e]}  |  replays epoch {r}? {orig[r] == resumed[e]}")
    shutil.rmtree(tmpdir, ignore_errors=True)
torch.distributed.destroy_process_group()

Root cause

In data_loader.py:

  • DataLoaderShard.__init__ sets self.iteration = 0 (line 558)
  • DataLoaderShard.__iter__ calls self.set_epoch(self.iteration) (line 573), then increments self.iteration (line 601)
  • DataLoaderShard.set_epoch propagates to SeedableRandomSampler.set_epoch (line 619-620)
  • SeedableRandomSampler.__iter__ seeds with self.epoch + self.initial_seed (line 99)

But self.iteration is never included in state_dict() / load_state_dict(). The DataLoaderAdapter.state_dict() only delegates to the base dataloader's state dict, which tracks within-epoch position but not the epoch counter.
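A quick way to confirm this (a hypothetical check building on the bug 1 repro above; it assumes the dataloader was prepared with DataLoaderConfiguration(use_stateful_dataloader=True, use_seedable_sampler=True) so that state_dict() is available):

# After acc.prepare(...) and a few completed epochs:
sd = dl.state_dict()          # delegates to the base StatefulDataLoader
print(sd.keys())              # within-epoch position fields only
print("iteration" in sd)      # False: the epoch counter is absent
print(dl.iteration)           # > 0 at this point, but never serialized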

Potential fix

DataLoaderShard.iteration (and equivalently DataLoaderDispatcher.iteration) should be saved and restored as part of the dataloader state. Possible approaches:

  1. Include iteration in state_dict()/load_state_dict(): Override these in DataLoaderShard to include self.iteration alongside the base dataloader state (see the sketch after this list).
  2. Save/restore SeedableRandomSampler.epoch: The checkpointing code already has a path for saving the sampler (currently gated on IterableDatasetShard). Extending this to also save for map-style datasets would fix bug 2 but not bug 1.
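A sketch of approach 1 (untested; it assumes state_dict()/load_state_dict() are already exposed through the DataLoaderAdapter base when use_stateful_dataloader=True):

# Hypothetical method overrides on DataLoaderShard:
def state_dict(self):
    sd = super().state_dict()
    sd["iteration"] = self.iteration  # persist the epoch counter
    return sd

def load_state_dict(self, state_dict):
    self.iteration = state_dict.pop("iteration", 0)
    super().load_state_dict(state_dict)
    # Re-sync the sampler so the next epoch seeds with the right offset.
    self.set_epoch(self.iteration)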

Workaround without a fix

After load_state, manually set the epoch counter:

accelerator.load_state(checkpoint_path)
first_epoch = global_step // steps_per_epoch
train_dataloader.set_epoch(first_epoch)

This fixes bug 1. For bug 2, additionally enable use_seedable_sampler=True so that the shuffle is deterministic per-epoch rather than depending on the generator's running state.
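Putting both pieces together, a resume path might look like this (a sketch; model, optimizer, train_dataloader, num_epochs, global_step, and steps_per_epoch are assumed to come from the surrounding training script):

from accelerate import Accelerator, DataLoaderConfiguration

# use_seedable_sampler=True makes the shuffle a pure function of
# initial_seed + epoch, so restoring the epoch counter is sufficient.
accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True)
)
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

accelerator.load_state(checkpoint_path)
first_epoch = global_step // steps_per_epoch
train_dataloader.set_epoch(first_epoch)  # manual workaround for the reset counter

for epoch in range(first_epoch, num_epochs):
    for batch in train_dataloader:
        ...
    # One set_epoch call suffices: DataLoaderShard increments its internal
    # iteration counter after each completed epoch.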

Expected behavior

After accelerator.load_state(), the DataLoader should produce, at the resume epoch and every epoch after it, the same shuffle orders the original uninterrupted run would have produced.
