Dataloader question upon restart #58

Open

Description

@cjolivier01

Upon a torchelastic restart, let's say with train_ddp.py, I haven't been able to find where the dataloader learns where to resume from, or whether it just starts from the "beginning" and assumes that the randomness of the sampler will not duplicate its samples. I expect I am missing something, right?

Activity

d4l3k (Member) commented on Jan 8, 2025

@cjolivier01 this is something we want to improve by automatically tracking the dataloader step and fast-forwarding as needed. It's feasible to do in torchft, but we haven't gotten around to implementing it yet.

The current recommended approach is to checkpoint the dataloader frequently using torchdata's StatefulDataLoader. To avoid replaying any data you would need to checkpoint it on every step: https://pytorch.org/data/beta/torchdata.stateful_dataloader.html

To minimize the overhead of that, you could write a custom checkpoint implementation that calls the dataloader's .state_dict() periodically (say, every 10 steps) and otherwise only saves the current offset to disk (e.g. 5 steps past the last snapshot). When reloading, you would restore the StatefulDataLoader from the checkpoint and then call next() on its iterator 5 times.
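
For illustration, here is a minimal sketch of that periodic-snapshot-plus-offset scheme, assuming a simple map-style dataset; the interval, file paths, and helper names are made up for this example and are not part of torchft or torchdata.

```python
import json

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

SNAPSHOT_EVERY = 10                      # full snapshot interval (assumed)
STATE_PATH = "dataloader_state.pt"       # illustrative file names
OFFSET_PATH = "dataloader_offset.json"

dataset = TensorDataset(torch.arange(1000).float())
dataloader = StatefulDataLoader(dataset, batch_size=32)

def checkpoint_dataloader(step: int) -> None:
    # Heavier full snapshot only every SNAPSHOT_EVERY steps...
    if step % SNAPSHOT_EVERY == 0:
        torch.save(dataloader.state_dict(), STATE_PATH)
    # ...but always record the cheap offset past the last snapshot.
    with open(OFFSET_PATH, "w") as f:
        json.dump({"offset": step % SNAPSHOT_EVERY}, f)

def restore_dataloader():
    # Restore the last full snapshot, then fast-forward the replayed steps.
    state = torch.load(STATE_PATH, weights_only=False)  # state holds non-tensor objects
    dataloader.load_state_dict(state)
    it = iter(dataloader)
    with open(OFFSET_PATH) as f:
        offset = json.load(f)["offset"]
    for _ in range(offset):
        next(it)
    return it
```

In practice you would also want the writes to be atomic (write to a temporary file, then rename) so a crash mid-write cannot corrupt the only snapshot.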

If you're interested in contributing better dataloader management, I'm happy to set up some time to chat; otherwise we'll get to it at some point :)

added the question (Further information is requested), data (Related to dataloading), and enhancement (New feature or request) labels on Jan 8, 2025
d4l3k (Member) commented on Jan 8, 2025

also see #37

cjolivier01 (Author) commented on Jan 9, 2025

Thank you for the reply! At the moment I am still trying to figure out how all the machinery works together. Still in the torchelastic realm (which I have not used before): am I correct in assuming that no state is currently saved or restored with respect to StatefulDataLoader when running train_ddp.py? Since there is no active checkpointing (right?) in train_ddp.py (although I do see the inline load_state and set_state that would seem to handle that, as well as the external CheckpointServer test), that isn't happening, right? I'm just not sure if I am running it incorrectly while trying to hook it all up and induce faults, etc.

d4l3k (Member) commented on Jan 9, 2025

@cjolivier01 the example train_ddp.py doesn't do any persistent checkpointing of the model state or the dataloader -- there is a TODO where it should occur, but we don't have it yet. For completeness we should probably add checkpointing or at least a dummy hook.

The load_state_dict and state_dict methods are used for live recovery (i.e. transfer from a healthy replica to a recovering replica) and do not create a persistent checkpoint of the dataloader, model, or optimizer.

You likely want to add checkpoint logic at the end of the train step, where the TODO is. Persistent checkpoints of the model and optimizer should probably happen every 100-1000 steps, but for now you likely need to checkpoint the dataloader on every step if you don't want to retrain on the same examples.
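
As a rough illustration of that split (not the actual train_ddp.py code), the end-of-step hook could look something like the sketch below; the function signature, interval, and file names are assumptions.

```python
import torch

MODEL_CKPT_EVERY = 500  # assumed value in the 100-1000 step range

def end_of_train_step(step, model, optimizer, dataloader):
    # Dataloader state is small, so save it every step to avoid replaying examples.
    # Assumes a dataloader that exposes state_dict(), e.g. torchdata's StatefulDataLoader.
    torch.save(dataloader.state_dict(), "dataloader.pt")

    # Model and optimizer snapshots are heavier; save them every few hundred steps.
    if step % MODEL_CKPT_EVERY == 0:
        torch.save(
            {
                "step": step,
                "model": model.state_dict(),
                "optim": optimizer.state_dict(),
            },
            f"checkpoint_{step}.pt",
        )
```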

cjolivier01 (Author) commented on Jan 10, 2025

I am trying to make torchft kick in, so I run two processes as if they're on different nodes (they're on the same node; only --node_rank differs):

TORCHFT_MANAGER_PORT=29512 \
    TORCHFT_LIGHTHOUSE="http://localhost:29510" \
    torchrun \
    --master_port=29501 \
    --nnodes=1:2  \
    --nproc_per_node=1 \
    --max-restarts=3 \
    --rdzv-id=asdsddsded \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost \
    --node_rank=0 \
    ./train_ddp.py

...and

TORCHFT_MANAGER_PORT=29512 \
    TORCHFT_LIGHTHOUSE="http://localhost:29510" \
    torchrun \
    --master_port=29501 \
    --nnodes=1:2  \
    --nproc_per_node=1 \
    --max-restarts=3 \
    --rdzv-id=asdsddsded \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost \
    --node_rank=1 \
    ./train_ddp.py

...and in the latter I have it exit during a training step, before optimizer.step(). What I then get on rank 0 is a RuntimeError (timeout) in the Manager's self._client.should_commit(), which throws the process out into exit. Is this the expected behavior? It does not recover. I must be doing something wrong, correct? Is there a granularity I am not understanding (i.e., is a full replica group expected to fail)?

d4l3k (Member) commented on Jan 10, 2025

@cjolivier01 you want the replica groups to not be part of the same torchelastic instance. These are the commands I use to run locally:

torchft_lighthouse --min_replicas 2 --join_timeout_ms 1000

CUDA_VISIBLE_DEVICES=0 TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29502 --nnodes 1 --nproc_per_node 1 --max-restarts 10 train_ddp.py

CUDA_VISIBLE_DEVICES=1 TORCHFT_MANAGER_PORT=29513 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 --max-restarts 10 train_ddp.py

The elastic and manager ports should be different between replica groups.

Once #67 lands, you won't need to specify the manager port at all.
