Skip to content

Conversation

@rakkit (Contributor) commented Dec 12, 2025

TL;DR

Adds SFT training to Torchtitan plus a small greedy_packing addition.

Most of the code is borrowed from Verl and OpenRLHF.

Changes

  • Added SFT dataset config to job config
  • Updated attention and modified get_attention_masks to support SFT masks (only wired up for Llama3 for now)
  • Temporarily uses HFTokenizer; this needs to be fixed later
  • Added an SFT dataset/dataloader that returns input_ids, labels (user tokens masked), attention_masks, and position_ids (see the sketch below)
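
For illustration, a hypothetical single sample in that format might look like the following (token ids are arbitrary, IGNORE_INDEX = -100 is the usual convention for labels excluded from the loss, and the exact field names may differ from the PR):

import torch

IGNORE_INDEX = -100  # assumed label value for tokens excluded from the loss

# hypothetical single-turn sample: the prompt/user tokens are masked out of the
# labels, so only the assistant response contributes to the loss
sample = {
    "input_ids": torch.tensor([11, 42, 7, 99, 3, 2]),  # prompt + response tokens
    "labels": torch.tensor([IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 99, 3, 2]),
    "attention_masks": torch.tensor([1, 1, 1, 1, 1, 1]),  # 1 = real token, 0 = padding
    "position_ids": torch.tensor([0, 1, 2, 3, 4, 5]),
}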

TODO

Run

CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh \
  --training.running_sft_training \
  --model.flavor=debugmodel_varlen_attn \
  --training.dataset_path=openai/gsm8k \
  --sft_data_config.dataset_subset=main

More tests:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_varlen \
 --training.dataset_path=openai/gsm8k \
 --sft_data_config.dataset_subset=main \
 --model.hf_assets_path="$HOME/Meta-Llama-3.1-8B-Instruct/" \
 --training.local_batch_size=8 \
 --activation_checkpoint.mode=full \
 --training.seq_len=2048 \
 --training.steps=100 \
 --lr_scheduler.warmup_steps=10 \
 --debug.seed 10 \
 --sft_data_config.pad_mode=right_padding \
 --metrics.enable_wandb

(torch 2.10.0.dev20251124+cu129, using cuDNN attention)

[W&B loss chart, 12/12/2025 8:32 PM]

Compile does not work for the no-padding case because the sequence length changes at every training step. We could pad the packed buffer up to seq_len when greedy_packing is turned on (i.e., packing plus padding) to make compile happy.
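
For concreteness, a minimal sketch of that idea (the helper name and arguments here are hypothetical, not part of this PR):

import torch.nn.functional as F

def pad_packed_batch(input_ids, labels, position_ids, seq_len, pad_id, ignore_index=-100):
    # right-pad a greedily packed batch up to a fixed seq_len so that
    # torch.compile sees static shapes across training steps
    cur_len = input_ids.shape[-1]
    if cur_len < seq_len:
        pad = seq_len - cur_len
        input_ids = F.pad(input_ids, (0, pad), value=pad_id)
        # padded label positions never contribute to the loss
        labels = F.pad(labels, (0, pad), value=ignore_index)
        position_ids = F.pad(position_ids, (0, pad), value=0)
    return input_ids, labels, position_ids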


@rakkit (Contributor, Author) commented Dec 12, 2025

Just confirmed it works on SFT datasets in "multiturn" format, with and without tools, and on reasoning data.

(Be aware that when apply_chat_template is turned on, we are supposed to provide chat_template.jinja in the tokenizer's folder. There is no such file for "Meta-Llama-3.1-8B-Instruct". For test purposes, you can use, e.g., the tokenizer from Olmo-3-7B-Instruct.)

The Llama-3.1 (pretrained) model doesn't have a chat_template; for Llama-3.1-8B-Instruct you can find the chat_template inside tokenizer_config.json.
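
As a workaround, something like this should also work (a rough sketch using the transformers API; the paths are placeholders):

import json
from transformers import AutoTokenizer

# base (pretrained) tokenizer, which ships no chat template of its own
tokenizer = AutoTokenizer.from_pretrained("path/to/Meta-Llama-3.1-8B")

# the Instruct repo stores its template inside tokenizer_config.json rather than
# as a standalone chat_template.jinja, so it can be read and passed explicitly
with open("path/to/Meta-Llama-3.1-8B-Instruct/tokenizer_config.json") as f:
    chat_template = json.load(f)["chat_template"]

messages = [{"role": "user", "content": "What is 2 + 2?"}]
text = tokenizer.apply_chat_template(
    messages, chat_template=chat_template, tokenize=False, add_generation_prompt=True
)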


To reproduce on the multiturn-format dataset Dolci-Instruct-SFT, without apply_chat_template:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_flex \
 --model.hf_assets_path="$HOME/torchtitan_assets/Meta-Llama-3.1-8B-Instruct/" \
 --training.dataset_path=$HOME/sft/Dolci-Instruct-SFT/ \
 --sft_data_config.is_multiturn \
 --training.seq_len=2048 \
 --training.local_batch_size=2

To reproduce on the multiturn-format dataset Dolci-Instruct-SFT, with apply_chat_template:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_flex \
 --model.hf_assets_path="$HOME/torchtitan_assets/Meta-Llama-3.1-8B-Instruct/" \
 --training.dataset_path=$HOME/sft/Dolci-Instruct-SFT/ \
 --sft_data_config.is_multiturn \
 --training.seq_len=2048 \
 --training.local_batch_size=2 \
 --sft_data_config.apply_chat_template \
 --sft_data_config.ignore_input_ids_mismatch


To reproduce on a tool dataset, e.g., the ReTool-SFT-multi-turn dataset:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_flex \
 --model.hf_assets_path="$HOME/Meta-Llama-3.1-8B-Instruct/" \
 --training.dataset_path=$HOME/ReTool-SFT-multi-turn/ \
 --sft_data_config.is_multiturn --sft_data_config.apply_chat_template \
 --training.seq_len=2048 \
 --training.local_batch_size=2 \
 --sft_data_config.ignore_input_ids_mismatch \
 --sft_data_config.tools_key=tools

To reproduce on a reasoning dataset, e.g., nvidia/Puzzle-KD-Nemotron-Post-Training-Dataset-v2:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --training.running_sft_training \
 --model.flavor=8B_flex \
 --model.hf_assets_path="$HOME/Meta-Llama-3.1-8B-Instruct/" \
 --training.dataset_path=$HOME/Puzzle-KD-Nemotron-Post-Training-Dataset-v2 \
 --sft_data_config.is_multiturn --sft_data_config.apply_chat_template \
 --sft_data_config.split=validation \
 --training.seq_len=2048 \
 --training.local_batch_size=2 \
 --sft_data_config.thinking_key=reasoning

@tianyu-l (Contributor) left a comment:
Thanks, left some initial comments.

return system_prompt, generate_prompt


def preprocess_data(

Contributor:
Is this function used?

Contributor Author (@rakkit):
It is used in the SFTDataset init to infer the "system_prompt" and "generation_prompt"; the lengths of these two prompts are used to remove the repeated [SYSTEM] prompt in multiturn data.
E.g., if we do tokenizer(message_i) for each message_i in a row, we get
[SYSTEM …][USER …][SYSTEM …][ASSISTANT …][SYSTEM …][USER …]...

So we want to remove the redundant [SYSTEM …] from the second message onward
(conditioned on the index inside the _process_single_message function).

Contributor:
Is this done in extract_system_prompt_and_generation? My question is whether preprocess_data is called anywhere in this PR.

the lengths of these two prompts are used to remove the repeated [SYSTEM] prompt in multiturn data.

Why do we want to remove [SYSTEM] from the second message onward? Another way to ask: why do we keep the first one?

Contributor Author (@rakkit), Dec 15, 2025:
Oh, my bad, it was accidentally copied over and I did not delete it. No, we don't need this preprocess_data.

We want to remove [SYSTEM] because of the way we tokenize the data.
The actual conversation is
conversation = [USER …][ASSISTANT …][USER …][ASSISTANT …][USER …][ASSISTANT …]
but we cannot do apply_chat_template(conversation).
Instead we need to loop over each [USER …][ASSISTANT …] pair, so we get extra [SYSTEM] prompts unexpectedly inserted into one conversation:

[SYSTEM …][USER …][ASSISTANT …][SYSTEM …][USER …][ASSISTANT …][SYSTEM …][USER …][ASSISTANT …]

starting from the second message.
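
Roughly like this (the helper and variable names here are only illustrative, not the PR's actual code):

def encode_multiturn(tokenizer, pairs, system_prompt_len):
    # pairs: list of [user_msg, assistant_msg] message-dict pairs of one conversation
    token_ids = []
    for i, pair in enumerate(pairs):
        ids = tokenizer.apply_chat_template(pair, tokenize=True)
        if i > 0:
            # every pair after the first is rendered with the same leading
            # [SYSTEM ...] prefix; slice it off so the concatenated conversation
            # keeps a single system prompt at the very beginning
            ids = ids[system_prompt_len:]
        token_ids.extend(ids)
    return token_ids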

Comment on lines 63 to 64
self.pad_id = None # only used for SFT
self.pad_token = None # only used for SFT

Contributor:
Let's have a separate (sub)class for the SFT tokenizer, and not depend on transformers.

See related comments #1802 (comment)

return sliding_window_mod


def get_sft_padding_mask_mod(attention_masks: torch.Tensor) -> _mask_mod_signature:

Contributor:
I don't think we want to materialize an attention_masks tensor to generate the mask_mod. It defeats the purpose of using mask_mod (to represent the mask sparsely).

Instead, we should use e.g. pad_id to infer the mask_mod here, similar to how we generate the document mask using eos_id.

Similarly for varlen metadata generation.
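
Something along these lines (just a sketch of the idea, with illustrative names):

import torch
from torch.nn.attention.flex_attention import _mask_mod_signature

def get_padding_mask_mod_from_pad_id(input_ids: torch.Tensor, pad_id: int) -> _mask_mod_signature:
    # derive the padding mask lazily from pad_id instead of materializing a
    # dense attention mask
    non_pad = input_ids != pad_id  # [B, S] bool

    def padding_mask_mod(b, h, q_idx, kv_idx):
        # queries may only attend to non-padding key positions
        return non_pad[b, kv_idx]

    return padding_mask_mod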

Contributor Author (@rakkit):
Made the refactor for flex/varlen. Now we manually compute the length of each document segment and the index of the first token of each segment, and use them to reconstruct the varlen and flex attention data.

(We cannot use eos_id to infer the boundaries of documents/sequences because eos_id in SFT may have a different meaning: it might mark the end of a sentence/query, and it may appear many times, especially in multiturn SFT.)
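
The idea is roughly the following (a simplified sketch; the actual helpers in this PR may differ):

import torch

def build_doc_ids(seq_lens: list[int], seq_len: int) -> torch.Tensor:
    # expand per-segment lengths into a per-token document id; any trailing
    # padding gets its own id so it forms a separate block
    doc_ids = torch.repeat_interleave(torch.arange(len(seq_lens)), torch.tensor(seq_lens))
    if doc_ids.numel() < seq_len:
        tail = torch.full((seq_len - doc_ids.numel(),), len(seq_lens), dtype=torch.long)
        doc_ids = torch.cat([doc_ids, tail])
    return doc_ids  # [seq_len]; stack over the batch to get [B, S]

def get_document_mask_mod_from_doc_ids(doc_ids: torch.Tensor):
    # doc_ids: [B, S]; tokens attend only within their own packed segment
    def document_mask_mod(b, h, q_idx, kv_idx):
        return doc_ids[b, q_idx] == doc_ids[b, kv_idx]
    return document_mask_mod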

Contributor:
We cannot use eos_id to infer the boundaries of documents/sequences because eos_id in SFT may have a different meaning: it might mark the end of a sentence/query, and it may appear many times, especially in multiturn SFT.

Then in theory we could use an eod_id (end of document) when encoding the SFT dataset? I somehow don't see that SFT data loading is worth this extra complexity.

cc @lkhphuc if you have feedback.

Contributor:
If I understand correctly, we are talking about the extra get_document_mask_mod_from_sequence_indices function, which I agree is a bit overkill.
For both pretraining and SFT, document masking just needs a special token to mark the boundary between separate sequences.
I think reusing get_document_mask_mod is fine, since you/the user can explicitly pass in whatever token is the "end of sequence" token for your particular dataset.

(We cannot use eos_id to infer the boundaries of documents/sequences because eos_id in SFT may have a different meaning: it might mark the end of a sentence/query, and it may appear many times, especially in multiturn SFT.)

I actually have never seen this practice, because usually people create another token like <|end_of_turn|> or <|im_end|> or something, but I might be wrong.
In any case, since you are in charge of the SFT dataloader here, nothing prevents us from enforcing <|eos|> as the actual document separator and using a different token for end of turn, thus reusing the same document-masking function for flex attention.

Unrelated P.S.: Sorry @tianyu-l for neglecting the VLM experiment for a while. End-of-year deadlines are a bit hectic, but I will get back to it really soon.

Contributor Author (@rakkit):
You mean <|eos|> or eos_token?
eos_token will not work because for some SFT/PT models eos_token does not mean the end of a "document".

<|eos|> can work, but we need to ask the user to explicitly provide a segment token, and we need to wrap the tokenizer and get_document_mask_mod to pick up the correct "segment token".

I actually have never seen this practice, because usually people create another token like <|end_of_turn|> or <|im_end|> or something, but I might be wrong.

I think that's the trade-off: do we assume users know SFT well enough to set it properly, or do we want to make sure training can be done correctly even for users who don't know anything? (I am OK with either.)

@rakkit (Contributor, Author) commented Dec 15, 2025

Thanks a lot for the feedback, will fix these later today/tomorrow.

Comment on lines 86 to 87
prompt_message = [{"role": "user", "content": prompt_message}]
response_message = [{"role": "assistant", "content": response_message}]

@tianyu-l (Contributor), Dec 15, 2025:
Maybe a noob question for SFT:

Strings like "role" / "content" and "user" / "assistant" / "system" seem dataset specific.

  1. When we start finetuning a model, do we assume the pretrained model knows that they are somewhat special? Or do we simply rely on the apply_chat_template function to map them to special strings / token IDs?
  2. It seems these strings are specific to this particular data loader -- is this a contract with the HF tokenizer's apply_chat_template function? If so, can we document them clearly?
  3. Do we ever want to finetune a checkpoint that is already finetuned, say llama3_8b_instruct?

Contributor Author (@rakkit):
I am also not an expert on SFT; these answers are based on my understanding:

Strings like "role" / "content" and "user" / "assistant" / "system" seem dataset-specific.

SFT datasets don’t have a single standard schema. Many modern instruction/chat datasets are represented as a list of messages ({role, content, ...}; sometimes also tools / function calls). Other datasets use task-specific columns (e.g., GSM8K-style question / answer). (Both of these formats are supported in this PR.)

We could optionally provide a “row adapter” / registration function so users can define how to map each dataset row into messages (or prompt/response), as sketched below.
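
For example, a row adapter for a GSM8K-style row could be as simple as (illustrative only, not code from this PR):

def gsm8k_row_to_messages(row: dict) -> list[dict]:
    # map a {"question": ..., "answer": ...} row into the messages schema
    return [
        {"role": "user", "content": row["question"]},
        {"role": "assistant", "content": row["answer"]},
    ]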

Do we assume the pretrained model knows these are somewhat special?

In "tradition" nope. But in reality people already mixiing sft and reasoninng data in pretraining data, not surprised to see model knows some of these special tokens.

Do we simply rely on apply_chat_template to map them to special strings / token IDs?

I think so, and apply_chat_template is an option that the user can turn on or off.

It seems these strings are specific to this particular data loader — is this a contract with HF tokenizer’s apply_chat_template? If so can we document them clearly?

It's not a contract with HF's implementation specifically; these "role" values depend on the chat_template file. What I see from HF's implementation is that you get an error if these "role" values are not defined in the template.

Do we ever want to fine-tune a checkpoint that is already fine-tuned (e.g., llama3_8b_instruct)?

Yes, depending on the goal. Multi-stage SFT (and sometimes model merging) is common: start from an instruct model and further specialize it on domain/style/tool-use data. Examples include Olmo 3 and Nemotron Nano 2.


@rakkit (Contributor, Author) commented Dec 15, 2025

Should we filter to ids != IGNORE_INDEX to get the "correct" ntokens_seen?

@rakkit force-pushed the staging_sft branch 2 times, most recently from 769e2e4 to db9030e, on December 15, 2025, 16:30
@lkhphuc (Contributor) commented Dec 17, 2025

Should we filter to ids != IGNORE_INDEX to get the "correct" ntokens_seen?

No, IMO. It's "tokens seen", not "tokens the loss is computed on", after all. We just need to ensure padding is excluded from tokens seen.

@tianyu-l (Contributor):
No, IMO. It's "tokens seen", not "tokens the loss is computed on", after all. We just need to ensure padding is excluded from tokens seen.

I agree. "tokens seen" is used to measure pure infra throughput. Tracking the number of tokens trained on with loss is probably worth its own metric, which doesn't sound urgent.
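
Concretely, it could be as simple as (a sketch; pad_id is the assumed name of the padding token id):

import torch

def count_tokens_seen(input_ids: torch.Tensor, pad_id: int) -> int:
    # every non-padding token counts as "seen", regardless of whether its
    # label is IGNORE_INDEX (i.e., excluded from the loss)
    return int((input_ids != pad_id).sum())

# a separate metric could later track loss-bearing tokens, e.g.
# (labels != IGNORE_INDEX).sum(), but that is not urgent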

@wwwjn (Contributor) left a comment:
Thanks for the work! Structurally, I would advocate starting in an experiments folder, which would greatly improve readability for users. We could have something like

  • experiments/sft:
    • tokenizer.py: inherit from torchtitan's HuggingFaceTokenizer, or depend on AutoTokenizer
    • datasets.py
    • job_config.py
    • __init__.py: use TrainSpec to plug in the dataloader and tokenizer here

Currently, branching in the main trainer / job_config seems confusing to me.

Contributor:
Should we move this file under tests/unit_tests?

Contributor Author (@rakkit):
Thanks. I moved all the SFT stuff to the experiments folder now (including test_mask.py). I will refactor the test function into a real unit test if we decide to keep this new mask mod.

assert seq_lens is not None and seq_start_indices is not None
return get_document_mask_mod_using_sequence_indices(
batch, seq_lens, seq_start_indices
)

Contributor:
Error out if both eos_id and extra_inputs are empty?

Contributor:
Would this be clearer? I think eos_id should never be None.

if extra_inputs is None:
    assert eos_id is not None
    get_document_mask_mod_using_eos_id()
else:
    assert seq_lens is not None and seq_start_indices is not None
    get_document_mask_mod_using_sequence_indices()

Contributor Author (@rakkit):
Yep, I think eos_id will never be None.

# batch is [b, s, h, d] shape

# this functioned can be called in two ways:
# 1. with eos_id: (we use this in the pre-training while eos_id is relible to get the document boundaries)

Contributor:
Suggested change
# 1. with eos_id: (we use this in the pre-training while eos_id is relible to get the document boundaries)
# 1. with eos_id: (we use this in the pre-training while eos_id is reliable to get the document boundaries)

# 2. with seq_lengths and seq_start_indices:
# - seq_lengths: the length of each sequence
# - seq_start_indices: the start index of each sequence
# - we use this in the post-training while eos_id is not relible to get the document boundaries

Contributor:
Suggested change
# - we use this in the post-training while eos_id is not relible to get the document boundaries
# - we use this in the post-training while eos_id is not reliable to get the document boundaries

Args:
batch: Input batch tensor with shape [b, s, h, d]
eos_id: End-of-sequence token ID that marks document boundaries
extra_inputs: Extra inputs to the mask modifier function

Contributor:
I'm wondering if SFT needs to be compatible with PP. From the comments here: https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L463 the extra_inputs are not passed through PP stages, which means that for PP rank 1 the extra_inputs will not contain seq_lens and seq_start_indices; will this function still work? cc @fegin


def _process_one_row(self, row_dict: dict):
"""Convert a dataset row into model-ready tensors with causal labels."""
messages = _build_messages(

Contributor:
_build_messages seems to be dataset specific, depending on the dataset's format and key names. I would suggest we move the current _build_messages function out of the SFTDataset class to make the SFTDataset class more general. In torchtitan, we plug in a dataset-specific processor here; see sample_processor: https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/hf_datasets/text_datasets.py#L35-L51

Contributor Author (@rakkit):
Yes, that's smart. Thanks.

Contributor Author (@rakkit):
Now, for each "kind" of SFT dataset, we will pass a message_builder to process each row. Three commonly used formats are pre-registered

@rakkit (Contributor, Author) commented Dec 18, 2025

Thanks again for the feedback ❤️. The SFT code has now moved to the experiments folder, and lots of configs have been simplified.
I will try to edit and improve the documentation.

To run a test on GSM8K:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh \
 --model.flavor=debugmodel_varlen_attn \
 --training.dataset question_answer \
 --training.dataset_path=openai/gsm8k \
 --sft_config.dataset_subset=main \
 --model.name=sft.llama3 --job.custom_config_module=torchtitan.experiments.sft.job_config

And to run a test on Dolci-Instruct-SFT with the chat template (llama3_8b flex):


rm -rf ./outputs/*
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=4 ./run_train.sh \
 --model.name=sft.llama3 --job.custom_config_module=torchtitan.experiments.sft.job_config \
 --model.flavor=8B_flex --training.dataset=multi_turn --sft_config.apply_chat_template \
 --training.dataset_path=$HOME/Dolci-Instruct-SFT/ \
 --model.hf_assets_path=$HOME/Meta-Llama-3.1-8B-Instruct/ \
 --training.local_batch_size=8 \
 --activation_checkpoint.mode=full \
 --metrics.log_freq=10 \
 --training.seq_len=2048 --training.steps=100 --lr_scheduler.warmup_steps=10 \
 --debug.seed 10 \
 --compile.enable

- Either tokenize the message["content"] directly, OR
- Apply the tokenizer's chat template (requires the chat template to be configured).
3.2) If there are MULTIPLE messages, we call `_process_single_message` for each message.

Contributor:
Wondering: instead of removing the system prompt later, would it be possible to add a flag to _process_single_message to control whether we want to add the system prompt?

Contributor Author (@rakkit):
We have a flag to control whether we want to apply the chat template or not.
I checked the apply_chat_template function from HF; it does not support "render everything except the system prompt" if the system prompt is hard-coded in the chat template.

But this raises another question: if the chat template does not automatically insert the system prompt, do we want to add it manually or not?

(From what I understand, "try to remove" the system prompt seems to be the safer way to prevent these weird prompts from being inserted in the middle of conversations.)


return document_mask


def get_document_mask_mod_using_sequence_indices(

Contributor:
I wonder if we could avoid the complexity here.

You mean <|eos|> or eos_token?
eos_token will not work because for some SFT/PT models eos_token does not mean the end of a "document".

<|eos|> can work, but we need to ask the user to explicitly provide a segment token, and we need to wrap the tokenizer and get_document_mask_mod to pick up the correct "segment token".

I actually have never seen this practice, because usually people create another token like <|end_of_turn|> or <|im_end|> or something, but I might be wrong.

I think that's the trade-off: do we assume users know SFT well enough to set it properly, or do we want to make sure training can be done correctly even for users who don't know anything? (I am OK with either.)

Since we/you are in charge of the dataloader (packing behavior) and tokenizer (eod_id), what would we expect the user to know?

I'd like to take a step back and ask about the purpose of this experiment:

  • If the purpose of this experiment is to show an example of SFT training in torchtitan, then we should make it simple enough that people can easily learn from it and fork it. Do you agree that "making an eod_id per document is no different from having eos_id to determine the mask"? Or do you think eos_id is a much more widely accepted thing than something like eod_id? I personally think they are the same thing, conceptually; e.g., we don't need get_document_mask_mod_using_eos_id, we can just require the pretraining data loader to also return seq_lengths and seq_start_indices.
  • If the purpose is to provide an out-of-the-box solution for SFT, then it also depends on what datasets people have. Users may not start from HF datasets (or do they?).

I personally think we should start with 1.

I'm also very curious how other libraries deal with the complexity of SFT attention-mask creation. Would appreciate it if you could educate me.

)


def create_varlen_metadata_for_document_using_sequence_indices(

Contributor:
I don't see the benefit of supporting varlen attention for SFT right now, so I'd prefer we delay this work until later.

class SFTConfig:
split: str = "train"
"""Split to use"""
dataset_subset: str | None = None

Contributor:
what is this field?

"""Subset to use"""
stream_dataset: bool = True
"""Whether to stream the dataset"""
is_multiturn: bool = False

Contributor:
is this field used?

Comment on lines +64 to +69
"prompt_response": _build_prompt_response_messages_from_row_dict,
"question_answer": partial(
_build_prompt_response_messages_from_row_dict,
prompt_key="question",
response_key="answer",
),

Contributor:
This separation sounds dataset specific. Without specifying the dataset, it's hard to tell whether they are general.

Maybe consider registering a DATASETS field similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/hf_datasets/text_datasets.py#L36
if you have good candidates.



DATASET_MESSAGE_BUILDERS = {
"multi_turn": _build_multi_turn_messages_from_row_dict,

Contributor:
Similar to the one above -- too abstract without seeing an example, and hard to verify whether it's correct.

build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_sft_text_dataloader,
build_tokenizer_fn=build_auto_tokenizer,
build_loss_fn=build_cross_entropy_loss,

Contributor:
You may want to use https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/vlm/infra/loss.py#L100

otherwise with DP / CP it won't be correct.

position_ids = F.pad(
position_ids, (0, target_length - sequence_length), value=0
)
elif sequence_length > target_length:

Contributor:
Do we have to support all three modes? Is there a single mainstream approach we should take?

else:
raise ValueError(f"Unknown truncation method {self.truncation}")

elif self.pad_mode == "greedy_packing":

Contributor:
similarly, do we have to support both packing and padding?
