[GRPO] adds experimental support for the SSR replay buffer #3325


Draft
wants to merge 6 commits into main
131 changes: 131 additions & 0 deletions tests/test_repad.py
@@ -0,0 +1,131 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

import torch

from trl.trainer.grpo_replay_buffer import repad


PAD_TOKEN_ID = 123


def test_repad_basic_padding():
sample = [
{
"prompt_ids": torch.LongTensor([1, 2, 3]),
"prompt_mask": torch.LongTensor([1, 1, 0]),
"completion_ids": torch.LongTensor([5, 6, 7, 8]),
"completion_mask": torch.LongTensor([1, 1, 1, 0]),
"old_per_token_logps": torch.tensor([0.1, 0.2, 0.3, 0.4]),
"ref_per_token_logps": torch.tensor([0.0, -0.1, -0.2, -0.3]),
},
{
"prompt_ids": torch.LongTensor([4, 5]),
"prompt_mask": torch.LongTensor([1, 1]),
"completion_ids": torch.LongTensor([9, 10]),
"completion_mask": torch.LongTensor([1, 1]),
"old_per_token_logps": torch.tensor([-0.5, -0.6]),
"ref_per_token_logps": torch.tensor([0.5, 0.6]),
},
]

padded = repad(deepcopy(sample), padding_value=PAD_TOKEN_ID)

assert len(padded[0]["prompt_ids"]) == 2
assert len(padded[0]["completion_ids"]) == 3

for ex in padded:
# All sequences in the same batch should have the same length
assert len(ex["prompt_ids"]) == len(padded[0]["prompt_ids"])
assert len(ex["prompt_mask"]) == len(padded[0]["prompt_mask"])
assert len(ex["completion_ids"]) == len(padded[0]["completion_ids"])
assert len(ex["completion_mask"]) == len(padded[0]["completion_mask"])

# Mask and ids should match in shape
assert ex["prompt_ids"].shape == ex["prompt_mask"].shape
assert ex["completion_ids"].shape == ex["completion_mask"].shape


def test_repad_logps_padding():
sample = [
{
"prompt_ids": torch.LongTensor([1]),
"prompt_mask": torch.LongTensor([1]),
"completion_ids": torch.LongTensor([2, 3, 4]),
"completion_mask": torch.LongTensor([1, 1, 0]),
"old_per_token_logps": torch.tensor([-0.1, -0.2, -0.3]),
"ref_per_token_logps": torch.tensor([-0.5, -0.6, -0.7]),
},
{
"prompt_ids": torch.LongTensor([5, 6]),
"prompt_mask": torch.LongTensor([1, 1]),
"completion_ids": torch.LongTensor([7, 8]),
"completion_mask": torch.LongTensor([1, 1]),
"old_per_token_logps": torch.tensor([0.4, 0.5]),
"ref_per_token_logps": torch.tensor([0.6, 0.7]),
},
]

padded = repad(deepcopy(sample), padding_value=PAD_TOKEN_ID)

for logps in ["old_per_token_logps", "ref_per_token_logps"]:
for ex in padded:
assert len(ex[logps]) == len(padded[0][logps])
assert isinstance(ex[logps], torch.Tensor)


def test_repad_empty_masks():
sample = [
{
"prompt_ids": torch.tensor([0]),
"prompt_mask": torch.tensor([0]),
"completion_ids": torch.tensor([0]),
"completion_mask": torch.tensor([0]),
"old_per_token_logps": torch.tensor([0.0]),
"ref_per_token_logps": torch.tensor([0.0]),
},
{
"prompt_ids": torch.tensor([1]),
"prompt_mask": torch.tensor([0]),
"completion_ids": torch.tensor([1]),
"completion_mask": torch.tensor([0]),
"old_per_token_logps": torch.tensor([0.0]),
"ref_per_token_logps": torch.tensor([0.0]),
},
{
"prompt_ids": torch.tensor([1, 1]),
"prompt_mask": torch.tensor([0, 1]),
"completion_ids": torch.tensor([1, 2]),
"completion_mask": torch.tensor([1, 0]),
"old_per_token_logps": torch.tensor([0.0, 1.0]),
"ref_per_token_logps": torch.tensor([0.0, 1.0]),
},
{
"prompt_ids": torch.tensor([1, 1]),
"prompt_mask": torch.tensor([1, 1]),
"completion_ids": torch.tensor([1, 2]),
"completion_mask": torch.tensor([1, 0]),
"old_per_token_logps": torch.tensor([0.0, 1.0]),
"ref_per_token_logps": torch.tensor([0.0, 1.0]),
},
]
padded = repad(deepcopy(sample), padding_value=999)

assert len(padded[0]["prompt_ids"]) == 2
assert len(padded[0]["completion_ids"]) == 1

assert padded[0]["prompt_ids"].eq(999).all()
assert padded[0]["completion_ids"].eq(999).all()
35 changes: 35 additions & 0 deletions train_grpo.py
@@ -0,0 +1,35 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer


dataset = load_dataset("trl-lib/tldr", split="train")


# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=1, replay_buffer_class="SSRReplayBuffer")
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_len,
args=training_args,
train_dataset=dataset,
)
trainer.train()
22 changes: 22 additions & 0 deletions trl/trainer/grpo_config.py
@@ -151,6 +151,8 @@ class GRPOConfig(TrainingArguments):
use_liger_loss (`bool`, *optional*, defaults to `False`):
Whether to use the Liger GRPO loss.

replay_buffer_class (`str`, *optional*, defaults to `"ReplayBuffer"`):
Replay buffer class to use. Options are `"ReplayBuffer"` and `"SSRReplayBuffer"`. The default
`"ReplayBuffer"` samples uniformly without replacement, while `"SSRReplayBuffer"` samples with
probabilities proportional to the absolute advantages of the stored samples.
ssr_capacity_scalar (`int`, *optional*, defaults to `4`):
Scalar multiplied by the number of training samples in the effective batch to obtain the replay buffer capacity.
ssr_alpha (`float`, *optional*, defaults to `1.0`):
Exponent applied to the absolute advantages when computing the sampling probabilities of the `SSRReplayBuffer`.

> Parameters that control the logging

log_completions (`bool`, *optional*, defaults to `False`):
@@ -387,6 +389,26 @@ class GRPOConfig(TrainingArguments):
metadata={"help": "Whether to use the Liger GRPO loss."},
)

replay_buffer_class: str = field(
default="ReplayBuffer",
metadata={
"help": "Replay buffer class to use, Options [ReplayBuffer, SSRReplayBuffer] The default is `ReplayBuffer`, that randomly samples without replacement."
},
)
ssr_capacity_scalar: int = field(
default=4,
metadata={
"help": "Scalar to multiply the replay buffer capacity. The default is 1, which means the capacity is "
"equal to the number of training samples in the effective batch."
},
)
ssr_alpha: float = field(
default=1.0,
metadata={
"help": "Alpha parameter for controlling the probablity distribution of the replay buffer. The default is 1.0, "
},
)

# Parameters that control the logging
log_completions: bool = field(
default=False,
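
Taken together, the new options configure prioritized replay. A minimal sketch of how they might be combined on `GRPOConfig`, assuming the defaults shown above (illustrative values, not part of this diff):

from trl import GRPOConfig

# The buffer capacity is ssr_capacity_scalar x the number of samples in the effective
# batch, and samples are drawn with probability proportional to |advantage| ** ssr_alpha.
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    replay_buffer_class="SSRReplayBuffer",
    ssr_capacity_scalar=4,
    ssr_alpha=1.0,
)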
154 changes: 154 additions & 0 deletions trl/trainer/grpo_replay_buffer.py
@@ -0,0 +1,154 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np

from .utils import pad


def repad(list_of_tensor_dicts, padding_value):
"""Strip the existing padding from each example and re-pad all fields to the new common batch lengths, in place."""
p_ids, p_attn_masks = remove_and_pad(
[tensor_dict["prompt_ids"] for tensor_dict in list_of_tensor_dicts],
[tensor_dict["prompt_mask"] for tensor_dict in list_of_tensor_dicts],
pad_token_id=padding_value,
padding_side="left",
)
c_ids, c_attn_masks = remove_and_pad(
[tensor_dict["completion_ids"] for tensor_dict in list_of_tensor_dicts],
[tensor_dict["completion_mask"] for tensor_dict in list_of_tensor_dicts],
pad_token_id=padding_value,
)
old_logps, _ = remove_and_pad(
[tensor_dict["old_per_token_logps"] for tensor_dict in list_of_tensor_dicts],
[tensor_dict["completion_mask"] for tensor_dict in list_of_tensor_dicts],
pad_token_id=-10000.0, # ignored so can be anything
)
ref_logps, _ = remove_and_pad(
[tensor_dict["ref_per_token_logps"] for tensor_dict in list_of_tensor_dicts],
[tensor_dict["completion_mask"] for tensor_dict in list_of_tensor_dicts],
pad_token_id=-10000.0, # ignored so can be anything
)

for i, (p_id, p_mask, c_id, c_mask, o_logp, r_logp) in enumerate(
zip(p_ids, p_attn_masks, c_ids, c_attn_masks, old_logps, ref_logps)
):
list_of_tensor_dicts[i]["prompt_ids"] = p_id
list_of_tensor_dicts[i]["prompt_mask"] = p_mask
list_of_tensor_dicts[i]["completion_ids"] = c_id
list_of_tensor_dicts[i]["completion_mask"] = c_mask
list_of_tensor_dicts[i]["old_per_token_logps"] = o_logp
list_of_tensor_dicts[i]["ref_per_token_logps"] = r_logp

return list_of_tensor_dicts


def remove_and_pad(list_of_ids, list_of_masks, pad_token_id=0, padding_side="right"):
"""
Remove padding from list_of_ids and list_of_masks, and then pad them to the same length.
"""
num_samples = len(list_of_ids)
if list_of_ids[0] is None:
# we are not using old_per_token_logps / ref_per_token_logps
return [None] * num_samples, [None] * num_samples
# Remove padding
list_of_ids = [ids[mask == 1] for ids, mask in zip(list_of_ids, list_of_masks)]
list_of_masks = [mask[mask == 1] for mask in list_of_masks]

ids = pad(list_of_ids, padding_value=pad_token_id, padding_side=padding_side)
masks = pad(list_of_masks, padding_value=0, padding_side=padding_side)

return ids, masks


def remove_padding(input_ids, attn_mask):
"""
Remove padding from input_ids and attn_mask.
"""
if attn_mask is not None:
input_ids = input_ids[attn_mask == 1]
attn_mask = attn_mask[attn_mask == 1]
return input_ids, attn_mask


class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity
self.buffer = []
self.sample_indices = []

def add(self, experience):
if len(self.buffer) < self.capacity:
self.buffer.append(experience)
else:
self.buffer.pop(0)
self.buffer.append(experience)

# Clear index queue when buffer changes
self.sample_indices.clear()

def _init_sampling_queue(self):
self.sample_indices = list(range(len(self.buffer)))
random.shuffle(self.sample_indices)

def sample(self, batch_size):
if not self.sample_indices:
self._init_sampling_queue()

batch = []
while len(batch) < batch_size and self.sample_indices:
idx = self.sample_indices.pop(0)
batch.append(self.buffer[idx])

if len(batch) != batch_size:
raise ValueError("Not enough samples in the buffer to fill the batch.")

return batch

def __len__(self):
return len(self.buffer)


class SSRReplayBuffer(ReplayBuffer):
# implementation of the SSR replay buffer from https://arxiv.org/pdf/2504.08837
def __init__(self, capacity, alpha=1.0):
super().__init__(capacity)
self.alpha = alpha
self.advantages = []

def add(self, experience):
EPS = 0.0001 # ensures we get non-zero advs when the buffer contains all 0 advantages
advantage = experience["advantages"].item()
if len(self.buffer) < self.capacity:
self.buffer.append(experience)
self.advantages.append(abs(advantage) + EPS) # Store absolute advantage
else:
# Replace the oldest entry if the buffer is full
self.buffer.pop(0)
self.advantages.pop(0)
self.buffer.append(experience)
self.advantages.append(abs(advantage) + EPS)

def sample(self, batch_size):
if not self.buffer:
raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")

# Convert advantages to priorities
scaled_priorities = np.power(self.advantages, self.alpha)
total_priority = np.sum(scaled_priorities)
probabilities = scaled_priorities / total_priority

indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
return [self.buffer[i] for i in indices]
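
A minimal standalone sketch of the prioritized sampling implemented above, assuming the module is importable as `trl.trainer.grpo_replay_buffer` (as in the tests):

import torch

from trl.trainer.grpo_replay_buffer import SSRReplayBuffer

# Illustrative values; each experience only needs an "advantages" entry for the
# priority computation performed in SSRReplayBuffer.add.
buffer = SSRReplayBuffer(capacity=8, alpha=1.0)
for adv in [0.0, 0.5, 2.0, -1.5]:
    buffer.add({"advantages": torch.tensor(adv)})

# Sampling is with replacement, with probability proportional to |advantage| ** alpha.
batch = buffer.sample(batch_size=4)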