
ProcessGroupBabyNCCL: support multiple streams and use event on start #91

Merged
merged 1 commit into main from d4l3k/baby_nccl_streams on Jan 31, 2025

Conversation

@d4l3k (Member) commented on Jan 30, 2025

This overhauls how we handle ProcessGroupBabyNCCL CUDA events and streams. It also enables the PG test suite for BabyNCCL, similar to #89, which enabled it for BabyGloo.

Key changes:

  • BabyNCCL is now stream aware and allocates one CUDA stream per stream_id in the parent process. This avoids issues when multiple streams are using NCCL, and avoids the unintentional synchronization caused by funneling everything through a single stream in the child process (see the sketch after this list).
  • BabyNCCL uses a CUDA event to synchronize the start of each NCCL operation. Previously, BabyNCCL would start running the NCCL operation immediately, even if the launching stream hadn't completed its queued work yet.
  • Added _pg_tests for BabyNCCL
  • Unified the codepath and got rid of _BabyNCCLWork/WORK_CLASS
  • Detect when the subprocess has crashed and error out early
  • Propagate device_id to the child process
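
A rough sketch of the stream and event handling described above, assuming a hypothetical `run_op` entry point on the child side and an interprocess CUDA event recorded by the parent; the actual implementation in torchft/process_group.py may differ:

```python
import torch

# Hypothetical sketch of the child-process side. One CUDA stream is kept per
# (device, stream_id) pair observed in the parent, so operations launched from
# different parent streams don't all serialize onto a single child stream.
streams: dict[tuple[str, int], torch.cuda.Stream] = {}

def run_op(stream_device: str, stream_id: int,
           start_event: torch.cuda.Event, op, *args):
    key = (stream_device, stream_id)
    if key not in streams:
        streams[key] = torch.cuda.Stream(device=stream_device)
    with torch.cuda.stream(streams[key]):
        # Block this stream on the event the parent recorded in the launching
        # stream (it must be created with interprocess=True to cross process
        # boundaries), so the NCCL op can't start before the parent's queued
        # work has completed.
        start_event.wait()
        return op(*args)
```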

Test plan:

pytest torchft/process_group_test.py
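
To target only the BabyNCCL cases, a filtered run should also work (assuming the relevant test names contain "baby", which may not match exactly):

pytest torchft/process_group_test.py -k baby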

Test integration with torchtitan (#82) with patch P1722192431.

@d4l3k requested review from fegin and H-Huang on January 30, 2025 19:26
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jan 30, 2025
@d4l3k force-pushed the d4l3k/baby_nccl_streams branch 3 times, most recently from 037f346 to 7d48c52, on January 30, 2025 23:56
@d4l3k requested a review from kwen2501 on January 30, 2025 23:56
@d4l3k force-pushed the d4l3k/baby_nccl_streams branch from 7d48c52 to 0f8b44a on January 30, 2025 23:57
Three review threads on torchft/process_group.py were resolved. One discussion, on this snippet:
```python
if stream_key not in streams:
    streams[stream_key] = torch.cuda.Stream(
        device=stream_device
    )
```
Contributor commented:

Are we going to have zombie streams if there are multiple failures? Will this cause a memory leak?

d4l3k (author) replied:

My understanding is that streams are specific to the CUDA context/process, so they will be cleaned up just fine when the subprocess is killed.

@d4l3k force-pushed the d4l3k/baby_nccl_streams branch from 0f8b44a to 30296e9 on January 31, 2025 00:33
@d4l3k requested a review from fegin on January 31, 2025 00:33
@d4l3k merged commit 2b23017 into main on Jan 31, 2025
6 checks passed
@d4l3k deleted the d4l3k/baby_nccl_streams branch on January 31, 2025 01:23
fegin added a commit to pytorch/torchtitan that referenced this pull request Feb 3, 2025
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment, as groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 will hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 will pass `should_commit`, but group 0 incorrectly decides it needs healing, and the healing process will cause another hang.~

Fixed with: pytorch/torchft#83

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 will continue to print out~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?***
Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command.

Seems to be fixed, but more testing is needed.

**Issue 5:**
A hang will happen when using functional collectives.
***How to reproduce?***
Pull the latest version of this PR and comment out line 41 and uncomment line 42 in `torchtitan/utils.py`
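
For context, a minimal sketch of the functional-collective code path this issue concerns; the function name and the exact call are assumptions, and the real lines in `torchtitan/utils.py` may differ:

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Hypothetical illustration only: functional collectives return an
# AsyncCollectiveTensor that synchronizes lazily, which is the code path
# suspected of hanging here.
def dist_mean(x: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    return funcol.all_reduce(x, reduceOp="avg", group=pg)
```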


**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Execute lighthouse
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```



[ghstack-poisoned]