Hello,
I was wondering what the recommended way is to use async checkpointing with the StatefulDataLoader.
Does the following seem correct?
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DistributedSampler
import torch.distributed.checkpoint as dcp


class AsyncCheckpointer(Stateful):
    def __init__(self, model, optimizer, dataloader):
        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader

    def state_dict(self):
        model_state_dict, optimizer_state_dict, dataloader_state_dict = get_state_dict(
            self.model, self.optimizer, self.dataloader
        )
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
            "dataloader": dataloader_state_dict,
        }

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model,
            self.optimizer,
            self.dataloader,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
            dataloader_state_dict=state_dict["dataloader"],
        )
...
sampler = DistributedSampler(
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
)
trainloader = StatefulDataLoader(
    batch_size=64,
    sampler=sampler,
    num_workers=2,
    collate_fn=data_collator,
)
...
checkpoint_future = None
trainloader.load_state_dict(state_dict)

for step, batch in enumerate(trainloader):
    ...
    if checkpoint_future is not None:
        checkpoint_future.result()
    dataloader_state_dict = trainloader.state_dict()
    state_dict = {"app": AsyncCheckpointer(model, optimizer, dataloader_state_dict)}
    checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

It is unclear to me from the documentation how these two should be combined.
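For context, the resume path I have in mind looks roughly like the sketch below. This is only my reading of the DCP Stateful protocol and not something I have verified: I am assuming dcp.load calls state_dict() on each Stateful value to build the target structure, loads the checkpoint into it in place, and then hands it back through load_state_dict(). last_step is just a placeholder for whichever step I am resuming from.

# Hypothetical resume sketch (my assumption of how dcp.load interacts with Stateful)
app = AsyncCheckpointer(model, optimizer, trainloader)
state_dict = {"app": app}
dcp.load(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{last_step}")
# ...after which I would expect the StatefulDataLoader to resume from its saved position.

If the recommended pattern is different (for example, storing trainloader.state_dict() directly in the app state rather than passing the loader through get_state_dict), a pointer to that would also be very helpful.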
Thank you,
Enrico