Description
Upon a torchelastic restart, say with train_ddp.py, I haven't been able to find where the dataloader learns from where to resume. Or does it just start from the "beginning", assuming that the randomness of the sampler will not duplicate its samples? I expect I am missing something, right?
Activity
d4l3k commented on Jan 8, 2025
@cjolivier01 this is something we want to improve by automatically tracking the dataloader step and fast forwarding as needed. It's feasible to do in torchft but we haven't gotten around to implementing it yet
The current recommended approach is to checkpoint the dataloader frequently using torchdata's StatefulDataLoader. To avoid replaying any data you would need to checkpoint it on every step. https://pytorch.org/data/beta/torchdata.stateful_dataloader.html
To minimize the overhead of that, you could write a custom checkpoint implementation that calls dataloader.state_dict() periodically (say every 10 steps) and in between only saves the step offset to disk (e.g. 5 steps). When reloading you would restore the StatefulDataLoader from the checkpoint and then call next(iter) on it 5 times.

If you're interested in contributing better dataloader management I'm happy to set up some time to chat; otherwise, we'll get to it at some point :)
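For concreteness, here is a minimal sketch of that snapshot-plus-offset pattern, assuming torchdata's StatefulDataLoader; the file names, cadence, and helper functions are illustrative, not torchft APIs:

```python
# Illustrative sketch: periodic full StatefulDataLoader snapshots plus a cheap
# per-step offset, so a restart only has to fast-forward a few batches.
import torch
from torchdata.stateful_dataloader import StatefulDataLoader

SNAPSHOT_EVERY = 10  # full dataloader state_dict every N steps (assumed cadence)

def save_dataloader_checkpoint(step, dataloader):
    if step % SNAPSHOT_EVERY == 0:
        # Heavier snapshot of the full dataloader state.
        torch.save(dataloader.state_dict(), "dataloader_snapshot.pt")
    # Lightweight per-step record of how far past the last snapshot we are.
    torch.save({"offset": step % SNAPSHOT_EVERY}, "dataloader_offset.pt")

def restore_dataloader(dataset):
    dataloader = StatefulDataLoader(dataset, batch_size=32, num_workers=2)
    dataloader.load_state_dict(torch.load("dataloader_snapshot.pt"))
    offset = torch.load("dataloader_offset.pt")["offset"]
    it = iter(dataloader)
    for _ in range(offset):
        next(it)  # fast-forward past batches consumed since the snapshot
    return dataloader, it
```

With this, at most SNAPSHOT_EVERY - 1 batches get re-read from storage after a restart, but none are trained on twice.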
d4l3k commented on Jan 8, 2025
also see #37
cjolivier01 commented on Jan 9, 2025
Thank you for the reply! At the moment, I am still trying to figure out how all the machinery works together. Still in the torchelastic realm (which I have not used before), am I correct in assuming that no state is currently saved or restored w.r.t. the StatefulDataLoader when running train_ddp.py? Since there is no active checkpointing (right?) in train_ddp.py (although I do see the inline load_state and set_state that would seem to handle that, as well as the external CheckpointServer test), that's not happening, right? I'm just not sure whether I am running it incorrectly while trying to hook it all up and induce faults, etc.
d4l3k commented on Jan 9, 2025
@cjolivier01 the example train_ddp.py doesn't do any persistent checkpointing of model state or dataloader -- there is a TODO where it should occur but we don't have it. For completeness we should probably add checkpointing or at least a dummy hook.
The load_state_dict and state_dict methods are used for live recovery (i.e. a transfer from a healthy replica to a recovering replica) and do not create a persistent checkpoint of the dataloader, model or optimizer.

You likely want to add checkpoint logic at the end of the train step where the TODO is. Persistent checkpoints should probably happen every 100-1000 steps or so, but for now you likely need to checkpoint the dataloader on every step if you don't want to retrain on the same examples.
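As a rough sketch of that cadence split (the hook name, paths, and intervals below are assumptions; the actual train_ddp.py only marks this spot with a TODO):

```python
# Illustrative only: checkpoint logic at the end of the train step where the
# TODO is. Paths and cadence are placeholders, not part of train_ddp.py.
import torch

PERSIST_EVERY = 500  # persistent model/optimizer checkpoint every few hundred steps

def end_of_step_checkpoint(step, model, optimizer, dataloader):
    # Dataloader state is small but needs to be captured every step so a
    # restart does not retrain on examples that were already consumed.
    torch.save(dataloader.state_dict(), "dataloader_state.pt")
    if step % PERSIST_EVERY == 0:
        # Heavier persistent checkpoint of model and optimizer.
        torch.save(
            {
                "step": step,
                "model": model.state_dict(),
                "optim": optimizer.state_dict(),
            },
            f"checkpoint_{step:08d}.pt",
        )
```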
cjolivier01 commented on Jan 10, 2025
I am trying to make torchft kick in, so I run two processes as if they're on different nodes (they're on the same node; only --node_rank differs):
```
TORCHFT_MANAGER_PORT=29512 \
TORCHFT_LIGHTHOUSE="http://localhost:29510" \
torchrun \
    --master_port=29501 \
    --nnodes=1:2 \
    --nproc_per_node=1 \
    --max-restarts=3 \
    --rdzv-id=asdsddsded \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost \
    --node_rank=0 \
    ./train_ddp.py
```

...and

```
TORCHFT_MANAGER_PORT=29512 \
TORCHFT_LIGHTHOUSE="http://localhost:29510" \
torchrun \
    --master_port=29501 \
    --nnodes=1:2 \
    --nproc_per_node=1 \
    --max-restarts=3 \
    --rdzv-id=asdsddsded \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost \
    --node_rank=1 \
    ./train_ddp.py
```
...and in the latter, I have it exit during a training step before optimizer.step(). What I then get on rank 0 is a RuntimeError (timeout) in the Manager's self._client.should_commit(), and this makes the process exit; it does not recover. Is this the expected behavior? I must be doing something wrong, correct? Is there a granularity I am not understanding (i.e. is a full replica group expected to fail)?
d4l3k commented on Jan 10, 2025
@cjolivier01 you want the replica groups to not be part of the same torchelastic instance. These are the commands I use to run locally:
The elastic and manager ports should be different between replica groups.
Once #67 lands you won't need to specify the manager port at all