Trainer.training_step incorrectly normalizes mean token loss when n_gpu > 1 #37474

Open
@wiwu2390

Description

System Info

- `transformers` version: 4.46.0
- Platform: Linux-5.15.0-136-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.29.2
- Safetensors version: 0.5.3
- Accelerate version: 1.4.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.4.1+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: yes
- Using GPU in script?: yes
- GPU type: NVIDIA RTX A5000

Who can help?

@zach-huggingface @SunMarc @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Full example setup:

from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    GPTNeoXForCausalLM,
    Trainer,
    TrainingArguments,
)

config = AutoConfig.from_pretrained('EleutherAI/pythia-14m')
model = GPTNeoXForCausalLM(config=config).to('cuda')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-14m')
tokenizer.pad_token = tokenizer.eos_token
train_data = load_dataset("wiwu2390/minipile-100k", split="train")

def tokenize_function(sample):
    return tokenizer(sample["text"], truncation=True, max_length=512)

tokenized_dataset = train_data.map(tokenize_function, batched=True, remove_columns=["text"])

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)

training_args = TrainingArguments(
    output_dir="../data/pythia-14m-minipile-100k",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="no",
    logging_steps=1,
    save_steps=100,
    learning_rate=1e-3,
    weight_decay=0.01,
    warmup_steps=100,
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

With 4 GPUs, the training loss at step 1 is ~2.7. However, the expected value is ~10.8. Indeed, this is what we get if we set CUDA_VISIBLE_DEVICES=0.

Expected behavior

Since the model is being trained from initialization, the training loss at the first few steps should be around log(vocab_size) ≈ 10.8. However, when using 4 GPUs, the reported loss is 1/4 of that (~2.7).
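
As a sanity check of that expected value (assuming the pythia-14m config's vocab_size of 50304; a freshly initialized model predicts roughly uniformly over the vocabulary):

import math

# Cross-entropy of a uniform distribution over the vocabulary:
# -log(1 / vocab_size) = log(vocab_size)
vocab_size = 50304  # vocab_size in the EleutherAI/pythia-14m config (assumption)
print(math.log(vocab_size))  # ≈ 10.83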

The reason that this is happening is that the DataParallel-wrapped model gets num_items_in_batch as an input kwarg in Trainer.compute_loss; this is equal to the number of tokens in the batch (combined across all devices). Each device gets a 1/4-size per-device batch and returns the sum of token losses divided by num_items_in_batch (see transformers.loss.loss_utils.fixed_cross_entropy). The correct way to aggregate these per-device losses is then to sum them. However, Trainer.training_step takes the mean:

loss = loss.mean() # mean() to average on multi-gpu parallel training
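
To make the factor of 4 concrete, here is a toy calculation with made-up numbers (4 replicas, 2 tokens per replica, every token loss equal to 10.8):

import torch

n_gpu, tokens_per_device, token_loss = 4, 2, 10.8
num_items_in_batch = n_gpu * tokens_per_device  # global token count, passed to every replica

# Each replica returns sum(per-token losses) / num_items_in_batch,
# which is only 1/n_gpu of the true mean token loss.
per_device_loss = torch.full((n_gpu,), tokens_per_device * token_loss / num_items_in_batch)

# training_step then averages the already-downscaled replica losses:
print(per_device_loss.mean().item())  # ~2.7 instead of ~10.8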

A quick and dirty fix would be:

if self.args.n_gpu > 1:
    loss = loss.mean() if num_items_in_batch is None else loss.sum()

I'm not sure if this is compatible with other workflows though.
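
A toy check of that conditional (aggregate is a hypothetical helper mirroring the proposed rule, not actual Trainer code): summing recovers the true mean token loss when the replicas have already divided by the global token count, while mean() remains correct when num_items_in_batch is None and each replica returns its own mean.

import torch

def aggregate(per_device_loss, num_items_in_batch):
    # Proposed rule: mean() when each replica returned its own mean,
    # sum() when each replica returned sum / global token count.
    return per_device_loss.mean() if num_items_in_batch is None else per_device_loss.sum()

token_losses = torch.full((4, 2), 10.8)  # 4 replicas, 2 tokens each

# num_items_in_batch passed: each replica returns sum / global count
scaled = token_losses.sum(dim=1) / token_losses.numel()
print(aggregate(scaled, token_losses.numel()).item())  # ~10.8

# num_items_in_batch is None: each replica returns its own mean
print(aggregate(token_losses.mean(dim=1), None).item())  # ~10.8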

Activity

SunMarc (Member) commented on Apr 16, 2025

@wiwu2390 thanks for the report.

> The reason that this is happening is that the DataParallel-wrapped model gets num_items_in_batch as an input kwarg in Trainer.compute_loss; this is equal to the number of tokens in the batch (combined across all devices). Each device gets a 1/4-size per-device batch and returns the sum of token losses divided by num_items_in_batch (see transformers.loss.loss_utils.fixed_cross_entropy).

num_items_in_batch shouldn't be the number of tokens combined across all devices, but only the token count on the respective device. The counts are only combined if you set average_tokens_across_devices=True, and the default for this arg is False.

Did you test both runs with the same per_device_train_batch_size? You need to divide it by 4 when running with 4 GPUs for the runs to be comparable; otherwise you are actually using a bigger global batch size. However, I don't think this is the real issue.
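
To illustrate the two settings mentioned above (purely as a sketch of what a comparable 4-GPU run could look like, not a confirmed fix), the reproduction's TrainingArguments could be adjusted along these lines:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="../data/pythia-14m-minipile-100k",
    num_train_epochs=3,
    per_device_train_batch_size=4,        # 16 / 4 GPUs, to keep the global batch size comparable
    average_tokens_across_devices=True,   # aggregate num_items_in_batch across devices
    logging_steps=1,
    learning_rate=1e-3,
    fp16=True,
)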
