Description
**Environment**
composer version 0.26.0
torch 2.4.0
**To reproduce**
Steps to reproduce the behavior:
```shell
$ git clone https://github.com/maxrousseau/rafale.git
$ cd rafale
$ uv venv
$ . .venv/bin/activate
$ uv pip install -r cuda-requirements.txt
$ uv pip install -e .
$ rafale-run test/pythia_tinystories.yaml
$ # cancel the current run
$ rafale-run test/pythia_tinystories.yaml # resumes from the "latest" checkpoint
```
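For context, my understanding is that the resume in the last step relies on Composer-style autoresume (the "latest" checkpoint symlink in the save folder). A minimal sketch of that behaviour, using a toy model and dataloader I made up for illustration rather than rafale's actual wiring:

```python
# Toy sketch of Composer's autoresume behaviour (not rafale's code): killing and
# re-running this script should continue from checkpoints/latest-rank0.pt.
import torch
from torch.utils.data import DataLoader, TensorDataset
from composer import Trainer
from composer.models import ComposerClassifier

toy_net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8, 2))
dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
train_loader = DataLoader(dataset, batch_size=16)

trainer = Trainer(
    model=ComposerClassifier(toy_net, num_classes=2),
    train_dataloader=train_loader,
    max_duration="200ba",
    run_name="autoresume-demo",   # must stay fixed across restarts
    save_folder="checkpoints",
    save_interval="50ba",         # same cadence as my run config
    autoresume=True,              # on restart, loads the "latest" checkpoint if present
    seed=42,
)
trainer.fit()
```

If rafale resumes through a different path (e.g. an explicit `load_path`), part of the hypothesis below may not apply.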
Expected behavior
Near-exact continuation of the training loss curve compared to the uninterrupted run. Instead, after the second or third resumption the loss begins to diverge (see plot below). I suspect gradient accumulation may be the issue: either the accumulated gradients are not stored in the checkpoint, or we are restarting mid-batch and the accumulated gradients are lost (a quick way to check what the checkpoint actually stores is sketched below the plot).
Note: purple is the uninterrupted run, which has the lower training loss.

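To check the mid-batch / missing-state theory, I have been opening the checkpoint directly and looking at which pieces of state it contains. The path below is hypothetical; adjust it to wherever rafale writes its Composer checkpoints:

```python
# Inspect what a Composer checkpoint actually stores (path is a guess; point it
# at the "latest" checkpoint of the interrupted run).
import torch

ckpt = torch.load("checkpoints/latest-rank0.pt", map_location="cpu")
print(list(ckpt.keys()))   # top-level structure of the checkpoint
state = ckpt.get("state", {})
print(list(state.keys()))  # look for optimizer, scheduler, timestamp,
                           # dataloader/dataset and RNG state here
```

If the timestamp and RNG state are present but the loss still diverges after a few resumptions, that would point away from missing state and toward how the "auto" microbatching interacts with resumption.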
Additional context
I am using device_microbatch_size="auto" for my training run. The configuration of the run is the following:
```yaml
run:
  name: "pythia14m-tinystories" # name of your experiment, used for checkpointing
  seed: 42
  n_epochs: 1
  max_lr: 6e-04
  warmup_pct: 0.01
  schedule: "cosine-warmup" # linear, linear-warmup, cosine, cosine-warmup
  optimizer: "AdamW"
  eval_interval: "100ba"
  clip_type: "norm"
  clip_value: 1.0
  device_bs: "auto"
  save_interval: "50ba"
  train_key: "train"
  eval_key: "validation"

model:
  config: "pythia14m" # config key
  type: "decoder"
  use_pretrained: True
  # mode: None
  # n_classes: None

data:
  pipeline: "tinystories_neox" # the preprocessing/tokenization pipeline
  config:
    name: "tinystories"
    num_processes: 8
    tokenizer_name: "neox"
    shuffle_dataset: True # this will shuffle the whole training dataset once
    input_id_key: "input_ids"
    train_batch_size: 1024
    eval_batch_size: 16
    shuffle_train: False
    dataset_path: "~/code/data/TinyStories"
    tokenizer_path: "EleutherAI/pythia-14m"
    max_sequence_length: 512
    pad_token_id: -100
    pad_inputs: True
    is_prepared: False
    subset_key_mappings: { "train": "train", "validation": "validation" } # (source: target)
```
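To isolate whether the "auto" microbatch sizing (and the gradient accumulation it implies) is the culprit, I plan to rerun the same repro with a fixed microbatch size. A small helper I use to derive a debug config from the one above; the fixed value 16 and the new run name are arbitrary choices of mine:

```python
# Write a variant of the repro config with a fixed microbatch size instead of
# "auto", so the interrupted/resumed run no longer depends on auto grad accumulation.
import yaml

with open("test/pythia_tinystories.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["run"]["device_bs"] = 16                              # was "auto"
cfg["run"]["name"] = "pythia14m-tinystories-fixed-mbs"    # keep checkpoints separate

with open("test/pythia_tinystories_fixed_mbs.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)

# then: rafale-run test/pythia_tinystories_fixed_mbs.yaml
```

If the resumed loss tracks the uninterrupted run with a fixed microbatch size but diverges with "auto", that would narrow the problem down considerably.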