Skip to content

Commit 59e95c7

Browse files
mscclpp_torchcomms: build _comms_nccl.so with bad_weak_ptr fix
Adds a separate _comms_nccl pybind11 module alongside the existing _comms_mscclpp module. The .so wraps the upstream torchcomms NCCL backend sources from build-torchcomm/_deps/torchcomms-src/comms/ torchcomms/nccl/, paired with a new csrc/NcclDynamicLoader.cpp that publishes the create_dynamic_loader_nccl entry point torchcomms's TorchCommFactory dlopen path requires. Why this exists: TorchCommFactory::create_generic_backend (TorchCommFactory.cpp) wraps the raw pointer returned by loader.new_comm() in a std::shared_ptr<TorchCommBackend>(rawBackendPtr, deleter). std::enable_shared_from_this<Y>'s internal weak_ptr is only initialized when the shared_ptr is constructed from a pointer *statically typed* as the derived class. Constructing shared_ptr<TorchCommBackend> from a pointer typed as TorchCommBackend* skips that machinery, so when TorchCommNCCL::createWork() later calls shared_from_this() it throws std::bad_weak_ptr — the very first all_reduce crashes, before any user code runs. Verified with a 6-line minimal repro (torchcomms.new_comm("nccl", ...).all_reduce(...)) that crashes with the upstream-only build and now succeeds with this fix. How the fix works: NcclDynamicLoader.cpp::new_comm_impl creates TorchCommNCCL via std::make_shared so the weak_ptr is properly populated, then stashes the shared_ptr in a static keep-alive map keyed by the TorchCommBackend* it returns. destroy_comm_impl drops the keep-alive entry. While the entry lives, shared_from_this() inside the NCCL backend constructs a new shared_ptr that aliases our keep-alive one — no upstream changes required. CMakeLists.txt: - Exclude NcclDynamicLoader.cpp from _comms_mscclpp source glob (it's the entry point for the separate _comms_nccl target). - Add _comms_nccl pybind11_add_module target that compiles the upstream torchcomms NCCL backend sources + the framework set shared with _comms_mscclpp + our loader. Links against PyTorch's bundled libnccl.so.2, torch libs, GPU libs, and glog. - Compile with FMT_HEADER_ONLY to sidestep the fmt v11/v12 ABI mismatch between conda's libfmt.so.11 and PyTorch's bundled fmt v12 headers (otherwise the .so dlopen fails with 'undefined symbol: fmt::v12::vformat...'). - Define USE_NVSHMEM to match the historical build's compile flags (NcclApi.cpp checks the macro). - torchcomm_nccl_copy custom target installs the .so into the package source tree alongside _comms_mscclpp. Build with: ./build_torchcomm.sh config cmake --build build-torchcomm --target _comms_nccl -j Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent e729d11 commit 59e95c7

2 files changed

Lines changed: 175 additions & 1 deletion

File tree

python/mscclpp_torchcomms/CMakeLists.txt

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ if(PYBIND11_FIND_RESULT EQUAL 0 AND PYBIND11_CMAKE_DIR)
2929
endif()
3030
find_package(pybind11 REQUIRED)
3131

32-
# Gather our C++ sources
32+
# Gather our C++ sources. NcclDynamicLoader.cpp is the entry point for the
33+
# separate _comms_nccl target — exclude it from the _comms_mscclpp module.
3334
file(GLOB_RECURSE TORCHCOMM_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
35+
list(REMOVE_ITEM TORCHCOMM_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/NcclDynamicLoader.cpp)
3436

3537
# Torchcomms framework sources we need to compile in directly.
3638
# Our module inherits from TorchWork, TorchCommBackend, and registers with
@@ -109,3 +111,93 @@ add_custom_target(torchcomm_lib_copy ALL
109111
${CMAKE_CURRENT_SOURCE_DIR}
110112
DEPENDS _comms_mscclpp
111113
)
114+
115+
# -----------------------------------------------------------------------------
116+
# Second target: _comms_nccl
117+
#
118+
# This builds the upstream torchcomms NCCL backend sources from
119+
# build-torchcomm/_deps/torchcomms-src/comms/torchcomms/nccl/ into a separate
120+
# .so, paired with our own NcclDynamicLoader.cpp that publishes the
121+
# create_dynamic_loader_nccl entry point torchcomms's TorchCommFactory dlopen
122+
# path requires.
123+
#
124+
# Why we don't reuse the upstream NCCL CMakeLists: it depends on the
125+
# torchcomms top-level CMake project that defines the `torchcomms` static
126+
# library and the ROOT, CONDA_INCLUDE, etc. variables. Mirroring the small
127+
# bit we need here keeps the dependency tree shallow and ensures the .so we
128+
# produce loads correctly through TorchCommFactory.
129+
# -----------------------------------------------------------------------------
130+
131+
set(TORCHCOMMS_NCCL_BACKEND_SOURCES
132+
${torchcomms_SOURCE_DIR}/comms/torchcomms/nccl/NcclApi.cpp
133+
${torchcomms_SOURCE_DIR}/comms/torchcomms/nccl/TorchCommNCCL.cpp
134+
${torchcomms_SOURCE_DIR}/comms/torchcomms/nccl/TorchCommNCCLBootstrap.cpp
135+
${torchcomms_SOURCE_DIR}/comms/torchcomms/nccl/TorchCommNCCLCCA.cpp
136+
${torchcomms_SOURCE_DIR}/comms/torchcomms/nccl/TorchCommNCCLPy.cpp
137+
${torchcomms_SOURCE_DIR}/comms/torchcomms/nccl/TorchCommNCCLUtils.cpp
138+
${torchcomms_SOURCE_DIR}/comms/torchcomms/nccl/TorchWorkNCCL.cpp
139+
${torchcomms_SOURCE_DIR}/comms/torchcomms/nccl/TorchWorkNCCLQueue.cpp
140+
${torchcomms_SOURCE_DIR}/comms/torchcomms/device/cuda/CudaApi.cpp
141+
)
142+
143+
# Framework sources are shared between _comms_mscclpp and _comms_nccl. Each
144+
# .so gets its own copy of TorchCommFactory's singleton (RTLD_LOCAL load) so
145+
# the registrations don't leak across backends.
146+
pybind11_add_module(_comms_nccl
147+
${TORCHCOMMS_NCCL_BACKEND_SOURCES}
148+
${TORCHCOMMS_FRAMEWORK_SOURCES}
149+
${CMAKE_CURRENT_SOURCE_DIR}/csrc/NcclDynamicLoader.cpp
150+
# TracingGuard.cpp lives outside the framework set used by mscclpp because
151+
# only the NCCL backend references it (via TracingGuard.hpp).
152+
${torchcomms_SOURCE_DIR}/comms/torchcomms/utils/TracingGuard.cpp
153+
)
154+
155+
# Locate libnccl.so.2 shipped with PyTorch (the same one PyTorch's c10d uses).
156+
get_filename_component(_TORCH_LIB_DIR "${TORCH_INSTALL_PREFIX}/lib" ABSOLUTE)
157+
set(_NCCL_PYTORCH_INCLUDE "${TORCH_INSTALL_PREFIX}/../nvidia/nccl/include")
158+
set(_NCCL_PYTORCH_LIB "${TORCH_INSTALL_PREFIX}/../nvidia/nccl/lib/libnccl.so.2")
159+
if(NOT EXISTS "${_NCCL_PYTORCH_LIB}")
160+
message(FATAL_ERROR
161+
"Could not find PyTorch's bundled libnccl.so.2 at ${_NCCL_PYTORCH_LIB}. "
162+
"Install PyTorch with the nvidia-nccl-cu12 wheel.")
163+
endif()
164+
165+
target_include_directories(_comms_nccl SYSTEM PRIVATE
166+
${torchcomms_SOURCE_DIR}
167+
${_NCCL_PYTORCH_INCLUDE}
168+
${GPU_INCLUDE_DIRS}
169+
)
170+
171+
target_link_libraries(_comms_nccl PRIVATE
172+
${TORCH_LIBRARIES}
173+
${GPU_LIBRARIES}
174+
${_NCCL_PYTORCH_LIB}
175+
glog::glog
176+
)
177+
178+
if(EXISTS "${TORCH_PYTHON_LIB}")
179+
target_link_libraries(_comms_nccl PRIVATE "${TORCH_PYTHON_LIB}")
180+
endif()
181+
182+
# Match the NVSHMEM macro used in the historical build (NcclApi.cpp checks it).
183+
target_compile_definitions(_comms_nccl PRIVATE
184+
USE_NVSHMEM
185+
# fmt v12 ships in PyTorch's bundled headers (torch/include/fmt/), but only
186+
# fmt v11 is available as a linkable shared object (libfmt.so.11 in conda).
187+
# Compiling against v12 headers and linking against v11 yields an
188+
# ``undefined symbol: fmt::v12::vformat...`` at dlopen time. Force the
189+
# header-only build of whichever fmt headers are picked so all fmt code
190+
# is inlined into _comms_nccl.so itself, eliminating the runtime dep.
191+
FMT_HEADER_ONLY
192+
)
193+
194+
target_compile_features(_comms_nccl PRIVATE cxx_std_20)
195+
196+
install(TARGETS _comms_nccl LIBRARY DESTINATION mscclpp_torchcomms COMPONENT torchcomm)
197+
198+
add_custom_target(torchcomm_nccl_copy ALL
199+
COMMAND ${CMAKE_COMMAND} -E copy_if_different
200+
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/_comms_nccl*.so
201+
${CMAKE_CURRENT_SOURCE_DIR}
202+
DEPENDS _comms_nccl
203+
)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
// Dynamic-loader entry point for the upstream torchcomms NCCL backend.
5+
//
6+
// Why this file exists separately from the upstream torchcomms tree:
7+
//
8+
// The torchcomms TorchCommFactory dlopen-based loader path
9+
// (TorchCommFactory::create_generic_backend in TorchCommFactory.cpp) wraps the
10+
// raw pointer returned by `loader.new_comm()` in a
11+
// `std::shared_ptr<TorchCommBackend>(rawBackendPtr, deleter)`. The
12+
// `enable_shared_from_this<Y>` mechanism only initializes its internal
13+
// weak_ptr when the shared_ptr is constructed from a pointer to the *derived*
14+
// type `Y`. Constructing `shared_ptr<TorchCommBackend>` from a pointer
15+
// statically typed as `TorchCommBackend*` skips that machinery, so when
16+
// `TorchCommNCCL::createWork()` later calls `shared_from_this()` it throws
17+
// `std::bad_weak_ptr` (the very first all_reduce crashes).
18+
//
19+
// To work around this without patching torchcomms, this loader keeps a
20+
// keep-alive `shared_ptr<TorchCommNCCL>` (created via `std::make_shared` so
21+
// the weak_ptr is set up correctly) alive in a static map keyed by the raw
22+
// `TorchCommBackend*` we hand back. The factory still wraps our pointer in
23+
// its own `shared_ptr<TorchCommBackend>` for ownership semantics, and
24+
// destroy_comm_impl drops the keep-alive entry — but as long as the entry
25+
// lives, `shared_from_this()` inside the NCCL backend successfully
26+
// constructs a new shared_ptr that aliases our keep-alive one.
27+
28+
#include <comms/torchcomms/TorchCommBackend.hpp>
29+
#include <comms/torchcomms/nccl/TorchCommNCCL.hpp>
30+
31+
#include <memory>
32+
#include <mutex>
33+
#include <unordered_map>
34+
35+
namespace {
36+
37+
std::mutex& keepaliveMutex() {
38+
static std::mutex m;
39+
return m;
40+
}
41+
42+
std::unordered_map<torch::comms::TorchCommBackend*, std::shared_ptr<torch::comms::TorchCommNCCL>>&
43+
keepaliveMap() {
44+
static std::unordered_map<torch::comms::TorchCommBackend*, std::shared_ptr<torch::comms::TorchCommNCCL>>
45+
m;
46+
return m;
47+
}
48+
49+
torch::comms::TorchCommBackend* new_comm_impl() {
50+
auto sp = std::make_shared<torch::comms::TorchCommNCCL>();
51+
auto* base = static_cast<torch::comms::TorchCommBackend*>(sp.get());
52+
{
53+
std::lock_guard<std::mutex> guard(keepaliveMutex());
54+
keepaliveMap().emplace(base, std::move(sp));
55+
}
56+
return base;
57+
}
58+
59+
void destroy_comm_impl(torch::comms::TorchCommBackend* comm) {
60+
std::lock_guard<std::mutex> guard(keepaliveMutex());
61+
auto it = keepaliveMap().find(comm);
62+
if (it != keepaliveMap().end()) {
63+
keepaliveMap().erase(it);
64+
} else {
65+
delete comm;
66+
}
67+
}
68+
69+
const char* get_supported_version_impl() {
70+
return torch::comms::TORCHCOMM_BACKEND_ABI_VERSION;
71+
}
72+
73+
} // namespace
74+
75+
extern "C" __attribute__((visibility("default"))) torch::comms::DynamicLoaderInterface
76+
create_dynamic_loader_nccl() {
77+
return torch::comms::DynamicLoaderInterface{
78+
.new_comm = new_comm_impl,
79+
.destroy_comm = destroy_comm_impl,
80+
.get_supported_version = get_supported_version_impl,
81+
};
82+
}

0 commit comments

Comments
 (0)