
Adding Classification Use Case #1502

Draft
wants to merge 7 commits into main

Conversation


@qqlabs commented Sep 5, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

My use case is adapting torchtune to support LlamaForSequenceClassification. This differs from many of the assumptions in the current finetuning code, which assumes we're finetuning for chat/instruct purposes. Issues: #1249, #1464

A sample dataset that I want this to work on is Amazon's shopping queries dataset, where we classify query-product pairs as exact, substitute, complement, or irrelevant. I concatenate the query text and product title into a single prompt to classify into one of the four classes.

Changelog

Key changes:

  • Support multi-class classification
  • Add inference for classification
  • Support lora adapter checkpointing/loading
  • Add validation step during training
  • Address OOM issues

Multi-class classification training:
We essentially want to implement LlamaForSequenceClassification. I added a dataset, recipe, configs, and some other code changes to get this to work; a rough sketch of the core step follows the list below.

  • Dataset is structured as prompt and completion. The class label mappings are passed into the dataset in the classes parameter. We will encode the classes in the exact order they are passed in. The dataset tokenizes the prompt as normal and encodes the label to its class mapping.
  • Collate function updated to one hot encode the label
  • Forward pass updated to look at the model output for the eos_token (end of the prompt). Ideally, this step is built into the TransformerDecoder function somehow.
  • Loss updated to BCEWithLogitsLoss
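
Taken together, the core of these changes is roughly the sketch below. Names and shapes are illustrative, not the exact recipe code: the collate step one-hot encodes the integer label, the forward pass pools the logits at the prompt's eos token, and BCEWithLogitsLoss is applied to the pooled output.

import torch
import torch.nn.functional as F

def classification_step(model, tokens, class_idx, num_classes, pad_id):
    # tokens: [batch, seq_len]; class_idx: [batch] integer labels from the dataset
    labels = F.one_hot(class_idx, num_classes).float()        # one-hot encoding done in the collate fn
    logits = model(tokens)                                     # [batch, seq_len, num_classes]
    # pick the model output at the last non-pad position (the eos token of the prompt)
    last_pos = (tokens != pad_id).sum(dim=-1) - 1
    pooled = logits[torch.arange(tokens.shape[0]), last_pos]   # [batch, num_classes]
    loss = F.binary_cross_entropy_with_logits(pooled, labels)
    return pooled, loss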

Inference using trained classification model:

  • TBD: need to add/update a classify function based on generate.
  • I'm having trouble getting the logits to align between the trained model loaded in torchtune and the trained model loaded in Hugging Face (I converted the checkpoints to be loadable with from_pretrained).

Lora Adapter Issues:
The current code saves the lora adapter weights as a separate file from the base model weights. We hit missing key and key mismatch errors when we try to load in the two files using load_checkpoint.

In addition, I often only save the adapter weights to save on disk space. There is no clear guidance on merging the weights and loading it back in to do inference.

I added some functions to be able to load back in the adapter weights and merge it with a separate base model file. Will upload example of this later.
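
For reference, the merge itself is conceptually just W' = W + (alpha / rank) * B @ A for each LoRA-adapted layer. A minimal sketch, with key names that are assumptions rather than torchtune's exact checkpoint format:

import torch

def merge_lora_into_base(base_sd, adapter_sd, alpha: float, rank: int):
    merged = dict(base_sd)
    scale = alpha / rank
    for key, lora_a in adapter_sd.items():
        if not key.endswith("lora_a.weight"):
            continue
        prefix = key[: -len("lora_a.weight")]
        lora_b = adapter_sd[prefix + "lora_b.weight"]   # [out_dim, rank]
        base_key = prefix + "weight"                    # the frozen base weight being adapted
        merged[base_key] = base_sd[base_key] + scale * (lora_b @ lora_a)
    return merged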

Validation Step:
Torchtune does not support a validation step during training. I added one to the recipe in a hacky way. This currently only works at the end of the epoch, since I hit deadlock issues when trying to validate in the middle of distributed training.
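
The added validation pass is roughly the following sketch (attribute names such as self._val_dataloader are assumptions, and the eos-token pooling from the training step is omitted for brevity):

import torch

def validate(self):
    self._model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in self._val_dataloader:
            logits = self._model(batch["tokens"])               # eos-token pooling omitted here
            total_loss += self._loss_fn(logits, batch["labels"]).item()
            num_batches += 1
    self._model.train()
    return total_loss / max(num_batches, 1)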

OOM Issues:
Reserved memory is not cleaned up at each training step, which causes OOM issues. I added a cleanup step at the end of each training step. I also found I needed to reduce the context length for my 40GB GPU.
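
The per-step cleanup is essentially the following (the exact calls used are an assumption):

import gc
import torch

def free_step_memory():
    gc.collect()                # drop Python references left over from the finished step
    torch.cuda.empty_cache()    # return unused cached blocks so reserved memory stops growing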

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

…ointing, and some memory optimization working

pytorch-bot bot commented Sep 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1502

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 5, 2024

self._loss_fn = config.instantiate(cfg.loss)

############################
Author:

2 dataloaders right now so that we can do train and validation

# Since this is classification drop the output weights from base llama model
# Otherwise load_state_dict would have key size mismatch
############################
del base_model_state_dict["output.weight"]
Author:

We can't directly load in the base llama model weights since they don't match the newly replaced mlp layer for the classification version. Current hack is to drop the weights before loading into the classification model.

partial(
padded_collate,
padding_idx=self._tokenizer.pad_id,
# ignore_idx=self._loss_fn.ignore_index,
Author:

BCEWithLogitsLoss doesn't have ignore_index

running_loss = 0
num_tokens = 0

############################
Author:

Validation only works at end of epoch


logits = self._model(tokens, mask=mask, input_pos=input_pos)

############################
Author:

Added a step here where we pick the output for the eos token. This should ideally be built into the model itself; not sure what's the best way of doing that.

num_tokens = 0
t0 = time.perf_counter()

############################
Author:

free up memory at end of training step


self._profiler.stop()

############################
Author:

validation loop

text_column: str = "text",
label_column: str = "label",
split: str = "train",
classes: Optional[List[Any]] = None,
Author:

The class/label mapping is currently passed in as an ordered list. Not sure what's the best practice for this (do we store a separate label mapping file somewhere?)

@@ -246,6 +246,29 @@ def tune_to_peft_adapter_config(
return adapter_config


# When we only save adapter weights, we need to convert back from peft format to tune format
Author:

I use this when trying to load adapter config weights back. Will add sample code for what I got working later

Collaborator:

As in, resuming training in torchtune using adapter weights saved from a previous torchtune LoRA run? How come that wasn't working?


# Note: this needs to be set before wrapping with FSDP
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)
Author:

do we also need to set the final mlp layer of the base model to be trainable?

Collaborator @SalmanMohammadi commented Sep 6, 2024:

Yes - excellent catch. This is done through the apply_lora_to_output recipe config arg

Contributor @felipemello1 commented Sep 5, 2024:

Thanks for the PR! This is not a review yet, but just food for thought: We have been having conversations for a while about removing the output head from the LLM forward, and running it as a second step.

Do you think that if this was implemented, we would still need a different recipe? For example:

hidden_state = model(inputs)
output = model.output_head(hidden_state) #edited
loss = self.loss_fn(logits, label) #edited

If we could make it work without adding a bunch of if/else everywhere in the recipe, then we could leverage the existing recipes instead of doubling the number of recipes in the library.

Author:

to clarify, did you mean loss = _loss_fn(output, label) for the last line?

I think this could work to separate out the model forward logic for the classification use case (where we need to find the last token and pick the hidden state for it). Another option is to build the forward logic into the model itself like how huggingface transformers are implemented.

The other big change for classification/prompt completion use cases is the logic for processing the labels. The chat/instruct use cases shift all the logits by 1 to do next token prediction for the labels. Prompt completion/classification tasks have the label defined in the dataset and processed by the dataloader. We would need to figure out how to pull out that label processing logic as well to avoid duplicating recipes.
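
To illustrate the difference in label processing (shapes only; the vocabulary size and class count are placeholders):

import torch
import torch.nn.functional as F

# Chat/instruct: labels are the input tokens shifted by one for next-token prediction.
tokens = torch.randint(0, 32000, (2, 128))        # [batch, seq_len]
lm_logits = torch.randn(2, 128, 32000)            # [batch, seq_len, vocab]
lm_loss = F.cross_entropy(
    lm_logits[:, :-1].reshape(-1, 32000), tokens[:, 1:].reshape(-1)
)

# Classification: one label per sequence, produced by the dataset/dataloader.
cls_logits = torch.randn(2, 4)                    # [batch, num_classes]
cls_labels = torch.tensor([1, 3])
cls_loss = F.cross_entropy(cls_logits, cls_labels)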

Contributor @felipemello1 commented Sep 6, 2024:

to clarify, did you mean loss = _loss_fn(output, label) for the last line?

yes! Thanks!

We would need to figure out how to pull out that label processing logic as well to avoid duplicating recipes

That sounds manageable. It could be in the dataloader or some utility function.

I guess a third issue is the checkpoint. I am not sure if this would require something different for classification vs generation.

So in summary, the main differences are:

  1. output head
  2. Label offset in generation
  3. checkpoint

Assuming you want to follow through, if you could adapt an existing generation recipe to support both generation/classification by making these changes, I think it would be easy to approve and a huge contribution to torchtune.

As a proof of concept, it's fine to have some if/else and not be polished. And if it looks good, then we can brainstorm about how to polish it. What do you think?

Collaborator:

FYI there are a lot of changes coming to generation in #1424 which might be worth holding off for. I'm actually not sure inference for seq->label should be in the generation recipe, which is beefed up for next-token prediction tasks + KV caching + compile. Not really stuff we need for classification tasks, right?

I think the code we need for prediction here is pretty minimal: grab a prompt, run the model, then extract the logits at the last valid token (i.e. the eos token) and apply an activation function/argmax on those.
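
For reference, a minimal version of that prediction path might look like this (classify is a hypothetical helper; shapes assumed as in the rest of the thread):

import torch

@torch.no_grad()
def classify(model, tokens, pad_id):
    logits = model(tokens)                                    # [batch, seq_len, num_classes]
    last_valid = (tokens != pad_id).sum(dim=-1) - 1           # eos token if present, else last token
    pooled = logits[torch.arange(tokens.shape[0]), last_valid]
    return torch.softmax(pooled, dim=-1).argmax(dim=-1)       # predicted class index per sequence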

Author:

Thanks for all the feedback! I have some immediate use cases I need to get working, so will be focused on getting a working proof of concept that can replicate my current finetuned model that I trained with huggingface. Once we reach parity, we can know that the logic is correct and start refactoring into a more generalized design.

My current plan is:

  1. Create an evaluate recipe for classification so that we have some way to batch deploy trained models. I'll use this to check whether my current models are reaching parity with my baseline relevance models.
  2. Fix training logic for classification. My validation metrics are showing increasing validation loss and very low validation accuracy right now, so I think I have a bug somewhere. Trying to hunt this down.
  3. Make checkpoint loadable by HF from_pretrained. There are some format conversion issues I had to work around to get from_pretrained working. I'm also getting different logits for some reason when running eval with the torchtune-loaded checkpoint vs the HF-loaded checkpoint, so more investigation is needed here.

Once the above are good to go, then can start refactoring to cleaner code design.

Collaborator @SalmanMohammadi commented Sep 6, 2024:

Sounds good! I think arriving at a sane baseline first of all is a good idea.

Make checkpoint loadable by HF from_pretrained. There are some format conversion issues that I had to adjust to get from_pretrained working.

I'm curious about this - what changes did you have to make?

I'm also getting different logits for some reason when running eval for torchtune loaded checkpoint vs HF loaded checkpoint, so more investigation needed here.

If you have any code to share here I'd be happy to look. I've tried matching logits with HF before, but also couldn't get them to match. There are several differences in the underlying transformer modelling code, and quite a lot of differences in generation. I haven't found this to be an issue yet, though. When testing against HF baselines, a good sanity check is to use less granular comparisons, e.g. similar loss curves/similar evaluation metrics across identical setups.

@SalmanMohammadi (Collaborator)

Hey @qqlabs. Thanks so much for sharing your work.

Currently I'm a bit full on with RLHF stuff, but I'm planning to revisit this properly in a couple of weeks.
Right now I want to help unblock you. It's great to see a concrete example of a classification use case - with a bit of discussion we can try merging some of the ideas you have here upstream.

Firstly, I've put up my draft work on classification datasets (#1424). Please take a look. On that branch, you can do something like:

from torchtune.datasets._imdb import imdb_dataset
from torchtune.models.llama2 import llama2_tokenizer, llama2

tokenizer = llama2_tokenizer("./dummy/tokenizer.model")

ds = imdb_dataset(
    tokenizer=tokenizer
)

Then, to collate properly in your dataloader, you'd use our generic padding collation utils which just landed (padded_collate). See the docstring for an example of how you'd use it for a classification task.

I see you're converting to one-hot out-of-the-box, but AFAIK you don't actually have to convert here since the relevant torch loss functions will accept class indices (https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html). Let me know if this isn't correct though.
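
A quick check of that point with cross_entropy (note that BCEWithLogitsLoss, which this PR currently uses, does still need float one-hot/multi-label targets):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 4)                 # [batch, num_classes]
class_idx = torch.tensor([0, 2, 1, 3])     # integer class indices straight from the dataset

# cross_entropy accepts class indices directly...
loss_idx = F.cross_entropy(logits, class_idx)

# ...and gives the same result as one-hot (probability) targets.
loss_onehot = F.cross_entropy(logits, F.one_hot(class_idx, num_classes=4).float())
assert torch.allclose(loss_idx, loss_onehot)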

I'll add more comments to address your inference/validation stuff soon. If you need more direct support please don't hesitate to ping me on our Discord.

@SalmanMohammadi (Collaborator)

@qqlabs check https://pytorch.org/torchtune/main/generated/torchtune.training.update_state_dict_for_classifier.html#torchtune.training.update_state_dict_for_classifier which might help your troubles on checkpointing. Essentially you just need to call this function once before any LoRA state dict validation and it should work. There's an example in the PPO recipe for non-LoRA state dict loading.

@SalmanMohammadi (Collaborator)

I also haven't looked too closely at how you're predicting the labels. AFAIK the standard way (which is also what HF does) is to grab the logits from the model at the last valid token in the prompt - this will usually be the EOS token if it exists, or the last token otherwise, then use your activation functions / argmaxing etc on the logits for that token(s). I can point you to an example of this in the codebase - we actually used to have a separate utility just for this but it was lost in a refactor.

@qqlabs (Author) commented Sep 6, 2024

I also haven't looked too closely at how you're predicting the labels. AFAIK the standard way (which is also what HF does) is to grab the logits from the model at the last valid token in the prompt - this will usually be the EOS token if it exists, or the last token otherwise, then use your activation functions / argmaxing etc on the logits for that token(s). I can point you to an example of this in the codebase - we actually used to have a separate utility just for this but it was lost in a refactor.

Great, this is what I did here! This is currently copied from HF's implementation, but using the token before the eos token instead of the padding token which HF does.

@@ -369,6 +368,83 @@ def lora_llama2_reward_7b(
quantize_base=quantize_base,
)

# Add llama multi classification model builders
Collaborator:

I definitely want to make it easier to use classification models in our codebase - this is high up on my TODO list, and your comment about replacing reward models is on point. Soon you should just be able to wrap any existing model builder without having to re-define builders.

… train recipe. This gives us distributed evaluation. Numbers are not matching the validation scores during training though, so not sure what's happening.
@qqlabs (Author) commented Sep 7, 2024

@qqlabs check https://pytorch.org/torchtune/main/generated/torchtune.training.update_state_dict_for_classifier.html#torchtune.training.update_state_dict_for_classifier which might help your troubles on checkpointing. Essentially you just need to call this function once before any LoRA state dict validation and it should work. There's an example in the PPO recipe for non-LoRA state dict loading.

@SalmanMohammadi I just realized this function is a bit problematic for LoRA classification. From my understanding, the LoRA adapter values are added to the corresponding layers when merging the adapter into the base model. This would mess up the output layer, since we're replacing the base model's output layer with random weights before merging in the LoRA adapter values - thus messing up the output weights post training, right? We need to keep track of the original output weights, which isn't done if we only save the adapter weights during training now.

@SalmanMohammadi (Collaborator) commented Sep 7, 2024

I'm sorry I'm not 100% sure I follow. If we're fine-tuning a language model for a classification task, here's roughly what I think should be happening:

  1. We should have a corresponding classifier model, which is identical to the base model, but the output projection layer is changed.
  2. Now, when we want to finetune our base model, we load all the weights in, aside from the output projection, which remains randomly initialized as you pointed out. From here on out we have no further need of the base model state dict.
  3. In full fine-tuning, we'll fine-tune the whole output projection weights, and in LoRA we'll fine-tune just an adapter on top of the randomly initialized weights.
  4. If LoRA fine-tuning, we merge the weights back in post training (if we've not set save adapters only). This maybe isn't great if we're merging trained output projection adapters with a randomly initialized output head.

Does that sound roughly right to you? I'm curious what you mean by keeping track of the original output weights.

You've made me think of something important here though. Adapter-based fine-tuning might not be suitable for the output projection, which is randomly initialized. This means we might want to apply adapter-based fine-tuning to the rest of the model, skip the adapter on the output projection, and just unfreeze the final layer so it trains from scratch.

I've been doing some LoRA classification model fine-tuning, but I'm going from pretrained classifier -> finetuned classifier so I haven't hit the same edge cases here. This is really useful.

@qqlabs (Author) commented Sep 7, 2024

  1. If LoRA fine-tuning, we merge the weights back in post training (if we've not set save adapters only). This maybe isn't great if we're merging trained output projection adapters with a randomly initialized output head.

This is the part that I'm worried about. I think it's ok right now if we have save_adapter_weights_only set to false, since we would merge the weights into the trained model while calling save_checkpoint and can load the whole model back in as is without looking at the adapter weights.

The issue is when we only save adapter weights as checkpoints since we lose the "random initialized output weights" that the adapter weights were trained on.

Current steps to load the model with only adapter checkpoints:

  1. Download original llama2_7b base model weights
  2. Load the base model weights.
  3. We don't have the random initialized weights that we used during finetuning which are associated with the adapter checkpoint. This means we can only call update_state_dict_for_classifier which creates a new random set of weights for the output layer.
  4. Merge the adapter weights with base model weights. At this step, we're basically adding a random number to each of the output layer weights of the adapter.

I think a finetuned llama model with a completely random output layer would produce random results. I'm seeing exactly random accuracy (~25% for my 4 class classification) with my eval script with this setup.

I did have the full model checkpointed for one of my runs and I am able to reproduce something closer to the validation metrics that I was seeing during my training process (which has some other bug since I'm getting like 9% accuracy...).

Overall, we cannot save only the adapter weights for classification tasks - we also need to save the randomly initialized output weights (or trained output weights) somewhere (this could be part of the adapter weights).
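
A hedged sketch of what saving the output weights alongside the adapter weights could look like (function and key names are illustrative, not torchtune's checkpoint format):

import torch

def save_adapter_checkpoint(model, path):
    state_dict = model.state_dict()
    adapter_sd = {k: v for k, v in state_dict.items() if "lora_" in k}
    # Also persist the (randomly initialized or trained) classification head; without it,
    # an adapter-only checkpoint cannot reconstruct the model's output layer.
    adapter_sd["output.weight"] = state_dict["output.weight"]
    torch.save(adapter_sd, path)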

@SalmanMohammadi (Collaborator) commented Sep 7, 2024

Can you try setting apply_lora_to_output: False, and then, as a quick hack, re-enable requires_grad=True on the output weight parameters? This should mean we freeze base model weights -> set up adapter params (not for the output weight) -> unfreeze output weight -> train.

This won't work with save_adapter_weights_only. I think longer term for classification tasks, we need some sensible solution which either trains the output weight from scratch, or adds an adapter only when you're not starting from randomly initialized weights (i.e. finetuning an existing classifier).
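
A rough sketch of that setup, reusing the get_adapter_params / set_trainable_params helpers the recipe already imports (the output. parameter prefix is an assumption):

from torchtune.modules.peft import get_adapter_params, set_trainable_params

def setup_trainable_params(model):
    # model built with apply_lora_to_output: False
    adapter_params = get_adapter_params(model)
    set_trainable_params(model, adapter_params)   # freezes everything outside the adapters
    # Quick hack: unfreeze the randomly initialized classification head so it trains from scratch.
    for name, param in model.named_parameters():
        if name.startswith("output."):
            param.requires_grad = True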

@qqlabs (Author) commented Sep 9, 2024

Can you try setting apply_lora_to_output: False, and then, as a quick hack, re-enable requires_grad=True on the output weight parameters? This should mean we freeze base model weights -> set up adapter params (not for the output weight) -> unfreeze output weight -> train.

I unfroze the output and also added LoRA weights to it (likely redundant), but am not getting very different results compared to keeping the output layer frozen. The main thing I was missing before was only saving adapter weights - we definitely need to figure out some way of storing the output weights so that the model can be recovered from just the adapter weights.

I think my classification code is actually working since train accuracy is increasing. I think I am overfitting to the training set since validation loss increases.

I went back and reimplemented prompt_completion (predicting the next token instead of doing multi-class classification) and was able to reproduce a model with closer performance to my baseline. This means I'm either overfitting the classification version now that the label output space is so small, or I have some bug/misunderstanding in my sequence classification implementation. Will continue debugging.

Labels: CLA Signed
4 participants