Add an example using a library such as Hugging Face datasets #1504

@conceptofmind

Description

🚀 The feature

I think it would make sense to provide a "real-world" example of using the StatefulDataLoader with a popular library such as Hugging Face datasets.

For example, the code below uses IterableDataset, StatefulDataLoader, and Hugging Face streaming datasets together:

import os
from typing import Optional

import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
)

# torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK in the environment.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert torch.cuda.is_available()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()


# An IterableDataset that streams a Hugging Face dataset, shards it across
# ranks, and tokenizes samples on the fly.
class TokenizedDataset(IterableDataset, Stateful):
    def __init__(
        self,
        path: str,
        tokenizer: AutoTokenizer,
        name: Optional[str] = None,
        split: str = "train",
        streaming: bool = True,
        max_length: int = 2048,
        ddp_rank: int = 0,
        ddp_world_size: int = 1,
    ):
        dataset = load_dataset(path, name, split=split, streaming=streaming)
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Shard the streaming dataset so each rank sees a disjoint slice.
        self.train_dataset = split_dataset_by_node(
            dataset=dataset, rank=ddp_rank, world_size=ddp_world_size
        )

    def __iter__(self):
        for sample in iter(self.train_dataset):
            tokenized = self.tokenizer(
                sample["text"],
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_special_tokens_mask=True,
            )
            yield tokenized

    def load_state_dict(self, state_dict):
        # Restore the underlying dataset's position in the stream.
        assert "data" in state_dict
        self.train_dataset.load_state_dict(state_dict["data"])

    def state_dict(self):
        # Delegate to the Hugging Face IterableDataset, which tracks its own
        # position when streaming.
        return {"data": self.train_dataset.state_dict()}


tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# Collator that applies masked-language-modeling masking on the fly.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.5
)

tokenized_dataset = TokenizedDataset(
    path="Salesforce/wikitext",
    name="wikitext-2-v1",
    tokenizer=tokenizer,
    max_length=2048,
    ddp_rank=rank,
    ddp_world_size=world_size,
)

# StatefulDataLoader exposes state_dict()/load_state_dict() so iteration can
# be checkpointed and resumed mid-epoch.
trainloader = StatefulDataLoader(
    dataset=tokenized_dataset,
    batch_size=64,
    num_workers=1,
    collate_fn=data_collator,
)

# Iterate a few steps, then capture the dataloader state mid-stream.
for step, batch in enumerate(trainloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    print(step)
    print(batch)
    if step == 2:
        dataloader_state_dict = trainloader.state_dict()
        print(dataloader_state_dict)
        break

# Restoring the state dict resumes the stream where the checkpoint left off.
print("restart from checkpoint")
trainloader.load_state_dict(dataloader_state_dict)
for step, batch in enumerate(trainloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    print(step)
    print(batch)
    if step == 2:
        dataloader_state_dict = trainloader.state_dict()
        print(dataloader_state_dict)
        break

dist.destroy_process_group()
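
In a real training run, each rank would persist its dataloader state to disk alongside the model and optimizer checkpoint rather than holding it in memory as above. Because the stream is sharded with split_dataset_by_node, each rank's state differs, so it should be saved and restored per rank. A minimal sketch of that pattern (the per-rank filename is hypothetical, made up for illustration):

# Sketch only: save each rank's dataloader state to its own file.
ckpt_path = f"dataloader_state_rank{rank}.pt"
torch.save(trainloader.state_dict(), ckpt_path)
dist.barrier()

# On restart, each rank reloads its own state before iterating again.
# weights_only=False because the state dict contains non-tensor objects.
trainloader.load_state_dict(torch.load(ckpt_path, weights_only=False))

Since the script reads RANK, WORLD_SIZE, and LOCAL_RANK from the environment, it is meant to be launched with torchrun, e.g. torchrun --nproc_per_node=2 example.py.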

Motivation, pitch

If something like the above is both correct and useful, I would be happy to provide it as an example in the repository.

Alternatives

If not, this can just be left as a closed issue for others to reference in the future.

Additional context

I am willing to add more to this example as well if needed.

Thank you,

Enrico
