Commit a1f4b97

Address PR #771 review feedback, add pip install and docs
Review feedback (chhwang):
- TorchCommMSCCLPP::init(): replace raw cudaSetDevice with RAII CudaDeviceGuard to restore previous device on return/exception
- TorchCommMSCCLPP::init(): remove redundant cudaGetDevice call, use device_.index() directly for compute capability queries
- Add pip install support via separate mscclpp-torchcomms package with pyproject.toml, scikit-build-core, and auto-discovery of backend .so
- docs/quickstart.md: add tested version table

Review feedback (Copilot bot):
- TorchCommMSCCLPPBootstrap: add "_" delimiter between name and counter in store key to prevent collisions, make counter_ std::atomic<int>
- TorchCommMSCCLPP::finalize(): wrap cudaStreamSynchronize and cudaStreamDestroy with MSCCLPP_CUDATHROW for error surfacing
- All 4 supported collectives: replace tensor.contiguous() with TORCH_CHECK(tensor.is_contiguous()) to prevent silently dropping results for non-contiguous tensors
- CMakeLists.txt: replace manual glog search with find_package(glog REQUIRED) for consistency with codebase conventions

Rename and documentation:
- Rename python/mscclpp_torchcomm to python/mscclpp_torchcomms for consistency with the torchcomms library naming
- Add docs/torchcomms.md: standalone doc covering architecture, algorithm selection, user-defined algorithms, testing, benchmarks, limitations, and troubleshooting
- Slim down quickstart.md TorchComms section to brief snippet + link
- Add torchcomms entry to docs/index.rst
- Add import mscclpp_torchcomms to all test/benchmark files for automatic backend .so discovery (no env var needed)
1 parent db92aee commit a1f4b97

23 files changed

Lines changed: 463 additions & 115 deletions
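
The store-key change listed in the commit message can be illustrated with a small sketch. It assumes the old key was the plain concatenation of name and counter (the message only says that a "_" delimiter was added), and every name below is hypothetical rather than taken from the backend's C++ code. `KeyCounter` is a lock-based Python stand-in for the `std::atomic<int>` counter.

```python
import threading


def key_without_delimiter(name: str, counter: int) -> str:
    # Assumed old scheme: name and counter are concatenated directly.
    return f"{name}{counter}"


def key_with_delimiter(name: str, counter: int) -> str:
    # Scheme described in the commit message: "_" separates name and counter.
    return f"{name}_{counter}"


class KeyCounter:
    """Thread-safe counter (hypothetical stand-in for std::atomic<int>)."""

    def __init__(self) -> None:
        self._value = 0
        self._lock = threading.Lock()

    def next(self) -> int:
        # The lock makes read-and-increment a single atomic step.
        with self._lock:
            value = self._value
            self._value += 1
            return value


# Two distinct (name, counter) pairs map to the same key without a delimiter.
print(key_without_delimiter("comm1", 12))   # comm112
print(key_without_delimiter("comm11", 2))   # comm112

# With the delimiter the keys stay distinct.
print(key_with_delimiter("comm1", 12))      # comm1_12
print(key_with_delimiter("comm11", 2))      # comm11_2
```

Because the counter is all digits and follows the last "_", a delimited key can be split back into its name and counter unambiguously, even when the name itself contains underscores.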

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -276,5 +276,5 @@ endif()
 
 # TorchComms MSCCL++ backend
 if(MSCCLPP_BUILD_EXT_TORCHCOMMS)
-add_subdirectory(python/mscclpp_torchcomm)
+add_subdirectory(python/mscclpp_torchcomms)
 endif()

docs/index.rst

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@ You can find the followings from this documentation.
 
 - **Overview:** An overview of MSCCL++ and its features. :doc:`🔗 <overview>`
 - **Quick Start:** A guide to build, install, and run MSCCL++. :doc:`🔗 <quickstart>`
+- **TorchComms:** Using MSCCL++ as a TorchComms backend for PyTorch training. :doc:`🔗 <torchcomms>`
 - **MSCCL++ DSL:** A guide to get started with the MSCCL++ DSL. :doc:`🔗 <dsl>`
 - **Tutorials:** A step-by-step guide for GPU communication using MSCCL++. :doc:`🔗 <tutorials>`
 - **Programming Guide:** Advanced topics and best practices for using MSCCL++. :doc:`🔗 <programming_guide>`
@@ -22,6 +23,7 @@ You can find the followings from this documentation.
 
 overview
 quickstart
+torchcomms
 dsl
 tutorials
 programming_guide

docs/quickstart.md

Lines changed: 4 additions & 56 deletions
@@ -232,72 +232,20 @@ torchrun --nnodes=1 --nproc_per_node=8 your_script.py
 
 MSCCL++ integrates with [TorchComms](https://github.com/meta-pytorch/torchcomms), enabling PyTorch users to use MSCCL++ collectives through the TorchComms API. This is the recommended way to use MSCCL++ in PyTorch training for mixed-backend setups (e.g., MSCCL++ for allreduce, NCCL for broadcast/barrier).
 
-#### Building
-
-Prerequisites: PyTorch, pybind11, and [torchcomms](https://github.com/meta-pytorch/torchcomms) (`pip install --pre torchcomms`).
-
 ```bash
-$ mkdir -p build && cd build
-$ cmake -DCMAKE_BUILD_TYPE=Release \
-    -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON \
-    ..
-$ make -j$(nproc)
-$ cd ..
-```
-
-This produces `_comms_mscclpp.*.so` in the build output. TorchComms discovers MSCCL++ via the `TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP` environment variable, where `MSCCLPP_BUILD` is your MSCCL++ build directory.
-
-#### Usage
-
-```bash
-$ export TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP=$MSCCLPP_BUILD/lib/_comms_mscclpp.cpython-*.so
-$ torchrun --nproc_per_node=8 your_script.py
+$ python -m pip install ./python/mscclpp_torchcomms
 ```
 
 ```python
-import torch
 import torchcomms
+import mscclpp_torchcomms # auto-registers the backend
 
-# Create an MSCCL++ communicator
-comm = torchcomms.new_comm("mscclpp", torch.device(f"cuda:{local_rank}"), name="my_comm")
-
-# Run allreduce (MSCCL++ automatically selects the best algorithm)
+comm = torchcomms.new_comm("mscclpp", device, name="my_comm")
 comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False)
-
-# Cleanup
 comm.finalize()
 ```
 
-#### Supported Collectives
-
-| Collective | Status | Notes |
-|---|---|---|
-| AllReduce | Supported | SUM, MIN. Auto-selects from ~10 native algorithms by message size and topology |
-| AllGather | Supported | Fullmesh algorithms |
-| ReduceScatter | Dispatched | Requires a registered DSL algorithm |
-| AllToAll | Dispatched | Requires a registered DSL algorithm |
-| All others | Not supported | Throws with guidance to use a separate NCCL/RCCL communicator |
-
-#### Environment Variables
-
-| Variable | Description |
-|---|---|
-| `TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP` | **Required.** Path to the built `_comms_mscclpp.*.so` module |
-
-#### Running Tests
-
-```bash
-$ export TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP=$MSCCLPP_BUILD/lib/_comms_mscclpp.cpython-*.so
-$ torchrun --nproc_per_node=8 test/torchcomms/test_correctness.py --all
-```
-
-#### Running Benchmarks
-
-```bash
-$ export TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP=$MSCCLPP_BUILD/lib/_comms_mscclpp.cpython-*.so
-$ torchrun --nproc_per_node=8 test/torchcomms/bench_torchcomms.py --collective allreduce --warmup 100 --iters 200
-$ torchrun --nproc_per_node=8 test/torchcomms/bench_torchcomms.py --collective allgather --warmup 100 --iters 200
-```
+See [TorchComms Integration](torchcomms.md) for full documentation including architecture, algorithm selection, user-defined algorithms, testing, benchmarks, and troubleshooting.
 
 ## Version Tracking
 
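
The quickstart snippet passes `tensor` straight to `comm.all_reduce`, and per the commit message the collectives now reject non-contiguous tensors instead of calling `tensor.contiguous()`. The sketch below is a loose pure-Python analogy of why a silent copy loses results for an in-place operation. It uses `memoryview` in place of a tensor and doubling in place of the reduction, and both function names are hypothetical. It is not the backend's code.

```python
def double_in_place_with_silent_copy(view: memoryview) -> None:
    """Old pattern (analogy): fall back to a contiguous copy for strided input."""
    if view.contiguous:
        target = view
    else:
        # Like tensor.contiguous(): a new buffer that the caller never sees.
        target = memoryview(bytearray(view.tobytes()))
    for i in range(len(target)):
        target[i] = target[i] * 2
    # If `target` was a copy, the doubled values are discarded on return.


def double_in_place_checked(view: memoryview) -> None:
    """New pattern (analogy): reject strided input instead of copying it."""
    if not view.contiguous:
        raise ValueError("input must be contiguous")
    for i in range(len(view)):
        view[i] = view[i] * 2


data = bytearray([1, 2, 3, 4])
strided = memoryview(data)[::2]  # every other element, so not contiguous

double_in_place_with_silent_copy(strided)
print(list(data))  # [1, 2, 3, 4]: the result went into the hidden copy

try:
    double_in_place_checked(strided)
except ValueError as err:
    print("rejected:", err)

double_in_place_checked(memoryview(data))
print(list(data))  # [2, 4, 6, 8]
```

With the check in place, a caller holding a sliced or transposed tensor has to make it contiguous explicitly before the call, which keeps the buffer that receives the result under the caller's control.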