
Commit e23a9c8

fix None reward
1 parent de6e595 commit e23a9c8

File tree

2 files changed: 30 additions, 24 deletions


swift/rlhf_trainers/grpo_trainer.py

Lines changed: 11 additions & 16 deletions
@@ -37,7 +37,7 @@
 from trl.models import prepare_deepspeed
 from trl.trainer import grpo_trainer
 from trl.trainer.callbacks import SyncRefModelCallback
-from trl.trainer.grpo_trainer import RepeatSampler, nanmax, nanmin, nanstd
+from trl.trainer.grpo_trainer import RepeatSampler, nanmax, nanmin
 from trl.trainer.utils import selective_log_softmax
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -53,7 +53,7 @@
 from .arguments import GRPOConfig
 from .rollout_mixin import DataType, RolloutTrainerMixin
 from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator,
-                    load_pil_img, make_chord_sft_dataset, pad_logps_back_to_batch, patch_save_last_checkpoint,
+                    load_pil_img, make_chord_sft_dataset, nanstd, pad_logps_back_to_batch, patch_save_last_checkpoint,
                     profiling_context, profiling_decorator, replace_assistant_response_with_ids)

 try:
@@ -545,20 +545,15 @@ def log_rewards_all(rewards_per_func: torch.Tensor):
             else:  # edge case: num_generations_eval=1
                 rewards_std = torch.zeros_like(rewards)
         elif self.scale_rewards == 'gdpo':
-            num_reward_funcs = rewards_per_func.shape[1]
-            normalized_advantages_list = []
-            for i in range(num_reward_funcs):
-                reward_i = rewards_per_func[:, i]
-                grouped_reward_i = reward_i.view(-1, K)
-                group_mean = grouped_reward_i.mean(dim=1, keepdim=True)
-                group_std = grouped_reward_i.std(dim=1, keepdim=True) + 1e-8
-                normalized_i = (grouped_reward_i - group_mean) / group_std
-                normalized_i = normalized_i.view(-1)
-                normalized_advantages_list.append(self.reward_weights[i] * normalized_i)
-            summed_advantages = sum(normalized_advantages_list)
-            batch_mean = summed_advantages.mean()
-            batch_std = summed_advantages.std() + 1e-8
-            advantages = (summed_advantages - batch_mean) / batch_std
+            grouped = rewards_per_func.view(-1, K, rewards_per_func.shape[1])
+            group_mean = torch.nanmean(grouped, dim=1, keepdim=True)
+            group_std = nanstd(grouped, dim=1, keepdim=True) if K > 1 else torch.zeros_like(group_mean)
+            normalized = (grouped - group_mean) / (group_std + 1e-8)
+            normalized = torch.nan_to_num(normalized, nan=0.0)
+            normalized = normalized.view(-1, rewards_per_func.shape[1])
+            advantages = (normalized * self.reward_weights.unsqueeze(0)).sum(dim=1)
+            batch_std = advantages.std() + 1e-8
+            advantages = (advantages - advantages.mean()) / batch_std
             rewards_std = None
         else:  # 'none'
             rewards_std = None
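For readers who want the new GDPO path in isolation: below is a minimal, self-contained sketch of the vectorized, NaN-aware normalization introduced above. Shapes and names (rewards_per_func of shape (B*K, F), K generations per prompt, reward_weights) mirror the diff; _nanstd is a compact stand-in for the nanstd helper patched in utils.py below, and gdpo_advantages is an illustrative name, not a trainer method.

import torch

def _nanstd(t: torch.Tensor, dim: int) -> torch.Tensor:
    # compact stand-in for the nanstd helper patched in utils.py below
    mean = t.nanmean(dim=dim, keepdim=True)
    var = ((t - mean) ** 2).nanmean(dim=dim, keepdim=True)
    n = (~torch.isnan(t)).sum(dim=dim, keepdim=True)
    return torch.sqrt(var * n / (n - 1).clamp(min=1))  # Bessel's correction

def gdpo_advantages(rewards_per_func: torch.Tensor, reward_weights: torch.Tensor, K: int) -> torch.Tensor:
    # (B*K, F) -> (B, K, F): one group of K generations per prompt
    F = rewards_per_func.shape[1]
    grouped = rewards_per_func.view(-1, K, F)
    group_mean = torch.nanmean(grouped, dim=1, keepdim=True)           # NaN-safe per-group mean
    group_std = _nanstd(grouped, dim=1) if K > 1 else torch.zeros_like(group_mean)
    normalized = (grouped - group_mean) / (group_std + 1e-8)
    normalized = torch.nan_to_num(normalized, nan=0.0)                 # a None (NaN) reward contributes zero
    advantages = (normalized.view(-1, F) * reward_weights.unsqueeze(0)).sum(dim=1)
    return (advantages - advantages.mean()) / (advantages.std() + 1e-8)

# toy example: 2 prompts x K=2 generations, 2 reward functions; one reward returned None (NaN)
rewards = torch.tensor([[1.0, 0.5],
                        [0.0, float('nan')],
                        [2.0, 1.0],
                        [1.0, 1.0]])
adv = gdpo_advantages(rewards, torch.tensor([1.0, 0.5]), K=2)  # shape (4,), zero mean

Unlike the old per-function Python loop, the NaN entry no longer poisons .mean()/.std() for its whole group; it is excluded from the statistics and zeroed in the advantage, which is the "fix None reward" behavior the commit message refers to.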

swift/rlhf_trainers/utils.py

Lines changed: 19 additions & 8 deletions
@@ -206,23 +206,34 @@ def _patched_stateless_pg_create(
 patch_stateless_process_group_for_ipv6()


-def nanstd(tensor: torch.Tensor) -> torch.Tensor:
+def nanstd(tensor: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False) -> torch.Tensor:
     """
-    refer: trl/trainer/utils
-    Compute the standard deviation of a tensor, ignoring NaNs. This function only supports 1D tensors.
+    Compute the standard deviation of a tensor, ignoring NaNs.
+
+    Refer: trl/trainer/utils.py

     Args:
         tensor (`torch.Tensor`):
-            Input tensor of shape `(N,)`.
+            Input tensor.
+        dim (`int`, *optional*):
+            Dimension to reduce. Defaults to all dimensions.
+        keepdim (`bool`, *optional*, defaults to `False`):
+            Whether to keep reduced dimensions.

     Returns:
         `torch.Tensor`:
             Standard deviation of the tensor, ignoring NaNs.
     """
-    variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True))**2)  # Compute variance ignoring NaNs
-    count = torch.sum(~torch.isnan(tensor))  # Count of non-NaN values
-    variance *= count / (count - 1)  # Bessel's correction
-    return torch.sqrt(variance)
+    mean = torch.nanmean(tensor, dim=dim, keepdim=True)
+    variance = torch.nanmean((tensor - mean)**2, dim=dim, keepdim=True)
+    count = torch.sum(~torch.isnan(tensor), dim=dim, keepdim=True)
+    correction = torch.where(count > 1, count / (count - 1), torch.full_like(count, float('nan')))
+    std = torch.sqrt(variance * correction)
+    if keepdim:
+        return std
+    if dim is None:
+        return std.squeeze()
+    return std.squeeze(dim)


 # code borrowed from verl/verl/utils/memory_utils.py
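As a sanity check of the intended semantics, here is a standalone re-derivation of the extended helper plus a quick comparison against torch.std. It is a sketch, not the repository code: the name nanstd_sketch is illustrative, and the non-NaN count is cast to float before Bessel's correction (an assumption for dtype safety; the diff above applies torch.where to the raw counts).

import torch
from typing import Optional

def nanstd_sketch(t: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False) -> torch.Tensor:
    # NaN-ignoring sample std; NaN wherever a slice has fewer than 2 valid values
    if dim is None:
        return nanstd_sketch(t.reshape(-1), dim=0)
    mean = torch.nanmean(t, dim=dim, keepdim=True)
    var = torch.nanmean((t - mean) ** 2, dim=dim, keepdim=True)
    count = (~torch.isnan(t)).sum(dim=dim, keepdim=True).float()  # cast before Bessel's correction
    correction = torch.where(count > 1, count / (count - 1), torch.full_like(count, float('nan')))
    std = torch.sqrt(var * correction)
    return std if keepdim else std.squeeze(dim)

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, float('nan'), 6.0]])
print(nanstd_sketch(x, dim=1))                          # tensor([1.0000, 1.4142]); the NaN is excluded
print(torch.isclose(nanstd_sketch(x[0]), x[0].std()))   # matches torch.std on NaN-free input
print(nanstd_sketch(x, dim=1, keepdim=True).shape)      # torch.Size([2, 1]), the form the GRPO trainer uses

The dim/keepdim extension is what lets grpo_trainer.py call nanstd(grouped, dim=1, keepdim=True) on a (B, K, F) tensor instead of looping over reward functions with the 1D-only version previously imported from trl.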
