Description
Reproduction
Run the example script in the repo with its sample command, but launched via accelerate.
Latest commit: 4871c82
accelerate launch examples/scripts/orpo.py --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --model_name_or_path=gpt2 --per_device_train_batch_size 4 --max_steps 1000 --learning_rate 8e-6 --gradient_accumulation_steps 1 --logging_steps 10 --eval_steps 500 --output_dir="gpt2-aligned-orpo" --warmup_steps 150 --bf16 --logging_first_step --no_remove_unused_columns --log_level detail
Training hangs at step 0.
I have tracked it down to these two lines:
trl/trl/trainer/orpo_trainer.py
Lines 847 to 850 in 4871c82
These are very large tensors being gathered across processes; each has shape (batch_size, sequence_length, vocab_size):
torch.Size([4, 223, 50257])
torch.Size([4, 223, 50257])
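For scale, a rough sketch of the per-process memory each of these tensors carries (assuming bf16, since the repro command passes --bf16; 50257 is the gpt2 vocabulary size):

```python
import torch

# Each gathered tensor: batch=4, seq_len=223, vocab=50257 (gpt2),
# in bf16 (2 bytes per element) as in the repro command.
t = torch.zeros(4, 223, 50257, dtype=torch.bfloat16)
size_mib = t.numel() * t.element_size() / 2**20
print(f"{size_mib:.1f} MiB per tensor, per process")  # ≈ 85.5 MiB
```

Gathering two of these from every rank on each step is a lot of collective traffic for what ends up logged as a scalar metric.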
Solution: call .detach().mean() on them prior to the gather.
Happy to make the PR once we decide whether to average them prior to the gather, or to sum them as in KTOTrainer:
trl/trl/trainer/kto_trainer.py
Lines 1270 to 1272 in 4871c82
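A minimal sketch of the proposed fix (the function and argument names here are illustrative, not the trainer's actual identifiers): detach and reduce to a scalar before any cross-process gather, so each rank communicates a 0-dim tensor instead of a [batch, seq_len, vocab]-sized one.

```python
import torch

def reduce_before_gather(per_token_logps: torch.Tensor) -> torch.Tensor:
    # Detach from the autograd graph and collapse to a scalar *before*
    # the cross-process gather, so only a 0-dim tensor (and no graph)
    # is communicated and logged.
    return per_token_logps.detach().mean()

x = torch.randn(2, 5, 11, requires_grad=True)
m = reduce_before_gather(x)
assert m.dim() == 0            # scalar, cheap to gather
assert not m.requires_grad     # no graph carried into metrics
```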
Unrelated note: the two lines below should also be .detach()ed, as I noticed they still carry the autograd graph:
trl/trl/trainer/orpo_trainer.py
Lines 852 to 853 in 4871c82
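To illustrate the note above (toy tensors, not the trainer's code): any tensor produced inside the forward pass stays attached to the autograd graph until it is detached, which keeps the intermediate activations alive for as long as the logged value is held.

```python
import torch

w = torch.randn(3, requires_grad=True)
metric = (w * 2).sum()             # computed inside the graph
assert metric.requires_grad        # still attached -> retains the graph
detached = metric.detach()
assert not detached.requires_grad  # safe to accumulate/log
```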
Credit to morphism on Discord, who helped track down the root-cause PR and provided hints.
System Info
- Platform: Linux-6.5.0-45-generic-x86_64-with-glibc2.35
- Python version: 3.11.11
- TRL version: 0.16.0.dev0+4871c82
- PyTorch version: 2.5.1+cu124
- CUDA device(s): NVIDIA A40, NVIDIA A40
- Transformers version: 4.49.0
- Accelerate version: 1.3.0
- Accelerate config: not found
- Datasets version: 3.2.0
- HF Hub version: 0.28.1
- bitsandbytes version: 0.45.2
- DeepSpeed version: 0.16.1
- Diffusers version: not installed
- Liger-Kernel version: 0.5.3
- LLM-Blender version: not installed
- OpenAI version: not installed
- PEFT version: 0.14.0
- vLLM version: not installed
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
- Any traceback provided is complete