
[RFC] Classification task fine-tuning #1464

Open
@SalmanMohammadi

Description

There has been some community appetite for classification tasks (see #1249 and #1124). Incidentally, because classification models are used for RLHF, we already have some of the components needed to support classification tasks, so I think we're not too far from supporting this. Concretely:

In-progress

  1. Land "Re-organizing collation utils" (#1463), which will provide generic collation utils for classification datasets.
  2. Add support for classification datasets (I've mostly completed this and will put a PR up soon).

TODO - if this sounds interesting to you and you'd like to help out here please don't hesitate to comment!

Add support for a classification loss

I don't think we need a new recipe for this task; we mainly need to refactor the label-slicing logic in _loss_step. Rather than slicing the labels there, we want to grab the scores from the classification model, like we do in the RLHF recipe.

cc @felipemello1 for some thoughts here on compatibility with chunked CE.

Hugging Face models also provide a reference for the classification loss calculation: https://github.com/huggingface/transformers/blob/38d58a4427c7c5093dc7bde45613d2bb0a5dea2c/src/transformers/models/llama/modeling_llama.py#L1409

This should enable fine-tuning with our already defined classification models, our classification datasets (coming soon!) and generic collation utils.
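A minimal sketch of what the classification loss described above could look like, assuming the classifier returns per-token scores of shape `[batch, seq_len, num_classes]`, sequences are right-padded, and each sequence has a single integer label. The function name `classification_loss` and its arguments are hypothetical and not part of any existing recipe. It scores each sequence at its last non-padding token, which is similar in spirit to the Hugging Face reference linked above.

```python
import torch
import torch.nn.functional as F


def classification_loss(
    logits: torch.Tensor,        # [batch, seq_len, num_classes] per-token scores
    padding_mask: torch.Tensor,  # [batch, seq_len], True where the token is padding
    labels: torch.Tensor,        # [batch] integer class ids
) -> torch.Tensor:
    # Index of the last non-padding token in each sequence.
    # Assumes right padding and at least one real token per sequence.
    last_token_idx = (~padding_mask).sum(dim=-1) - 1
    batch_idx = torch.arange(logits.shape[0], device=logits.device)
    # One score vector per sequence: [batch, num_classes]
    pooled = logits[batch_idx, last_token_idx]
    return F.cross_entropy(pooled, labels)


# Toy usage: random scores stand in for a classifier's output.
torch.manual_seed(0)
logits = torch.randn(2, 4, 3)
padding_mask = torch.tensor(
    [[False, False, True, True],
     [False, False, False, False]]
)
labels = torch.tensor([0, 2])
loss = classification_loss(logits, padding_mask, labels)
```

Because the loss is computed on one score vector per sequence rather than on every token position, the shifted per-token label slicing is not needed on this path. For a regression-style head with a single output, a mean-squared-error loss on the pooled score would be the analogous choice.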

Testing this will be the most complex step here. I haven't found many existing examples to benchmark against, so if you're reading this and know of some, please chime in! At the very least, we'll need to verify correctness by checking that we see sensible loss values and eval outputs.
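One cheap way to judge whether loss values are "sensible" is to compare them against the chance-level baseline: a model that predicts a uniform distribution over C classes has a cross-entropy of ln(C). The sketch below computes that baseline plus a plain accuracy metric. Both helper names are hypothetical, and the baseline only applies to cross-entropy over mutually exclusive classes.

```python
import math


def chance_level_loss(num_classes: int) -> float:
    """Cross-entropy of a uniform prediction over `num_classes` classes."""
    return math.log(num_classes)


def accuracy(predictions: list[int], labels: list[int]) -> float:
    """Fraction of predictions that match the labels."""
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


binary_baseline = chance_level_loss(2)          # ln(2), about 0.693
toy_accuracy = accuracy([0, 1, 1], [0, 1, 0])   # 2 of 3 predictions are correct
```

A training loss that never drops below the baseline, or an eval accuracy stuck at the majority-class rate, would suggest a problem in the loss or the data pipeline rather than in the model's capacity.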

Optional - Generalize classification models
We currently only support Mistral and Llama classifier models, and the model builders for these classifiers are only provided for binary classification (or regression) tasks. To let people use other models for classification tasks, we should discuss a sensible way to add generic support for converting an existing model into a classification model, without needing to define builders for each model.
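As a starting point for that discussion, here is a rough sketch of one possible conversion: replace the model's final projection with a freshly initialised linear head of size `num_classes`. It assumes the model exposes that projection as an `nn.Linear` attribute, called `output` here. `ToyBackbone` and `to_classifier` are hypothetical names used only for illustration, and models with tied embeddings or a non-linear output layer would need different handling.

```python
import torch
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a decoder with an `output` projection."""

    def __init__(self, vocab_size: int = 32, embed_dim: int = 8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))


def to_classifier(model: nn.Module, num_classes: int) -> nn.Module:
    # Swap the language-modelling head for a new classification head.
    # Assumes `model.output` is an nn.Linear (an assumption, see above).
    in_features = model.output.in_features
    model.output = nn.Linear(in_features, num_classes, bias=False)
    return model


clf = to_classifier(ToyBackbone(), num_classes=3)
scores = clf(torch.tensor([[1, 2, 3]]))  # per-token class scores, [1, 3, 3]
```

The new head is randomly initialised, so any pretrained checkpoint would need to be loaded for the backbone only, with the head trained from scratch.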

Thoughts, comments, criticisms, appreciation, all welcome here.
