Adding Classification Use Case #1502
base: main
Conversation
…ointing, and some memory optimization working
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1502
```python
self._loss_fn = config.instantiate(cfg.loss)
```
2 dataloaders right now so that we can do train and validation.
```python
# Since this is classification, drop the output weights from the base llama model.
# Otherwise load_state_dict would hit a key size mismatch.
del base_model_state_dict["output.weight"]
```
We can't directly load the base llama model weights since they don't match the newly replaced mlp layer in the classification version. The current hack is to drop the weights before loading them into the classification model.
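The shape-mismatch workaround described above can be sketched with plain PyTorch modules. The `base_model`/`clf_model` stand-ins below are hypothetical, not torchtune's actual model classes; only the key-dropping pattern is the point:

```python
import torch
from torch import nn

# Hypothetical stand-ins: both models share a "backbone", but the
# classifier replaces the LM head ("output") with a smaller projection,
# so the checkpoint keys clash on shape.
base_model = nn.ModuleDict(
    {"backbone": nn.Linear(8, 8), "output": nn.Linear(8, 32000, bias=False)}
)
clf_model = nn.ModuleDict(
    {"backbone": nn.Linear(8, 8), "output": nn.Linear(8, 4, bias=False)}
)

base_state_dict = base_model.state_dict()

# Drop the LM head weights: their shape (32000, 8) does not match the
# classifier head's (4, 8), so load_state_dict would raise otherwise.
del base_state_dict["output.weight"]

# strict=False tolerates the now-missing head key; the classifier head
# keeps its (randomly initialized) weights.
missing, unexpected = clf_model.load_state_dict(base_state_dict, strict=False)
print(missing, unexpected)
```

With `strict=False`, the backbone weights load cleanly and the only reported missing key is the head that was deliberately dropped.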
```python
partial(
    padded_collate,
    padding_idx=self._tokenizer.pad_id,
    # ignore_idx=self._loss_fn.ignore_index,
)
```
BCEWithLogitsLoss doesn't have an ignore_index.
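A small sketch of why this matters: `BCEWithLogitsLoss` expects float targets shaped like the logits, so class indices get one-hot encoded, and there is no built-in `ignore_index` to skip padding positions (the tensors below are toy values for illustration):

```python
import torch
from torch import nn

batch, num_classes = 3, 4
logits = torch.randn(batch, num_classes)
labels = torch.tensor([2, 0, 3])  # class indices from the dataset

# BCEWithLogitsLoss wants float targets of the same shape as the
# logits, so class indices are converted to one-hot vectors. There is
# no ignore_index argument; padding must be handled in the collater
# (hence the commented-out ignore_idx above).
targets = nn.functional.one_hot(labels, num_classes).float()
loss = nn.BCEWithLogitsLoss()(logits, targets)
print(loss.item())
```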
```python
running_loss = 0
num_tokens = 0
```
Validation only works at the end of an epoch.
```python
logits = self._model(tokens, mask=mask, input_pos=input_pos)
```
Added a step here where we pick the output for the eos token. This should ideally be built into the model itself; not sure what's the best way of doing that.
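The eos/last-valid-token selection being discussed can be sketched with toy tensors (assuming right-padded batches; the token ids and shapes here are illustrative, not torchtune's actual code):

```python
import torch

pad_id = 0
# Toy batch of token ids, right-padded with pad_id.
tokens = torch.tensor(
    [
        [5, 6, 7, 2, 0, 0],  # length 4 (2 = eos)
        [5, 9, 2, 0, 0, 0],  # length 3
    ]
)
# Pretend model output: [batch, seq_len, num_classes].
logits = torch.randn(2, 6, 4)

# Index of the last non-pad token per sequence.
seq_lens = (tokens != pad_id).sum(dim=-1)
last_idx = seq_lens - 1
# Gather the logits at that position for each sequence.
pooled = logits[torch.arange(tokens.size(0)), last_idx]  # [batch, num_classes]
print(pooled.shape)
```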
```python
num_tokens = 0
t0 = time.perf_counter()
```
Free up memory at the end of each training step.
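A hedged sketch of what such a cleanup step could look like (`free_step_memory` is a hypothetical helper, not torchtune's actual implementation; deleting the large intermediates first is what makes the cache release effective):

```python
import gc
import torch

def free_step_memory() -> None:
    # Best-effort cleanup between training steps.
    gc.collect()  # drop unreachable Python objects
    if torch.cuda.is_available():
        # Return cached-but-unused blocks to the driver. This does not
        # free live tensors, but can reduce fragmentation-driven OOMs.
        torch.cuda.empty_cache()

# Typical use at the end of a step: delete the big intermediates first
# so their storage is actually unreferenced when the cache is emptied.
logits = torch.randn(4, 128, 1000)
loss = logits.mean()
loss_val = loss.item()
del logits, loss
free_step_memory()
print(loss_val)
```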
```python
self._profiler.stop()
```
Validation loop.
```python
text_column: str = "text",
label_column: str = "label",
split: str = "train",
classes: Optional[List[Any]] = None,
```
The class/label mapping is currently passed in as an ordered list. Not sure what's the best practice for this (do we store a separate label mapping file somewhere?).
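For illustration, one simple scheme is to derive the label-to-index mapping from the order of the list itself (`build_label_map` is a hypothetical helper, not part of the PR; the class names come from the dataset described later in this PR):

```python
from typing import Any, Dict, List, Optional

def build_label_map(classes: Optional[List[Any]]) -> Dict[Any, int]:
    # The position of each entry in the ordered `classes` list defines
    # its class index; None means no mapping is available.
    if classes is None:
        return {}
    return {label: idx for idx, label in enumerate(classes)}

classes = ["exact", "substitute", "complement", "irrelevant"]
label_map = build_label_map(classes)
print(label_map["substitute"])
```

A separate mapping file would make the index assignment reproducible across runs, but an ordered list in the config achieves the same as long as its order never changes.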
```python
@@ -246,6 +246,29 @@ def tune_to_peft_adapter_config(
    return adapter_config

# When we only save adapter weights, we need to convert back from peft format to tune format
```
I use this when trying to load adapter config weights back. Will add sample code for what I got working later.

As in, resuming training in torchtune using adapter weights saved from a previous torchtune LoRA run? How come that wasn't working?
```python
# Note: this needs to be set before wrapping with FSDP
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)
```
Do we also need to set the final mlp layer of the base model to be trainable?

Yes, excellent catch. This is done through the apply_lora_to_output recipe config arg.
Thanks for the PR! This is not a review yet, but just food for thought: we have been having conversations for a while about removing the output head from the LLM forward, and running it as a second step.
Do you think that if this was implemented, we would still need a different recipe? For example:
```python
hidden_state = model(inputs)
output = model.output_head(hidden_state)  # edited
loss = self.loss_fn(logits, label)  # edited
```
If we could make it work without adding a bunch of if/else everywhere in the recipe, then we could leverage the existing ones instead of 2x-ing the number of recipes in the library.
To clarify, did you mean `loss = _loss_fn(output, label)` for the last line?
I think this could work to separate out the model forward logic for the classification use case (where we need to find the last token and pick the hidden state for it). Another option is to build the forward logic into the model itself, like how huggingface transformers are implemented.
The other big change for classification/prompt completion use cases is the logic for processing the labels. The chat/instruct use cases shift all the logits by 1 to do next-token prediction for the labels. Prompt completion/classification tasks have the label defined in the dataset and processed by the dataloader. We would need to figure out how to pull out that label processing logic as well to avoid duplicating recipes.
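The difference in label handling can be illustrated with toy tensors (a sketch only; the real recipes reshape, mask, and batch differently):

```python
import torch
import torch.nn.functional as F

# Next-token prediction (chat/instruct): targets are the inputs shifted
# left by one, so logits at position t are scored against token t + 1.
tokens = torch.tensor([[5, 6, 7, 2]])
lm_logits = torch.randn(1, 4, 11)        # [batch, seq, vocab]
lm_inputs = lm_logits[:, :-1, :]         # drop the last position
lm_targets = tokens[:, 1:]               # drop the first token
lm_loss = F.cross_entropy(lm_inputs.reshape(-1, 11), lm_targets.reshape(-1))

# Sequence classification: the label comes straight from the dataset,
# no shifting; only the pooled per-sequence logits are scored.
clf_logits = torch.randn(1, 4)           # [batch, num_classes]
clf_label = torch.tensor([2])
clf_loss = F.cross_entropy(clf_logits, clf_label)
print(lm_loss.item(), clf_loss.item())
```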
> to clarify, did you mean `loss = _loss_fn(output, label)` for the last line?

Yes! Thanks!
> We would need to figure out how to pull out that label processing logic as well to avoid duplicating recipes

That sounds manageable. It could be in the dataloader or some utility function.
I guess a third issue is the checkpoint. I am not sure if this would require something different for classification vs generation.
So in summary, the main differences are:
- output head
- label offset in generation
- checkpoint
Assuming you want to follow through, if you could adapt an existing generation recipe to support both generation/classification by making these changes, I think it would be easy to approve and a huge contribution to torchtune.
As a proof of concept, it's fine to have some if/else and not be polished. If it looks good, we can then brainstorm about how to polish it. What do you think?
FYI there are a lot of changes coming to generation in #1424 which might be worth holding off for. I'm actually not sure inference for seq->label should be in the generation recipe, which is beefed up for next-token prediction tasks + KV caching + compile. Not really stuff we need for classification tasks, right?
I think the code we need for prediction here is pretty minimal: grab a prompt, run the model, then extract the logits at the last valid token (i.e. the eos token) and apply an activation function/argmax on those.
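That minimal prediction path could look roughly like this (`predict_label` is a hypothetical helper, assuming right-padded batches; the toy logits are rigged so the expected classes are obvious):

```python
import torch

def predict_label(logits: torch.Tensor, tokens: torch.Tensor, pad_id: int) -> torch.Tensor:
    # Take the logits at the last valid token of each sequence,
    # then apply softmax and argmax to get a class per sequence.
    last_idx = (tokens != pad_id).sum(dim=-1) - 1
    pooled = logits[torch.arange(tokens.size(0)), last_idx]
    probs = torch.softmax(pooled, dim=-1)
    return probs.argmax(dim=-1)

tokens = torch.tensor([[5, 6, 2, 0], [7, 2, 0, 0]])  # 2 = eos, 0 = pad
logits = torch.zeros(2, 4, 3)
logits[0, 2, 1] = 5.0  # class 1 strongly favored at seq 0's eos
logits[1, 1, 2] = 5.0  # class 2 strongly favored at seq 1's eos
print(predict_label(logits, tokens, pad_id=0))  # tensor([1, 2])
```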
Thanks for all the feedback! I have some immediate use cases I need to get working, so will be focused on getting a working proof of concept that can replicate my current finetuned model that I trained with huggingface. Once we reach parity, we can know that the logic is correct and start refactoring into a more generalized design.
My current plan is:
- Create an evaluate recipe for classification so that we have some way to batch deploy trained models. I'll use this to check whether my current models are reaching parity with my baseline relevance models.
- Fix training logic for classification. My validation metrics are showing increasing validation loss and very low validation accuracy right now, so I think I have a bug somewhere. Trying to hunt this down.
- Make checkpoint loadable by HF from_pretrained. There are some format conversion issues that I had to adjust to get from_pretrained working. I'm also getting different logits for some reason when running eval for torchtune loaded checkpoint vs HF loaded checkpoint, so more investigation needed here.
Once the above are good to go, then can start refactoring to cleaner code design.
Sounds good! I think arriving at a sane baseline first of all is a good idea.
> Make checkpoint loadable by HF from_pretrained. There are some format conversion issues that I had to adjust to get from_pretrained working.

I'm curious about this - what changes did you have to make?

> I'm also getting different logits for some reason when running eval for torchtune loaded checkpoint vs HF loaded checkpoint, so more investigation needed here.

If you have any code to share here I'd be happy to look. I've tried matching logits against HF before, but also couldn't get them to match. There are several differences in the underlying transformer modelling code, and quite a lot of differences in generation. I haven't found this to be an issue yet, though. When testing against HF baselines, a good sanity check is to use less granular comparisons, e.g. similar loss curves/similar evaluation metrics across identical setups.
Hey @qqlabs. Thanks so much for sharing your work. Currently I'm a bit full on with RLHF stuff, but I'm planning to revisit this properly in a couple weeks. Firstly, I've put up my draft work on classification datasets (#1424). Please take a look. On that branch, you can do something like:

```python
from torchtune.datasets._imdb import imdb_dataset
from torchtune.models.llama2 import llama2_tokenizer, llama2

tokenizer = llama2_tokenizer("./dummy/tokenizer.model")
ds = imdb_dataset(
    tokenizer=tokenizer
)
```

Then, to collate properly in your dataloader, you'd use our generic padding collation utils which just landed (torchtune/data/_collate.py, line 55 in 277fbf8).
I see you're converting to one-hot out of the box, but AFAIK you don't actually have to convert here since the relevant torch loss functions will accept class indices (https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html). Let me know if this isn't correct though. I'll add more comments to address your inference/validation stuff soon. If you need more direct support please don't hesitate to ping me on our Discord.
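For reference, the one-hot conversion is indeed optional with `F.cross_entropy`: it accepts class indices directly, and (since PyTorch 1.10) also probability-style targets, of which one-hot is a special case. A quick check:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(3, 4)
labels = torch.tensor([2, 0, 3])  # plain class indices

# cross_entropy accepts class indices directly...
loss_idx = F.cross_entropy(logits, labels)
# ...and also class probabilities, which is what a one-hot encoding
# amounts to. Both give the same value here.
loss_onehot = F.cross_entropy(logits, F.one_hot(labels, 4).float())
print(torch.allclose(loss_idx, loss_onehot))
```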
@qqlabs check https://pytorch.org/torchtune/main/generated/torchtune.training.update_state_dict_for_classifier.html#torchtune.training.update_state_dict_for_classifier which might help your troubles on checkpointing. Essentially you just need to call this function once before any LoRA state dict validation and it should work. There's an example in the PPO recipe for non-LoRA state dict loading.
I also haven't looked too closely at how you're predicting the labels. AFAIK the standard way (which is also what HF does) is to grab the logits from the model at the last valid token in the prompt - this will usually be the EOS token if it exists, or the last token otherwise - then use your activation functions / argmaxing etc. on the logits for that token(s). I can point you to an example of this in the codebase - we actually used to have a separate utility just for this but it was lost in a refactor.
Great, this is what I did here! This is currently copied from HF's implementation, but using the token before the eos token instead of the padding token which HF does.
```python
@@ -369,6 +368,83 @@ def lora_llama2_reward_7b(
    quantize_base=quantize_base,
)

# Add llama multi classification model builders
```
I definitely want to make it easier to use classification models in our codebase - this is high up on my TODO list - your comment about replacing reward models is on point. Soon you should just be able to wrap any existing model builder without having to re-define builders.
… train recipe. This gives us distributed evaluation. Numbers are not matching the validation scores during training, though, so not sure what's happening.
@SalmanMohammadi I just realized this function is a bit problematic for lora classification. From my understanding, the lora adapter values are added to the corresponding layers when merging the adapter into the base model. This would mess up the output layer, since we're replacing the base model's output layer with random weights before merging in lora adapter values - thus messing up the output weights post training, right? We need to keep track of the original output weights, which isn't done if we only save the adapter weights during training now.
I'm sorry I'm not 100% sure I follow. If we're fine-tuning a language model for a classification task, here's roughly what I think should be happening:
Does that sound roughly right to you? I'm curious what you mean by keeping track of the original output weights. You've made me think of something important here, though. Adapter-based fine-tuning might not be suitable for the output projection, which is randomly initialized. This means we might want to do adapter-based fine-tuning for the rest of the model, not apply an adapter to the output projection, and just unfreeze the final layer so it trains from scratch. I've been doing some LoRA classification model fine-tuning, but I'm going from pretrained classifier -> finetuned classifier so I haven't hit the same edge cases here. This is really useful.
This is the part that I'm worried about. I think it's ok right now if we have save_adapter_weights_only set to false, since we would merge the weights with the trained model while calling save_checkpoint and can load the whole model back in as-is without looking at the adapter weights. The issue is when we only save adapter weights as checkpoints, since we lose the "random initialized output weights" that the adapter weights were trained on. Current steps to load the model with only adapter checkpoints:
I think a finetuned llama model with a completely random output layer would produce random results. I'm seeing exactly random accuracy (~25% for my 4-class classification) with my eval script in this setup. I did have the full model checkpointed for one of my runs, and I am able to reproduce something closer to the validation metrics that I was seeing during my training process (which has some other bug, since I'm getting something like 9% accuracy...). Overall, we cannot save only the adapter weights for classification tasks - we also need to save the randomly initialized output weights (or trained output weights) somewhere (it could be part of the adapter weights).
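One possible shape for that fix, sketched with stand-in modules (`adapter_plus_head_state_dict` is hypothetical, not an existing torchtune utility; the idea is just that an adapter-only checkpoint must also carry the classifier head, which cannot be recovered from the base checkpoint):

```python
import torch
from torch import nn

def adapter_plus_head_state_dict(model: nn.Module, adapter_keys: set) -> dict:
    # Keep the adapter params, plus the classifier head ("output.*"),
    # which was randomly initialized and trained from scratch.
    full = model.state_dict()
    return {
        k: v for k, v in full.items()
        if k in adapter_keys or k.startswith("output.")
    }

model = nn.ModuleDict({"layer": nn.Linear(8, 8), "output": nn.Linear(8, 4)})
adapter_keys = {"layer.weight"}  # stand-in for the real LoRA params
ckpt = adapter_plus_head_state_dict(model, adapter_keys)
print(sorted(ckpt))
```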
Can you try set … This won't work with …
I unfroze the output and also added lora weights to it (likely redundant), but am not getting very different results compared to keeping the output layer frozen. The main thing I was missing before was only saving adapter weights - we definitely need to figure out some way of storing the output weights so that the model can be recovered from just the adapter weights.
I think my classification code is actually working, since train accuracy is increasing. I think I am overfitting to train, since validation loss increases. I went back and reimplemented prompt_completion (predicting the next token instead of multi-class classification) and was able to reproduce a model with performance closer to my baseline. This means I'm either overfitting the classification version now that the label output space is so small, OR I have some bug/misunderstanding in implementing the sequence classification. Will continue debugging.
Context
What is the purpose of this PR? Is it to
My use case is adapting torchtune to be able to do LlamaForSequenceClassification. This differs from a lot of assumptions made in the current finetuning code, which assumes we're finetuning for chat/instruct purposes. Issues: #1249, #1464
A sample dataset that I want this to work on is Amazon's shopping queries dataset, where we try to classify query-product pairs as exact, substitute, complement, or irrelevant. I concatenate the query text and product title to form the prompt to classify into one of the 4 classes.
Changelog
Key changes:
Multi-class classification training:
We essentially want to implement LlamaForSequenceClassification. I added a dataset, recipe, configs, and some other code changes to get this to work.
Inference using trained classification model:
Lora Adapter Issues:
The current code saves the lora adapter weights as a separate file from the base model weights. We hit missing-key and key mismatch errors when we try to load in the two files using load_checkpoint.
In addition, I often only save the adapter weights to save on disk space. There is no clear guidance on merging the weights and loading them back in to do inference.
I added some functions to be able to load the adapter weights back in and merge them with a separate base model file. Will upload an example of this later.
Validation Step:
Torchtune does not support validation step during training. I added it into the recipe in a hacky way. This currently only works at the end of the epoch since I hit deadlock issues when trying to validate in the middle of distributed training.
OOM Issues:
Reserved memory is not cleaned up at each training step, which causes OOM issues. Added a cleanup step for each training step. Also found out I needed to reduce the context length for my 40GB GPU.
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- `pre-commit install`
- `pytest tests`
- `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring: torchtune/torchtune/modules/vision_transformer.py, line 285 in 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models