20
20
21
21
namespace {
22
22
23
- bool has_multicast_support ( ) {
23
+ bool device_has_multicast_support ( int device_idx ) {
24
24
#if defined(CUDART_SUPPORTS_MULTICAST)
25
- return c10::cuda::DriverAPI::get ()->cuMulticastCreate_ != nullptr ;
25
+ if (c10::utils::check_env (" TORCH_SYMM_MEM_DISABLE_MULTICAST" ) == true ) {
26
+ return false ;
27
+ }
28
+ // Multicast support requirements:
29
+ // - CUDA Runtime version >= 12030: Checked at compile time using
30
+ // CUDART_VERSION.
31
+ // - Driver version >= 535: Checked at runtime by verifying the existence of
32
+ // cuMulticastCreate_.
33
+ // - Device support: Determined by querying
34
+ // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime.
35
+ auto driver_api = c10::cuda::DriverAPI::get ();
36
+ int multicast_supported;
37
+ C10_CUDA_DRIVER_CHECK (driver_api->cuDeviceGetAttribute_ (
38
+ &multicast_supported,
39
+ CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
40
+ device_idx));
41
+ return driver_api->cuMulticastCreate_ != nullptr && multicast_supported;
26
42
#else
27
43
return false ;
28
44
#endif
@@ -70,7 +86,16 @@ class IpcChannel {
70
86
cmsg->cmsg_len = CMSG_LEN (sizeof (int ));
71
87
cmsg->cmsg_level = SOL_SOCKET;
72
88
cmsg->cmsg_type = SCM_RIGHTS;
73
- memcpy (CMSG_DATA (cmsg), &fd, sizeof (fd));
89
+
90
+ if (fd != -1 ) {
91
+ // memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
92
+ std::copy (
93
+ reinterpret_cast <const char *>(&fd),
94
+ reinterpret_cast <const char *>(&fd) + sizeof (fd),
95
+ reinterpret_cast <char *>(CMSG_DATA (cmsg)));
96
+ } else {
97
+ msg.msg_controllen = 0 ;
98
+ }
74
99
75
100
TORCH_CHECK (
76
101
sendmsg (socket_, &msg, 0 ) > 0 , " Failed to send fd: " , strerror (errno));
@@ -94,6 +119,10 @@ class IpcChannel {
94
119
" Failed to receive fd: " ,
95
120
strerror (errno));
96
121
122
+ if (msg.msg_controllen == 0 ) {
123
+ return -1 ;
124
+ }
125
+
97
126
auto cmsg = CMSG_FIRSTHDR (&msg);
98
127
TORCH_CHECK (cmsg != NULL );
99
128
TORCH_CHECK (cmsg->cmsg_len == CMSG_LEN (sizeof (int )));
@@ -319,7 +348,7 @@ size_t CUDASymmetricMemory::get_signal_pad_size() {
319
348
}
320
349
321
350
bool CUDASymmetricMemory::has_multicast_support () {
322
- return :: has_multicast_support () ;
351
+ return mc_addr_ != nullptr ;
323
352
}
324
353
325
354
void * CUDASymmetricMemory::get_multicast_ptr () {
@@ -555,10 +584,11 @@ struct RendezvousRequest {
555
584
size_t block_size;
556
585
size_t buffer_size;
557
586
size_t signal_pad_offset;
587
+ bool has_multicast_support;
558
588
};
559
589
560
590
void validate_rendezvous_requests (
561
- const std::vector<RendezvousRequest> reqs,
591
+ const std::vector<RendezvousRequest>& reqs,
562
592
int world_size) {
563
593
TORCH_CHECK (reqs.size () == (size_t )world_size);
564
594
@@ -582,6 +612,92 @@ void validate_rendezvous_requests(
582
612
}
583
613
}
584
614
615
+ static bool check_group_multicast_support (
616
+ const std::vector<RendezvousRequest>& reqs) {
617
+ std::vector<size_t > ranks_with_multicast_support;
618
+ for (size_t r = 0 ; r < reqs.size (); ++r) {
619
+ if (reqs[r].has_multicast_support ) {
620
+ ranks_with_multicast_support.push_back (r);
621
+ }
622
+ }
623
+ if (ranks_with_multicast_support.size () == reqs.size ()) {
624
+ return true ;
625
+ } else {
626
+ // We don't expect this to happen. But we want to let the user to know if
627
+ // this happens.
628
+ if (ranks_with_multicast_support.size () != 0 ) {
629
+ LOG (WARNING)
630
+ << " Only a subset of ranks in the group has multicast support: "
631
+ << ranks_with_multicast_support << " (world_size=" << reqs.size ()
632
+ << " ). Skipping multicast initialization because this is unexpected." ;
633
+ }
634
+ return false ;
635
+ }
636
+ }
637
+
638
+ static void init_multicast_for_block (
639
+ HandleType& mc_handle,
640
+ void *& mc_addr,
641
+ const c10::intrusive_ptr<Block>& block,
642
+ IpcChannel& ipc_channel,
643
+ const std::vector<int >& pids,
644
+ const c10::intrusive_ptr<c10d::Store>& store,
645
+ int rank,
646
+ int world_size) {
647
+ #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) && \
648
+ defined (CUDART_SUPPORTS_MULTICAST)
649
+ auto driver_api = c10::cuda::DriverAPI::get ();
650
+ if (rank == 0 ) {
651
+ CUmulticastObjectProp mc_prop{};
652
+ mc_prop.numDevices = world_size;
653
+ mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
654
+ mc_prop.size = block->block_size ;
655
+
656
+ auto err = driver_api->cuMulticastCreate_ (&mc_handle, &mc_prop);
657
+ if (err != CUDA_SUCCESS) {
658
+ const char * err_str;
659
+ CUresult get_error_str_err = driver_api->cuGetErrorString_ (err, &err_str);
660
+ if (get_error_str_err != CUDA_SUCCESS) {
661
+ err_str = " unknown cuda driver error" ;
662
+ }
663
+ LOG (WARNING)
664
+ << " SymmetricMemory: cuMulticastCreate failed with: \" " << err_str
665
+ << " \" . Gracefully skipping multicast initialization. "
666
+ << " However, this is unexpected. Please report the issue on GitHub." ;
667
+ // Allow peers gracefully skip multicast initialization by sending -1
668
+ ipc_channel.broadcast_fds (rank, 0 , pids, -1 );
669
+ return ;
670
+ }
671
+
672
+ int mc_fd;
673
+ C10_CUDA_DRIVER_CHECK (driver_api->cuMemExportToShareableHandle_ (
674
+ &mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 ));
675
+ ipc_channel.broadcast_fds (rank, 0 , pids, mc_fd);
676
+ // Ref count is incremented as soon as SCM_RIGHTS send happens
677
+ close (mc_fd);
678
+ } else {
679
+ int mc_fd = ipc_channel.broadcast_fds (rank, 0 , pids, -1 );
680
+ if (mc_fd == -1 ) {
681
+ return ;
682
+ }
683
+ C10_CUDA_DRIVER_CHECK (driver_api->cuMemImportFromShareableHandle_ (
684
+ &mc_handle,
685
+ (void *)(uintptr_t )mc_fd,
686
+ CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
687
+ close (mc_fd);
688
+ }
689
+
690
+ // All rank adds their physical allocation to the multicast object
691
+ C10_CUDA_DRIVER_CHECK (
692
+ driver_api->cuMulticastAddDevice_ (mc_handle, block->device_idx ));
693
+ C10_CUDA_DRIVER_CHECK (driver_api->cuMulticastBindMem_ (
694
+ mc_handle, 0 , block->handle , 0 , block->block_size , 0 ));
695
+
696
+ map_block (&mc_addr, mc_handle, block->block_size , block->device_idx );
697
+ store_barrier (store, rank, world_size);
698
+ #endif
699
+ }
700
+
585
701
c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous (
586
702
void * ptr) {
587
703
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
@@ -610,7 +726,8 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
610
726
.pid = getpid (),
611
727
.block_size = block->block_size ,
612
728
.buffer_size = block->buffer_size ,
613
- .signal_pad_offset = block->signal_pad_offset };
729
+ .signal_pad_offset = block->signal_pad_offset ,
730
+ .has_multicast_support = device_has_multicast_support (block->device_idx )};
614
731
auto reqs = store_all_gather (store, rank, world_size, local_req);
615
732
validate_rendezvous_requests (reqs, world_size);
616
733
@@ -642,45 +759,13 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
642
759
store_barrier (store, rank, world_size);
643
760
close (block_fd);
644
761
645
- CUmemGenericAllocationHandle mc_handle{};
762
+ HandleType mc_handle{};
646
763
void * mc_addr = nullptr ;
647
- #if defined(CUDART_SUPPORTS_MULTICAST)
648
- // We have to further check if the driver supports multicast
649
- if (has_multicast_support ()) {
650
- // Rank 0 creates a multicast object and share it with peers
651
- if (rank == 0 ) {
652
- CUmulticastObjectProp mc_prop{};
653
- mc_prop.numDevices = world_size;
654
- mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
655
- mc_prop.size = block->block_size ;
656
-
657
- CUresult res = driver_api->cuMulticastCreate_ (&mc_handle, &mc_prop);
658
- TORCH_CHECK (res == CUDA_SUCCESS);
659
-
660
- int mc_fd;
661
- C10_CUDA_DRIVER_CHECK (driver_api->cuMemExportToShareableHandle_ (
662
- &mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 ));
663
- ipc_channel.broadcast_fds (rank, 0 , pids, mc_fd);
664
- // Ref count is incremented as soon as SCM_RIGHTS send happens
665
- close (mc_fd);
666
- } else {
667
- int mc_fd = ipc_channel.broadcast_fds (rank, 0 , pids, -1 );
668
- C10_CUDA_DRIVER_CHECK (driver_api->cuMemImportFromShareableHandle_ (
669
- &mc_handle,
670
- (void *)(uintptr_t )mc_fd,
671
- CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
672
- close (mc_fd);
673
- }
674
- // All rank adds their physical allocation to the multicast object
675
- C10_CUDA_DRIVER_CHECK (
676
- driver_api->cuMulticastAddDevice_ (mc_handle, block->device_idx ));
677
- C10_CUDA_DRIVER_CHECK (driver_api->cuMulticastBindMem_ (
678
- mc_handle, 0 , block->handle , 0 , block->block_size , 0 ));
679
-
680
- map_block (&mc_addr, mc_handle, block->block_size , block->device_idx );
681
- store_barrier (store, rank, world_size);
764
+ bool group_has_multicast_support = check_group_multicast_support (reqs);
765
+ if (group_has_multicast_support) {
766
+ init_multicast_for_block (
767
+ mc_handle, mc_addr, block, ipc_channel, pids, store, rank, world_size);
682
768
}
683
- #endif
684
769
685
770
// Initializing CUDASymmetricMemory with an allocation transfers its
686
771
// ownership to the CUDASymmetricMemory object. So that outstanding
@@ -713,8 +798,8 @@ bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
713
798
return block->symm_mem != nullptr ;
714
799
}
715
800
716
- bool CUDASymmetricMemoryAllocator::has_multicast_support () {
717
- return :: has_multicast_support ( );
801
+ bool CUDASymmetricMemoryAllocator::has_multicast_support (int device_idx ) {
802
+ return device_has_multicast_support (device_idx );
718
803
}
719
804
720
805
c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block (void * ptr) {
0 commit comments