@@ -457,7 +457,6 @@ Source code for torchtune.rlhf.loss.dpo
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchtune.utils._logging import deprecated


 class DPOLoss(nn.Module):
@@ -607,80 +606,6 @@ Source code for torchtune.rlhf.loss.dpo
         )

         return losses, chosen_rewards, rejected_rewards
-
-
-@deprecated(msg="SimPOLoss will be deprecated in an upcoming release.")
-class SimPOLoss(nn.Module):
-    """
-    SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734.
-    Intuition from the paper:
-
-    The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as
-    the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to
-    encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance.
-
-    Based on the TRL implementation:
-    https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/cpo_trainer.py#L603
-
-    SimPO is pretty much identical to DPO but uses average logprobs to eliminate the need for a reference model to regularize
-    the policy during training. It also uses a target reward margin to guide the policy towards better responses.
-    This follows the same intuition as :class:`~torchtune.rlhf.loss.IPOLoss`, but instead of optimizing against
-    a margin between the reference policy and policy models, we're optimizing against a margin between the chosen and
-    rejected responses.
-
-    Args:
-        beta (float): Equivalent temperature scaling parameter to DPO loss, typically in the range of 2.0 to 2.5. Default is 2.0.
-        gamma (float): Target reward margin hyperparameter, typically we have ``gamma in (0, 1.5]``.
-            Default is 0.5.
-        label_smoothing (float): Parameter encoding uncertainty about the labels. Default is 0.
-    """
-
-    def __init__(
-        self,
-        beta: float = 2.0,
-        gamma: float = 0.5,
-        label_smoothing: float = 0.0,
-    ):
-        super().__init__()
-        self.beta = beta
-        self.gamma = gamma
-        self.label_smoothing = label_smoothing
-
-    def forward(
-        self,
-        policy_chosen_logps: torch.Tensor,
-        policy_rejected_logps: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Compute the SimPO loss for a batch of chosen and rejected average log probabilities.
-
-        Args:
-            policy_chosen_logps (torch.Tensor): Average log probabilities of the policy model
-                for the chosen responses with shape [b,].
-            policy_rejected_logps (torch.Tensor): Average log probabilities of the policy model
-                for the rejected responses with shape [b,].
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors with shape [b,]:
-                - losses: The SimPO loss for each example in the batch.
-                - chosen_rewards: Rewards for the chosen responses.
-                - rejected_rewards: Rewards for the rejected responses.
-        """
-
-        pi_logratios = policy_chosen_logps - policy_rejected_logps
-
-        gamma_logratios = self.gamma / self.beta
-        logits = pi_logratios - gamma_logratios
-
-        losses = (
-            -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
-            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
-        )
-
-        chosen_rewards = self.beta * (policy_chosen_logps).detach()
-        rejected_rewards = self.beta * (policy_rejected_logps).detach()
-
-        return losses, chosen_rewards, rejected_rewards