Adding reverse and symmetric KLD losses #2094
base: main
Conversation
- Adding reverse and symmetric KLD losses based on [distillm/losses.py](https://github.com/jongwooko/distillm/blob/17c0f98bc263b1861a02d5df578c84aea652ee65/distillm/losses.py)
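For orientation, below is a rough, self-contained sketch of what the forward, reverse, and symmetric KL objectives compute over `[batch, seq_len, vocab]` logits. It is illustrative only: the function names, masking convention, and the `lam` weighting for the symmetric variant are assumptions, not the exact implementation in this PR.

```python
import torch
import torch.nn.functional as F


def forward_kl(student_logits, teacher_logits, mask):
    # KL(p_teacher || q_student): the teacher's probabilities weight the log-ratio.
    t_logprobs = F.log_softmax(teacher_logits.float(), dim=-1)
    s_logprobs = F.log_softmax(student_logits.float(), dim=-1)
    per_token = (t_logprobs.exp() * (t_logprobs - s_logprobs)).sum(dim=-1)
    per_token = per_token * mask  # ignore padded / non-target positions
    return per_token.sum() / mask.sum().clamp(min=1)


def reverse_kl(student_logits, teacher_logits, mask):
    # KL(q_student || p_teacher): the student's probabilities weight the log-ratio,
    # which makes the objective mode-seeking rather than mode-covering.
    t_logprobs = F.log_softmax(teacher_logits.float(), dim=-1)
    s_logprobs = F.log_softmax(student_logits.float(), dim=-1)
    per_token = (s_logprobs.exp() * (s_logprobs - t_logprobs)).sum(dim=-1)
    per_token = per_token * mask
    return per_token.sum() / mask.sum().clamp(min=1)


def symmetric_kl(student_logits, teacher_logits, mask, lam=0.5):
    # Symmetric KL: a weighted combination of the two directions.
    return lam * forward_kl(student_logits, teacher_logits, mask) + (
        1.0 - lam
    ) * reverse_kl(student_logits, teacher_logits, mask)
```

With `lam=0.5` the symmetric variant in this sketch is simply the average of the two directions.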
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2094
Note: Links to docs will display an error until the docs builds have been completed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @insop! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with `CLA Signed`. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
@ebsmothers, @lindawangg, PTAL.
Thanks @insop for the PR! I left a few comments but no major concerns. One thing you'll need to fix is the failing linter job -- if you haven't already you can set up and run pre-commit on all your modified files by following this section of our contributing guide (assuming you already performed a dev install). If you have any trouble do let me know and we can help out.
```diff
@@ -138,3 +237,164 @@ def forward(
        )

        return total_fkl_loss / torch.sum(mask.view(-1), dim=0)
+
+
+class ReverseKLWithChunkedOutputLoss(torch.nn.Module):
```
Not necessary for this PR, but as we are starting to have a proliferation of chunked loss implementations, I wonder whether it'd be worth investing in a general utility to wrap an arbitrary loss with a chunking operation @felipemello1
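This suggestion is out of scope for the PR, but a minimal sketch of what such a generic chunking utility could look like is below, assuming the wrapped loss takes (student chunk, teacher chunk, mask chunk) and returns an unnormalized sum; the wrapper name, constructor arguments, and normalization here are all hypothetical.

```python
import torch


class ChunkedLossWrapper(torch.nn.Module):
    """Hypothetical utility: apply a token-level loss chunk by chunk along the
    sequence dimension so the full [batch, seq_len, vocab] logits never need to
    be upcast to fp32 all at once."""

    def __init__(self, loss_fn, num_output_chunks: int = 8):
        super().__init__()
        self.loss_fn = loss_fn  # assumed to return a summed (unnormalized) loss per chunk
        self.num_output_chunks = num_output_chunks

    def forward(self, student_logits, teacher_logits, mask):
        total = student_logits.new_zeros(())
        chunks = zip(
            student_logits.chunk(self.num_output_chunks, dim=1),
            teacher_logits.chunk(self.num_output_chunks, dim=1),
            mask.chunk(self.num_output_chunks, dim=1),
        )
        for s, t, m in chunks:
            total = total + self.loss_fn(s.float(), t.float(), m)
        # Normalize once over the whole sequence, mirroring the mask-sum
        # normalization shown in the diff above.
        return total / mask.sum().clamp(min=1)
```

The existing ForwardKLWithChunkedOutputLoss and the new reverse/symmetric variants could then share the chunking loop and differ only in the per-chunk `loss_fn`.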
Thank you for the review and comments, @ebsmothers.

@insop do you have any results training with the losses that you could add to the test plan?
My apologies for the long delay. This addresses the pre-commit check and the review comments.
My apologies for the long delay. Here is one example of running a small-scale test using Llama 3.1 8B as a teacher and Llama 3.2 1B as a student with the Code Alpaca dataset. I'm not entirely sure about 'adding to the test plan' in this case. Could you please clarify your suggestion?
Let me know if you have any comments.
@insop sorry for the delayed response. Re "adding to the test plan": I think it's just referring to updating the PR summary to show your results. Let me quickly update it for you based on your previous comment. Also, where possible it's helpful to provide repro commands as part of the test plan to make it easier for others to verify. (Don't worry about it in this case; since your results are several weeks old now, there is no need to go and dig them up.) Otherwise, it appears that your linter job is still failing; lmk if you need any help with this.
Codecov Report
Attention: patch coverage for this PR. Additional details and impacted files:

| Coverage Diff | main | #2094 | +/- |
|---|---:|---:|---:|
| Coverage | 65.41% | 67.04% | +1.62% |
| Files | 344 | 352 | +8 |
| Lines | 20658 | 20698 | +40 |
| Hits | 13514 | 13877 | +363 |
| Misses | 7144 | 6821 | -323 |

☔ View full report in Codecov by Sentry.
Thank you @ebsmothers, I have fixed the lint issue. I used the default config files with argument overrides for my training; I will put them together for others to use shortly.
Context
What is the purpose of this PR? Is it to add a new feature, fix a bug, or update tests and/or documentation?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.

- run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
- run unit tests via `pytest tests`
- run recipe tests via `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.
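As a placeholder for this section, here is a hedged sketch of what calling the new loss might look like, assuming its constructor and forward signature mirror the existing ForwardKLWithChunkedOutputLoss; the import path, the `num_output_chunks` argument, and the chunked-logits/labels calling convention below are assumptions, not the finalized API.

```python
import torch

# Import path is assumed; the class is introduced in this PR.
from torchtune.modules.loss import ReverseKLWithChunkedOutputLoss

num_chunks = 8
loss_fn = ReverseKLWithChunkedOutputLoss(num_output_chunks=num_chunks)  # argument name assumed

batch_size, seq_len, vocab_size = 2, 128, 32_000

# Chunked logits: a list of [batch, seq_len // num_chunks, vocab] tensors,
# as produced when the model's output projection is chunked.
student_logits = [
    torch.randn(batch_size, seq_len // num_chunks, vocab_size) for _ in range(num_chunks)
]
teacher_logits = [
    torch.randn(batch_size, seq_len // num_chunks, vocab_size) for _ in range(num_chunks)
]
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

loss = loss_fn(student_logits, teacher_logits, labels)
print(loss)
```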