
Commit c907c77

fduwjj authored and pytorchmergebot committed
[c10d][Sym mem] Make nccl backend full fledged with nccl 2.28.9-1 (pytorch#168129)
(This PR will be rebased on pytorch#166174. A separate PR, pytorch#168091, updates the NCCL version.)

This change does the following:

1. Exchange the buffer pointer and the signal pad pointer via the NCCL device API introduced in NCCL 2.28.
2. With #1, show that symmetric memory from the NCCL backend works with the existing one_shot_all_reduce kernel (a unit test is added for it).
3. Add simple put, put-with-signal, wait-for-signal, and get operations so that symmetric memory's one-sided API works.
4. Show in a unit test that symmetric memory from the NCCL backend also works with traditional c10d collectives.
5. Store the DevComm inside the symmetric memory object so that users can access it for customized kernels.

Resolves pytorch#167682

Pull Request resolved: pytorch#168129
Approved by: https://github.com/kwen2501, https://github.com/ngimel, https://github.com/atalman
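For orientation, below is a minimal, self-contained sketch of how the NCCL-backed symmetric memory exercised by this change can be driven from user code. It is not an official example: it mirrors the new tests in test/distributed/test_nccl.py in this diff, and the torchrun launch, the hard-coded two-rank setup, and the argument-order comments on the one-sided ops are assumptions read off the test code rather than documented API.

# Minimal sketch mirroring NCCLSymmetricMemoryTest; assumes a launch such as
# `torchrun --nproc-per-node 2 this_script.py`, one GPU per rank, NCCL >= 2.28.
import torch
import torch.distributed as c10d
import torch.distributed._symmetric_memory as symm_mem

c10d.init_process_group(backend="nccl")
rank = c10d.get_rank()
torch.cuda.set_device(rank)

symm_mem.set_backend("NCCL")                   # select the NCCL symmetric-memory backend
c10d.all_reduce(torch.ones(1, device="cuda"))  # warm-up collective to create the NCCL communicator (see TODO in the tests)
group_name = c10d.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

# Allocate and rendezvous a symmetric buffer, then run the existing
# one_shot_all_reduce kernel on it (point 2 above).
inp = symm_mem.empty(1024, dtype=torch.float, device="cuda").fill_(rank)
symm_mem.rendezvous(inp, group=group_name)
out = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)

# One-sided ops added in this PR (point 3): rank 1 puts its buffer into rank 0's
# copy and raises a signal; rank 0 spins until the signal arrives.
tensor = symm_mem.empty(1024, dtype=torch.float, device="cuda").fill_(rank)
symm_mem.rendezvous(tensor, group=group_name)
signal_val = 5
c10d.barrier()
if rank == 1:
    torch.ops.symm_mem.nccl_put_with_signal(tensor, signal_val, 0)  # arg order per the test: (tensor, signal, peer rank)
elif rank == 0:
    torch.ops.symm_mem.nccl_wait_for_signal(tensor, signal_val)
    assert torch.equal(tensor, torch.ones_like(tensor))             # rank 1 filled its buffer with 1.0

c10d.barrier()
c10d.destroy_process_group()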
1 parent b626cc1 commit c907c77

File tree

8 files changed (+592, -27 lines)


.ci/pytorch/test.sh

Lines changed: 1 addition & 1 deletion
@@ -373,7 +373,7 @@ _run_symm_mem_tests() {
   time python test/run_test.py --include distributed/test_symmetric_memory.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   time python test/run_test.py --include distributed/test_nvshmem.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   time python test/run_test.py --include distributed/test_nvshmem_triton.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
-  time python test/run_test.py --include distributed/test_nccl.py $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  time python test/run_test.py --include distributed/test_nccl.py -k NCCLSymmetricMemoryTest $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   assert_git_not_dirty
 }

build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -768,6 +768,7 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp",
     "torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp",
     "torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu",
+    "torch/csrc/distributed/c10d/symm_mem/nccl_extension.cu",
     "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp",
     "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
     "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -593,6 +593,7 @@ if(USE_CUDA)
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu
+    ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/nccl_extension.cu
     ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp
     PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
   )

test/distributed/test_nccl.py

Lines changed: 110 additions & 0 deletions
@@ -14,6 +14,7 @@
 )
 from torch.testing._internal.common_distributed import (
     MultiProcContinuousTest,
+    requires_nccl_version,
     skip_if_lt_x_gpu,
 )
 from torch.testing._internal.common_utils import (
@@ -227,6 +228,7 @@ def device(self) -> torch.device:
 
     @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
     @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
+    @requires_nccl_version((2, 27), "NCCL Symmetric Memory support from nccl 2.27")
     @skip_if_lt_x_gpu(2)
     def test_nccl_symmem_alloc(self):
         symm_mem.set_backend("NCCL")
@@ -250,6 +252,114 @@ def foo():
         out = symm_mem.empty(numel, dtype=dtype, device=self.device)
         symm_mem.rendezvous(out, group=group_name)
 
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
+    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
+    @requires_nccl_version(
+        (2, 28), "NCCL Symmetric Memory support device API from nccl 2.28"
+    )
+    @skip_if_lt_x_gpu(2)
+    def test_nccl_symmem_collective(self):
+        symm_mem.set_backend("NCCL")
+        torch.cuda.set_device(self.rank)
+        # Need this all_reduce to initialize NCCL communicator. Otherwise, the
+        # test will hang. TODO: investigate how NCCLSymmetricMemory can
+        # initialize NCCL communicator.
+        c10d.all_reduce(torch.ones(1, device=self.device))
+        group_name = c10d.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        numel = 1024
+
+        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
+        symm_mem.rendezvous(out, group=group_name)
+        c10d.all_reduce(out)
+        torch.cuda.synchronize()
+        self.assertEqual(
+            out, torch.full_like(out, (self.world_size - 1) * self.world_size / 2)
+        )
+
+        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
+        symm_mem.rendezvous(inp, group=group_name)
+        res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)
+        self.assertEqual(out, res)
+
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
+    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
+    @requires_nccl_version(
+        (2, 28), "NCCL Symmetric Memory support device API from nccl 2.28"
+    )
+    @skip_if_lt_x_gpu(2)
+    def test_nccl_symmem_put(self):
+        symm_mem.set_backend("NCCL")
+        torch.cuda.set_device(self.rank)
+        # Need this all_reduce to initialize NCCL communicator. Otherwise, the
+        # test will hang. TODO: investigate how NCCLSymmetricMemory can
+        # initialize NCCL communicator.
+        c10d.all_reduce(torch.ones(1, device=self.device))
+        group_name = c10d.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        numel = 1024
+        tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
+        # This is needed to make sure we don't get blocked the second time we call rendezvous
+        # for the same tensor because it will be cached by that moment.
+        symm_mem.rendezvous(tensor, group=group_name)
+        signal_val = 5
+        c10d.barrier()
+
+        if self.rank == 1:
+            torch.ops.symm_mem.nccl_put_with_signal(tensor, signal_val, 0)
+        elif self.rank == 0:
+            torch.ops.symm_mem.nccl_wait_for_signal(tensor, signal_val)
+            torch.testing.assert_close(
+                tensor, torch.ones(numel, dtype=dtype, device=self.device)
+            )
+        c10d.barrier()
+        if self.rank == 1:
+            tensor *= 2
+            torch.ops.symm_mem.nccl_put(tensor, 0)
+            c10d.barrier()
+        else:
+            c10d.barrier()
+            if self.rank == 0:
+                torch.testing.assert_close(
+                    tensor, torch.ones(numel, dtype=dtype, device=self.device) * 2
+                )
+
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
+    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
+    @skip_if_lt_x_gpu(2)
+    def test_nccl_symmem_get(self):
+        symm_mem.set_backend("NCCL")
+        torch.cuda.set_device(self.rank)
+        # Need this all_reduce to initialize NCCL communicator. Otherwise, the
+        # test will hang. TODO: investigate how NCCLSymmetricMemory can
+        # initialize NCCL communicator.
+        c10d.all_reduce(torch.ones(1, device=self.device))
+        group_name = c10d.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        numel = 1024
+        tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
+        # This is needed to make sure we don't get blocked the second time we call rendezvous
+        # for the same tensor because it will be cached by that moment.
+        symm_mem.rendezvous(tensor, group=group_name)
+        c10d.barrier()
+        if self.rank == 0:
+            torch.ops.symm_mem.nccl_get(tensor, 1)
+            # TODO: remove after we have wait_signal
+            c10d.barrier()
+            torch.testing.assert_close(
+                tensor, torch.ones(numel, dtype=dtype, device=self.device)
+            )
+        else:
+            # handle.wait_signal(src_rank=0)
+            # TODO: remove after we have wait_signal
+            c10d.barrier()
+
 
 instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")
