 )
 from torch.testing._internal.common_distributed import (
     MultiProcContinuousTest,
+    requires_nccl_version,
     skip_if_lt_x_gpu,
 )
 from torch.testing._internal.common_utils import (
@@ -227,6 +228,7 @@ def device(self) -> torch.device:

     @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
     @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
+    @requires_nccl_version((2, 27), "NCCL symmetric memory requires NCCL >= 2.27")
     @skip_if_lt_x_gpu(2)
     def test_nccl_symmem_alloc(self):
         symm_mem.set_backend("NCCL")
@@ -250,6 +252,114 @@ def foo():
         out = symm_mem.empty(numel, dtype=dtype, device=self.device)
         symm_mem.rendezvous(out, group=group_name)

+    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
+    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
+    @requires_nccl_version(
+        (2, 28), "NCCL symmetric memory device API requires NCCL >= 2.28"
+    )
+    @skip_if_lt_x_gpu(2)
+    def test_nccl_symmem_collective(self):
+        symm_mem.set_backend("NCCL")
+        torch.cuda.set_device(self.rank)
+        # This all_reduce initializes the NCCL communicator; without it the
+        # test hangs. TODO: investigate how NCCLSymmetricMemory can
+        # initialize the NCCL communicator itself.
+        c10d.all_reduce(torch.ones(1, device=self.device))
+        group_name = c10d.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        numel = 1024
+
+        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
+        symm_mem.rendezvous(out, group=group_name)
+        c10d.all_reduce(out)
+        torch.cuda.synchronize()
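+        # Each rank contributed its rank id, so the all-reduced value is
+        # sum(range(world_size)) = (world_size - 1) * world_size / 2.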
+        self.assertEqual(
+            out, torch.full_like(out, (self.world_size - 1) * self.world_size / 2)
+        )
+
+        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
+        symm_mem.rendezvous(inp, group=group_name)
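+        # one_shot_all_reduce lets each rank read its peers' symmetric
+        # buffers and reduce locally; the result should match the
+        # c10d.all_reduce output computed above.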
+        res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)
+        self.assertEqual(out, res)
+
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
+    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
+    @requires_nccl_version(
+        (2, 28), "NCCL symmetric memory device API requires NCCL >= 2.28"
+    )
+    @skip_if_lt_x_gpu(2)
+    def test_nccl_symmem_put(self):
+        symm_mem.set_backend("NCCL")
+        torch.cuda.set_device(self.rank)
+        # This all_reduce initializes the NCCL communicator; without it the
+        # test hangs. TODO: investigate how NCCLSymmetricMemory can
+        # initialize the NCCL communicator itself.
+        c10d.all_reduce(torch.ones(1, device=self.device))
+        group_name = c10d.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        numel = 1024
+        tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
+        # Rendezvous now so the handle is cached and a second rendezvous on
+        # the same tensor does not block.
+        symm_mem.rendezvous(tensor, group=group_name)
+        signal_val = 5
+        c10d.barrier()
+
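+        # Rank 1 writes its buffer (filled with 1.0) into rank 0's symmetric
+        # tensor and raises the signal; rank 0 blocks until the signal
+        # arrives, then checks that it sees rank 1's data.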
+        if self.rank == 1:
+            torch.ops.symm_mem.nccl_put_with_signal(tensor, signal_val, 0)
+        elif self.rank == 0:
+            torch.ops.symm_mem.nccl_wait_for_signal(tensor, signal_val)
+            torch.testing.assert_close(
+                tensor, torch.ones(numel, dtype=dtype, device=self.device)
+            )
+        c10d.barrier()
+        if self.rank == 1:
+            tensor *= 2
+            torch.ops.symm_mem.nccl_put(tensor, 0)
+            c10d.barrier()
+        else:
+            c10d.barrier()
+        if self.rank == 0:
+            torch.testing.assert_close(
+                tensor, torch.ones(numel, dtype=dtype, device=self.device) * 2
+            )
+
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
+    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
+    @skip_if_lt_x_gpu(2)
+    def test_nccl_symmem_get(self):
+        symm_mem.set_backend("NCCL")
+        torch.cuda.set_device(self.rank)
+        # This all_reduce initializes the NCCL communicator; without it the
+        # test hangs. TODO: investigate how NCCLSymmetricMemory can
+        # initialize the NCCL communicator itself.
+        c10d.all_reduce(torch.ones(1, device=self.device))
+        group_name = c10d.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        numel = 1024
+        tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
+        # Rendezvous now so the handle is cached and a second rendezvous on
+        # the same tensor does not block.
+        symm_mem.rendezvous(tensor, group=group_name)
+        c10d.barrier()
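+        # Rank 0 pulls rank 1's buffer (filled with 1.0) into its own
+        # symmetric tensor; barriers stand in for signaling until a
+        # wait_signal primitive is available (see TODOs below).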
+        if self.rank == 0:
+            torch.ops.symm_mem.nccl_get(tensor, 1)
+            # TODO: remove after we have wait_signal
+            c10d.barrier()
+            torch.testing.assert_close(
+                tensor, torch.ones(numel, dtype=dtype, device=self.device)
+            )
+        else:
+            # handle.wait_signal(src_rank=0)
+            # TODO: remove after we have wait_signal
+            c10d.barrier()
+

 instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")
