🚀 The feature
I think it would make sense to provide a "real-world" example of using the StatefulDataLoader with a popular library such as Hugging Face datasets.
For example, the code below uses an IterableDataset, the StatefulDataLoader, and Hugging Face streaming datasets together:
import os
from typing import Optional

import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
)

# Set up the process group; torchrun provides RANK / WORLD_SIZE / LOCAL_RANK.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert torch.cuda.is_available()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()


# Wraps a streaming Hugging Face dataset so that it tokenizes on the fly and
# exposes the underlying dataset's state for StatefulDataLoader to checkpoint.
class TokenizedDataset(IterableDataset, Stateful):
    def __init__(
        self,
        path: str,
        tokenizer: AutoTokenizer,
        name: Optional[str] = None,
        split: str = "train",
        streaming: bool = True,
        max_length: int = 2048,
        ddp_rank: int = 0,
        ddp_world_size: int = 1,
    ):
        dataset = load_dataset(path, name, split=split, streaming=streaming)
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Shard the streaming dataset so each rank sees a disjoint slice.
        self.train_dataset = split_dataset_by_node(
            dataset=dataset, rank=ddp_rank, world_size=ddp_world_size
        )

    def __iter__(self):
        for sample in iter(self.train_dataset):
            tokenized = self.tokenizer(
                sample["text"],
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_special_tokens_mask=True,
            )
            yield tokenized

    # Delegate state handling to the Hugging Face IterableDataset, which
    # tracks its own position in the stream.
    def load_state_dict(self, state_dict):
        assert "data" in state_dict
        self.train_dataset.load_state_dict(state_dict["data"])

    def state_dict(self):
        return {"data": self.train_dataset.state_dict()}


tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.5
)
tokenized_dataset = TokenizedDataset(
    path="Salesforce/wikitext",
    name="wikitext-2-v1",
    tokenizer=tokenizer,
    max_length=2048,
    ddp_rank=rank,
    ddp_world_size=world_size,
)
trainloader = StatefulDataLoader(
    dataset=tokenized_dataset,
    batch_size=64,
    num_workers=1,
    collate_fn=data_collator,
)

# Iterate a few steps, then snapshot the dataloader state.
for step, batch in enumerate(trainloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    print(step)
    print(batch)
    if step == 2:
        dataloader_state_dict = trainloader.state_dict()
        print(dataloader_state_dict)
        break

# Restore the snapshot and continue iterating from where the stream left off.
print("restart from checkpoint")
trainloader.load_state_dict(dataloader_state_dict)
for step, batch in enumerate(trainloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    print(step)
    print(batch)
    if step == 2:
        dataloader_state_dict = trainloader.state_dict()
        print(dataloader_state_dict)
        break

destroy_process_group()

Motivation, pitch
If something like the above is both correct and useful, I would be happy to provide it as an example in the repository.
Alternatives
If not, this can simply be left as a closed issue for others to reference in the future.
Additional context
I am willing to add more to this example as well if needed.
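For instance, the snapshot in the example above is only held in memory; a natural extension would be persisting it to disk so that a fresh process can resume the stream after a restart. Below is a minimal sketch of that, reusing the objects defined in the script above; the per-rank filename is just a placeholder I made up, and whether weights_only needs to be set explicitly depends on the torch version:

import torch

# Save: each rank persists its own dataloader state (hypothetical filename).
ckpt_path = f"dataloader_rank{rank}.pt"
torch.save(trainloader.state_dict(), ckpt_path)

# Resume (e.g. in a new process after a restart): rebuild the dataloader
# exactly as before, then load the saved state *before* iterating so the
# stream continues from the last checkpointed position.
resumed_loader = StatefulDataLoader(
    dataset=tokenized_dataset,
    batch_size=64,
    num_workers=1,
    collate_fn=data_collator,
)
# weights_only=False because the state dict holds plain Python objects,
# not just tensors; adjust as needed for your torch version.
resumed_loader.load_state_dict(torch.load(ckpt_path, weights_only=False))

for step, batch in enumerate(resumed_loader):
    batch = {k: v.to(device) for k, v in batch.items()}
    print(step)
    break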
Thank you,
Enrico