From 2316830467d54fee613d1c0a6e3f75585216570c Mon Sep 17 00:00:00 2001 From: Sergei Isaev <48261488+s-isaev@users.noreply.github.com> Date: Tue, 25 Apr 2023 12:56:48 +0300 Subject: [PATCH] Fix numerical instability of critic loss --- .../DeepSpeed-Chat/training/utils/model/reward_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/DeepSpeed-Chat/training/utils/model/reward_model.py b/applications/DeepSpeed-Chat/training/utils/model/reward_model.py index c8912de56..f9ad3691e 100644 --- a/applications/DeepSpeed-Chat/training/utils/model/reward_model.py +++ b/applications/DeepSpeed-Chat/training/utils/model/reward_model.py @@ -99,8 +99,8 @@ def forward(self, chosen_reward[c_ind - 1]) #use the end score for reference rejected_mean_scores.append(rejected_reward[r_ind - 1]) - loss += -torch.log( - torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean() + loss += nn.functional.softplus( + r_truncated_reward - c_truncated_reward).mean() loss = loss / bs chosen_mean_scores = torch.stack(chosen_mean_scores)