Add support for DGPO (ICLR 2026) to GRPO #5102
@@ -546,6 +546,15 @@ def __init__(
             raise NotImplementedError(
                 "Liger Kernels don't currently support masking token positions based on entropy."
             )
+        self.use_dgpo_dgae = args.use_dgpo_dgae
+        self.use_dgpo_dqw = args.use_dgpo_dqw
+        self.dgpo_dqw_temp = args.dgpo_dqw_temp
+        self.dgpo_dqw_acc_reward_index = args.dgpo_dqw_acc_reward_index
+        if self.use_dgpo_dqw and (self.dgpo_dqw_acc_reward_index < 0 or self.dgpo_dqw_acc_reward_index >= len(self.reward_funcs)):
+            raise ValueError(
+                f"dgpo_dqw_acc_reward_index must be in [0, {len(self.reward_funcs)}), got "
+                f"{self.dgpo_dqw_acc_reward_index}."
+            )
         if self.use_liger_kernel and not self.importance_sampling_level == "token":
             raise NotImplementedError(
                 "Liger Kernels currently only support token-level importance sampling. Please set"
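For reviewers, a rough usage sketch of the new options. It assumes this PR also adds matching `use_dgpo_dgae`, `use_dgpo_dqw`, `dgpo_dqw_temp`, and `dgpo_dqw_acc_reward_index` fields to `GRPOConfig` (only the trainer side is shown in this diff); the values below are purely illustrative.

```python
# Hypothetical configuration sketch; field names mirror the attributes read from
# `args` above, and the values are only illustrative.
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="grpo-dgpo",
    num_generations=8,
    use_dgpo_dgae=True,           # MAD-based advantage normalization (DGAE)
    use_dgpo_dqw=True,            # question-level difficulty balancing weights (DQW)
    dgpo_dqw_temp=1.0,            # softmax temperature used by DQW
    dgpo_dqw_acc_reward_index=0,  # which entry of reward_funcs is the accuracy reward
)
```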
@@ -1583,6 +1592,90 @@ def _generate(self, prompts: list):
             extra_fields,
         )
 
+    def _compute_advantages_with_dgae(
+        self,
+        rewards: torch.Tensor,
+        num_generations: int,
+        *,
+        use_group_mad: bool | None = None,
+    ) -> torch.Tensor:
+        """Compute advantages using MAD (DGAE) as denominator. Call only when use_dgpo_dgae is True."""
+        advantages = rewards - rewards.mean()
+        if self.scale_rewards != "none":
+            if use_group_mad is None:
+                use_group_mad = self.scale_rewards == "group" and num_generations > 1
+            if use_group_mad:
+                mad_rewards = (
+                    advantages.abs()
+                    .view(-1, num_generations)
+                    .mean(dim=1)
+                    .repeat_interleave(num_generations, dim=0)
+                )
+            else:
+                mad_rewards = advantages.abs().mean().expand_as(rewards)
+            advantages = advantages / (mad_rewards + 1e-4)
+        return advantages
+
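A minimal, single-tensor sketch of what `_compute_advantages_with_dgae` computes, mirroring the group-MAD branch as written above; the toy rewards and group layout are assumptions for illustration only.

```python
import torch

# Toy sketch of DGAE: centre rewards on the batch mean, then divide by the
# group-wise mean absolute deviation (MAD) instead of the standard deviation.
# Assumes two questions with num_generations = 4 completions each, laid out contiguously.
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0,   # question 1
                        1.0, 1.0, 1.0, 0.0])  # question 2
num_generations = 4

centered = rewards - rewards.mean()
mad_rewards = (
    centered.abs()
    .view(-1, num_generations)
    .mean(dim=1)
    .repeat_interleave(num_generations, dim=0)
)
advantages = centered / (mad_rewards + 1e-4)
print(advantages)
```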
+    def _compute_valid_token_balancing_ratios(
+        self,
+        completion_mask: torch.Tensor,
+        is_std_zero: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute valid token-level balancing ratios (zero_mask_ratio and global_balancing_ratio).
+        Returns (zero_mask_ratio, global_balancing_ratio). Apply zero_mask_ratio to advantages before slice,
+        global_balancing_ratio after slice. Call only when use_dgpo_dgae or use_dgpo_dqw is True.
+        """
+        completion_length_local = completion_mask.sum(dim=1)
+        completion_length_global = gather(completion_length_local)
+
+        global_completion_length_sum = completion_length_global.sum().clamp(min=1e-8)
+        local_completion_length_sum = completion_length_local.sum()
+
+        global_balancing_ratio = (
+            self.accelerator.num_processes * local_completion_length_sum / global_completion_length_sum
+        )
+
+        valid_mask_global = ~gather(is_std_zero)
+        if valid_mask_global.any():
+            valid_completion_length_sum = completion_length_global[valid_mask_global].sum().clamp(min=1e-8)
+            zero_mask_ratio = global_completion_length_sum / valid_completion_length_sum
+        else:
+            zero_mask_ratio = torch.tensor(1.0, device=completion_mask.device, dtype=completion_mask.dtype)
+
+        return zero_mask_ratio, global_balancing_ratio
+
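A single-process sketch of the two ratios this helper returns; it drops the distributed `gather` (so local and global coincide) and uses made-up masks, purely to show the arithmetic.

```python
import torch

# With one process, local == global, so the length-based rebalancing factor is 1;
# zero_mask_ratio upweights completions from non-degenerate (non-zero-std) groups.
completion_mask = torch.tensor([[1, 1, 1, 0],   # 3 valid tokens
                                [1, 1, 1, 1],   # 4 valid tokens
                                [1, 0, 0, 0]])  # 1 valid token
is_std_zero = torch.tensor([False, True, False])  # middle completion's group had zero reward std

completion_length = completion_mask.sum(dim=1).float()
global_length_sum = completion_length.sum().clamp(min=1e-8)

num_processes = 1
global_balancing_ratio = num_processes * completion_length.sum() / global_length_sum  # -> 1.0

valid_length_sum = completion_length[~is_std_zero].sum().clamp(min=1e-8)
zero_mask_ratio = global_length_sum / valid_length_sum  # 8 / 4 = 2.0
print(zero_mask_ratio, global_balancing_ratio)
```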
+    def _compute_dqw_weights(
+        self,
+        rewards: torch.Tensor,
+        rewards_per_func: torch.Tensor,
+        num_generations: int,
+    ) -> torch.Tensor:
+        """
+        Compute question-level difficulty balancing weights (DQW).
+        Returns difficulty_balancing_weights (num_questions,); expand with repeat_interleave at call site.
+        Weights sum to num_questions; zero-variance questions get weight 1.
+        Call only when use_dgpo_dqw is True.
+        """
+        num_questions = rewards.size(0) // num_generations
+        acc_rewards = rewards_per_func[:, self.dgpo_dqw_acc_reward_index]  # (N,)
+
+        mean_per_q_acc = acc_rewards.view(-1, num_generations).nanmean(dim=1)  # (num_questions,)
+        std_per_q_acc = acc_rewards.view(-1, num_generations).std(dim=1)  # (num_questions,)
+        is_std_zero_q = std_per_q_acc < 1e-8
+
+        num_zero_variance_questions = is_std_zero_q.sum().item()
+        difficulty_balancing_weights = torch.ones(
+            num_questions, device=rewards.device, dtype=rewards.dtype
+        )
+        if num_zero_variance_questions < num_questions:
+            mean_per_q_acc_modified = mean_per_q_acc.clone()
+            mean_per_q_acc_modified[(mean_per_q_acc == 0) | torch.isnan(mean_per_q_acc)] = 1.0
+            difficulty_balancing_weights[~is_std_zero_q] = (
+                num_questions - num_zero_variance_questions
+            ) * torch.nn.functional.softmax(
+                -mean_per_q_acc_modified[~is_std_zero_q] / self.dgpo_dqw_temp, dim=0
+            )
+        return difficulty_balancing_weights
+
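A toy sketch of the DQW weighting on made-up accuracy rewards, skipping the zero/NaN-mean fix-up done in the helper above; it shows how lower mean accuracy (harder questions) yields larger weights while the weights still sum to the number of questions.

```python
import torch
import torch.nn.functional as F

# Toy DQW sketch: harder questions (lower mean accuracy) get larger weights via a
# softmax over -mean_accuracy / temperature; zero-variance questions keep weight 1,
# and the non-degenerate weights are rescaled so the total equals num_questions.
acc_rewards = torch.tensor([1.0, 0.0, 1.0, 0.0,   # q1: mean 0.50 (harder)
                            1.0, 1.0, 1.0, 0.0,   # q2: mean 0.75 (easier)
                            1.0, 1.0, 1.0, 1.0])  # q3: mean 1.00, zero variance
num_generations = 4
dqw_temp = 1.0

per_q = acc_rewards.view(-1, num_generations)
mean_per_q = per_q.mean(dim=1)
is_std_zero_q = per_q.std(dim=1) < 1e-8

weights = torch.ones(per_q.size(0))
valid = ~is_std_zero_q
weights[valid] = valid.sum() * F.softmax(-mean_per_q[valid] / dqw_temp, dim=0)
print(weights, weights.sum())  # q1 > q2, q3 stays at 1; weights sum to the number of questions
```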
     def _generate_and_score_completions(
         self, inputs: list[dict[str, torch.Tensor | Any]]
     ) -> dict[str, torch.Tensor | Any]:
@@ -1824,9 +1917,14 @@ def _generate_and_score_completions(
                     f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
                 )
 
-            advantages = rewards - mean_grouped_rewards
-            if self.scale_rewards != "none":
-                advantages = advantages / (std_rewards + 1e-4)
+            if self.use_dgpo_dgae:
+                advantages = self._compute_advantages_with_dgae(
+                    rewards, num_generations
+                )
+            else:
+                advantages = rewards - mean_grouped_rewards
+                if self.scale_rewards != "none":
+                    advantages = advantages / (std_rewards + 1e-4)
             is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))  # for logging
 
         elif self.multi_objective_aggregation == "normalize_then_sum":
@@ -1837,7 +1935,12 @@ def _generate_and_score_completions(
                 reward_k = reward_k.view(-1, len(self.reward_funcs))
                 rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
                 std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards)
-                advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)
+                if self.use_dgpo_dgae:
+                    advantages = self._compute_advantages_with_dgae(
+                        rewards, num_generations, use_group_mad=False
+                    )
+                else:
+                    advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4)
                 is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))  # for logging
 
         else:
@@ -1846,6 +1949,20 @@ def _generate_and_score_completions(
                 "'sum_then_normalize' or 'normalize_then_sum'."
             )
 
+        # Valid token-level loss averaging: zero_mask_ratio before slice, global_balancing_ratio after slice
+        if self.use_dgpo_dgae or self.use_dgpo_dqw:
+            zero_mask_ratio, global_balancing_ratio = self._compute_valid_token_balancing_ratios(
+                completion_mask, is_std_zero
+            )
+            advantages = advantages * zero_mask_ratio
+
+        # DQW: multiply advantages by question-level weights; weights sum to num_questions, zero-variance questions get 1
+        if self.use_dgpo_dqw:
+            difficulty_balancing_weights = self._compute_dqw_weights(
+                rewards, rewards_per_func, num_generations
+            )
+            advantages = advantages * difficulty_balancing_weights.repeat_interleave(num_generations)
+
         # Slice to keep only the local part of the data
         process_slice = slice(
             self.accelerator.process_index * len(prompts),
@@ -1854,6 +1971,9 @@ def _generate_and_score_completions(
         all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
 
+        if self.use_dgpo_dgae or self.use_dgpo_dqw:
+            advantages = advantages * global_balancing_ratio
+
         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
             mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
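Taken together, the hunks above apply the DGPO factors in a fixed order. A single-process sketch with made-up values, only to make that ordering explicit (the stand-in numbers replace the outputs of the helpers shown earlier):

```python
import torch

# Order of operations as wired above, on toy values (single process, so the slice
# covers the whole batch and global_balancing_ratio is 1).
num_generations = 2
advantages = torch.tensor([0.5, -0.5, 0.8, -0.8])  # from DGAE or the mean/std path
zero_mask_ratio = torch.tensor(2.0)                 # applied before the process slice
dqw_weights = torch.tensor([1.2, 0.8])              # one weight per question
global_balancing_ratio = torch.tensor(1.0)          # applied after the process slice

advantages = advantages * zero_mask_ratio
advantages = advantages * dqw_weights.repeat_interleave(num_generations)
process_slice = slice(0, advantages.numel())        # local shard (everything here)
advantages = advantages[process_slice] * global_balancing_ratio
print(advantages)
```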