Skip to content
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4a8bf02
Add DGPO (Difficulty-Aware Group Policy Optimization, ICLR 2026) supp…
YanqiDai Feb 15, 2026
0383a57
Merge branch 'main' into grpo-dgpo
YanqiDai Feb 18, 2026
b0f72ef
Revise DGPO description and usage instructions
YanqiDai Feb 20, 2026
90fc3f5
Remove DGPO section from grpo_trainer.md
YanqiDai Feb 20, 2026
34a69eb
Remove ICLR 2026
YanqiDai Feb 20, 2026
da8c445
Apply all other suggestions from code review
YanqiDai Feb 20, 2026
721c3eb
Merge branch 'main' into grpo-dgpo
YanqiDai Feb 20, 2026
25def33
Polish the description of accuracy handling logic in DQW
YanqiDai Feb 20, 2026
1ecdae6
Rewrite the DGPO code
YanqiDai Feb 20, 2026
6dd54db
Merge branch 'huggingface:main' into grpo-dgpo
YanqiDai Feb 20, 2026
6291ae3
Remove ICLR 2026
YanqiDai Feb 20, 2026
df1fb48
Recover the code position of is_std_zero
YanqiDai Feb 20, 2026
e2b254d
Merge branch 'main' into grpo-dgpo
YanqiDai Feb 23, 2026
fb95c87
Merge branch 'main' into grpo-dgpo
YanqiDai Feb 25, 2026
89e1384
Apply suggestions from code review
YanqiDai Feb 25, 2026
e07db90
Remove repeated use_bias_correction_kl in suggestions from grpo_config
YanqiDai Feb 25, 2026
f34d3f1
Modify _compute_advantages_with_dgae and implement it directly in _ge…
YanqiDai Feb 25, 2026
1535983
Merge branch 'main' into grpo-dgpo
YanqiDai Feb 26, 2026
2137c9d
Merge branch 'main' into grpo-dgpo
YanqiDai Feb 27, 2026
7c5acea
Merge branch 'main' into grpo-dgpo
YanqiDai Feb 27, 2026
ecdbe7e
Merge branch 'main' into grpo-dgpo
YanqiDai Mar 1, 2026
4a2c30c
Resolve conflicts between the grpo-dgpo branch and the main branch.
YanqiDai Mar 1, 2026
aec3912
Remove an extra blank line in grpo_config.py
YanqiDai Mar 1, 2026
4b864c1
Merge branch 'main' into grpo-dgpo
YanqiDai Mar 5, 2026
4420d51
Remove redundant gather() for is_std_zero found by cursor
YanqiDai Mar 5, 2026
158b3de
Merge branch 'main' into grpo-dgpo
YanqiDai Mar 12, 2026
56e489a
Fix type casting for `global_completion_length_sum`, `local_completio…
YanqiDai Mar 13, 2026
87782d9
Refactor standard deviation calculation in GRPOTrainer to use nanstd …
YanqiDai Mar 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/source/paper_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,30 @@ training_args = GRPOConfig(
)
```

### DGPO: Difficulty-Aware Group Policy Optimization

**📜 Paper**: https://huggingface.co/papers/2601.20614

DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) and is supported in [`GRPOTrainer`] via [`GRPOConfig`].

- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled using the Mean Absolute Deviation (MAD) instead of the standard deviation, i.e. `advantage = (reward - mean) / (MAD + eps)`. This addresses an implicit imbalance of standard-deviation scaling, where update magnitudes peak for questions of moderate difficulty and are suppressed for both easier and harder ones.
- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is assigned a weight based on its difficulty (e.g., mean accuracy reward). Harder questions get higher weight, so the policy focuses more on them. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower = more focus on hard questions) and `dgpo_dqw_acc_reward_index` to specify which reward in `reward_funcs` is used as the accuracy/difficulty signal.

