Commit 185bf5b

committed
fix is_clipped according to TRL
1 parent 31ca95b commit 185bf5b

2 files changed: 6 additions & 2 deletions

src/liger_kernel/chunked_loss/grpo_loss.py
Lines changed: 3 additions & 1 deletion
@@ -57,7 +57,9 @@ def rlhf_loss_fn(
         metrics = []
         if beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)))
-        is_clipped = (per_token_loss1 < per_token_loss2).float()
+        is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+            (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
+        )
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0))
         return loss, metrics

test/chunked_loss/test_grpo_loss.py
Lines changed: 3 additions & 1 deletion
@@ -97,7 +97,9 @@ def forward(
         metrics = []
         if self.beta != 0.0:
             metrics.append(((kl_div * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0)))
-        is_clipped = (per_token_loss1 < per_token_loss2).float()
+        is_clipped = ((coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
+            (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
+        )
         metrics.append((is_clipped * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0))
         return loss, metrics
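For context, here is a minimal standalone sketch of what the new metric computes. The names coef_1, advantages, attention_mask, epsilon_low, and epsilon_high follow the diff; the toy tensor values are made up for illustration. A token now counts as clipped only when the probability ratio leaves the [1 - epsilon_low, 1 + epsilon_high] band in the direction where PPO-style clipping actually binds, which is the behavior the commit message attributes to TRL.

```python
import torch

# Toy shapes: batch of 2 sequences, 3 tokens each (values are assumptions for illustration).
coef_1 = torch.tensor([[0.5, 1.0, 1.6],
                       [0.7, 1.3, 1.0]])   # per-token probability ratio pi_theta / pi_old
advantages = torch.tensor([-1.0, 2.0])     # one advantage per sequence
attention_mask = torch.ones_like(coef_1)   # all tokens valid in this toy example

epsilon_low, epsilon_high = 0.2, 0.2

# New metric (mirroring the diff): a token is "clipped" when the ratio is
# below 1 - epsilon_low with a negative advantage, or above 1 + epsilon_high
# with a positive advantage, i.e. when the clipped term actually limits the loss.
is_clipped = ((coef_1 < 1 - epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
    (coef_1 > 1 + epsilon_high) & (advantages.unsqueeze(1) > 0)
)

# Fraction of attended tokens that were clipped.
clip_ratio = (is_clipped * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0)
print(is_clipped)
# tensor([[ True, False, False],
#         [False,  True, False]])
print(clip_ratio)  # tensor(0.3333)
```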
