
PGTransport: add inplace transport (3x faster) #119

Merged
merged 1 commit into main from d4l3k/pg_inplace on Feb 25, 2025
Conversation

@d4l3k (Member) commented Feb 24, 2025

This adds a new state_dict argument to PGTransport that, when provided, supplies a state_dict to use for in-place tensor operations. This has been measured at ~3x faster.
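The in-place idea can be sketched as follows. This is a hedged, pure-Python illustration using lists in place of torch tensors; `recv_into` and `recv_checkpoint_inplace` are hypothetical names, not the actual torchft API:

```python
# Hypothetical sketch of an in-place receive: instead of allocating new
# tensors for incoming data, write into buffers the caller already owns.
# Plain lists stand in for torch tensors here.

def recv_into(dst, src):
    """Stand-in for an in-place recv: overwrite dst's existing storage."""
    dst[:] = src  # no new destination allocation


def recv_checkpoint_inplace(state_dict_leaves, incoming):
    """Receive each leaf into its preallocated buffer, keyed by path."""
    for path, buf in state_dict_leaves:
        recv_into(buf, incoming[path])


# Caller-provided state_dict leaves (path, preallocated buffer) and
# the incoming checkpoint data keyed by the same paths.
w = [0.0, 0.0]
leaves = [(("model", "w"), w)]
recv_checkpoint_inplace(leaves, {("model", "w"): [1.5, 2.5]})
print(w)  # → [1.5, 2.5]; the original buffer was reused, not replaced
```

Skipping the per-tensor allocation on the receive path is where the speedup comes from: the destination buffers already exist in the provided state_dict.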

Test plan:

Correctness:

pytest torchft/checkpointing/pg_transport_test.py

The improvements have been measured via pg_transport_bench.

For the in-place operation we see ~3x faster transfers (15s -> 4-5s) for 12GB with 3MB tensors. The remaining overhead comes primarily from torchft ProcessGroupBaby queue communication and is not proportional to the size of the tensors.

Reducing this overhead requires some careful consideration and will be addressed in a follow up PR.

python torchft/checkpointing/pg_transport_bench.py --device cuda
python torchft/checkpointing/pg_transport_bench.py --device cuda --inplace

inplace

12GB/3MB (4k tensors)

INFO:torchft.checkpointing.pg_transport:send_checkpoint took 5.05398303642869s
INFO:torchft.checkpointing.pg_transport:recv_checkpoint took 5.637796577066183s

16KB/4B (4k tensors)

INFO:torchft.checkpointing.pg_transport:send_checkpoint took 4.909562937915325s
INFO:torchft.checkpointing.pg_transport:recv_checkpoint took 4.766054484993219s

48GB/12MB (4k tensors)

INFO:torchft.checkpointing.pg_transport:send_checkpoint took 18.53099210932851s
INFO:torchft.checkpointing.pg_transport:recv_checkpoint took 18.847779247909784s

not inplace

12GB/3MB (4k tensors)

INFO:torchft.checkpointing.pg_transport:send_checkpoint took 15.791493758559227s
INFO:torchft.checkpointing.pg_transport:recv_checkpoint took 17.16875096037984s

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 24, 2025
@d4l3k d4l3k marked this pull request as ready for review February 24, 2025 21:55
@fegin (Contributor) left a comment:

LGTM


        for w in work:
            w.wait(timeout)

    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        state_dict = self._state_dict() if self._state_dict else {}
        state_dict_leaves, _ = tree_flatten_with_path(state_dict)
@fegin (Contributor) commented:

Is tree_flatten_with_path a new one? Is it going to give you the FQN?

@d4l3k (Member, Author) replied:
It's been there for a few versions of torch -- it gives a path like:

(MappingKey(key='user'), MappingKey(key='optimizer'), MappingKey(key='state.layers.7.feed_forward.w2.weight.step'))
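For illustration, here is a minimal pure-Python analogue of what such a flatten-with-path produces for nested dicts (torch's tree_flatten_with_path yields MappingKey entries as shown above; joining the key path gives an FQN-like name). The helper name is illustrative, not the torch API:

```python
# Minimal pure-Python analogue of a flatten-with-path over nested dicts.
# Joining each key path with "." yields an FQN-like name for the leaf.

def tree_flatten_with_path_sketch(tree, prefix=()):
    """Return a list of (key_path, leaf) pairs, depth-first."""
    leaves = []
    for key, value in tree.items():
        if isinstance(value, dict):
            leaves.extend(tree_flatten_with_path_sketch(value, prefix + (key,)))
        else:
            leaves.append((prefix + (key,), value))
    return leaves


sd = {"user": {"optimizer": {"state.layers.7.feed_forward.w2.weight.step": 10}}}
paths = [".".join(path) for path, _ in tree_flatten_with_path_sketch(sd)]
print(paths)  # → ['user.optimizer.state.layers.7.feed_forward.w2.weight.step']
```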

@d4l3k d4l3k merged commit 6fe4c8e into main Feb 25, 2025
6 checks passed
@d4l3k d4l3k deleted the d4l3k/pg_inplace branch February 25, 2025 00:21