Skip to content

Async Checkpointing with DCP and Stateful Dataloader #1502

@conceptofmind

Description

@conceptofmind

Hello,

I was wondering what the recommended way is to use async checkpointing (`dcp.async_save`) with `StatefulDataLoader`.

Does this seem correct?

import pickle

import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data.distributed import DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader

class AsyncCheckpointer(Stateful):
    def __init__(self, model, optimizer, dataloader):
        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader

    def state_dict(self):
        model_state_dict, optimizer_state_dict, dataloader_state_dict = get_state_dict(
            self.model, self.optimizer, self.dataloader
        )
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
            "dataloader": dataloader_state_dict
        }
    
    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model,
            self.optimizer,
            self.dataloader,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
            dataloader_state_dict=state_dict["dataloader"]
        )

...
# NOTE: `train_dataset` stands for the dataset built in the elided setup above.
# Both DistributedSampler and StatefulDataLoader require it as their first argument.
sampler = DistributedSampler(
    train_dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)

trainloader = StatefulDataLoader(
    train_dataset,
    batch_size=64,
    sampler=sampler,
    num_workers=2,
    collate_fn=data_collator,
)
...

# Wrap the live objects, not a snapshot of their state: DCP calls
# state_dict() when saving and load_state_dict() when loading.
app_state = {"app": AsyncCheckpointer(model, optimizer, trainloader)}

# Resume before creating the dataloader iterator. dcp.load restores model,
# optimizer and dataloader in place through AsyncCheckpointer.load_state_dict.
# RESUME_CHECKPOINT_ID is a placeholder from the elided setup (None = fresh run).
if RESUME_CHECKPOINT_ID is not None:
    dcp.load(app_state, checkpoint_id=RESUME_CHECKPOINT_ID)

checkpoint_future = None
for step, batch in enumerate(trainloader):
    ...
    # Wait for the previous background write so at most one save is in flight.
    if checkpoint_future is not None:
        checkpoint_future.result()

    # NOTE(review): `step` restarts at 0 after a resume, so checkpoint ids repeat;
    # track a global step in the checkpoint if that matters.
    checkpoint_future = dcp.async_save(app_state, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

# Ensure the last asynchronous save has finished before the process exits.
if checkpoint_future is not None:
    checkpoint_future.result()

It is unclear to me from the documentation how these two should be combined.

Thank you,

Enrico

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions