
Commit 14ff5fa

Author: Yifu Wang
[SymmetricMemory] improve multicast initialization/fallback logic (pytorch#136577)
Fixes pytorch#136494

Currently, CUDASymmetricMemory::rendezvous() initializes a multicast address if multicast support is present. However, if we believe multicast support is present but cuMulticastCreate still fails for some reason, we do not fall back gracefully. This commit:

- In addition to the CUDART and driver version checks, queries CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED to determine multicast support for a rank/device.
- Before initializing multicast for a block, ensures that all ranks/devices have multicast support.
- If cuMulticastCreate still fails on rank 0 (unlikely, but possible), prints the corresponding driver error message as a warning and gracefully skips multicast initialization for the block.
- Introduces an environment variable (TORCH_SYMM_MEM_DISABLE_MULTICAST) that allows users to explicitly disable multicast support as a workaround.

Pull Request resolved: pytorch#136577
Approved by: https://github.com/Chillee, https://github.com/eqy

(cherry picked from commit d55eef5)
1 parent bc421d4 commit 14ff5fa
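The environment variable and the per-device query introduced by this commit are both reachable from Python. A minimal sketch, assuming the _SymmetricMemory/DeviceType import paths used by test/distributed/test_symmetric_memory.py in this commit:

import os

# Workaround added by this commit: set before the first rendezvous to force
# the non-multicast fallback path even on capable hardware.
os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1"

import torch
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory

if torch.cuda.is_available():
    # The device index argument is new in this commit (see the diffs below).
    print(_SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0))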

File tree

6 files changed: +141 -51 lines changed


c10/cuda/driver_api.h

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
   } while (0)

 #define C10_LIBCUDA_DRIVER_API(_) \
+  _(cuDeviceGetAttribute)         \
   _(cuMemAddressReserve)          \
   _(cuMemRelease)                 \
   _(cuMemMap)                     \

test/distributed/test_symmetric_memory.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def requires_cuda_p2p_access():
 def requires_multicast_support():
     has_multicast_support = (
         torch.cuda.is_available()
-        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA)
+        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
     )
     return skip_but_pass_in_sandcastle_if(
         not has_multicast_support,
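Elsewhere in the test file, this helper is applied as a skip decorator; a hypothetical example using it:

# Hypothetical test: skipped (but reported as passing in Sandcastle)
# when rank/device 0 lacks multicast support.
@requires_multicast_support()
def test_multimem_op(self) -> None:
    ...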

torch/csrc/distributed/c10d/CUDASymmetricMemory.cu

Lines changed: 130 additions & 45 deletions
@@ -20,9 +20,25 @@

 namespace {

-bool has_multicast_support() {
+bool device_has_multicast_support(int device_idx) {
 #if defined(CUDART_SUPPORTS_MULTICAST)
-  return c10::cuda::DriverAPI::get()->cuMulticastCreate_ != nullptr;
+  if (c10::utils::check_env("TORCH_SYMM_MEM_DISABLE_MULTICAST") == true) {
+    return false;
+  }
+  // Multicast support requirements:
+  // - CUDA Runtime version >= 12030: Checked at compile time using
+  //   CUDART_VERSION.
+  // - Driver version >= 535: Checked at runtime by verifying the existence of
+  //   cuMulticastCreate_.
+  // - Device support: Determined by querying
+  //   CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime.
+  auto driver_api = c10::cuda::DriverAPI::get();
+  int multicast_supported;
+  C10_CUDA_DRIVER_CHECK(driver_api->cuDeviceGetAttribute_(
+      &multicast_supported,
+      CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
+      device_idx));
+  return driver_api->cuMulticastCreate_ != nullptr && multicast_supported;
 #else
   return false;
 #endif
@@ -70,7 +86,16 @@ class IpcChannel {
     cmsg->cmsg_len = CMSG_LEN(sizeof(int));
     cmsg->cmsg_level = SOL_SOCKET;
     cmsg->cmsg_type = SCM_RIGHTS;
-    memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
+
+    if (fd != -1) {
+      // memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
+      std::copy(
+          reinterpret_cast<const char*>(&fd),
+          reinterpret_cast<const char*>(&fd) + sizeof(fd),
+          reinterpret_cast<char*>(CMSG_DATA(cmsg)));
+    } else {
+      msg.msg_controllen = 0;
+    }

     TORCH_CHECK(
         sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno));
@@ -94,6 +119,10 @@ class IpcChannel {
         "Failed to receive fd: ",
         strerror(errno));

+    if (msg.msg_controllen == 0) {
+      return -1;
+    }
+
     auto cmsg = CMSG_FIRSTHDR(&msg);
     TORCH_CHECK(cmsg != NULL);
     TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
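The two IpcChannel hunks above encode a sentinel: a real fd travels as SCM_RIGHTS ancillary data, while fd == -1 is transmitted by sending no ancillary data at all (msg_controllen == 0), which the receiver maps back to -1. A self-contained Python sketch of the same convention over a socket pair (illustrative only, not PyTorch's IpcChannel):

import array
import socket

def send_fd(sock: socket.socket, fd: int) -> None:
    ancdata = []
    if fd != -1:
        # Real fd: attach it as SCM_RIGHTS ancillary data.
        ancdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd]))]
    # For fd == -1 the payload goes out with an empty control buffer.
    sock.sendmsg([b"\x00"], ancdata)

def recv_fd(sock: socket.socket) -> int:
    _, ancdata, _, _ = sock.recvmsg(1, socket.CMSG_SPACE(array.array("i").itemsize))
    if not ancdata:
        return -1  # peer sent no control message: interpret as "no fd"
    _, _, data = ancdata[0]
    return array.array("i", bytes(data[:4]))[0]

left, right = socket.socketpair()
send_fd(left, -1)
assert recv_fd(right) == -1  # the graceful-skip path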
@@ -319,7 +348,7 @@ size_t CUDASymmetricMemory::get_signal_pad_size() {
 }

 bool CUDASymmetricMemory::has_multicast_support() {
-  return ::has_multicast_support();
+  return mc_addr_ != nullptr;
 }

 void* CUDASymmetricMemory::get_multicast_ptr() {
@@ -555,10 +584,11 @@ struct RendezvousRequest {
   size_t block_size;
   size_t buffer_size;
   size_t signal_pad_offset;
+  bool has_multicast_support;
 };

 void validate_rendezvous_requests(
-    const std::vector<RendezvousRequest> reqs,
+    const std::vector<RendezvousRequest>& reqs,
     int world_size) {
   TORCH_CHECK(reqs.size() == (size_t)world_size);

@@ -582,6 +612,92 @@
   }
 }

+static bool check_group_multicast_support(
+    const std::vector<RendezvousRequest>& reqs) {
+  std::vector<size_t> ranks_with_multicast_support;
+  for (size_t r = 0; r < reqs.size(); ++r) {
+    if (reqs[r].has_multicast_support) {
+      ranks_with_multicast_support.push_back(r);
+    }
+  }
+  if (ranks_with_multicast_support.size() == reqs.size()) {
+    return true;
+  } else {
+    // We don't expect this to happen. But we want to let the user know if
+    // this happens.
+    if (ranks_with_multicast_support.size() != 0) {
+      LOG(WARNING)
+          << "Only a subset of ranks in the group has multicast support: "
+          << ranks_with_multicast_support << " (world_size=" << reqs.size()
+          << "). Skipping multicast initialization because this is unexpected.";
+    }
+    return false;
+  }
+}
+
+static void init_multicast_for_block(
+    HandleType& mc_handle,
+    void*& mc_addr,
+    const c10::intrusive_ptr<Block>& block,
+    IpcChannel& ipc_channel,
+    const std::vector<int>& pids,
+    const c10::intrusive_ptr<c10d::Store>& store,
+    int rank,
+    int world_size) {
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) && \
+    defined(CUDART_SUPPORTS_MULTICAST)
+  auto driver_api = c10::cuda::DriverAPI::get();
+  if (rank == 0) {
+    CUmulticastObjectProp mc_prop{};
+    mc_prop.numDevices = world_size;
+    mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
+    mc_prop.size = block->block_size;
+
+    auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
+    if (err != CUDA_SUCCESS) {
+      const char* err_str;
+      CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str);
+      if (get_error_str_err != CUDA_SUCCESS) {
+        err_str = "unknown cuda driver error";
+      }
+      LOG(WARNING)
+          << "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str
+          << "\". Gracefully skipping multicast initialization. "
+          << "However, this is unexpected. Please report the issue on GitHub.";
+      // Allow peers to gracefully skip multicast initialization by sending -1
+      ipc_channel.broadcast_fds(rank, 0, pids, -1);
+      return;
+    }
+
+    int mc_fd;
+    C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
+        &mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
+    ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
+    // Ref count is incremented as soon as SCM_RIGHTS send happens
+    close(mc_fd);
+  } else {
+    int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
+    if (mc_fd == -1) {
+      return;
+    }
+    C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
+        &mc_handle,
+        (void*)(uintptr_t)mc_fd,
+        CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
+    close(mc_fd);
+  }
+
+  // All ranks add their physical allocation to the multicast object
+  C10_CUDA_DRIVER_CHECK(
+      driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
+  C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
+      mc_handle, 0, block->handle, 0, block->block_size, 0));
+
+  map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
+  store_barrier(store, rank, world_size);
+#endif
+}
+
 c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
     void* ptr) {
 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
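Taken together, the two new helpers split policy from mechanism: check_group_multicast_support makes an all-or-nothing decision from the all-gathered capability bits, while init_multicast_for_block performs the rank-0 create/export, the fd broadcast (with -1 as the failure sentinel), and the per-rank bind/map. A standalone sketch of the group decision in Python (illustrative names, not the real API):

import logging

def check_group_multicast_support(has_mc_per_rank: list) -> bool:
    supported = [r for r, ok in enumerate(has_mc_per_rank) if ok]
    if len(supported) == len(has_mc_per_rank):
        return True  # every rank/device reported multicast support
    if supported:  # partial support is unexpected: warn, then fall back
        logging.warning(
            "Only a subset of ranks has multicast support: %s (world_size=%d)",
            supported,
            len(has_mc_per_rank),
        )
    return False

assert check_group_multicast_support([True, True])
assert not check_group_multicast_support([True, False])   # warns
assert not check_group_multicast_support([False, False])  # silent fallback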
@@ -610,7 +726,8 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
       .pid = getpid(),
       .block_size = block->block_size,
       .buffer_size = block->buffer_size,
-      .signal_pad_offset = block->signal_pad_offset};
+      .signal_pad_offset = block->signal_pad_offset,
+      .has_multicast_support = device_has_multicast_support(block->device_idx)};
   auto reqs = store_all_gather(store, rank, world_size, local_req);
   validate_rendezvous_requests(reqs, world_size);

@@ -642,45 +759,13 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
   store_barrier(store, rank, world_size);
   close(block_fd);

-  CUmemGenericAllocationHandle mc_handle{};
+  HandleType mc_handle{};
   void* mc_addr = nullptr;
-#if defined(CUDART_SUPPORTS_MULTICAST)
-  // We have to further check if the driver supports multicast
-  if (has_multicast_support()) {
-    // Rank 0 creates a multicast object and share it with peers
-    if (rank == 0) {
-      CUmulticastObjectProp mc_prop{};
-      mc_prop.numDevices = world_size;
-      mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
-      mc_prop.size = block->block_size;
-
-      CUresult res = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
-      TORCH_CHECK(res == CUDA_SUCCESS);
-
-      int mc_fd;
-      C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
-          &mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
-      ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
-      // Ref count is incremented as soon as SCM_RIGHTS send happens
-      close(mc_fd);
-    } else {
-      int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
-      C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
-          &mc_handle,
-          (void*)(uintptr_t)mc_fd,
-          CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
-      close(mc_fd);
-    }
-    // All rank adds their physical allocation to the multicast object
-    C10_CUDA_DRIVER_CHECK(
-        driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
-    C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
-        mc_handle, 0, block->handle, 0, block->block_size, 0));
-
-    map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
-    store_barrier(store, rank, world_size);
+  bool group_has_multicast_support = check_group_multicast_support(reqs);
+  if (group_has_multicast_support) {
+    init_multicast_for_block(
+        mc_handle, mc_addr, block, ipc_channel, pids, store, rank, world_size);
   }
-#endif

   // Initializing CUDASymmetricMemory with an allocation transfers its
   // ownership to the CUDASymmetricMemory object. So that outstanding
@@ -713,8 +798,8 @@ bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
   return block->symm_mem != nullptr;
 }

-bool CUDASymmetricMemoryAllocator::has_multicast_support() {
-  return ::has_multicast_support();
+bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) {
+  return device_has_multicast_support(device_idx);
 }

 c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
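With these changes, rendezvous succeeds whether or not multicast could be initialized, and has_multicast_support() on the resulting CUDASymmetricMemory reflects whether a multicast address was actually mapped (mc_addr_ != nullptr) rather than a static capability check. A hedged end-to-end sketch; the empty_strided_p2p/rendezvous signatures follow the Python bindings of this era, and group registration details may differ by version:

# Run with torchrun, one CUDA device per rank (sketch, not verified).
import torch
import torch.distributed as dist
from torch._C._distributed_c10d import _SymmetricMemory

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

t = _SymmetricMemory.empty_strided_p2p(
    size=(64, 64),
    stride=(64, 1),
    dtype=torch.float32,
    device=torch.device(f"cuda:{rank}"),
    group_name=dist.group.WORLD.group_name,
)
symm_mem = _SymmetricMemory.rendezvous(t)
# False if multicast initialization was gracefully skipped.
print(symm_mem.has_multicast_support())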

torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
   size_t get_alloc_size(void* ptr) override;
   c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
   bool is_rendezvous_completed(void* ptr) override;
-  bool has_multicast_support() override;
+  bool has_multicast_support(int device_idx) override;

  private:
   c10::intrusive_ptr<Block> find_block(void* ptr);

torch/csrc/distributed/c10d/SymmetricMemory.cpp

Lines changed: 4 additions & 2 deletions
@@ -189,9 +189,11 @@ c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
   return allocator->rendezvous(tensor.data_ptr());
 }

-TORCH_API bool has_multicast_support(c10::DeviceType device_type) {
+TORCH_API bool has_multicast_support(
+    c10::DeviceType device_type,
+    int device_idx) {
   auto allocator = get_allocator(device_type);
-  return allocator->has_multicast_support();
+  return allocator->has_multicast_support(device_idx);
 }
 } // namespace symmetric_memory
 } // namespace c10d

torch/csrc/distributed/c10d/SymmetricMemory.hpp

Lines changed: 4 additions & 2 deletions
@@ -81,7 +81,7 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
   virtual size_t get_alloc_size(void* ptr) = 0;
   virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
   virtual bool is_rendezvous_completed(void* ptr) = 0;
-  virtual bool has_multicast_support() = 0;
+  virtual bool has_multicast_support(int device_idx) = 0;
 };

 C10_EXPORT bool is_finalizing();
@@ -154,6 +154,8 @@ TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
 TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
     const at::Tensor& tensor);

-TORCH_API bool has_multicast_support(c10::DeviceType device_type);
+TORCH_API bool has_multicast_support(
+    c10::DeviceType device_type,
+    int device_idx);
 } // namespace symmetric_memory
 } // namespace c10d
