
Commit f8f700d

auto-generating sphinx docs

1 parent 73159a6 · commit f8f700d

11 files changed: +5 / -951 lines

main/_modules/torchtune/rlhf/loss/dpo.html (-75 lines)
@@ -457,7 +457,6 @@ Source code for torchtune.rlhf.loss.dpo
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchtune.utils._logging import deprecated


 class DPOLoss(nn.Module):
@@ -607,80 +606,6 @@ Source code for torchtune.rlhf.loss.dpo
         )

         return losses, chosen_rewards, rejected_rewards
-
-
-@deprecated(msg="SimPOLoss will be deprecated in an upcoming release.")
-class SimPOLoss(nn.Module):
-    """
-    SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734.
-    Intuition from the paper:
-
-    The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as
-    the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to
-    encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance.
-
-    Based on the TRL implementation:
-    https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/cpo_trainer.py#L603
-
-    SimPO is pretty much identical to DPO but uses average logprobs to eliminate the need for a reference model to regularize
-    the policy during training. It also uses a target reward margin to guide the policy towards better responses.
-    This is kind of the same intuition as in :class:`~torchtune.rlhf.loss.IPOLoss`, but instead of optimizing against
-    a margin between the reference policy and policy models, we're optimizing against a margin between the chosen and
-    rejected responses.
-
-    Args:
-        beta (float): Equivalent temperature scaling parameter to DPO loss, typically in the range of 2.0 to 2.5. Default is 2.0.
-        gamma (float): Target reward margin hyperparameter, typically we have ``gamma in (0, 1.5]``.
-            Default is 0.5.
-        label_smoothing (float): Parameter encoding uncertainty about the labels. Default is 0.
-    """
-
-    def __init__(
-        self,
-        beta: float = 2.0,
-        gamma: float = 0.5,
-        label_smoothing: float = 0.0,
-    ):
-        super().__init__()
-        self.beta = beta
-        self.gamma = gamma
-        self.label_smoothing = label_smoothing
-
-    def forward(
-        self,
-        policy_chosen_logps: torch.Tensor,
-        policy_rejected_logps: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Compute the SimPO loss for a batch of chosen and rejected average log probabilities.
-
-        Args:
-            policy_chosen_logps (torch.Tensor): Average log probabilities of the policy model
-                for the chosen responses with shape [b,].
-            policy_rejected_logps (torch.Tensor): Average log probabilities of the policy model
-                for the rejected responses with shape [b,].
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors with shape [b,]:
-                - losses: The SimPO loss for each example in the batch.
-                - chosen_rewards: Rewards for the chosen responses.
-                - rejected_rewards: Rewards for the rejected responses.
-        """
-
-        pi_logratios = policy_chosen_logps - policy_rejected_logps
-
-        gamma_logratios = self.gamma / self.beta
-        logits = pi_logratios - gamma_logratios
-
-        losses = (
-            -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
-            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
-        )
-
-        chosen_rewards = self.beta * (policy_chosen_logps).detach()
-        rejected_rewards = self.beta * (policy_rejected_logps).detach()
-
-        return losses, chosen_rewards, rejected_rewards
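For readers tracking what this removal takes away: the deleted forward() above is a reference-free, DPO-style objective over average log probabilities. The following is a minimal standalone sketch of that computation, using plain torch with made-up tensor values and the class defaults shown above; it is not a torchtune API call.

# Standalone sketch of the removed SimPO computation (beta=2.0, gamma=0.5,
# label_smoothing=0.0 are the defaults from the deleted class); inputs are placeholders.
import torch
import torch.nn.functional as F

beta, gamma, label_smoothing = 2.0, 0.5, 0.0

# Average log-probabilities of the policy for chosen/rejected responses, shape [b,]
policy_chosen_logps = torch.tensor([-0.8, -1.1])
policy_rejected_logps = torch.tensor([-1.5, -1.2])

# Reference-free log-ratio minus the target reward margin (gamma / beta)
logits = (policy_chosen_logps - policy_rejected_logps) - gamma / beta
losses = (
    -F.logsigmoid(beta * logits) * (1 - label_smoothing)
    - F.logsigmoid(-beta * logits) * label_smoothing
)
chosen_rewards = beta * policy_chosen_logps.detach()
rejected_rewards = beta * policy_rejected_logps.detach()
print(losses, chosen_rewards, rejected_rewards)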

main/_sources/api_ref_rlhf.rst.txt (-1 line)

@@ -16,4 +16,3 @@ Components and losses for RLHF algorithms like PPO and DPO.
 loss.PPOLoss
 loss.DPOLoss
 loss.RSOLoss
-loss.SimPOLoss

main/_sources/generated/torchtune.rlhf.loss.SimPOLoss.rst.txt (-6 lines)

This file was deleted.

main/_sources/recipes/dpo.rst.txt (-2 lines)

@@ -56,8 +56,6 @@ To use any of these, simply use the ``loss`` config entry or flag through the :r
 loss=torchtune.modules.loss.RSOLoss \
 gamma=0.5

-.. todo (@SalmanMohammadi) point to an example repo for SimPO
-
 For a deeper understanding of the different levers you can pull when using this recipe,
 see our documentation for the different PEFT training paradigms we support:
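The snippet in that hunk selects the loss purely through recipe config overrides. In Python terms this corresponds roughly to constructing the loss module directly; a sketch, assuming the class is importable from torchtune.rlhf.loss (the module documented in this commit) and accepts the gamma keyword implied by the override:

# Rough Python equivalent of the config override above (a sketch, not the recipe's
# actual wiring). Assumes RSOLoss takes a `gamma` argument, as `gamma=0.5` suggests.
from torchtune.rlhf.loss import RSOLoss

loss_fn = RSOLoss(gamma=0.5)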

main/api_ref_rlhf.html (-3 lines)

@@ -473,9 +473,6 @@ torchtune.rlhf
 <tr class="row-even"><td><p><a class="reference internal" href="generated/torchtune.rlhf.loss.RSOLoss.html#torchtune.rlhf.loss.RSOLoss" title="torchtune.rlhf.loss.RSOLoss"><code class="xref py py-obj docutils literal notranslate"><span class="pre">loss.RSOLoss</span></code></a></p></td>
 <td><p>Statistical Rejection Sampling Optimization (RSO) or "hinge" loss module: <a class="reference external" href="https://arxiv.org/abs/2309.06657">https://arxiv.org/abs/2309.06657</a>.</p></td>
 </tr>
-<tr class="row-odd"><td><p><a class="reference internal" href="generated/torchtune.rlhf.loss.SimPOLoss.html#torchtune.rlhf.loss.SimPOLoss" title="torchtune.rlhf.loss.SimPOLoss"><code class="xref py py-obj docutils literal notranslate"><span class="pre">loss.SimPOLoss</span></code></a></p></td>
-<td><p>Simple Preference Optimization with a Reference-Free Reward: <a class="reference external" href="https://arxiv.org/abs/2405.14734">https://arxiv.org/abs/2405.14734</a>.</p></td>
-</tr>
 </tbody>
 </table>
 </section>

main/api_ref_training.html (+2, -2 lines)

@@ -41,7 +41,7 @@
 <link rel="index" title="Index" href="genindex.html" />
 <link rel="search" title="Search" href="search.html" />
 <link rel="next" title="FullModelHFCheckpointer" href="generated/torchtune.training.FullModelHFCheckpointer.html" />
-<link rel="prev" title="torchtune.rlhf.loss.SimPOLoss" href="generated/torchtune.rlhf.loss.SimPOLoss.html" />
+<link rel="prev" title="RSOLoss" href="generated/torchtune.rlhf.loss.RSOLoss.html" />
 <!-- Google Tag Manager -->
 <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':
 new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0],

@@ -620,7 +620,7 @@ Miscellaneous
 <a href="generated/torchtune.training.FullModelHFCheckpointer.html" class="btn btn-neutral float-right" title="FullModelHFCheckpointer" accesskey="n" rel="next">Next <img src="_static/images/chevron-right-orange.svg" class="next-page"></a>


-<a href="generated/torchtune.rlhf.loss.SimPOLoss.html" class="btn btn-neutral" title="torchtune.rlhf.loss.SimPOLoss" accesskey="p" rel="prev"><img src="_static/images/chevron-right-orange.svg" class="previous-page"> Previous</a>
+<a href="generated/torchtune.rlhf.loss.RSOLoss.html" class="btn btn-neutral" title="RSOLoss" accesskey="p" rel="prev"><img src="_static/images/chevron-right-orange.svg" class="previous-page"> Previous</a>

 </div>
main/generated/torchtune.rlhf.loss.RSOLoss.html (+2, -2 lines)

@@ -40,7 +40,7 @@
 <link rel="stylesheet" href="../_static/css/custom_torchtune.css" type="text/css" />
 <link rel="index" title="Index" href="../genindex.html" />
 <link rel="search" title="Search" href="../search.html" />
-<link rel="next" title="torchtune.rlhf.loss.SimPOLoss" href="torchtune.rlhf.loss.SimPOLoss.html" />
+<link rel="next" title="torchtune.training" href="../api_ref_training.html" />
 <link rel="prev" title="DPOLoss" href="torchtune.rlhf.loss.DPOLoss.html" />
 <!-- Google Tag Manager -->
 <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start':

@@ -516,7 +516,7 @@ RSOLoss
 <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">

-<a href="torchtune.rlhf.loss.SimPOLoss.html" class="btn btn-neutral float-right" title="torchtune.rlhf.loss.SimPOLoss" accesskey="n" rel="next">Next <img src="../_static/images/chevron-right-orange.svg" class="next-page"></a>
+<a href="../api_ref_training.html" class="btn btn-neutral float-right" title="torchtune.training" accesskey="n" rel="next">Next <img src="../_static/images/chevron-right-orange.svg" class="next-page"></a>


 <a href="torchtune.rlhf.loss.DPOLoss.html" class="btn btn-neutral" title="DPOLoss" accesskey="p" rel="prev"><img src="../_static/images/chevron-right-orange.svg" class="previous-page"> Previous</a>