To use DGPO in TRL, enable the corresponding options in [`GRPOConfig`]:

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
...,
use_dgpo_dgae=True,
use_dgpo_dqw=True,
dgpo_dqw_temp=2.0,
dgpo_dqw_acc_reward_index=0,
)
trainer = GRPOTrainer(..., args=training_args, reward_funcs=[...], train_dataset=...)
```

### Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)

**📜 Paper**: https://huggingface.co/papers/2508.08221
Expand Down
25 changes: 25 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,31 @@ def test_training_loss_types(self, loss_type):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_dgpo(self):
    """Smoke test for DGPO (Difficulty-Aware Group Policy Optimization): training runs and logs a loss."""
    train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    config = GRPOConfig(
        output_dir=self.tmp_dir,
        report_to="none",
        learning_rate=0.1,
        num_generations=3,
        per_device_train_batch_size=3,
        max_completion_length=32,
        # Exercise both DGPO mechanisms: MAD-based advantage scaling (DGAE)
        # and difficulty-aware question-level weighting (DQW).
        use_dgpo_dgae=True,
        use_dgpo_dqw=True,
        dgpo_dqw_temp=2.0,
        dgpo_dqw_acc_reward_index=0,
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
        args=config,
        train_dataset=train_dataset,
    )
    trainer.train()

    final_log = trainer.state.log_history[-1]
    assert final_log["train_loss"] is not None

def test_training_with_eval(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

Expand Down
56 changes: 56 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,26 @@ class GRPOConfig(_BaseConfig):
Whether to use the unbiased KL divergence estimator with importance sampling correction. This corrects the
KL divergence estimate by multiplying it with the importance sampling ratio. This is described in the
[DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556).
use_dgpo_dgae (`bool`, *optional*, defaults to `False`):
Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, group-relative
advantages are normalized by the mean absolute deviation (MAD) of rewards (instead of the standard
deviation): `advantage = (reward - mean) / (MAD + eps)`, where `MAD = mean(|reward - mean|)`. Introduced
in the [MathForge paper](https://huggingface.co/papers/2601.20614).
use_dgpo_dqw (`bool`, *optional*, defaults to `False`):
Whether to use difficulty-aware question-level weighting (DQW). When `True`, each question gets a weight
based on its estimated difficulty, and that weight is multiplied into the advantages—so harder questions
produce larger effective updates. Difficulty is computed as a softmax over the negative per-question mean
accuracy reward from `reward_funcs[dgpo_dqw_acc_reward_index]`. Introduced in the [MathForge
paper](https://huggingface.co/papers/2601.20614).
dgpo_dqw_temp (`float`, *optional*, defaults to `2.0`):
Temperature for the DQW softmax over difficulty scores (negative mean accuracy reward). Higher values make
weights more uniform across questions; lower values concentrate weight on the hardest questions.
Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614).
dgpo_dqw_acc_reward_index (`int`, *optional*, defaults to `0`):
Index into `reward_funcs` selecting the reward used as the "accuracy" signal for DQW's difficulty
estimate. For each question, DQW uses the mean reward at this index across the group to compute the
difficulty score (lower mean ⇒ harder), which is then turned into a weight via the DQW softmax.
Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614).

> Parameters that control the logging

Expand Down Expand Up @@ -787,6 +807,42 @@ class GRPOConfig(_BaseConfig):
"This is described in the [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556)."
},
)
use_dgpo_dgae: bool = field(
default=False,
metadata={
"help": "Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, group-relative "
"advantages are normalized by the mean absolute deviation (MAD) of rewards (instead of the standard "
"deviation): `advantage = (reward - mean) / (MAD + eps)`, where `MAD = mean(|reward - mean|)`. Introduced "
"in the [MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)
use_dgpo_dqw: bool = field(
default=False,
metadata={
"help": "Whether to use difficulty-aware question-level weighting (DQW). When `True`, each question gets a "
"weight based on its estimated difficulty, and that weight is multiplied into the advantages—so harder "
"questions produce larger effective updates. Difficulty is computed as a softmax over the negative "
"per-question mean accuracy reward from `reward_funcs[dgpo_dqw_acc_reward_index]`. Introduced in the "
"[MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)
dgpo_dqw_temp: float = field(
default=2.0,
metadata={
"help": "Temperature for the DQW softmax over difficulty scores (negative mean accuracy reward). Higher "
"values make weights more uniform across questions; lower values concentrate weight on the hardest "
"questions. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)
dgpo_dqw_acc_reward_index: int = field(
default=0,
metadata={
"help": "Index into `reward_funcs` selecting the reward used as the \"accuracy\" signal for DQW's "
"difficulty estimate. For each question, DQW uses the mean reward at this index across the group to "
"compute the difficulty score (lower mean ⇒ harder), which is then turned into a weight via the DQW "
"softmax. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)

# Parameters that control the logging
log_completions: bool = field(
Expand Down
120 changes: 118 additions & 2 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,15 @@ def __init__(
f"Unknown importance sampling level: {self.importance_sampling_level}. "
"Possible values are 'token' and 'sequence'."
)
self.use_dgpo_dgae = args.use_dgpo_dgae
self.use_dgpo_dqw = args.use_dgpo_dqw
self.dgpo_dqw_temp = args.dgpo_dqw_temp
self.dgpo_dqw_acc_reward_index = args.dgpo_dqw_acc_reward_index
if self.use_dgpo_dqw and (self.dgpo_dqw_acc_reward_index < 0 or self.dgpo_dqw_acc_reward_index >= len(self.reward_funcs)):
raise ValueError(
f"dgpo_dqw_acc_reward_index must be in [0, {len(self.reward_funcs)}), got "
f"{self.dgpo_dqw_acc_reward_index}."
)

# Datasets
self.shuffle_dataset = args.shuffle_dataset
Expand Down Expand Up @@ -1606,6 +1615,86 @@ def _generate(self, prompts: list):
extra_fields,
)

def _compute_valid_token_balancing_ratios(
    self,
    completion_mask: torch.Tensor,
    is_std_zero: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute token-level balancing ratios for distributed training with filtered questions.

    When zero-variance questions are masked out, the effective number of valid tokens changes across
    processes. This method produces two correction factors: `zero_mask_ratio` compensates for the tokens
    lost by masking zero-variance questions, and `global_balancing_ratio` corrects for uneven token counts
    across processes.

    Args:
        completion_mask: Boolean tensor of shape `(batch_size, seq_len)` indicating valid completion tokens.
        is_std_zero: Boolean tensor of shape `(batch_size,)` indicating zero-variance questions.
            NOTE(review): it is indexed against the *gathered* per-sequence lengths below, so its length
            must match the global (all-process) batch — confirm against the caller.

    Returns:
        A tuple `(zero_mask_ratio, global_balancing_ratio)` of scalar tensors.
    """
    # Cast token counts to float up front: summing a boolean mask yields an integer
    # tensor, and the clamp/division below should happen in floating point.
    completion_length_local = completion_mask.sum(dim=1).float()
    completion_length_global = gather(completion_length_local)

    # clamp guards against division by zero when no valid tokens exist anywhere.
    global_completion_length_sum = completion_length_global.sum().clamp(min=1e-8)
    local_completion_length_sum = completion_length_local.sum()

    # Ratio of this process's token share to the uniform share (1.0 == perfectly balanced).
    global_balancing_ratio = (
        self.accelerator.num_processes * local_completion_length_sum / global_completion_length_sum
    )

    valid_mask_global = ~is_std_zero
    if valid_mask_global.any():
        valid_completion_length_sum = completion_length_global[valid_mask_global].sum().clamp(min=1e-8)
        zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum
    else:
        # Neutral fallback when every question is zero-variance. Use a float tensor:
        # the previous `dtype=completion_mask.dtype` made this a boolean tensor
        # (torch.tensor(1.0, dtype=torch.bool) is True) since the mask is boolean.
        zero_mask_ratio = torch.tensor(1.0, device=completion_mask.device)

    return zero_mask_ratio, global_balancing_ratio

def _compute_dqw_weights(
self,
rewards: torch.Tensor,
rewards_per_func: torch.Tensor,
num_generations: int,
) -> torch.Tensor:
"""
Compute question-level difficulty balancing weights (DQW).

Assigns higher weight to harder questions (lower mean accuracy) using a temperature-scaled softmax over
per-question accuracy means. Zero-variance questions receive a neutral weight of 1.
The returned weights sum to `num_questions`.

Args:
rewards: Tensor of shape `(num_questions * num_generations,)` with per-generation rewards.
rewards_per_func: Tensor of shape `(num_questions * num_generations, num_reward_funcs)` with
per-reward-function scores; the column at `dgpo_dqw_acc_reward_index` is used as accuracy.
num_generations: Number of generations per question.

Returns:
Tensor of shape `(num_questions,)` with difficulty balancing weights.
"""
num_questions = rewards.size(0) // num_generations
acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index]
mean_per_q_acc = acc_rewards.view(-1, num_generations).nanmean(dim=1) # (num_questions,)
std_per_q_acc = acc_rewards.view(-1, num_generations).std(dim=1) # (num_questions,)
is_std_zero_q = std_per_q_acc < 1e-8
num_zero_variance_questions = is_std_zero_q.sum().item()
difficulty_balancing_weights = torch.ones(
num_questions, device=rewards.device, dtype=rewards.dtype
)
if num_zero_variance_questions < num_questions:
mean_per_q_acc_modified = mean_per_q_acc.clone()
mean_per_q_acc_modified[(mean_per_q_acc == 0) | torch.isnan(mean_per_q_acc)] = 1.0
difficulty_balancing_weights[~is_std_zero_q] = (
num_questions - num_zero_variance_questions
) * torch.nn.functional.softmax(
-mean_per_q_acc_modified[~is_std_zero_q] / self.dgpo_dqw_temp, dim=0
)
return difficulty_balancing_weights

def _generate_and_score_completions(
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
Expand Down Expand Up @@ -1891,7 +1980,12 @@ def _generate_and_score_completions(

advantages = rewards - mean_grouped_rewards
if self.scale_rewards != "none":
advantages = advantages / (std_rewards + 1e-4)
if self.use_dgpo_dgae:
mad_rewards = advantages.abs().view(-1, num_generations).mean(dim=1)
mad_rewards = mad_rewards.repeat_interleave(num_generations, dim=0)
advantages = advantages / (mad_rewards + 1e-4)
else:
advantages = advantages / (std_rewards + 1e-4)
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging

elif self.multi_objective_aggregation == "normalize_then_sum":
Expand All @@ -1902,7 +1996,13 @@ def _generate_and_score_completions(
reward_k = reward_k.view(-1, len(self.reward_funcs))
rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards)
advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)
advantages = rewards - rewards.mean()
if self.use_dgpo_dgae:
mad_rewards = advantages.abs().view(-1, num_generations).mean(dim=1)
mad_rewards = mad_rewards.repeat_interleave(num_generations, dim=0)
advantages = advantages / (mad_rewards + 1e-4)
else:
advantages = advantages / (std_rewards + 1e-4)
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging

else:
Expand All @@ -1911,6 +2011,19 @@ def _generate_and_score_completions(
"'sum_then_normalize' or 'normalize_then_sum'."
)

# zero_mask_ratio must be applied before the process slice; global_balancing_ratio after
if self.use_dgpo_dgae or self.use_dgpo_dqw:
zero_mask_ratio, global_balancing_ratio = self._compute_valid_token_balancing_ratios(
completion_mask, is_std_zero
)
advantages = advantages * zero_mask_ratio

if self.use_dgpo_dqw:
difficulty_balancing_weights = self._compute_dqw_weights(
rewards, rewards_per_func, num_generations
)
advantages = advantages * difficulty_balancing_weights.repeat_interleave(num_generations)

# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts),
Expand All @@ -1919,6 +2032,9 @@ def _generate_and_score_completions(
all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
advantages = advantages[process_slice]

if self.use_dgpo_dgae or self.use_dgpo_dqw:
advantages = advantages * global_balancing_ratio

# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
Expand Down