Description
Context
Continuing the discussion from #2214.
When LLMs are trained on chat data, you typically need control over which messages the model is actually trained on and which ones are masked, i.e., excluded from the loss calculation. The most common setting is to train only on the model responses in the dataset:
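```python
from torchtune.data import Message

# Only the assistant response is unmasked, so only it contributes to the loss.
messages = [
    Message(role="user", content="What is the greatest TV show of all time?", masked=True),
    Message(role="assistant", content="The Sopranos, c'mon Tony", masked=False),
]
```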
In this example, we don't need to train the model to predict the next tokens of the user prompt, only those of the answer to it. Thus, the user message is masked. An alternative is to train the model on all messages, including the user prompt. This is the approach taken by Stanford Alpaca.
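To make "masked" concrete: a common convention, matching PyTorch's `CrossEntropyLoss` (whose `ignore_index` defaults to -100), is to replace the labels of masked tokens with an ignore value so they contribute nothing to the loss. The sketch below is illustrative only: `ToyMessage` is a hypothetical stand-in holding made-up token ids, and the one-position label shift used for next-token prediction is left out.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Targets equal to this value are skipped by PyTorch's CrossEntropyLoss by default.
IGNORE_INDEX = -100


@dataclass
class ToyMessage:
    """Hypothetical tokenized message: token ids plus a masked flag."""
    token_ids: List[int]
    masked: bool


def build_tokens_and_labels(messages: List[ToyMessage]) -> Tuple[List[int], List[int]]:
    """Concatenate all message tokens; masked messages get IGNORE_INDEX labels."""
    tokens: List[int] = []
    labels: List[int] = []
    for msg in messages:
        tokens.extend(msg.token_ids)
        if msg.masked:
            labels.extend([IGNORE_INDEX] * len(msg.token_ids))
        else:
            labels.extend(msg.token_ids)
    return tokens, labels


conversation = [
    ToyMessage(token_ids=[11, 12, 13], masked=True),  # user prompt
    ToyMessage(token_ids=[21, 22], masked=False),     # assistant response
]
tokens, labels = build_tokens_and_labels(conversation)
print(tokens)  # [11, 12, 13, 21, 22]
print(labels)  # [-100, -100, -100, 21, 22]
```

The model still sees the user tokens as input; only their labels are ignored.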
In torchtune, we provide this control to the user via `train_on_input`, which is supported by all our built-in message transforms. This lets you easily configure the masking strategy from the config or command line:
```yaml
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
```
Or add `dataset.train_on_input=True` to your launch command:
```bash
tune run full_finetune_single_device --config my_config.yaml dataset.train_on_input=True
```
However, as brought up in #2214, there are other common masking strategies that we should support directly. The most requested one is to mask out all messages except the last assistant message in a multi-turn conversation. Currently, supporting this requires writing a custom message transform, which is not intuitive for beginner users.
Proposal
We can replace the binary `train_on_input` parameter with a string parameter that specifies the masking strategy. We can add a "mask all but the last" strategy since it was requested, plus any other masking strategies the community needs in the future.
masking_strategy (str): string parameter indicating the masking strategy to use. Must be one of ``train_on_all``, ``train_on_assistant``, ``train_on_last``
- ``train_on_all``: both user and assistant messages are unmasked and included in the loss calculation
- ``train_on_assistant``: user messages are masked; only assistant messages are unmasked and included in the loss calculation
- ``train_on_last``: only the last assistant message is unmasked and included in the loss calculation
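The three strategies could be implemented as a small helper that flips the `masked` flag in place. This is a minimal sketch, not the final implementation: the `Message` dataclass is a hypothetical stand-in for torchtune's own (only `role`, `content`, and `masked`), and it assumes system messages are always masked, which the proposal does not specify. The name `mask_messages` follows the utility proposed in the task list.

```python
from dataclasses import dataclass
from typing import List, Optional

VALID_STRATEGIES = ("train_on_all", "train_on_assistant", "train_on_last")


@dataclass
class Message:
    """Hypothetical stand-in for torchtune's Message with only the fields used here."""
    role: str
    content: str
    masked: bool = False


def mask_messages(messages: List[Message], masking_strategy: str) -> None:
    """Set ``masked`` on each message in place according to ``masking_strategy``."""
    if masking_strategy not in VALID_STRATEGIES:
        raise ValueError(
            f"masking_strategy must be one of {VALID_STRATEGIES}, got {masking_strategy!r}"
        )

    # Position of the final assistant message; stays None if there is none.
    last_assistant_idx: Optional[int] = None
    for i, message in enumerate(messages):
        if message.role == "assistant":
            last_assistant_idx = i

    for i, message in enumerate(messages):
        if message.role == "system":
            # Assumption: system prompts are never trained on, whatever the strategy.
            message.masked = True
        elif masking_strategy == "train_on_all":
            message.masked = False
        elif masking_strategy == "train_on_assistant":
            message.masked = message.role != "assistant"
        else:  # "train_on_last": unmask only the final assistant message
            message.masked = i != last_assistant_idx


conversation = [
    Message(role="system", content="You are a helpful assistant."),
    Message(role="user", content="What is the greatest TV show of all time?"),
    Message(role="assistant", content="The Sopranos, c'mon Tony"),
    Message(role="user", content="Any runner-up?"),
    Message(role="assistant", content="The Wire"),
]
mask_messages(conversation, "train_on_last")
print([m.masked for m in conversation])  # [True, True, True, True, False]
```

Because `train_on_last` selects the final assistant message by position, a conversation that ends on a user turn still trains on the last assistant reply before it, and a conversation with no assistant message ends up fully masked.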
The names of the strategies can change if folks feel strongly (e.g., `none`, `user`, `except_last`, if those make it clearer what is being masked), but I like being explicit about what is being trained.
We will need to do the following:
- Create a `mask_messages` utility in `torchtune/data/_messages.py`. It should take a list of messages and the string masking strategy, set the `masked` attribute on each message accordingly in place, and return nothing.
- Update all of our built-in message transforms to take a new `masking_strategy` parameter and call `mask_messages` after all the messages are processed. Make sure to update the docstrings. (Also don't forget to update `AlpacaToMessages`, which is strangely not documented here...)
- Keep the original behavior of `train_on_input` for BC, but under the hood map `train_on_input=True` or `train_on_input=False` to the corresponding strategy. Then file an issue to deprecate it in the next release.
- Update the documentation, replacing all references to `train_on_input` with `masking_strategy`. See here for a starting point: https://pytorch.org/torchtune/main/basics/message_transforms.html#configuring-message-transforms
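For the backward-compatibility step, a transform could resolve the legacy flag into a strategy string before masking. This is a minimal sketch under stated assumptions: `resolve_masking_strategy` is a hypothetical helper name, both arguments default to `None` so an explicitly passed `train_on_input` can be detected and warned about, and falling back to `train_on_assistant` when neither is given is an assumption (individual builders may want a different default).

```python
import warnings
from typing import Optional


def resolve_masking_strategy(
    train_on_input: Optional[bool] = None,
    masking_strategy: Optional[str] = None,
) -> str:
    """Hypothetical helper: map the legacy train_on_input flag to a masking strategy."""
    if train_on_input is None:
        # New-style call; the fallback default here is an assumption.
        return masking_strategy if masking_strategy is not None else "train_on_assistant"
    if masking_strategy is not None:
        raise ValueError("Pass either train_on_input or masking_strategy, not both.")
    warnings.warn(
        "train_on_input is deprecated; use masking_strategy instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # train_on_input=True trains on user and assistant messages alike;
    # train_on_input=False masks the user prompts and trains on assistant replies.
    return "train_on_all" if train_on_input else "train_on_assistant"


print(resolve_masking_strategy(masking_strategy="train_on_last"))  # train_on_last
```

Raising when both arguments are passed avoids silently preferring one over the other during the deprecation window.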
Ask
Please share any other masking strategies you feel are needed for your use case. We also welcome any feedback on the approach, as well as any contributors willing to work on this.