This repo draws from the excellently written HALOs repo and DPO repo. We have preserved many design choices from the original.
This repo provides a general framework for aligning large language models (LLMs) with the Transformers and Datasets libraries from Huggingface. Unlike Huggingface's TRL framework, it incorporates the following features:
- Support for modifying the weights of training samples.
- Support for generating responses from the LLM policy.
- Support for collecting feedback from an online reward model or language model.
A high-level diagram of the data flow is shown below:
In the following, we introduce the major components of the framework, though not in the order of the data flow.
The BatchFactory is a class that generates batches of data for training.
It takes a DatasLoader object as input and generates train/test batches for the Trainer in the following format:
```python
{
    'prompt': str,                         # prompt for the generated texts
    'prompt_token_ids': tensor,            # token ids of the prompt
    'prompt_attention_mask': tensor,       # attention mask of the prompt
    'generation1': str,                    # text of prompt + 1st generation
    'generation1_response_only': str,      # text of the 1st generation only
    'generation1_token_ids': tensor,       # token ids of the 1st generation
    'generation1_attention_mask': tensor,  # attention mask of the 1st generation
    'generation1_reward': float,           # reward of the 1st generation
    'generation1_weight': float,           # weight of the 1st generation
    'generation2': str,                    # text of prompt + 2nd generation
    'generation2_response_only': str,      # text of the 2nd generation only
    'generation2_token_ids': tensor,       # token ids of the 2nd generation
    'generation2_attention_mask': tensor,  # attention mask of the 2nd generation
    'generation2_reward': float,           # reward of the 2nd generation
    'generation2_weight': float,           # weight of the 2nd generation
}
```

Note that the above items are not necessarily all included in the batch.
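For illustration, here is a minimal sketch of how a Trainer might consume a batch in this format for weighted supervised fine-tuning. It is not code from this repo: the model name, the assumption that the per-sample `generation1_weight` values are collated into a tensor, and the helper name `weighted_sft_loss` are all hypothetical.

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

def weighted_sft_loss(batch, model):
    """Per-sample weighted next-token loss on generation1 (illustrative sketch)."""
    input_ids = batch["generation1_token_ids"]            # (batch, seq_len)
    attention_mask = batch["generation1_attention_mask"]  # (batch, seq_len)
    weights = batch["generation1_weight"]                 # (batch,) per-sample weights

    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # Shift so that tokens < n predict token n.
    logits = outputs.logits[:, :-1, :]
    labels = input_ids[:, 1:]
    mask = attention_mask[:, 1:].float()
    # In practice one would typically also mask out prompt tokens from the loss.

    token_loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction="none"
    ).view(labels.shape)
    # Average over valid tokens, then weight each sample before averaging the batch.
    per_sample = (token_loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return (per_sample * weights).mean()
```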
Below is a diagram of the data flow inside BatchFactory. Note that the final output batches contain only the items required by the chosen learning task, as listed below.
We hereby list the learning tasks, the batch items used by each, and the source of those items (see the summary sketch after the list):

- Supervised fine-tuning: only `prompt` and `generation1` are in the batch. Moreover, the `generation1_reward` is `None` and the `generation1_weight` is always 1.0.
- Reward modelling in RLHF: `prompt`, `generation1`, and `generation2` are all included. However, `generation1_reward` and `generation2_reward` are both `None`. The `generation1_weight` and `generation2_weight` are always 1.0.
- Reinforcement learning: only `prompt` and `generation1` are included, and the `generation1_reward` is obtained from the online Reward Model.
- Offline pointwise preference learning: only `prompt` and `generation1` are included. The `generation1_reward` is 1.0 if `generation1` is a desired response, and otherwise 0.0 to indicate that `generation1` is undesired. The `generation1_weight` is always 1.0. (Check out the HALOs repo for the details of training models with pointwise desirable/undesirable feedback.) Both the generations and the annotations come from the pre-collected and fixed `DatasLoader`, so this is an OFFLINE learning setup.
- Online pointwise preference learning: same as the offline pointwise preference learning above, except that `generation1` is sampled from the LLM policy being trained and the `generation1_reward` is obtained from the online Annotator. The generations come from the LLM policy being trained and the feedback from the online Annotator, so this is an ONLINE learning setup.
- Offline pairwise preference learning: `prompt`, `generation1`, and `generation2` are all included. The `generation1_reward` is 1.0 and the `generation2_reward` is 0.0 to indicate that `generation1` is preferred over `generation2`. The `generation1_weight` and `generation2_weight` are always 1.0. Like the offline pointwise preference learning setup, the generations and annotations come from the pre-collected and fixed `DatasLoader`, so this is an OFFLINE learning setup.
- Online pairwise preference learning: same as the offline pairwise preference learning above, except that `generation1` and `generation2` are sampled from the LLM policy being trained. `generation1_reward` is 1.0 if `generation1` is preferred over `generation2` by the online Annotator, and otherwise 0.0. The generations come from the LLM policy being trained and the feedback from the online Annotator, so this is an ONLINE learning setup.
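The mapping below summarizes the list above in code. It is only an illustrative sketch (the dictionary and its keys are not part of the repo's API): for each task it records which generation fields are present, where the generations come from, and where the rewards come from.

```python
# Illustrative summary of the learning tasks described above (not part of the repo).
TASK_BATCH_ITEMS = {
    "sft":               {"generations": ["generation1"],                "source": "dataset", "rewards": None},
    "reward_modelling":  {"generations": ["generation1", "generation2"], "source": "dataset", "rewards": None},
    "rl":                {"generations": ["generation1"],                "source": "policy",  "rewards": "online reward model"},
    "pointwise_offline": {"generations": ["generation1"],                "source": "dataset", "rewards": "pre-collected 1.0/0.0 labels"},
    "pointwise_online":  {"generations": ["generation1"],                "source": "policy",  "rewards": "online annotator 1.0/0.0 labels"},
    "pairwise_offline":  {"generations": ["generation1", "generation2"], "source": "dataset", "rewards": "pre-collected preference (1.0 vs 0.0)"},
    "pairwise_online":   {"generations": ["generation1", "generation2"], "source": "policy",  "rewards": "online annotator preference (1.0 vs 0.0)"},
}
```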
The DatasLoader is a class that loads the original data from the Huggingface hub.
Note that there might be a DISTRIBUTION SHIFT problem between the responses in the DatasLoader and the responses generated by the LLM policy being trained.
To be more specific, suppose that the responses in the DatasLoader were generated by a language model other than the one being trained; then the distribution of those pre-collected responses may differ from the distribution of responses the LLM policy itself would generate.
Difference between DatasLoader and BatchFactory: In short, the DatasLoader is a component of the BatchFactory.
The DatasLoader yields pre-collected and pre-annotated responses; the BatchFactory can either keep those responses and preferences, or discard them and instead use fresh generations from the LLM policy together with feedback from the Annotator.
In the latter case, only the prompts from the DatasLoader are kept by the BatchFactory.
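To make this relationship concrete, here is a minimal, hypothetical sketch of how a BatchFactory might wrap a DatasLoader in the offline and online settings. The constructor arguments and the `generate`/`score` methods are assumptions for illustration, not the repo's actual interfaces.

```python
class DatasLoader:
    """Yields pre-collected prompts, responses, and annotations (sketch only)."""
    def __iter__(self):
        yield {"prompt": "...", "generation1": "...",
               "generation1_reward": 1.0, "generation1_weight": 1.0}

class BatchFactory:
    def __init__(self, data_loader, policy=None, annotator=None, online=False):
        self.data_loader = data_loader  # the DatasLoader is a component of the BatchFactory
        self.policy = policy            # LLM policy, used only in the online setting
        self.annotator = annotator      # online reward model or language model
        self.online = online

    def batches(self):
        for example in self.data_loader:
            if self.online:
                # Keep only the prompt; regenerate the response and re-annotate it.
                generation = self.policy.generate(example["prompt"])        # hypothetical method
                reward = self.annotator.score(example["prompt"], generation)  # hypothetical method
                yield {"prompt": example["prompt"], "generation1": generation,
                       "generation1_reward": reward, "generation1_weight": 1.0}
            else:
                # Keep the pre-collected response and annotation as-is.
                yield example
```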

