Description
When you pass a torchtune.modules.tokenizers.ModelTokenizer to torchtune.datasets.SFTDataset as the model_transform, pyre will complain because ModelTokenizer is not typed as a Transform. We don't run into any runtime errors because our tokenizers typically inherit from both ModelTokenizer and Transform (e.g., see the Llama3 Tokenizer), but if I were to put together a custom dataset, the typing would be incorrect:
from torchtune.data import AlpacaToMessages
from torchtune.datasets import PackedDataset, SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer


def python_code_instructions_alpaca(tokenizer: ModelTokenizer) -> PackedDataset:
    """
    Python code instruction-input-output pairs from iamtarun/python_code_instructions_18k_alpaca
    templated with Alpaca.
    """
    ds = SFTDataset(
        model_transform=tokenizer,
        source="iamtarun/python_code_instructions_18k_alpaca",
        message_transform=AlpacaToMessages(
            train_on_input=False,
        ),
        # pyre-ignore[6]: Incompatible parameter type
        split="train",
    )
    if tokenizer.max_seq_len is None:
        raise ValueError(
            "PackedDataset requires a max_seq_len to be set on the tokenizer."
        )
    return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=False)
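For reference, here's roughly the double-inheritance pattern I mean (a minimal sketch; it assumes ModelTokenizer and Transform are both importable protocols, with Transform coming from torchtune.modules.transforms, and the class body is elided):

from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform


class MyCustomTokenizer(ModelTokenizer, Transform):
    # Because concrete tokenizers inherit from both protocols, the instance
    # handed to SFTDataset happens to satisfy Transform at runtime, even though
    # recipe signatures (like the one above) only annotate it as ModelTokenizer.
    ...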
This is not an issue with our linters since we don't run pyre (I believe?), but the typing is technically incorrect and should be adjusted at some point.
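One possible adjustment, purely as a sketch (TokenizerTransform is a hypothetical name, not an existing torchtune type, and this assumes ModelTokenizer and Transform are both typing.Protocols): define a small combined protocol and use it wherever a tokenizer is passed as a model_transform, so the pyre-ignore becomes unnecessary.

from typing import Protocol

from torchtune.datasets import PackedDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform


class TokenizerTransform(ModelTokenizer, Transform, Protocol):
    """Hypothetical protocol: anything that is both a tokenizer and a transform."""


def python_code_instructions_alpaca(tokenizer: TokenizerTransform) -> PackedDataset:
    # With the combined annotation, passing `tokenizer` as model_transform
    # type-checks without a suppression comment.
    ...

The exact name and location are up for discussion; the point is just that the annotation should reflect that the tokenizer is also used as a Transform.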