File tree Expand file tree Collapse file tree
src/liger_kernel/chunked_loss Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -57,7 +57,9 @@ def rlhf_loss_fn(
5757 metrics = []
5858 if beta != 0.0 :
5959 metrics .append (((kl_div * attention_mask ).sum () / torch .clamp (full_attention_mask .sum (), min = 1.0 )))
60- is_clipped = (per_token_loss1 < per_token_loss2 ).float ()
60+ is_clipped = ((coef_1 < 1 - epsilon_low ) & (advantages .unsqueeze (1 ) < 0 )) | (
61+ (coef_1 > 1 + epsilon_high ) & (advantages .unsqueeze (1 ) > 0 )
62+ )
6163 metrics .append ((is_clipped * attention_mask ).sum () / torch .clamp (full_attention_mask .sum (), min = 1.0 ))
6264 return loss , metrics
6365
Original file line number Diff line number Diff line change @@ -97,7 +97,9 @@ def forward(
9797 metrics = []
9898 if self .beta != 0.0 :
9999 metrics .append (((kl_div * attention_mask ).sum () / torch .clamp (attention_mask .sum (), min = 1.0 )))
100- is_clipped = (per_token_loss1 < per_token_loss2 ).float ()
100+ is_clipped = ((coef_1 < 1 - self .epsilon_low ) & (advantages .unsqueeze (1 ) < 0 )) | (
101+ (coef_1 > 1 + self .epsilon_high ) & (advantages .unsqueeze (1 ) > 0 )
102+ )
101103 metrics .append ((is_clipped * attention_mask ).sum () / torch .clamp (attention_mask .sum (), min = 1.0 ))
102104 return loss , metrics
103105
You can’t perform that action at this time.
0 commit comments