This is related to #116 but focuses on timing Manager operations.
We don't have any profiling for Manager operations, and it would be great to add record_function
annotations to Manager so we can track torchft overhead with the PyTorch profiler.
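One low-touch way to add these annotations is a small decorator that runs a whole method inside a record_function region. The sketch below assumes a hypothetical helper named annotate, a stand-in FakeManager class, and a hypothetical "torchft::manager::..." naming scheme; none of these exist in torchft today. From PyTorch it uses only torch.profiler.record_function and torch.profiler.profile.

```python
import functools

import torch
from torch.profiler import ProfilerActivity, profile, record_function


def annotate(name):
    """Hypothetical helper: run the decorated function inside a
    record_function region called `name`."""

    def decorator(fn):
        @functools.wraps(fn)  # keep the original method name/docstring
        def wrapper(*args, **kwargs):
            with record_function(name):
                return fn(*args, **kwargs)

        return wrapper

    return decorator


class FakeManager:
    """Stand-in for torchft's Manager (not the real class)."""

    @annotate("torchft::manager::should_commit")  # hypothetical region name
    def should_commit(self) -> bool:
        # Placeholder work; the real method coordinates with other workers.
        return bool(torch.ones(1).sum().item() == 1.0)


def region_names(fn):
    """Run fn under the CPU profiler and return the recorded event names."""
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        fn()
    return {evt.name for evt in prof.events()}


names = region_names(lambda: FakeManager().should_commit())
print("torchft::manager::should_commit" in names)
```

A decorator keeps the annotation out of the method body, but it can only time a whole method; sub-steps inside _async_quorum would still need explicit `with record_function(...)` blocks.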
There are a few key points we want to track:
- _async_quorum overall execution
  https://github.com/pytorch/torchft/blob/main/torchft/manager.py#L436
- manager.quorum: this is where we compute a new quorum
  https://github.com/pytorch/torchft/blob/main/torchft/manager.py#L445
- send_checkpoint/recv_checkpoint: this is called when a new replica group is recovering and can take a significant amount of time
  https://github.com/pytorch/torchft/blob/main/torchft/manager.py#L493-L540
- process group reconfiguration: this is called after a new quorum is found and requires P2P communication to set up the new process group
  https://github.com/pytorch/torchft/blob/main/torchft/manager.py#L490
- should_commit: this is called before the optimizer step to check whether all workers successfully exchanged gradients
  https://github.com/pytorch/torchft/blob/main/torchft/manager.py#L559
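Assuming, as the linked line numbers suggest, that the quorum, process group, and checkpoint steps all run inside _async_quorum while should_commit runs separately, the regions would nest roughly as in the toy sketch below. SketchManager, its recovering flag, and every region name are hypothetical; the method bodies are placeholders, and the order of the inner regions is illustrative only.

```python
from torch.profiler import ProfilerActivity, profile, record_function

# Hypothetical region names, one per key point listed above.
ASYNC_QUORUM = "torchft::manager::_async_quorum"
QUORUM = "torchft::manager::quorum"
CONFIGURE_PG = "torchft::manager::configure_pg"
CHECKPOINT = "torchft::manager::checkpoint_transfer"
SHOULD_COMMIT = "torchft::manager::should_commit"


class SketchManager:
    """Toy manager that mimics only the shape of one step; every body is a
    placeholder for the real torchft logic."""

    def __init__(self, recovering: bool) -> None:
        # Hypothetical flag: True if this replica group must fetch state.
        self.recovering = recovering

    def _async_quorum(self) -> None:
        with record_function(ASYNC_QUORUM):  # overall quorum phase
            with record_function(QUORUM):
                pass  # compute the new quorum
            with record_function(CONFIGURE_PG):
                pass  # P2P setup of the new process group
            if self.recovering:
                with record_function(CHECKPOINT):
                    pass  # send_checkpoint / recv_checkpoint

    def should_commit(self) -> bool:
        with record_function(SHOULD_COMMIT):
            return True  # placeholder for the cross-worker check


def profile_step(recovering: bool) -> set:
    """Profile one fake step and return the names of all recorded events."""
    manager = SketchManager(recovering)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        manager._async_quorum()
        manager.should_commit()
    return {evt.name for evt in prof.events()}
```

Nesting the sub-steps under one outer region means the profiler reports both the total quorum overhead and how it splits between quorum computation, process group setup, and checkpoint transfer.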
Relevant code in PyTorch:
- record_function: https://github.com/pytorch/pytorch/blob/b1a81a4a650a17988ef9289b1a7697a384c48b26/torch/autograd/profiler.py#L703
- PyTorch profiler recipe: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
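Following the profiler recipe, recorded regions can be inspected either as a summary table via key_averages().table(...) or as a Chrome trace via export_chrome_trace(...). The sketch below profiles a fake step; the region names fake_train_step and fake_should_commit are made up for illustration.

```python
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile, record_function


def profile_fake_step():
    """Profile one fake step; return (summary table, chrome trace path)."""
    x = torch.randn(64, 64)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with record_function("fake_train_step"):
            y = x @ x  # stand-in for forward/backward work
            with record_function("fake_should_commit"):
                y.sum()  # stand-in for the commit check
    # Aggregate events by name, most expensive first.
    table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)
    # Write a trace that can be opened in chrome://tracing or Perfetto.
    trace_path = os.path.join(tempfile.mkdtemp(), "trace.json")
    prof.export_chrome_trace(trace_path)
    return table, trace_path


table, trace_path = profile_fake_step()
print(table)
```

The table is handy for a quick per-region total, while the trace shows where the Manager regions fall on the timeline relative to the rest of the training step.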
Testing:
To test this, we should add a new mocked test in manager_test.py that enables the profiler, runs through a step, and checks that all of the relevant regions were recorded.
Example manager mocked test: https://github.com/pytorch/torchft/blob/main/torchft/manager_test.py#L130-L164
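A mocked test along those lines might look like the sketch below. It does not construct the real Manager: StubManager, its client.quorum / client.should_commit / pg.configure calls, and the region names are hypothetical stand-ins for whatever the annotated Manager ends up calling. MagicMock replaces the network-facing dependencies, and the test asserts that every expected region appears in prof.events().

```python
import unittest
from unittest.mock import MagicMock

from torch.profiler import ProfilerActivity, profile, record_function

# Hypothetical region names the annotated Manager would emit.
EXPECTED_REGIONS = {
    "torchft::manager::_async_quorum",
    "torchft::manager::quorum",
    "torchft::manager::configure_pg",
    "torchft::manager::should_commit",
}


class StubManager:
    """Hypothetical stand-in: calls injected (mocked) dependencies inside
    record_function regions, the way an annotated Manager might."""

    def __init__(self, client, pg) -> None:
        self._client = client
        self._pg = pg

    def step(self) -> bool:
        with record_function("torchft::manager::_async_quorum"):
            with record_function("torchft::manager::quorum"):
                quorum = self._client.quorum()
            with record_function("torchft::manager::configure_pg"):
                self._pg.configure(quorum)
        with record_function("torchft::manager::should_commit"):
            return self._client.should_commit()


class ManagerProfilerTest(unittest.TestCase):
    def test_regions_recorded(self) -> None:
        client = MagicMock()
        client.should_commit.return_value = True
        pg = MagicMock()
        manager = StubManager(client, pg)

        with profile(activities=[ProfilerActivity.CPU]) as prof:
            self.assertTrue(manager.step())

        names = {evt.name for evt in prof.events()}
        missing = EXPECTED_REGIONS - names
        self.assertFalse(missing, f"regions not recorded: {missing}")
        # The mocks confirm the step actually went through each phase.
        client.quorum.assert_called_once()
        pg.configure.assert_called_once_with(client.quorum.return_value)
```

In manager_test.py the same profile-then-assert pattern would presumably wrap a step driven through the real Manager with its client mocked out, as in the linked example test; a recovering-replica case could additionally check for a checkpoint region.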
You can also run train_ddp.py with torchx, for example:
torchx run -- --replicas 2