Description
@edbeeching and I noticed that models trained with SFT sometimes do not learn to stop their generations. In other words, the model never learns to generate an EOS token.
Upon some digging, I noticed this is mainly an issue with the dataset preprocessing. In particular, if we simply pass a dataset like https://huggingface.co/datasets/timdettmers/openassistant-guanaco to the SFTTrainer, the trainer may not append an EOS token to the completion.
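
As a quick sanity check (a hedged sketch, assuming the `facebook/opt-350m` tokenizer and the guanaco dataset above), the raw dataset text itself contains no EOS marker, so unless the trainer appends one the model never sees it:

```python
# Hedged sanity check: the raw guanaco text carries no EOS marker, so unless the
# trainer appends one, the model never sees EOS as a target.
from datasets import load_dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-350m")
ds = load_dataset("timdettmers/openassistant-guanaco", split="train")

print(repr(ds[0]["text"][-40:]))          # the sample just ends with plain text
print(tok.eos_token in ds[0]["text"])     # expected: False, no EOS marker anywhere
```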
If we insert the following debug print at https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/trainer.py#L3207

```python
for item1, item2 in zip(inputs["input_ids"][1], inputs["attention_mask"][1]):
    print(item1, item2)
```

and run our SFT example:
```bash
python examples/scripts/sft.py \
    --model_name_or_path="facebook/opt-350m" \
    --report_to="wandb" \
    --learning_rate=1.41e-5 \
    --per_device_train_batch_size=2 \
    --gradient_accumulation_steps=16 \
    --output_dir="sft_openassistant-guanaco" \
    --logging_steps=1 \
    --num_train_epochs=3 \
    --max_steps=-1 \
    --push_to_hub \
    --gradient_checkpointing \
    --dataset_text_field text
```

Notice how, in the printed output, the pad token / EOS token corresponds to attention mask = 0.
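
Tokens at attention mask 0 positions are just padding and are excluded from the loss, which is why the model never gets a learning signal for EOS. Below is a minimal sketch of that behaviour, assuming the default `DataCollatorForLanguageModeling(mlm=False)` collator used for causal-LM fine-tuning (illustrative, not TRL internals verbatim):

```python
# Illustrative: padded positions get attention mask 0 and label -100, so the loss
# ignores them; an EOS that only ever appears as padding is never a training target.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tok = AutoTokenizer.from_pretrained("facebook/opt-350m")
collator = DataCollatorForLanguageModeling(tok, mlm=False)

batch = collator([tok("short"), tok("a much longer example that forces padding of the first one")])
for tid, mask, label in zip(batch["input_ids"][0], batch["attention_mask"][0], batch["labels"][0]):
    print(int(tid), int(mask), int(label))   # padded tail: mask 0, label -100
```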
Potential solution
This can be resolved if we add an EOS token to the dataset text itself. For example, see `trl/examples/scripts/minimal/sft.py`, line 57 at commit `dc012ea`.
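
A hedged sketch of that idea (the function name below is illustrative, not a verbatim copy of the referenced line): append the tokenizer's EOS token to each example before handing the dataset to the SFTTrainer.

```python
# Hedged sketch of the workaround: make every training text end with the EOS token.
from datasets import load_dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

def append_eos(example):
    example["text"] = example["text"] + tok.eos_token
    return example

dataset = dataset.map(append_eos)
print(repr(dataset[0]["text"][-10:]))   # now ends with </s>
```

Re-running training with the minimal script (which includes this change) and repeating the same inspection: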
```bash
python examples/scripts/minimal/sft.py \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 32 \
    --learning_rate 5e-05 \
    --logging_steps 10 \
    --evaluation_strategy epoch \
    --max_seq_length 1024 \
    --num_train_epochs 5 \
    --output_dir models/minimal/sft
```

Notice how the first EOS token now corresponds to attention mask = 1.
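
As a final check (a hedged sketch; the output directory comes from the command above and the prompt format from the guanaco dataset, both purely illustrative), a model that has learned EOS should stop on its own well before hitting the generation limit:

```python
# Hedged post-training check: a model that learned EOS should terminate generation
# on its own instead of running to the max_new_tokens limit.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("models/minimal/sft")
model = AutoModelForCausalLM.from_pretrained("models/minimal/sft")

prompt = "### Human: What is the capital of France?### Assistant:"
inputs = tok(prompt, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=256, eos_token_id=tok.eos_token_id)

print(tok.decode(out[0], skip_special_tokens=False))
print("stopped early:", out.shape[1] - inputs["input_ids"].shape[1] < 256)
```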