
[Bug] ORPO Trainer hangs with multi-gpu on step 0 #3068

Open
@NanoCode012

Description


Reproduction

Use the example script in the repo with its sample command, launched via accelerate.

Latest commit: 4871c82

accelerate launch examples/scripts/orpo.py \
    --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \
    --model_name_or_path=gpt2 \
    --per_device_train_batch_size 4 \
    --max_steps 1000 \
    --learning_rate 8e-6 \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="gpt2-aligned-orpo" \
    --warmup_steps 150 \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns \
    --log_level detail

The training would hang on step 0.

I have tracked it down to these two lines:

metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
)
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()

These are very large tensors being gathered across processes:

torch.Size([4, 223, 50257])
torch.Size([4, 223, 50257])

Proposed solution: call .detach().mean() on them prior to the gather.
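
A minimal sketch of that change (assuming the surrounding variable names stay as they are in the ORPO trainer):

# Sketch of the proposed fix: reduce each logits tensor to a scalar on the local
# process before gathering, so only a 0-dim tensor is communicated instead of a
# [batch, seq_len, vocab] tensor.
metrics[f"{prefix}logits/rejected"] = (
    self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean()).mean()
)
metrics[f"{prefix}logits/chosen"] = (
    self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean()).mean()
)
# Caveat: the mean of per-process means equals the global mean only when each
# process contributes tensors of the same shape.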

Happy to make the PR once we decide whether to average them prior to the gather, or to sum them like in KTOTrainer (a sketch of that variant follows the snippet below):

trl/trl/trainer/kto_trainer.py

Lines 1270 to 1272 in 4871c82

metrics["logits/chosen_sum"] = (
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
)


Unrelated note: the two lines below should also call .detach(), since I noticed they still carry the computation graph:

metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()

Credit to morphism on Discord, who helped track down the root-cause PR and provided hints.

System Info

  • Platform: Linux-6.5.0-45-generic-x86_64-with-glibc2.35
  • Python version: 3.11.11
  • TRL version: 0.16.0.dev0+4871c82
  • PyTorch version: 2.5.1+cu124
  • CUDA device(s): NVIDIA A40, NVIDIA A40
  • Transformers version: 4.49.0
  • Accelerate version: 1.3.0
  • Accelerate config: not found
  • Datasets version: 3.2.0
  • HF Hub version: 0.28.1
  • bitsandbytes version: 0.45.2
  • DeepSpeed version: 0.16.1
  • Diffusers version: not installed
  • Liger-Kernel version: 0.5.3
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: 0.14.0
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
