Staging SFT training #2148
base: main
Conversation
Just confirm it works on the SFT dataset in "multiturn" format, with and without the tool, and on reasoning data. (Be aware that the Llama-3.1 (pretrained) model doesn't have a chat_template; you can find the chat_template for Llama-3.1-8B-Instruct in …) To reproduce it on the …
tianyu-l
left a comment
Thanks, left some initial comments.
| return system_prompt, generate_prompt |
| def preprocess_data( |
Is this function used?
It's used in the SFTDataset init to infer the "system_prompt" and "generation_prompt"; the lengths of these two prompts are used to remove the repeated [SYSTEM] prompt in multiturn data.
e.g. if we do tokenizer(message_i) for each message_i in a row, we get
[SYSTEM …][USER …][SYSTEM …][ASSISTANT …][SYSTEM …][USER …]...
So we want to remove the redundant [SYSTEM …] from the second message onward.
(conditioned on the index inside the _process_single_message func)
Is this done in extract_system_prompt_and_generation? My question is whether preprocess_data is called anywhere in this PR.
the lengths of these two prompts are used to remove the repeated [SYSTEM] prompt in multiturn data.
Why do we want to remove [SYSTEM] from the second message onward? Another way to ask: why do we keep the first?
Oh, my bad, it was accidentally copied in and I did not delete it. No, we don't need this preprocess_data.
We want to remove [SYSTEM] because of the way we tokenize the data.
The actual conversation is
conversation = [USER …][ASSISTANT …][USER …][ASSISTANT …][USER …][ASSISTANT …]
but we cannot do apply_chat_template(conversation) on the whole thing.
Instead we need to loop over each [USER …][ASSISTANT …] pair, so each call adds an unexpected [SYSTEM] to the conversation:
[SYSTEM …][USER …][ASSISTANT …][SYSTEM …][USER …][ASSISTANT …][SYSTEM …][USER …][ASSISTANT …]
from the second message onward.
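For illustration only, here is a minimal sketch of the effect being described, using a public chat model as a stand-in (the model name, the pair-wise loop, and the common-prefix stripping are assumptions, not the PR's actual code):

```python
import os
from transformers import AutoTokenizer

# Qwen's chat template injects a default system message, which makes the
# duplication easy to see; any template with a system prefix behaves similarly.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

conversation = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4."},
]

# Rendering one [USER][ASSISTANT] pair at a time (needed to build per-turn
# labels) re-runs the template, so every pair gets its own [SYSTEM ...] prefix.
pairs = [conversation[i : i + 2] for i in range(0, len(conversation), 2)]
rendered = [tok.apply_chat_template(p, tokenize=False) for p in pairs]

# Keep the shared system/header prefix once and drop it from later pairs.
prefix = os.path.commonprefix(rendered)
merged = rendered[0] + "".join(r[len(prefix):] for r in rendered[1:])
print(merged)
```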
torchtitan/components/tokenizer.py
Outdated
| self.pad_id = None # only used for SFT |
| self.pad_token = None # only used for SFT |
Let's have a separate (sub)class for the SFT tokenizer, and not depend on transformers.
See the related comments in #1802 (comment)
torchtitan/models/attention.py
Outdated
| return sliding_window_mod |
| def get_sft_padding_mask_mod(attention_masks: torch.Tensor) -> _mask_mod_signature: |
I don't think we want to materialize an attention_masks tensor to generate a mask_mod. That defeats the purpose of using mask_mod (to represent the mask sparsely).
Instead, we should use e.g. pad_id to infer the mask_mod here, similar to how we generate the document mask using eos_id.
Similar for varlen metadata generation.
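For concreteness, the pad-id-based alternative suggested here could look roughly like the following sketch (the function name and wiring are hypothetical, not the PR's API); it only keeps a [B, S] indicator around instead of a dense [B, S, S] mask:

```python
import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

def get_padding_mask_mod(input_ids: torch.Tensor, pad_id: int):
    # input_ids: [B, S]; a token attends only to / from non-pad positions.
    is_real = input_ids != pad_id  # [B, S] bool, no [B, S, S] materialization

    def padding_mask_mod(b, h, q_idx, kv_idx):
        return is_real[b, q_idx] & is_real[b, kv_idx]

    return padding_mask_mod

def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Example: combine with the causal mask and build a block mask for FlexAttention.
B, S, pad_id = 1, 128, 0
input_ids = torch.randint(1, 100, (B, S))
input_ids[:, 100:] = pad_id  # last positions are padding
mask_mod = and_masks(causal_mask_mod, get_padding_mask_mod(input_ids, pad_id))
block_mask = create_block_mask(mask_mod, B=B, H=None, Q_LEN=S, KV_LEN=S, device="cpu")
```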
Made the refactor for flex/varlen. Now we manually compute the length of each document segment and the index of the first token of each segment, and we use them to reconstruct the varlen and flex attention data.
(We cannot use eos_id to infer the boundary of documents/sequences because eos_id in SFT may have a different meaning: it might refer to the end of a sentence/query, and it may appear many times, especially in multiturn SFT.)
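A rough sketch of what consuming that metadata for FlexAttention could look like (single-sample version; names and shapes are assumptions, not the PR's get_document_mask_mod_using_sequence_indices):

```python
import torch

def document_mask_mod_from_seq_lens(seq_lens: torch.Tensor, total_len: int):
    # Map every position in the packed sequence to a document id, e.g.
    # seq_lens = [3, 2] -> doc_ids = [0, 0, 0, 1, 1, pad, pad, ...]
    doc_ids = torch.repeat_interleave(
        torch.arange(seq_lens.numel(), device=seq_lens.device), seq_lens
    )
    if doc_ids.numel() < total_len:  # trailing padding gets its own id
        pad_ids = doc_ids.new_full((total_len - doc_ids.numel(),), seq_lens.numel())
        doc_ids = torch.cat([doc_ids, pad_ids])

    def mask_mod(b, h, q_idx, kv_idx):
        # Attend only within the same document, causally.
        return (doc_ids[q_idx] == doc_ids[kv_idx]) & (q_idx >= kv_idx)

    return mask_mod
```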
We cannot use eos_id to infer the boundary of documents/sequences because eos_id in SFT may have a different meaning: it might refer to the end of a sentence/query, and it may appear many times, especially in multiturn SFT.
Then in theory we could use an eod_id (end of document) when encoding the SFT dataset? I somehow don't see that SFT data loading is worth this extra complexity.
cc @lkhphuc if you have feedback.
If I understand correctly we are talking about the extra get_document_mask_mod_from_sequence_indices function, which I agree is a bit overkill.
For both pretraining and SFT, document masking just needs a special token to mark the boundary between separate sequences.
I think reusing get_document_mask_mod is fine, since you/the user can explicitly pass in whatever token is the "end of sequence" token for your particular dataset.
(We cannot use eos_id to infer the boundary of documents/sequences because eos_id in SFT may have a different meaning: it might refer to the end of a sentence/query, and it may appear many times, especially in multiturn SFT.)
I've actually never seen this practice, because usually people create another token like <|end_of_turn|> or <|im_end|> or something, but I might be wrong.
In any case, since you are in charge of the SFT dataloader here, nothing prevents us from enforcing <|eos|> as the actual document separator and using a different token for end of turn, thus reusing the same document masking function for flex attention.
Unrelated P/S: Sorry @tianyu-l for neglecting the VLM experiment for a while. End-of-year deadlines are a bit hectic, but I will get back to continue it real soon.
You mean <|eos|> or eos_token?
eos_token will not work, because for some SFT/PT models the eos_token does not mean end of "document".
<|eos|> can work, but we need to ask the user to explicitly provide a segment token, and we need to wrap the tokenizer and get_document_mask_mod to pick up the correct "segment token".
I've actually never seen this practice, because usually people create another token like <|end_of_turn|> or <|im_end|> or something, but I might be wrong.
I think that's the argument: either we assume users know SFT well enough to set it properly, or we want to make sure training can be done correctly even for users who don't know anything. (I am OK with both.)
Thanks a lot for the feedback, will fix them later today/tomorrow.
| prompt_message = [{"role": "user", "content": prompt_message}] |
| response_message = [{"role": "assistant", "content": response_message}] |
Maybe a noob question for SFT:
Strings like "role" / "content" and "user" / "assistant" / "system" seem dataset specific.
- When we start finetuning a model, do we assume the pretrained model knows that they are somewhat special? Or do we simply rely on the apply_chat_template function to map them to special strings / token_ids?
- It seems these strings are specific to this particular data loader -- is this a contract with HF tokenizer's apply_chat_template function? If so, can we document them clearly?
- Do we ever want to finetune a checkpoint that is already finetuned? Say llama3_8b_instruct.
I am also not an expert on SFT; these are based on my understanding:
Strings like "role" / "content" and "user" / "assistant" / "system" seem dataset-specific.
SFT datasets don't have a single standard schema. Many modern instruction/chat datasets are represented as a list of messages ({role, content, ...}; sometimes also tools / function calls). Other datasets use task-specific columns (e.g., GSM8K-style question / answer). (Both of these formats are supported in this PR.)
We could optionally provide a "row adapter" / registration function so users can define how to map each dataset row into messages (or prompt/response); see the sketch below.
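A sketch of what such a row adapter could look like (all names here are hypothetical, just to make the idea concrete):

```python
from functools import partial
from typing import Any, Callable

Row = dict[str, Any]
Messages = list[dict[str, str]]

def columns_to_messages(row: Row, prompt_key: str, response_key: str) -> Messages:
    # Task-specific columns, e.g. GSM8K's "question" / "answer".
    return [
        {"role": "user", "content": row[prompt_key]},
        {"role": "assistant", "content": row[response_key]},
    ]

def chat_to_messages(row: Row) -> Messages:
    # Datasets that already store a list of {role, content, ...} messages.
    return [{"role": m["role"], "content": m["content"]} for m in row["messages"]]

ROW_ADAPTERS: dict[str, Callable[[Row], Messages]] = {
    "question_answer": partial(
        columns_to_messages, prompt_key="question", response_key="answer"
    ),
    "multi_turn": chat_to_messages,
}
```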
Do we assume the pretrained model knows these are somewhat special?
Traditionally, no. But in reality people already mix SFT and reasoning data into pretraining data, so I would not be surprised to see a model that knows some of these special tokens.
Do we simply rely on apply_chat_template to map them to special strings / token IDs?
I think so, and apply_chat_template is an option that the user can turn on or off.
It seems these strings are specific to this particular data loader -- is this a contract with HF tokenizer's apply_chat_template? If so can we document them clearly?
Not a contract with HF's implementation; these "role" strings depend on the chat_template file. What I see from HF's implementation is that you get an error if these roles are not defined in the template.
Do we ever want to fine-tune a checkpoint that is already fine-tuned (e.g., llama3_8b_instruct)?
Yes, depending on the goal. It really depends on what people want to do, but multi-stage SFT (and sometimes model merging) is common: start from an instruct model and further specialize it on domain/style/tool-use data. For example, Olmo 3.

should we filter out …
Force-pushed from 769e2e4 to db9030e.
No IMO. It's "tokens seen", not "tokens the loss was computed on", after all. We just need to ensure padding is removed from tokens seen.
I agree. "tokens seen" is used to measure pure infra. Tracking the number of tokens trained with loss is probably worth its own metric, which doesn't sound urgent.
wwwjn
left a comment
Thanks for the work! Structurally, I would advocate that we start in an experiments folder, so we could greatly increase readability for users. We could have something like experiments/sft:
- tokenizer.py: inherit from torchtitan's HuggingFaceTokenizer, or depend on AutoTokenizer
- datasets.py
- job_config.py
- __init__.py: use TrainSpec to plug in the dataloader and tokenizer here
Currently, branching in the main trainer / job_config seems confusing to me.
Should we move this file under tests/unit_tests?
Thanks. I moved all the SFT stuff to the experiments folder now (including test_mask.py). I will refactor the test function into a real unit test if we decide to keep this new mask mod.
| assert seq_lens is not None and seq_start_indices is not None |
| return get_document_mask_mod_using_sequence_indices( |
|     batch, seq_lens, seq_start_indices |
| ) |
Error-out if both eos_id and extra_inputs are empty?
Would this be more clear? I think eos_id should never be None.
if extra_inputs is None:
    assert eos_id is not None
    get_document_mask_mod_using_eos_id()
else:
    assert seq_lens is not None and seq_start_indices is not None
    get_document_mask_mod_using_sequence_indices()
Yep, I think eos_id will never be None.
| # batch is [b, s, h, d] shape |
| # this functioned can be called in two ways: |
| # 1. with eos_id: (we use this in the pre-training while eos_id is relible to get the document boundaries) |
| # 1. with eos_id: (we use this in the pre-training while eos_id is relible to get the document boundaries) |
| # 1. with eos_id: (we use this in the pre-training while eos_id is reliable to get the document boundaries) |
| # 2. with seq_lengths and seq_start_indices: |
| # - seq_lengths: the length of each sequence |
| # - seq_start_indices: the start index of each sequence |
| # - we use this in the post-training while eos_id is not relible to get the document boundaries |
| # - we use this in the post-training while eos_id is not relible to get the document boundaries |
| # - we use this in the post-training while eos_id is not reliable to get the document boundaries |
| Args: |
|     batch: Input batch tensor with shape [b, s, h, d] |
|     eos_id: End-of-sequence token ID that marks document boundaries |
|     extra_inputs: Extra inputs to the mask modifier function |
I'm wondering if SFT needs to be compatible with PP. From the comments here: https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L463 the extra_inputs are not passed through PP stages, which means that for PP rank 1 the extra_inputs will not contain seq_lens and seq_start_indices. Will this function still work? cc @fegin
| def _process_one_row(self, row_dict: dict): |
|     """Convert a dataset row into model-ready tensors with causal labels.""" |
|     messages = _build_messages( |
_build_messages seems to be dataset-specific, depending on the dataset's format and key names. I would suggest we move the current _build_messages function out of the SFTDataset class to make the SFTDataset class more general. In torchtitan, we plug in a dataset-specific processor here; see sample_processor: https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/hf_datasets/text_datasets.py#L35-L51
Yes, that's smart. Thanks.
Now, for each "kind" of SFT dataset, we pass a message_builder to process each row. Three commonly used formats are pre-registered.
Thanks again for the feedback ❤️. The SFT code has now been moved to the experiments folder. To run a test on GSM8K: … and to run a test on …
| - Either tokenize the message["content"] directly, OR |
| - Apply the tokenizer's chat template (requires the chat template to be configured). |
| 3.2) If there are MULTIPLE messages, we call `_process_single_message` for each message. |
Wondering: instead of removing the system prompt later, would it be possible to add a flag to _process_single_message to control whether we want to add the system prompt?
We have a flag to control whether we want to apply the chat template or not.
I checked the apply_chat_template function from HF; it does not support "render other prompts except the system prompt" if the system prompt is hard-coded in the chat template.
But this brings up another question: if the chat template does not automatically insert the system prompt, do we want to add it manually or not?
(From what I understand, "try to remove" the system prompt seems to be a safer way to prevent these weird prompts from being inserted in the middle of conversations.)
| class AutoTokenizer(BaseTokenizer): |
|     def __init__(self, tokenizer_path: str): |
|         self.tokenizer = HF_AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) |
| return document_mask |
| def get_document_mask_mod_using_sequence_indices( |
I wonder if we could avoid the complexity here.
You mean <|eos|> or eos_token?
eos_token will not work, because for some SFT/PT models the eos_token does not mean end of "document".
<|eos|> can work, but we need to ask the user to explicitly provide a segment token, and we need to wrap the tokenizer and get_document_mask_mod to pick up the correct "segment token".
I've actually never seen this practice, because usually people create another token like <|end_of_turn|> or <|im_end|> or something, but I might be wrong.
I think that's the argument: either we assume users know SFT well enough to set it properly, or we want to make sure training can be done correctly even for users who don't know anything. (I am OK with both.)
Since we/you are in charge of the dataloader (packing behavior) and tokenizer (eod_id), what would we expect the user to know?
I'd like to take a step back and ask about the purpose of this experiment:
- If the purpose of this experiment is to show an example of SFT training in torchtitan, then we should make it simple enough that people can easily learn from and fork it. Do you agree that "making an eod_id per document is nothing different from having eos_id to determine the mask"? Or do you think eos_id is a much more widely accepted thing than something like eod_id? I personally think they are the same thing, conceptually; e.g., we don't need to have get_document_mask_mod_using_eos_id, we can just require the pretraining data loader to also return seq_lengths and seq_start_indices.
- If the purpose is to provide an out-of-the-box solution for SFT, then it also depends on what datasets people have. Users may not start from HF datasets (or do they?).
I personally think we should start with 1.
I'm also very curious how other libraries deal with the complexity of SFT attention mask creation. Would appreciate it if you could educate us.
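To make the comparison concrete, deriving that metadata from an end-of-document token on the loader side is a small amount of code; a sketch (the eod_id handling and names are assumptions, not the PR's code):

```python
import torch

def seq_metadata_from_eod(tokens: torch.Tensor, eod_id: int):
    # tokens: packed 1-D token ids where eod_id terminates each document.
    eod_pos = (tokens == eod_id).nonzero(as_tuple=True)[0]
    starts = torch.cat(
        [torch.zeros(1, dtype=torch.long, device=tokens.device), eod_pos + 1]
    )
    starts = starts[starts < tokens.numel()]  # drop a start past the end
    ends = torch.cat([starts[1:], torch.tensor([tokens.numel()], device=tokens.device)])
    seq_lens = ends - starts
    return starts, seq_lens

# e.g. tokens = [a, b, EOD, c, EOD, d, e] -> starts = [0, 3, 5], seq_lens = [3, 2, 2]
```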
| ) |
| def create_varlen_metadata_for_document_using_sequence_indices( |
I don't see the benefit of supporting varlen attention for SFT right now, so I'd prefer we delay this work until later.
| class SFTConfig: |
|     split: str = "train" |
|     """Split to use""" |
|     dataset_subset: str | None = None |
what is this field?
|     """Subset to use""" |
|     stream_dataset: bool = True |
|     """Whether to stream the dataset""" |
|     is_multiturn: bool = False |
is this field used?
|     "prompt_response": _build_prompt_response_messages_from_row_dict, |
|     "question_answer": partial( |
|         _build_prompt_response_messages_from_row_dict, |
|         prompt_key="question", |
|         response_key="answer", |
|     ), |
This separation sounds dataset-specific. Without specifying the dataset, it's hard to tell if they are general.
Maybe consider registering a DATASETS field similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/hf_datasets/text_datasets.py#L36
if you have a good candidate.
| DATASET_MESSAGE_BUILDERS = { |
|     "multi_turn": _build_multi_turn_messages_from_row_dict, |
Similar to the one above -- too abstract without seeing an example, and hard to verify that it's correct.
| build_lr_schedulers_fn=build_lr_schedulers, |
| build_dataloader_fn=build_sft_text_dataloader, |
| build_tokenizer_fn=build_auto_tokenizer, |
| build_loss_fn=build_cross_entropy_loss, |
You may want to use https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/vlm/infra/loss.py#L100
otherwise with DP / CP it won't be correct.
| position_ids = F.pad( |
|     position_ids, (0, target_length - sequence_length), value=0 |
| ) |
| elif sequence_length > target_length: |
do we have to support the three modes? Is there a single mainstream approach we should take?
| else: |
|     raise ValueError(f"Unknown truncation method {self.truncation}") |
| elif self.pad_mode == "greedy_packing": |
similarly, do we have to support both packing and padding?
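For context, "greedy packing" here presumably means the usual pack-samples-until-the-length-budget-is-full scheme; a minimal sketch of that idea (illustrative only, not the PR's dataloader, which also has to pack labels / position_ids and respect document boundaries in the mask):

```python
def greedy_pack(samples: list[list[int]], seq_len: int) -> list[list[int]]:
    """Greedily concatenate tokenized samples into buffers of at most seq_len tokens."""
    packs: list[list[int]] = []
    current: list[int] = []
    for tokens in samples:
        tokens = tokens[:seq_len]  # truncate any single sample that is too long
        if current and len(current) + len(tokens) > seq_len:
            packs.append(current)  # close the current buffer and start a new one
            current = []
        current += tokens
    if current:
        packs.append(current)
    return packs

# e.g. greedy_pack([[1] * 300, [2] * 500, [3] * 400], seq_len=1024)
# -> two packs of 800 and 400 tokens
```

Plain padding would instead pad each sample to seq_len individually, which is simpler but spends compute on pad tokens.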





TL;DR
Adds SFT training to Torchtitan plus a small greedy_packing addition. Most of the code is borrowed from Verl and OpenRLHF.
Changes
- get_attention_masks to support SFT masks (only landed on Llama3 now)
- HFTokenizer, need to fix this later
- input_ids, labels (user tokens masked), attention_masks, position_ids
Run
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh \
  --training.running_sft_training \
  --model.flavor=debugmodel_varlen_attn \
  --training.dataset_path=openai/gsm8k \
  --sft_data_config.dataset_subset=main
(torch 2.10.0.dev20251124+cu129, and I am using cudnn attention)
TODO
- more tests
- compile does not work for no-padding because the seq-len for each training step keeps changing. We could pad the buffer to seqlen when turning on greedy_packing (packing_on ++) to make compile happy.
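Regarding the compile note: a sketch of padding each packed buffer up to a fixed seq_len so that every step sees static shapes (pad_id and the ignore_index convention are assumptions, not the PR's code):

```python
import torch
import torch.nn.functional as F

def pad_pack_to_seq_len(input_ids: torch.Tensor, labels: torch.Tensor,
                        seq_len: int, pad_id: int, ignore_index: int = -100):
    # input_ids / labels: 1-D tensors for one packed sample, length <= seq_len.
    missing = seq_len - input_ids.numel()
    if missing < 0:
        raise ValueError("packed sample is longer than seq_len")
    input_ids = F.pad(input_ids, (0, missing), value=pad_id)
    labels = F.pad(labels, (0, missing), value=ignore_index)  # no loss on padding
    return input_ids, labels
```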