
Make ModelTokenizer a Transform #1922

Open
@RdoubleA

Description

When you pass a torchtune.modules.tokenizers.ModelTokenizer to torchtune.datasets.SFTDataset as the model_transform, pyre complains because ModelTokenizer is not typed as a Transform. We don't hit any runtime errors because our tokenizers typically inherit from both ModelTokenizer and Transform (e.g., see the Llama3 tokenizer), but if I were to put together a custom dataset, the typing would be incorrect:

```python
from torchtune.data import AlpacaToMessages
from torchtune.datasets import PackedDataset, SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer


def python_code_instructions_alpaca(tokenizer: ModelTokenizer) -> PackedDataset:
    """
    Python code instruction-input-output pairs from
    iamtarun/python_code_instructions_18k_alpaca templated with Alpaca.
    """
    ds = SFTDataset(
        # pyre-ignore[6]: Incompatible parameter type
        model_transform=tokenizer,
        source="iamtarun/python_code_instructions_18k_alpaca",
        message_transform=AlpacaToMessages(
            train_on_input=False,
        ),
        split="train",
    )
    if tokenizer.max_seq_len is None:
        raise ValueError(
            "PackedDataset requires a max_seq_len to be set on the tokenizer."
        )
    return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=False)
```

This doesn't surface in our linters since we don't run pyre (I believe?), but the typing is technically incorrect and should be fixed at some point.
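One way to resolve this would be for ModelTokenizer itself to inherit from (or structurally satisfy) Transform, so every tokenizer subclass gets the correct typing for free. A minimal sketch, assuming a simplified Transform protocol — the real torchtune.modules.transforms.Transform and ModelTokenizer definitions differ, and these names are illustrative only:

```python
# Sketch only: a simplified stand-in for torchtune's Transform protocol and a
# ModelTokenizer that inherits from it. The real torchtune definitions differ.
from typing import Any, Mapping, Optional, Protocol, runtime_checkable


@runtime_checkable
class Transform(Protocol):
    """Loose stand-in for torchtune.modules.transforms.Transform."""

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        ...


class ModelTokenizer(Transform):
    """If ModelTokenizer subclassed Transform, passing one as model_transform
    would type-check without a pyre-ignore."""

    max_seq_len: Optional[int] = None

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # A real tokenizer would tokenize sample["messages"] here; this stub
        # just attaches an empty token list to show the Transform contract.
        return dict(sample, tokens=[])
```

With an arrangement like this, a model_transform parameter typed as Transform would accept any ModelTokenizer subclass without the pyre-ignore shown above.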

Labels: better engineering (tasks which help improve eng productivity, e.g., building tools, cleaning up code, writing docs)
