@@ -37,7 +37,7 @@
37 | 37 | from trl.models import prepare_deepspeed |
38 | 38 | from trl.trainer import grpo_trainer |
39 | 39 | from trl.trainer.callbacks import SyncRefModelCallback |
40 | | -from trl.trainer.grpo_trainer import RepeatSampler, nanmax, nanmin, nanstd |
| 40 | +from trl.trainer.grpo_trainer import RepeatSampler, nanmax, nanmin |
41 | 41 | from trl.trainer.utils import selective_log_softmax |
42 | 42 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
43 | 43 |
@@ -53,7 +53,7 @@
53 | 53 | from .arguments import GRPOConfig |
54 | 54 | from .rollout_mixin import DataType, RolloutTrainerMixin |
55 | 55 | from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator, |
56 | | - load_pil_img, make_chord_sft_dataset, pad_logps_back_to_batch, patch_save_last_checkpoint, |
| 56 | + load_pil_img, make_chord_sft_dataset, nanstd, pad_logps_back_to_batch, patch_save_last_checkpoint, |
57 | 57 | profiling_context, profiling_decorator, replace_assistant_response_with_ids) |
58 | 58 |
59 | 59 | try: |
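
This change also sources `nanstd` from the local `.utils` module instead of `trl.trainer.grpo_trainer`, since the vectorized branch below needs a dimension-aware variant. For reference, a minimal sketch of such a helper, assuming it mirrors `torch.std` (Bessel-corrected) over the non-NaN entries; the actual implementation in `.utils` may differ:

```python
import torch

def nanstd(tensor: torch.Tensor, dim: int, keepdim: bool = False) -> torch.Tensor:
    # Hypothetical sketch of a NaN-aware, dimension-wise standard deviation.
    mean = torch.nanmean(tensor, dim=dim, keepdim=True)
    # Number of non-NaN entries per slice along `dim`.
    count = (~tensor.isnan()).sum(dim=dim, keepdim=True)
    # NaN entries contribute zero squared deviation, i.e. are treated as missing.
    sq_dev = torch.nan_to_num(tensor - mean, nan=0.0) ** 2
    # Unbiased variance, matching torch.std's default Bessel correction.
    var = sq_dev.sum(dim=dim, keepdim=True) / (count - 1).clamp(min=1)
    std = var.sqrt()
    return std if keepdim else std.squeeze(dim)
```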
@@ -545,20 +545,15 @@ def log_rewards_all(rewards_per_func: torch.Tensor): |
545 | 545 | else: # edge case: num_generations_eval=1 |
546 | 546 | rewards_std = torch.zeros_like(rewards) |
547 | 547 | elif self.scale_rewards == 'gdpo': |
548 | | - num_reward_funcs = rewards_per_func.shape[1] |
549 | | - normalized_advantages_list = [] |
550 | | - for i in range(num_reward_funcs): |
551 | | - reward_i = rewards_per_func[:, i] |
552 | | - grouped_reward_i = reward_i.view(-1, K) |
553 | | - group_mean = grouped_reward_i.mean(dim=1, keepdim=True) |
554 | | - group_std = grouped_reward_i.std(dim=1, keepdim=True) + 1e-8 |
555 | | - normalized_i = (grouped_reward_i - group_mean) / group_std |
556 | | - normalized_i = normalized_i.view(-1) |
557 | | - normalized_advantages_list.append(self.reward_weights[i] * normalized_i) |
558 | | - summed_advantages = sum(normalized_advantages_list) |
559 | | - batch_mean = summed_advantages.mean() |
560 | | - batch_std = summed_advantages.std() + 1e-8 |
561 | | - advantages = (summed_advantages - batch_mean) / batch_std |
| 548 | + grouped = rewards_per_func.view(-1, K, rewards_per_func.shape[1]) |
| 549 | + group_mean = torch.nanmean(grouped, dim=1, keepdim=True) |
| 550 | + group_std = nanstd(grouped, dim=1, keepdim=True) if K > 1 else torch.zeros_like(group_mean) |
| 551 | + normalized = (grouped - group_mean) / (group_std + 1e-8) |
| 552 | + normalized = torch.nan_to_num(normalized, nan=0.0) |
| 553 | + normalized = normalized.view(-1, rewards_per_func.shape[1]) |
| 554 | + advantages = (normalized * self.reward_weights.unsqueeze(0)).sum(dim=1) |
| 555 | + batch_std = advantages.std() + 1e-8 |
| 556 | + advantages = (advantages - advantages.mean()) / batch_std |
562 | 557 | rewards_std = None |
563 | 558 | else: # 'none' |
564 | 559 | rewards_std = None |
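
The rewritten `gdpo` branch vectorizes what the old loop did per reward function: normalize each reward column within its prompt group of `K` generations, combine the columns with `reward_weights`, then re-normalize the summed advantages across the batch. Below is a self-contained sketch of the same computation, under the hypothetical name `gdpo_advantages` and with plain `mean`/`std` standing in for the NaN-aware variants the patched code uses:

```python
import torch

def gdpo_advantages(rewards_per_func: torch.Tensor,  # (B * K, F) rewards
                    num_generations: int,             # K completions per prompt
                    reward_weights: torch.Tensor,     # (F,) per-function weights
                    eps: float = 1e-8) -> torch.Tensor:
    num_funcs = rewards_per_func.shape[1]
    # Group rows by prompt: (num_prompts, K, F).
    grouped = rewards_per_func.view(-1, num_generations, num_funcs)
    group_mean = grouped.mean(dim=1, keepdim=True)
    # A single generation per prompt has no within-group spread.
    group_std = (grouped.std(dim=1, keepdim=True)
                 if num_generations > 1 else torch.zeros_like(group_mean))
    # Normalize each reward function within its own group.
    normalized = (grouped - group_mean) / (group_std + eps)
    # Weighted sum over reward functions, flattened back to (B * K,).
    advantages = (normalized.view(-1, num_funcs) * reward_weights).sum(dim=1)
    # Final batch-level normalization.
    return (advantages - advantages.mean()) / (advantages.std() + eps)

# Example: 4 prompts x 2 generations each, 2 reward functions.
adv = gdpo_advantages(torch.randn(8, 2), num_generations=2,
                      reward_weights=torch.tensor([1.0, 0.5]))
```

Computing the group statistics once on the `(num_prompts, K, F)` view replaces the Python loop over reward functions with a single batched reduction, which is both shorter and cheaper on GPU.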