Commits (28)
4a8bf02 Add DGPO (Difficulty-Aware Group Policy Optimization, ICLR 2026) supp… (YanqiDai, Feb 15, 2026)
0383a57 Merge branch 'main' into grpo-dgpo (YanqiDai, Feb 18, 2026)
b0f72ef Revise DGPO description and usage instructions (YanqiDai, Feb 20, 2026)
90fc3f5 Remove DGPO section from grpo_trainer.md (YanqiDai, Feb 20, 2026)
34a69eb Remove ICLR 2026 (YanqiDai, Feb 20, 2026)
da8c445 Apply all other suggestions from code review (YanqiDai, Feb 20, 2026)
721c3eb Merge branch 'main' into grpo-dgpo (YanqiDai, Feb 20, 2026)
25def33 Polish the description of accuracy handling logic in DQW (YanqiDai, Feb 20, 2026)
1ecdae6 Rewrite the DGPO code (YanqiDai, Feb 20, 2026)
6dd54db Merge branch 'huggingface:main' into grpo-dgpo (YanqiDai, Feb 20, 2026)
6291ae3 Remove ICLR 2026 (YanqiDai, Feb 20, 2026)
df1fb48 Recover the code position of is_std_zero (YanqiDai, Feb 20, 2026)
e2b254d Merge branch 'main' into grpo-dgpo (YanqiDai, Feb 23, 2026)
fb95c87 Merge branch 'main' into grpo-dgpo (YanqiDai, Feb 25, 2026)
89e1384 Apply suggestions from code review (YanqiDai, Feb 25, 2026)
e07db90 Remove repeated use_bias_correction_kl in suggestions from grpo_config (YanqiDai, Feb 25, 2026)
f34d3f1 Modify _compute_advantages_with_dgae and implement it directly in _ge… (YanqiDai, Feb 25, 2026)
1535983 Merge branch 'main' into grpo-dgpo (YanqiDai, Feb 26, 2026)
2137c9d Merge branch 'main' into grpo-dgpo (YanqiDai, Feb 27, 2026)
7c5acea Merge branch 'main' into grpo-dgpo (YanqiDai, Feb 27, 2026)
ecdbe7e Merge branch 'main' into grpo-dgpo (YanqiDai, Mar 1, 2026)
4a2c30c Resolve conflicts between the grpo-dgpo branch and the main branch. (YanqiDai, Mar 1, 2026)
aec3912 Remove an extra blank line in grpo_config.py (YanqiDai, Mar 1, 2026)
4b864c1 Merge branch 'main' into grpo-dgpo (YanqiDai, Mar 5, 2026)
4420d51 Remove redundant gather() for is_std_zero found by cursor (YanqiDai, Mar 5, 2026)
158b3de Merge branch 'main' into grpo-dgpo (YanqiDai, Mar 12, 2026)
56e489a Fix type casting for `global_completion_length_sum`, `local_completio… (YanqiDai, Mar 13, 2026)
87782d9 Refactor standard deviation calculation in GRPOTrainer to use nanstd … (YanqiDai, Mar 13, 2026)
24 changes: 24 additions & 0 deletions docs/source/paper_index.md
@@ -221,6 +221,30 @@ training_args = GRPOConfig(
)
```

### DGPO: Difficulty-Aware Group Policy Optimization

**📜 Paper**: https://huggingface.co/papers/2601.20614

DGPO extends GRPO with difficulty-aware mechanisms to improve training on tasks with varying question difficulty (e.g., math reasoning). It is introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614) and is supported in [`GRPOTrainer`] via [`GRPOConfig`].

- **DGAE (Difficulty-balanced Group Advantage Estimation)**: When `use_dgpo_dgae=True`, advantages are scaled by the Mean Absolute Deviation (MAD) of the rewards instead of the standard deviation, i.e. advantage = (reward - mean) / (MAD + eps). This addresses an implicit imbalance in standard-deviation scaling, where update magnitudes peak for questions of moderate difficulty and are suppressed for both easier and harder ones; see the sketch below.
- **DQW (Difficulty-aware Question-level Weighting)**: When `use_dgpo_dqw=True`, each question (prompt group) is weighted by its difficulty, measured by its mean accuracy reward, so the policy focuses more on harder questions. Use `dgpo_dqw_temp` to control how sharp the weighting is (lower values concentrate more weight on hard questions) and `dgpo_dqw_acc_reward_index` to select which reward in `reward_funcs` serves as the accuracy/difficulty signal.
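
A minimal sketch of the DGAE scaling under group-level statistics (an illustration with hypothetical reward values, not the trainer's exact code path):

```python
import torch

# Hypothetical accuracy rewards: 2 questions x 3 generations each
rewards = torch.tensor([0.0, 1.0, 1.0, 0.0, 0.0, 1.0])
num_generations = 3

grouped = rewards.view(-1, num_generations)
centered = grouped - grouped.mean(dim=1, keepdim=True)
mad = centered.abs().mean(dim=1, keepdim=True)  # Mean Absolute Deviation per group
advantages = (centered / (mad + 1e-4)).view(-1)
```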

To use DGPO in TRL, enable the corresponding options in [`GRPOConfig`]:

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
...,
use_dgpo_dgae=True,
use_dgpo_dqw=True,
dgpo_dqw_temp=2.0,
dgpo_dqw_acc_reward_index=0,
)
trainer = GRPOTrainer(..., args=training_args, reward_funcs=[...], train_dataset=...)
```
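
For intuition, the DQW weights amount to a softmax over negative mean accuracy, rescaled so the weights sum to the number of questions (a simplified sketch with hypothetical values; the trainer additionally leaves zero-variance questions at weight 1):

```python
import torch

mean_acc = torch.tensor([0.9, 0.5, 0.1])  # hypothetical mean accuracy per question
temp = 2.0

# Harder questions (lower mean accuracy) receive larger weights; rescaling by
# the number of questions keeps the overall update scale comparable.
weights = len(mean_acc) * torch.softmax(-mean_acc / temp, dim=0)
```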

### Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)

**📜 Paper**: https://huggingface.co/papers/2508.08221
25 changes: 25 additions & 0 deletions tests/test_grpo_trainer.py
@@ -230,6 +230,31 @@ def test_training_loss_types(self, loss_type):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_dgpo(self):
"""Test DGPO (Difficulty-Aware Group Policy Optimization) runs without error."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=32,
report_to="none",
use_dgpo_dgae=True,
use_dgpo_dqw=True,
dgpo_dqw_temp=2.0,
dgpo_dqw_acc_reward_index=0,
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None

def test_training_with_eval(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

49 changes: 49 additions & 0 deletions trl/trainer/grpo_config.py
@@ -294,6 +294,24 @@ class GRPOConfig(BaseConfig):
Whether to use the unbiased KL divergence estimator with importance sampling correction. This corrects the
KL divergence estimate by multiplying it with the importance sampling ratio. This is described in the
[DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556).
use_dgpo_dgae (`bool`, *optional*, defaults to `False`):
Whether to use difficulty-balanced group advantage estimation (DGAE). When `True`, advantages are scaled
by the Mean Absolute Deviation (MAD) of the rewards instead of the standard deviation, i.e.
advantage = (reward - mean) / (MAD + eps) with MAD = mean(|reward - mean|). Introduced in the [MathForge
paper](https://huggingface.co/papers/2601.20614).
use_dgpo_dqw (`bool`, *optional*, defaults to `False`):
Whether to use difficulty-aware question-level weighting (DQW). When `True`, question weights (softmax over
negative mean accuracy reward at `dgpo_dqw_acc_reward_index`) are multiplied directly onto the advantages,
so harder questions get larger effective advantages. Introduced in the [MathForge
paper](https://huggingface.co/papers/2601.20614).
dgpo_dqw_temp (`float`, *optional*, defaults to `2.0`):
Temperature for the DQW softmax over negative mean (accuracy) reward. Higher values make the weighting more
uniform; lower values concentrate weight on harder questions. Introduced in the [MathForge
paper](https://huggingface.co/papers/2601.20614).
dgpo_dqw_acc_reward_index (`int`, *optional*, defaults to `0`):
Index of the accuracy reward in `reward_funcs` used by DQW as the difficulty measure. The per-question
mean of the reward at this index determines the question weights: lower mean accuracy means a harder
question. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614).
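
To illustrate the effect of `dgpo_dqw_temp` (a standalone sketch with hypothetical accuracy values, independent of the config class):

```python
import torch

mean_acc = torch.tensor([0.9, 0.5, 0.1])
for temp in (0.5, 2.0, 8.0):
    # Lower temperature concentrates weight on the hardest (lowest-accuracy)
    # question; higher temperature pushes the weights toward uniform.
    print(temp, torch.softmax(-mean_acc / temp, dim=0))
```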

> Parameters that control the logging

@@ -776,6 +794,37 @@ class GRPOConfig(BaseConfig):
},
)

use_dgpo_dgae: bool = field(
default=False,
metadata={
"help": "Whether to use difficulty-balanced group advantage estimation (DGAE). When True, the denominator "
"when scaling advantages uses the Mean Absolute Deviation (MAD) of rewards instead of the standard "
"deviation. Introduced in the [MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)
use_dgpo_dqw: bool = field(
default=False,
metadata={
"help": "Whether to use difficulty-aware question-level weighting (DQW). When True, question weights are "
"multiplied directly onto the advantages. Introduced in the [MathForge paper](https://huggingface.co/"
"papers/2601.20614)."
},
)
dgpo_dqw_temp: float = field(
default=2.0,
metadata={
"help": "Temperature for the DQW softmax over negative mean (accuracy) reward. Introduced in the "
"[MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)
dgpo_dqw_acc_reward_index: int = field(
default=0,
metadata={
"help": "Index of the accuracy reward in reward_funcs used by DQW for difficulty measure. Introduced in "
"the [MathForge paper](https://huggingface.co/papers/2601.20614)."
},
)

# Parameters that control the logging
log_completions: bool = field(
default=False,
128 changes: 124 additions & 4 deletions trl/trainer/grpo_trainer.py
@@ -546,6 +546,15 @@ def __init__(
raise NotImplementedError(
"Liger Kernels don't currently support masking token positions based on entropy."
)
self.use_dgpo_dgae = args.use_dgpo_dgae
self.use_dgpo_dqw = args.use_dgpo_dqw
self.dgpo_dqw_temp = args.dgpo_dqw_temp
self.dgpo_dqw_acc_reward_index = args.dgpo_dqw_acc_reward_index
if self.use_dgpo_dqw and not (0 <= self.dgpo_dqw_acc_reward_index < len(self.reward_funcs)):
raise ValueError(
f"dgpo_dqw_acc_reward_index must be in [0, {len(self.reward_funcs)}), got "
f"{self.dgpo_dqw_acc_reward_index}."
)
if self.use_liger_kernel and not self.importance_sampling_level == "token":
raise NotImplementedError(
"Liger Kernels currently only support token-level importance sampling. Please set"
@@ -1583,6 +1592,90 @@ def _generate(self, prompts: list):
extra_fields,
)

def _compute_advantages_with_dgae(
self,
rewards: torch.Tensor,
num_generations: int,
*,
use_group_mad: bool | None = None,
) -> torch.Tensor:
"""Compute advantages scaled by the Mean Absolute Deviation (DGAE). Call only when use_dgpo_dgae is True."""
if use_group_mad is None:
use_group_mad = self.scale_rewards == "group" and num_generations > 1
if use_group_mad:
# Center within each group so that the MAD below is the true group-level MAD
grouped_rewards = rewards.view(-1, num_generations)
advantages = (grouped_rewards - grouped_rewards.mean(dim=1, keepdim=True)).view(-1)
else:
advantages = rewards - rewards.mean()
if self.scale_rewards != "none":
if use_group_mad:
mad_rewards = (
advantages.abs()
.view(-1, num_generations)
.mean(dim=1)
.repeat_interleave(num_generations, dim=0)
)
else:
mad_rewards = advantages.abs().mean().expand_as(rewards)
advantages = advantages / (mad_rewards + 1e-4)
return advantages

def _compute_valid_token_balancing_ratios(
self,
completion_mask: torch.Tensor,
is_std_zero: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compute the valid-token loss-balancing ratios and return (zero_mask_ratio, global_balancing_ratio).
Apply zero_mask_ratio to the advantages before slicing to the local process shard, and
global_balancing_ratio after slicing. Call only when use_dgpo_dgae or use_dgpo_dqw is True.
"""
completion_length_local = completion_mask.sum(dim=1)
completion_length_global = gather(completion_length_local)

# Cast the integer token counts to float before computing the ratios below
global_completion_length_sum = completion_length_global.sum().float().clamp(min=1e-8)
local_completion_length_sum = completion_length_local.sum().float()

global_balancing_ratio = (
self.accelerator.num_processes * local_completion_length_sum / global_completion_length_sum
)

valid_mask_global = ~gather(is_std_zero)
if valid_mask_global.any():
valid_completion_length_sum = completion_length_global[valid_mask_global].sum().clamp(min=1e-8)
zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum
else:
zero_mask_ratio = torch.tensor(1.0, device=completion_mask.device)

return zero_mask_ratio, global_balancing_ratio
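
# Illustration with hypothetical numbers: if 2 processes hold 30 and 10 valid
# completion tokens, global_balancing_ratio is 2 * 30 / 40 = 1.5 on the first
# process and 2 * 10 / 40 = 0.5 on the second, so each valid token contributes
# equally to the loss regardless of which process it lands on.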

def _compute_dqw_weights(
self,
rewards: torch.Tensor,
rewards_per_func: torch.Tensor,
num_generations: int,
) -> torch.Tensor:
"""
Compute question-level difficulty balancing weights (DQW).
Returns difficulty_balancing_weights (num_questions,); expand with repeat_interleave at call site.
Weights sum to num_questions; zero-variance questions get weight 1.
Call only when use_dgpo_dqw is True.
"""
num_questions = rewards.size(0) // num_generations
acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index] # (N,)
mean_per_q_acc = acc_rewards.view(-1, num_generations).nanmean(dim=1)  # (num_questions,)
std_per_q_acc = nanstd(acc_rewards.view(-1, num_generations), dim=1)  # (num_questions,), NaN-aware
is_std_zero_q = std_per_q_acc < 1e-8
num_zero_variance_questions = is_std_zero_q.sum().item()
difficulty_balancing_weights = torch.ones(
num_questions, device=rewards.device, dtype=rewards.dtype
)
if num_zero_variance_questions < num_questions:
mean_per_q_acc_modified = mean_per_q_acc.clone()
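# Map zero or NaN mean accuracy to 1.0 so these questions get the smallest weights in the softmax below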
mean_per_q_acc_modified[(mean_per_q_acc == 0) | torch.isnan(mean_per_q_acc)] = 1.0
difficulty_balancing_weights[~is_std_zero_q] = (
num_questions - num_zero_variance_questions
) * torch.nn.functional.softmax(
-mean_per_q_acc_modified[~is_std_zero_q] / self.dgpo_dqw_temp, dim=0
)
return difficulty_balancing_weights

def _generate_and_score_completions(
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
@@ -1824,9 +1917,14 @@ def _generate_and_score_completions(
f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
)

advantages = rewards - mean_grouped_rewards
if self.scale_rewards != "none":
advantages = advantages / (std_rewards + 1e-4)
if self.use_dgpo_dgae:
advantages = self._compute_advantages_with_dgae(rewards, num_generations)
Collaborator commented:

this split feels a bit asymmetric: the DGAE path goes into _compute_advantages_with_dgae while the standard advantage computation (center by mean, divide by std) stays inline. When reading the code you now have to jump to a separate method for one path but not the other, even though they're doing the same conceptual thing.

I think there's motivation for a larger refactoring of the advantage calculations, LOC 1844-1890, but I'd like a maintainer's thoughts on this. My suggestion would be to at least pull out the std_rewards calculation into a helper, but I'm open to a larger refactor as well.

@YanqiDai (Author) commented on Feb 25, 2026:

Thank you for your suggestion. We have found that the implementation of _compute_advantages_with_dgae can be simplified and requires some minor adjustments. After making these corrections, we implemented it directly in _generate_and_score_completions (it only takes 3 lines of code, just as concise as using std_rewards).
else:
advantages = rewards - mean_grouped_rewards
if self.scale_rewards != "none":
advantages = advantages / (std_rewards + 1e-4)
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging

elif self.multi_objective_aggregation == "normalize_then_sum":
@@ -1837,7 +1935,12 @@ def _generate_and_score_completions(
reward_k = reward_k.view(-1, len(self.reward_funcs))
rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards)
advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)
if self.use_dgpo_dgae:
advantages = self._compute_advantages_with_dgae(rewards, num_generations, use_group_mad=False)
else:
advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging

else:
@@ -1846,6 +1949,20 @@
"'sum_then_normalize' or 'normalize_then_sum'."
)

# Valid token-level loss averaging: zero_mask_ratio before slice, global_balancing_ratio after slice
if self.use_dgpo_dgae or self.use_dgpo_dqw:
zero_mask_ratio, global_balancing_ratio = self._compute_valid_token_balancing_ratios(
completion_mask, is_std_zero
)
advantages = advantages * zero_mask_ratio

# DQW: multiply advantages by question-level weights; weights sum to num_questions, zero-variance questions get 1
if self.use_dgpo_dqw:
difficulty_balancing_weights = self._compute_dqw_weights(
rewards, rewards_per_func, num_generations
)
advantages = advantages * difficulty_balancing_weights.repeat_interleave(num_generations)

# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts),
@@ -1854,6 +1971,9 @@
all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
advantages = advantages[process_slice]

if self.use_dgpo_dgae or self.use_dgpo_dqw:
advantages = advantages * global_balancing_ratio

# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()