Skip to content

Commit 70c1d4d

Browse files
committed
add ipc cache
1 parent 51a8663 commit 70c1d4d

2 files changed

Lines changed: 50 additions & 9 deletions

File tree

src/include/registered_memory.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct RegisteredMemory::Impl {
5555
bool isCuMemMapAlloc;
5656
TransportFlags transports;
5757
std::vector<TransportInfo> transportInfos;
58+
std::shared_ptr<void> peerHandle;
5859

5960
// Only used for IB transport
6061
std::unordered_map<Transport, std::unique_ptr<const IbMr>> ibMrMap;

src/registered_memory.cc

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <unistd.h>
88

99
#include <algorithm>
10+
#include <cstring>
1011
#include <mscclpp/gpu_utils.hpp>
1112

1213
#include "api.h"
@@ -28,7 +29,22 @@
2829
} \
2930
} while (false)
3031

32+
namespace std {
33+
template <>
34+
struct hash<cudaIpcMemHandle_t> {
35+
size_t operator()(const cudaIpcMemHandle_t& handle) const {
36+
std::string_view view(handle.reserved, sizeof(handle.reserved));
37+
return std::hash<std::string_view>{}(view);
38+
}
39+
};
40+
} // namespace std
41+
42+
/// Two IPC handles are equal when their `reserved` byte arrays are byte-for-byte identical.
inline bool operator==(const cudaIpcMemHandle_t& lhs, const cudaIpcMemHandle_t& rhs) {
  const int byteDiff = std::memcmp(lhs.reserved, rhs.reserved, sizeof(lhs.reserved));
  return byteDiff == 0;
}
45+
3146
namespace {
47+
3248
CUmemAllocationHandleType getNvlsMemHandleType() {
3349
#if (CUDA_NVLS_API_AVAILABLE)
3450
if (mscclpp::detail::nvlsCompatibleMemHandleType & CU_MEM_HANDLE_TYPE_FABRIC) {
@@ -41,6 +57,37 @@ CUmemAllocationHandleType getNvlsMemHandleType() {
4157
#endif
4258
}
4359

60+
/// Open a CUDA IPC memory handle and return a shared owner of the mapped base address.
///
/// The returned shared_ptr carries a deleter that calls cudaIpcCloseMemHandle() on the mapped
/// address once the last owner releases it, logging a warning on failure and an info message on
/// success.
///
/// When built for HIP/AMD (`__HIP_PLATFORM_AMD__`), opened mappings are cached per IPC handle via
/// weak references: while a mapping for `ipcHandle` is still alive, callers share that mapping
/// instead of opening the handle again. Once every owner has released it, the next request
/// reopens the handle.
///
/// @param ipcHandle IPC handle to open.
/// @return Shared owner of the base address obtained from cudaIpcOpenMemHandle().
/// @note A failing cudaIpcOpenMemHandle() is reported through MSCCLPP_CUDATHROW (defined
///       elsewhere in this file).
std::shared_ptr<void> getPeerMemoryHandle(cudaIpcMemHandle_t ipcHandle) {
  void* addr = nullptr;
  auto deleter = [](void* p) {
    cudaError_t err = cudaIpcCloseMemHandle(p);
    if (err != cudaSuccess) {
      WARN("Failed to close CUDA IPC handle at pointer %p: %s", p, cudaGetErrorString(err));
    } else {
      INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", p);
    }
  };
#if defined(__HIP_PLATFORM_AMD__)
  static std::unordered_map<cudaIpcMemHandle_t, std::weak_ptr<void>> peerMemoryHandleMap;
  // The mutex has to be static as well: a function-local (per-call) mutex would hand every caller
  // its own lock and leave the shared map unprotected.
  static std::mutex mutex;
  std::lock_guard<std::mutex> lock(mutex);
  auto it = peerMemoryHandleMap.find(ipcHandle);
  if (it != peerMemoryHandleMap.end()) {
    if (auto ptr = it->second.lock()) {
      return ptr;
    }
    // Every previous owner released the mapping, so the cached entry is stale. Drop it and fall
    // through to open the handle again rather than failing the caller.
    peerMemoryHandleMap.erase(it);
  }
  MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&addr, ipcHandle, cudaIpcMemLazyEnablePeerAccess));
  std::shared_ptr<void> ptr(addr, deleter);
  // Only a weak reference is cached, so the cache never keeps a mapping alive on its own.
  peerMemoryHandleMap[ipcHandle] = ptr;
  return ptr;
#else
  MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&addr, ipcHandle, cudaIpcMemLazyEnablePeerAccess));
  return std::shared_ptr<void>(addr, deleter);
#endif
}
90+
4491
} // namespace
4592

4693
namespace mscclpp {
@@ -256,8 +303,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>::const_iterator& begin,
256303
throw Error("Unexpected error", ErrorCode::InternalError);
257304
#endif // !(CUDA_NVLS_API_AVAILABLE)
258305
} else if (getHostHash() == this->hostHash) {
259-
MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
260-
this->data = static_cast<char*>(base) + entry.cudaIpcOffsetFromBase;
306+
this->peerHandle = getPeerMemoryHandle(entry.cudaIpcBaseHandle);
307+
this->data = static_cast<char*>(this->peerHandle.get()) + entry.cudaIpcOffsetFromBase;
261308
}
262309
}
263310
if (this->data != nullptr) {
@@ -291,13 +338,6 @@ RegisteredMemory::Impl::~Impl() {
291338
MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size));
292339
MSCCLPP_CULOG_WARN(cuMemRelease(handle));
293340
MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size));
294-
} else {
295-
cudaError_t err = cudaIpcCloseMemHandle(base);
296-
if (err != cudaSuccess) {
297-
WARN("Failed to close CUDA IPC handle at pointer %p: %s", base, cudaGetErrorString(err));
298-
} else {
299-
INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base);
300-
}
301341
}
302342
data = nullptr;
303343
fileDesc = -1;

0 commit comments

Comments
 (0)