Adds validation loss to LoRA fine tune single device #2238

Open
wants to merge 3 commits into main

Conversation

@MaxFrax commented Jan 8, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.
#1042

Changelog

What are the changes made in this PR?
Adds support for a validation dataset and computes the loss on it after each epoch.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Jan 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2238

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit df8cd1e with merge base 27fd3a1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @MaxFrax!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@MaxFrax (Author) commented Jan 8, 2025

@felipemello1 I've finally been able to work on this. I'll make my way through the testing plan, but feel free to share any comments you might already have.

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the CLA Signed label Jan 8, 2025
@felipemello1 self-requested a review January 9, 2025 03:02
@felipemello1 (Contributor)

Hey @MaxFrax, thank you! I am on PTO this week. I will get to it next week if someone doesn't do it before me.

@ebsmothers (Contributor) left a comment

Hey @MaxFrax, thanks for this PR! The validation loop itself looks pretty reasonable, but I think we should figure out the right way to integrate it. E.g. right now it seems like we perform validation after every training epoch inside the train method. Personally I would be in favor of splitting it out into multiple methods to make things clearer. That will be a bit more work, but I want to make sure we expose this as clearly as possible. What about something like this?

def validate(self):
    # Should be roughly the code you added

def train(self):
    # Keep this mostly as it is, but add something like:
    if self.global_step % self.run_val_every_n_steps == 0:
        self.validate()

Then we can expose run_val_every_n_steps via config. A couple of other things to think about would be a maximum number of batches in the val loop and early stopping. I don't think we need to worry about the latter for this PR, but we should make sure it's something we're able to support later on.

Also cc @joecummings @felipemello1 @calvinpelletier for anything I've missed here.

for idx, batch in enumerate(self._dataloader_val):
    utils.batch_to_device(batch, self._device)

    current_loss = self._loss_step(batch)
@ebsmothers (Contributor):

I think we will need to toggle eval <-> train mode for the model, right? (Another reason why having a separate method will probably be cleanest.)

@MaxFrax (Author):
Thanks @ebsmothers! It makes sense. I will look into it. Do you have any pointers on how to do that?

@MaxFrax (Author) commented Jan 15, 2025

Hi @ebsmothers! I have updated the PR with the following edits, as per your recommendations:

  • Created a standalone method for the validation loop
  • Added a run_val_every_n_steps parameter to invoke validate at specific points in the training epoch
  • Added max_validation_batches to cap the number of batches run in each validation step

If there's any other feedback or comment, just let me know!

@felipemello1 (Contributor)

Thanks for making the changes! I will take a look at this PR later today.

@RdoubleA (Contributor) left a comment

This looks good to me, although I'm getting worried about the proliferation of cfg.get and cfg validation logic in the recipes. There's nothing inherently wrong with cfg.get, but it encourages the use of hidden parameters not exposed to the user. I don't have a good long-term solution for this, but since we are only modifying one recipe, maybe we could at least update the lora single device configs to expose these fields so users know they exist and we can check whether the cfg field is None directly?

dataset_validation: null
run_val_every_n_steps: null
max_validation_batches: -1

I know this will affect a lot of files, so open to thoughts. We could also do this in a follow-up.

step=(curr_epoch + 1) * idx + idx,
)

if self.run_validation:
Contributor:

You don't need this check, since you only call validate() if self.run_validation is True.

)

self.run_val_every_n_steps = cfg.get("run_val_every_n_steps", None)
if self.run_validation:
Contributor:

nit: could remove this if statement and keep all the logic below under the first if self.run_validation check

@@ -335,6 +335,29 @@ def setup(self, cfg: DictConfig) -> None:
last_epoch=self.global_step - 1,
)

# Setup the validation dataset
self.run_validation = "dataset_validation" in cfg
Contributor:

would prefer the name validation_dataset

In fact, what do you think about organizing the validation arguments in the configs like so:

validation:
  dataset:
    ...
  run_every_n_steps: null
  max_batches: -1

That way you can just set validation: null and query that for self.run_validation.
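For illustration, a minimal sketch of how setup() might consume such a nested block, assuming the hypothetical attribute names discussed in this PR (run_validation, run_val_every_n_steps, max_validation_batches); the dataloader construction itself is omitted:

# Hypothetical parsing of a nested `validation` config block in setup().
validation_cfg = cfg.get("validation", None)
self.run_validation = validation_cfg is not None
if self.run_validation:
    # null could mean "validate once per epoch"; -1 means "no cap on batches".
    self.run_val_every_n_steps = validation_cfg.get("run_every_n_steps", None)
    self.max_validation_batches = validation_cfg.get("max_batches", -1)
    # The validation dataloader would be built from validation_cfg.dataset,
    # mirroring however the training dataloader is set up in this recipe.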

@felipemello1 (Contributor)

> maybe we could at least update the lora single device configs to expose these fields so users know they exist and we can check whether the cfg field is None directly?

Let's do this as a follow-up. I can use my script to bulk update. But let's make sure that we all agree on how it should look in the config.

@codecov-commenter commented Jan 15, 2025

Codecov Report

Attention: Patch coverage is 0% with 28 lines in your changes missing coverage. Please review.

Project coverage is 23.93%. Comparing base (213f386) to head (df8cd1e).
Report is 18 commits behind head on main.

Files with missing lines: recipes/lora_finetune_single_device.py (Patch: 0.00%, Lines: 28 Missing ⚠️)

❗ There is a different number of reports uploaded between BASE (213f386) and HEAD (df8cd1e): HEAD has 6 fewer uploads than BASE (9 vs. 3).
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2238       +/-   ##
===========================================
- Coverage   65.41%   23.93%   -41.49%     
===========================================
  Files         344      357       +13     
  Lines       20658    21153      +495     
===========================================
- Hits        13514     5062     -8452     
- Misses       7144    16091     +8947     

☔ View full report in Codecov by Sentry.

@felipemello1 (Contributor) left a comment

Thanks for the PR! It looks simple and the functions make sense!

I added a few comments/ideas. Please push back on anything you disagree with.

IMO, to approve this, we would need two things:

  1. Testing, like I suggested in one of the comments. Let me know if you are comfortable running it, otherwise we can help you out.
  2. An example of how the config should look. The UI should be a big factor in this PR.

@@ -652,6 +675,43 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:

return loss

def validate(self, curr_epoch) -> None:
@felipemello1 (Contributor):

  1. Do we have model.eval() somewhere?

Usually we want to set the model to .eval() mode, because some layers, like dropout, have different behavior in eval mode.

By doing that, we then require less memory, because we only need the forward pass, which allows us to have a higher batch_size --> faster validation step.

I am not sure about the implications it may have for compile/FSDP. For example, compile will have to create a new graph that doesn't require grad, so compile time will increase. If the number of graph breaks increases, we may have to manually change the threshold for the maximum number of graph breaks. (There is an example of that in one of our RL recipes.)

  2. Not all recipes have self._loss_step. We would have to standardize and make sure that they all do, but this requires a different PR.

IMO, if you have access to more than one GPU, I would encourage you to implement it in lora_distributed with the QLoRA config, add .eval(), and run it:

  • with eval + compile + opt_in_bwd + activation ckpt + activation offloading
  • without eval + compile + opt_in_bwd + activation ckpt + activation offloading

If nothing breaks, I would feel more confident in approving it.

PS: we would also have to add model.train() in the training loop.
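For illustration only, a minimal sketch of how the eval/train toggling, the no-grad forward, and a batch cap could fit together; the attribute names (_model, _dataloader_val, _metric_logger, max_validation_batches) are modeled on this PR and the sketch is not a definitive implementation:

# Sketch only: assumes `torch` and torchtune's `utils` are already imported in the
# recipe, and that max_validation_batches < 0 means "no cap".
def validate(self) -> None:
    self._model.eval()  # disable dropout etc. for evaluation
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():  # forward-only, so activations are not kept for backward
        for idx, batch in enumerate(self._dataloader_val):
            if self.max_validation_batches >= 0 and idx >= self.max_validation_batches:
                break
            utils.batch_to_device(batch, self._device)
            total_loss += self._loss_step(batch).item()
            n_batches += 1
    self._metric_logger.log_dict(
        {"val_loss": total_loss / max(n_batches, 1)}, step=self.global_step
    )
    self._model.train()  # restore training mode before the train loop resumes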

@MaxFrax (Author):

@felipemello1 Thanks for the detailed breakdown and suggestions. Should we also unload the model being trained before loading the eval one? Having just one in memory would allow for bigger batch sizes.

That said, I’m currently constrained on time and not very familiar with the implementation details for this. If I were to take this on, it would likely take me a significant amount of time to get it done properly.

Would you be able to take the lead on this?

@felipemello1 (Contributor) commented Jan 17, 2025

Hey @MaxFrax, completely understandable. Thanks for sharing it.

I don't think that I will have bandwidth soon, but if I do, this PR is a good start.

@Ankur-singh, cc'ing you in case you are looking for more issues to contribute to! :D

Thank you guys!


# This bit allows seeing the loss for each batch. Not sure about step indexing.
log_dict = {
    "val_loss": current_loss.item(),
Contributor:

I wonder if we should be logging memory/TPS too. If memory is very low, this would show the user that they can increase bsz. What do you think?
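As an illustration, peak GPU memory could be surfaced in the same log dict shown in the snippet above using plain torch.cuda counters; the key name here is made up:

# Hypothetical addition next to "val_loss": peak GPU memory so far, in GiB.
if self._device.type == "cuda":
    log_dict["val_peak_memory_gib"] = (
        torch.cuda.max_memory_allocated(self._device) / 1024**3
    )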

@MaxFrax (Author):

It makes sense. I will definitely look into it.
By the way, when is the previous training batch deallocated from memory? Do I have to deallocate it manually? It would be handy to do so before starting the validation step to have more memory available.

@@ -779,6 +839,12 @@ def train(self) -> None:
)
)

if (
self.run_validation
and self.global_step % self.run_val_every_n_steps == 0
Contributor:

I think it makes a lot of sense to eval every N steps, but currently a lot of our training logic is based on epochs. I wonder if we should honor this and keep it based on epochs. Maybe users could pass a float, e.g. every 0.5 epochs.

Not 100% sure about this, just brainstorming.
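For illustration, a fractional-epoch frequency could be translated into a step interval once at setup time; the names below (run_val_every_n_epochs, _gradient_accumulation_steps) are hypothetical sketch names, not the recipe's actual fields:

# Hypothetical: convert run_val_every_n_epochs (e.g. 0.5) into a step interval.
steps_per_epoch = len(self._dataloader) // self._gradient_accumulation_steps
self.run_val_every_n_steps = max(1, int(run_val_every_n_epochs * steps_per_epoch))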

@MaxFrax (Author):

I'm happy to change it as you suggest; I like it better too. I just want to point out that the scheduler has the parameter num_warmup_steps, which contradicts your statement:

> our training logic is based on epochs

As a user, I'd love num_warmup_steps to be based on epochs as well.

@MaxFrax (Author) commented Jan 16, 2025

Thanks @felipemello1! Some help on the testing side would be much appreciated.
When you say:

> An example of how the config should look. The UI should be a big factor in this PR.

what exactly do you mean?

Should I provide a recipe using the validation dataset? Are we talking about the docs?
Let me know more precisely what I should do, and I'll be happy to look into it.

@felipemello1 (Contributor) commented Jan 16, 2025

@MaxFrax

> Should I provide a recipe using the validation dataset? Are we talking about the docs?
> Let me know more precisely what I should do, and I'll be happy to look into it.

Your PR only contains changes to the recipe. I would encourage you to:

  • Make changes to one of the configs too, to illustrate how users would use it
  • Put the command to launch this config in the description of the PR, under the testing section (a sketch of such a command follows this list)
  • Share an image of the logs generated in Weights & Biases under the testing section
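For illustration only, the launch command could look something like the following, assuming the updated config already defines the validation fields; the config name and field names are hypothetical and would need to match whatever the final config exposes:

tune run lora_finetune_single_device \
  --config llama3_2/1B_lora_single_device \
  run_val_every_n_steps=100 \
  max_validation_batches=10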

Labels: CLA Signed (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)

6 participants