[RFC] Additional chat loss masking strategies #2261

Open · opened by @RdoubleA

Description

Context

Continuing the discussion from #2214.

When LLMs are trained on chat data, you typically need control over which messages the model is actually trained on and which ones are masked out of the loss and therefore do not affect the loss calculation. The most common setting is to train only on the model's responses in the dataset:

from torchtune.data import Message

messages = [
    Message(role="user", content="What is the greatest TV show of all time?", masked=True),
    Message(role="assistant", content="The Sopranos, c'mon Tony", masked=False),
]

In this example, we don't need to train the model to predict the next tokens in the user prompt, only the answer to the prompt. Thus, the user message is masked. An alternative is to train the model on all messages, including the user prompt. This is the approach taken by Stanford Alpaca.
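With that approach, the same example would simply leave both messages unmasked:

messages = [
    Message(role="user", content="What is the greatest TV show of all time?", masked=False),
    Message(role="assistant", content="The Sopranos, c'mon Tony", masked=False),
]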

In torchtune, we provide this control to the user via train_on_input, which is supported by all of our built-in message transforms. This lets you easily configure the masking strategy from the config or the command line:

dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True

or add dataset.train_on_input=True to your launch command: tune run full_finetune_single_device --config my_config.yaml dataset.train_on_input=True

However, as brought up in #2214, there are other common masking strategies that we should provide strong support for. The most requested one is to mask out all messages except the last assistant message in a multi-turn conversation. Currently, supporting this requires writing a custom message transform, which is not intuitive for beginners.

Proposal

We can replace the binary train_on_input parameter with a string parameter that specifies the masking strategy. We can add a "mask all but the last assistant message" strategy, since this was requested, along with any other masking strategies the community needs in the future.

masking_strategy (str): string parameter indicating the masking strategy to use. Must be one of ``train_on_all``, ``train_on_assistant``, ``train_on_last``
	- ``train_on_all``: both user and assistant messages are unmasked and included in loss calculation
	- ``train_on_assistant``: user messages are masked, only assistant messages are unmasked and included in loss calculation
	- ``train_on_last``: only the last assistant message is unmasked and included in loss calculation; all other messages are masked

The names of the strategies can be different if folks feel strongly (e.g., none, user, except_last if that more clearly conveys what is being masked), but I do like being explicit about what is being trained on.
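For illustration, the proposed parameter would slot into a config the same way the current one does (hypothetical, since masking_strategy does not exist yet):

dataset:
  _component_: torchtune.datasets.alpaca_dataset
  masking_strategy: train_on_last

or on the command line: tune run full_finetune_single_device --config my_config.yaml dataset.masking_strategy=train_on_last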

We will need to do the following:

  • Create a mask_messages utility in torchtune/data/_messages.py. It should take a list of messages and the masking strategy string, set the masked attribute on each message in place, and return nothing (see the sketch after this list).
  • Update all of our built-in message transforms to take a new masking_strategy parameter and call mask_messages after all messages are processed. Make sure to update the docstrings. (Also don't forget to update AlpacaToMessages, which is strangely not documented here...)
  • Keep the original train_on_input behavior for backwards compatibility, but under the hood map train_on_input=True/False to the corresponding strategy. Then file an issue to deprecate it in the next release.
  • Update the documentation, replacing all references to train_on_input with masking_strategy. See here for a starting point: https://pytorch.org/torchtune/main/basics/message_transforms.html#configuring-message-transforms
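To make the first three bullets concrete, here is a rough sketch of what mask_messages and the BC mapping could look like. This is illustrative only: the exact signature, the handling of system messages under train_on_all, and the helper name strategy_from_train_on_input are assumptions, not settled API.

from typing import List

from torchtune.data import Message


def mask_messages(messages: List[Message], masking_strategy: str) -> None:
    """Set the ``masked`` attribute on each message in place per the strategy."""
    # Index of the last assistant message, or -1 if there is none
    last_assistant_idx = max(
        (i for i, m in enumerate(messages) if m.role == "assistant"), default=-1
    )
    for i, message in enumerate(messages):
        if masking_strategy == "train_on_all":
            # Unmask user and assistant turns (assumption: system stays masked)
            message.masked = message.role not in ("user", "assistant")
        elif masking_strategy == "train_on_assistant":
            # Only assistant turns contribute to the loss
            message.masked = message.role != "assistant"
        elif masking_strategy == "train_on_last":
            # Only the final assistant turn contributes to the loss
            message.masked = i != last_assistant_idx
        else:
            raise ValueError(f"Unknown masking strategy: {masking_strategy}")


# Hypothetical BC shim: map the legacy boolean onto the new strategies
def strategy_from_train_on_input(train_on_input: bool) -> str:
    return "train_on_all" if train_on_input else "train_on_assistant"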

Ask

Please share any other masking strategies you feel are needed for your use case. Feedback on the approach is also welcome, and we welcome any willing contributors to work on this.

cc @EugenHotaj @tginart @ebsmothers @pbontrager

Labels: community help wanted, discussion, enhancement, good first issue, rfc
