
Issue with GRPO trainer when using a reward_func that is a pretrained model? #3202

Open
@sowmaster

Description


Hi, I am running into an error when I use the GRPO trainer with a reward_func that is a pretrained reward model instead of a custom reward function. It throws the following error:

```
assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings"
AssertionError: Padding_idx must be within num_embeddings
```

However, I checked and confirmed that padding_idx is within weight.size(0) (i.e. num_embeddings). I was able to reproduce this error with the minimal example from the GRPO docs by replacing the custom reward function with a pretrained reward model (trl-lib/Qwen2-0.5B-Reward):

```python
# test_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")


# Custom reward function from the docs: rewards completions that are close to 20 characters.
# Not used below; the pretrained reward model replaces it.
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir="data/Qwen2-0.5B-GRPO",
    logging_steps=10,
    per_device_train_batch_size=2,
    bf16=True,
    num_generations=2,
    max_prompt_length=128,
    max_completion_length=128,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs="trl-lib/Qwen2-0.5B-Reward",  # instead of reward_len
    args=training_args,
    train_dataset=dataset,
)

trainer.train()
```
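For contrast, a custom reward function such as `reward_len` is just a callable: it receives the generated `completions` (extra keyword arguments such as the prompts are accepted and ignored) and returns one score per completion. A string entry in `reward_funcs` is handled differently; per the `GRPOTrainer` documentation it is loaded with `AutoModelForSequenceClassification.from_pretrained` (with `num_labels=1`), which is the path where the error above appears. A quick sketch calling the function directly on a few made-up completions:

```python
# Same reward function as in test_grpo.py: the closer a completion is to 20 characters, the higher the score.
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


# Made-up completions; `prompts` is passed only to show that extra keyword arguments are ignored.
completions = ["Short.", "a" * 20, "a" * 25]
scores = reward_len(completions, prompts=["p1", "p2", "p3"])
print(scores)  # [-14, 0, -5]
```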

I ran the script with DeepSpeed using TRL v0.16.0.
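For anyone repeating the `padding_idx` check described above, here is a minimal sketch in plain PyTorch. `embedding_padding_report` is a hypothetical helper (not part of TRL or PyTorch) that lists every `nn.Embedding` with a `padding_idx` next to its weight's row count; the toy model stands in for the reward model. The second half shows that `torch.nn.functional.embedding`, where the assertion above is raised, compares `padding_idx` against the rows of the weight tensor it receives at call time. An embedding that looks fine when inspected can therefore still fail if the tensor passed during the forward pass has fewer rows. The zero-row tensor is purely illustrative.

```python
import torch
import torch.nn.functional as F
from torch import nn


def embedding_padding_report(model: nn.Module):
    """Hypothetical helper: (name, padding_idx, weight rows) for each nn.Embedding with a padding_idx."""
    return [
        (name, module.padding_idx, module.weight.size(0))
        for name, module in model.named_modules()
        if isinstance(module, nn.Embedding) and module.padding_idx is not None
    ]


# Toy stand-in for a reward model: 32-token vocabulary, padding token id 5 (made-up values).
toy = nn.Sequential(nn.Embedding(num_embeddings=32, embedding_dim=8, padding_idx=5))
print(embedding_padding_report(toy))  # [('0', 5, 32)]

# F.embedding checks padding_idx against the rows of the weight it receives at call time.
ids = torch.tensor([[1, 5, 7]])
out = F.embedding(ids, toy[0].weight, padding_idx=5)  # OK: 5 < 32

# The same call with a zero-row weight tensor (illustrative only) trips the assertion from this issue.
msg = None
try:
    F.embedding(ids, torch.empty(0), padding_idx=5)
except AssertionError as err:
    msg = str(err)
print(msg)  # Padding_idx must be within num_embeddings
```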

Labels: 🏋 GRPO (Related to GRPO), 🏋 Reward (Related to Reward modelling), 🐛 bug (Something isn't working)