Learning to generate EOS tokens · #1623

Closed
@vwxyzjn

@edbeeching and I noticed that trained SFT models sometimes do not learn to stop generating. In other words, the model never learns to generate EOS tokens.

After some digging, I found that this is mainly a dataset preprocessing issue. In particular, if we simply pass a dataset like https://huggingface.co/datasets/timdettmers/openassistant-guanaco to the SFTTrainer, the trainer may not append an EOS token to the completion.

If we run for item1, item2 in zip(inputs["input_ids"][1], inputs["attention_mask"][1]): print(item1, item2) at https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/trainer.py#L3207 with our SFT example, we get:

python examples/scripts/sft.py \
    --model_name_or_path="facebook/opt-350m" \
    --report_to="wandb" \
    --learning_rate=1.41e-5 \
    --per_device_train_batch_size=2 \
    --gradient_accumulation_steps=16 \
    --output_dir="sft_openassistant-guanaco" \
    --logging_steps=1 \
    --num_train_epochs=3 \
    --max_steps=-1 \
    --push_to_hub \
    --gradient_checkpointing \
    --dataset_text_field text
[image: input_ids printed next to their attention_mask values]

Notice how the pad token / eos token corresponds to attention mask = 0.
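The same check can be automated instead of read off a printout. Below is a minimal sketch, assuming one tokenized sequence and its attention mask are available as plain lists of integers. The helper name `ends_with_attended_eos` and the token IDs are hypothetical and are not taken from the run above.

```python
from typing import Sequence


def ends_with_attended_eos(
    input_ids: Sequence[int],
    attention_mask: Sequence[int],
    eos_token_id: int,
) -> bool:
    """Return True if the last token with attention mask 1 is the EOS token.

    Hypothetical helper: it only looks at attended positions, so EOS/pad
    tokens that were added as padding (mask 0) do not count.
    """
    attended = [tok for tok, mask in zip(input_ids, attention_mask) if mask == 1]
    return bool(attended) and attended[-1] == eos_token_id


# Illustrative IDs only: 2 stands in for the EOS/pad token.
EOS_ID = 2
padded_only = ([5, 6, 7, EOS_ID, EOS_ID], [1, 1, 1, 0, 0])  # EOS appears only as padding
with_real_eos = ([5, 6, 7, EOS_ID, EOS_ID], [1, 1, 1, 1, 0])  # first EOS is attended

print(ends_with_attended_eos(*padded_only, EOS_ID))
print(ends_with_attended_eos(*with_real_eos, EOS_ID))
```

A sequence for which this returns False gives the model no attended EOS position at the end of the text, which matches the symptom described in this issue.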

Potential solution

This can be resolved by adding an EOS token to the dataset itself. For example,

"{% for message in messages %}{{' ' + message['content']}}{% endfor %}{{ eos_token }}"
always adds an EOS token to the tokenized dataset, and as a result we get

python examples/scripts/minimal/sft.py \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 32 \
    --learning_rate 5e-05 \
    --logging_steps 10 \
    --evaluation_strategy epoch \
    --max_seq_length 1024 \
    --num_train_epochs 5 \
    --output_dir models/minimal/sft
[image: input_ids printed next to their attention_mask values]

Notice how the first eos token corresponds to attention mask = 1.
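For reference, here is a minimal plain-Python sketch of the same idea, with no templating engine involved. `render_messages` mirrors the template above, and `append_eos` covers datasets that store raw text in a single field. Both function names, the `text_field` parameter, and the `"</s>"` string are assumptions made for illustration; in practice the EOS string would come from the tokenizer.

```python
from typing import Dict, List


def render_messages(messages: List[Dict[str, str]], eos_token: str) -> str:
    """Plain-Python equivalent of the template above:
    a space plus each message's content, followed by one EOS token."""
    return "".join(" " + m["content"] for m in messages) + eos_token


def append_eos(
    example: Dict[str, str], eos_token: str, text_field: str = "text"
) -> Dict[str, str]:
    """Return a copy of `example` whose text field ends with `eos_token`.

    Hypothetical helper for raw-text datasets; it leaves the text unchanged
    if it already ends with the EOS string.
    """
    text = example[text_field]
    if not text.endswith(eos_token):
        text = text + eos_token
    return {**example, text_field: text}


EOS = "</s>"  # assumed EOS string; use the tokenizer's own value in practice
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
print(repr(render_messages(messages, EOS)))
print(append_eos({"text": "### Human: Hi### Assistant: Hello!"}, EOS))
```

With a Hugging Face dataset, something like `dataset.map(lambda ex: append_eos(ex, tokenizer.eos_token))` should apply this before the dataset is handed to the trainer. This assumes the tokenizer maps the EOS string back to its special token ID, which is worth confirming with the printout check described earlier.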
