@@ -415,7 +415,9 @@ def compute_policy_loss(
415415 clip_ratio_low : float ,
416416 clip_ratio_high : float ,
417417 clip_ratio_dual : float ,
418- loss_type : Literal ["default" , "gspo" , "gspo_token" , "cispo" ],
418+ tau_positive : float ,
419+ tau_negative : float ,
420+ loss_type : Literal ["default" , "gspo" , "gspo_token" , "cispo" , "sapo" ],
419421 loss_avg_mode : Literal ["token" , "seq" ],
420422 ** kwargs ,
421423) -> tuple [torch .Tensor , dict [str , float ]]:
@@ -438,6 +440,10 @@ def compute_policy_loss(
438440 The higher clip range used in DAPO. See https://arxiv.org/pdf/2503.14476
439441 clip_ratio_dual: (float)
440442 The dual clip range used in Dual-clip PPO. See https://arxiv.org/pdf/1912.09729
443+ tau_positive: (float)
444+ The temperature that controls the positive tokens' clipping in SAPO. See https://arxiv.org/pdf/2511.20347
445+ tau_negative: (float)
446+ The temperature that controls the negative tokens' clipping in SAPO. See https://arxiv.org/pdf/2511.20347
441447 loss_avg_mode: (Literal["token", "seq"])
442448 "token": average the loss in the whole batch
443449 "seq": average the loss in each sequence then average the mean of the means
@@ -481,6 +487,12 @@ def compute_policy_loss(
481487
482488 if loss_type == "cispo" :
483489 final_pg_loss = - advantages * log_probs * clipped_ratio .detach ()
490+ elif loss_type == "sapo" :
491+ positive_token_mask = (advantages >= 0 ).float ()
492+ negative_token_mask = (advantages < 0 ).float ()
493+ gate_negative = 4.0 / tau_negative * torch .sigmoid (tau_negative * (ratio - 1.0 ))
494+ gate_positive = 4.0 / tau_positive * torch .sigmoid (tau_positive * (ratio - 1.0 ))
495+ final_pg_loss = - advantages * (positive_token_mask * gate_positive + negative_token_mask * gate_negative )
484496 else :
485497 pg_loss = - advantages * ratio # -ratio * A
486498 pg_loss2 = - advantages * clipped_ratio # -clip(ratio, 1-clip_low, 1+clip_high) * A
0 commit comments