
Commit 8e0853a

gdengk authored and pre-commit-ci[bot] committed
Introduce NVSHMEM based communication API for pytorch (#1430)
* add nvshmem based api support (Signed-off-by: gdeng <[email protected]>)
* fix lint and license issue (Signed-off-by: gdeng <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove asset (Signed-off-by: gdeng <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix the lib (Signed-off-by: gdeng <[email protected]>)
* address comments (Signed-off-by: gdeng <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Signed-off-by: gdeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 02180cd commit 8e0853a

File tree: 10 files changed (+312, -1 lines)


build_tools/pytorch.py (+15)

@@ -89,6 +89,19 @@ def setup_pytorch_extension(
         cxx_flags.append("-DNVTE_UB_WITH_MPI")
         nvcc_flags.append("-DNVTE_UB_WITH_MPI")

+    library_dirs = []
+    libraries = []
+    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
+        assert (
+            os.getenv("NVSHMEM_HOME") is not None
+        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
+        nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
+        include_dirs.append(nvshmem_home / "include")
+        library_dirs.append(nvshmem_home / "lib")
+        libraries.append("nvshmem_host")
+        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
+        nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
+
     # Construct PyTorch CUDA extension
     sources = [str(path) for path in sources]
     include_dirs = [str(path) for path in include_dirs]
@@ -102,4 +115,6 @@ def setup_pytorch_extension(
             "cxx": cxx_flags,
             "nvcc": nvcc_flags,
         },
+        libraries=[str(lib) for lib in libraries],
+        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
     )

setup.py (+6)

@@ -64,6 +64,12 @@ def setup_common_extension() -> CMakeExtension:
         ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
         cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

+    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
+        assert (
+            os.getenv("NVSHMEM_HOME") is not None
+        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
+        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")
+
     if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
         cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
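Both build paths above key off the same two environment variables. As a hedged sketch (the install prefix and pip invocation below are assumptions, not part of this commit), an NVSHMEM-enabled build could be driven like this:

import os
import subprocess

# NVTE_ENABLE_NVSHMEM=1 turns on -DNVTE_ENABLE_NVSHMEM for the PyTorch extension and
# -DNVTE_ENABLE_NVSHMEM=ON for the core CMake build; NVSHMEM_HOME must point at an
# installation that contains include/ and lib/.
env = dict(os.environ, NVTE_ENABLE_NVSHMEM="1", NVSHMEM_HOME="/opt/nvidia/nvshmem")
subprocess.run(["pip", "install", "-v", "."], env=env, check=True)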

transformer_engine/common/CMakeLists.txt (+9)

@@ -96,6 +96,8 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 target_include_directories(transformer_engine PUBLIC
                            "${CMAKE_CURRENT_SOURCE_DIR}/include")

+
+
 # Configure dependencies
 target_link_libraries(transformer_engine PUBLIC
                       CUDA::cublas
@@ -114,6 +116,13 @@ if (NVTE_UB_WITH_MPI)
     target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
 endif()

+option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
+if (NVTE_ENABLE_NVSHMEM)
+    add_subdirectory(nvshmem_api)
+    target_link_libraries(transformer_engine PUBLIC nvshmemapi)
+    target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
+endif()
+
 # Hack to enable dynamic loading in cuDNN frontend
 target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)

transformer_engine/common/libtransformer_engine.version (+3, -1)

@@ -14,7 +14,9 @@
       transformer_engine::typeToSize*;
       *transformer_engine::CommOverlapBase*;
       *transformer_engine::CommOverlapP2PBase*;
-      *transformer_engine::CommOverlapCore*
+      *transformer_engine::CommOverlapCore*;
+      *nvshmem_wait_on_stream*;
+      *nvshmemi_init_thread*
     };
   local: *;
 };
transformer_engine/common/nvshmem_api/CMakeLists.txt (new file, +27)

@@ -0,0 +1,27 @@
+##########################################################################
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+##########################################################################
+cmake_minimum_required (VERSION 3.18)
+project(nvshmemapi LANGUAGES CXX CUDA)
+
+# Configure dependencies
+find_package(CUDAToolkit REQUIRED)
+# find_package(MPI REQUIRED)
+set(NVSHMEM_HOME "$ENV{NVSHMEM_HOME}" CACHE STRING "Location of NVSHMEM installation")
+
+add_library(nvshmemapi STATIC nvshmem_waitkernel.cu)
+set(NVSHMEMAPI_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" PARENT_SCOPE)
+target_link_directories(nvshmemapi PUBLIC ${NVSHMEM_HOME}/lib)
+target_link_libraries(nvshmemapi PUBLIC -static-libstdc++ nvshmem_device nvshmem_host CUDA::nvml CUDA::cublas CUDA::cuda_driver)
+target_include_directories(nvshmemapi PRIVATE
+                           ${NVSHMEM_HOME}/include/)
+target_include_directories(nvshmemapi PUBLIC
+                           ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
+                           "${CMAKE_CURRENT_SOURCE_DIR}")
+
+set_target_properties(nvshmemapi PROPERTIES
+                      CUDA_STANDARD 17
+                      POSITION_INDEPENDENT_CODE ON
+                      CUDA_SEPARABLE_COMPILATION ON)
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu (new file, +51)

@@ -0,0 +1,51 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include <cuda.h>
+#include <cuda_bf16.h>
+#include <nvshmem.h>
+
+#include <cstdio>
+#include <cstdlib>
+#include <functional>
+#include <iostream>
+#include <sstream>
+#include <string>
+
+#include "../util/logging.h"
+#include "nvshmem_waitkernel.h"
+
+__global__ void __launch_bounds__(1)
+    wait_until_on_stream_and_reset(uint64_t* wait_flag, uint64_t wait_value,
+                                   uint64_t signal_reset) {
+  nvshmem_uint64_wait_until(wait_flag, NVSHMEM_CMP_EQ, wait_value);
+  *wait_flag = signal_reset;
+}
+void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream) {
+  uint64_t wait_value = 1;
+  uint64_t signal_reset = 0;
+  cudaStream_t cur_stream = stream;
+
+  NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT && wait_kind <= WaitKind::STREAM_WAIT,
+             "Invalid wait kind: ", static_cast<int>(wait_kind));
+
+  switch (wait_kind) {
+    case WaitKind::KERNEL_WAIT:
+      wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
+      break;
+    case WaitKind::NVSHMEM_WAIT:
+      nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
+      cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
+                           CU_STREAM_WRITE_VALUE_DEFAULT);
+      break;
+    case WaitKind::STREAM_WAIT:
+      cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value,
+                          CU_STREAM_WAIT_VALUE_GEQ);
+      cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
+                           CU_STREAM_WRITE_VALUE_DEFAULT);
+      break;
+  }
+}
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h (new file, +38)

@@ -0,0 +1,38 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#ifndef TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
+#define TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
+
+#ifdef __cplusplus
+#include <cstdint>
+extern "C" {
+#else
+#include <stdint.h>
+#endif
+
+/*! \enum WaitKind
+ *  \brief Types of wait operations that can be performed.
+ */
+enum class WaitKind {
+  KERNEL_WAIT = 0,  /*!< Wait using a CUDA kernel */
+  NVSHMEM_WAIT = 1, /*!< Wait using NVSHMEM wait operation */
+  STREAM_WAIT = 2   /*!< Wait using CUDA stream synchronization */
+};
+
+/*! \brief Wait on a signal until a certain condition is met.
+ *
+ *  \param[in] sig_addr The address of the signal to wait on.
+ *  \param[in] wait_kind The kind of wait to perform.
+ *  \param[in] stream The stream to wait on.
+ */
+void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif  // TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H

transformer_engine/pytorch/csrc/extensions.h (+17)

@@ -373,6 +373,23 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                              std::vector<size_t> input_row_list,
                              std::vector<size_t> padded_input_row_list);

+/***************************************************************************************************
+ * NVSHMEM APIs
+ **************************************************************************************************/
+
+namespace nvshmem_api {
+void init_nvshmem_backend(c10d::ProcessGroup *process_group);
+
+torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
+
+void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
+                                    torch::Tensor signal);
+
+void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);
+
+void nvshmem_finalize();
+}  // namespace nvshmem_api
+
 /***************************************************************************************************
  * swizzle
  **************************************************************************************************/
New file under transformer_engine/pytorch/csrc/extensions/ (+129)

@@ -0,0 +1,129 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include "../extensions.h"
+
+#ifdef NVTE_ENABLE_NVSHMEM
+#include <nvshmem.h>
+#include <nvshmem_api/nvshmem_waitkernel.h>
+#include <nvshmemx.h>
+#endif
+
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <torch/cuda.h>
+#include <torch/extension.h>
+
+namespace nvshmem_api {
+void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
+#ifdef NVTE_ENABLE_NVSHMEM
+  nvshmemx_init_attr_t attr = {};
+  nvshmemx_uniqueid_t id = {};
+
+  int my_rank = process_group->getRank();
+  int num_ranks = process_group->getSize();
+  if (my_rank == 0) {
+    nvshmemx_get_uniqueid(&id);
+  }
+
+  auto backend_is_nccl = (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL);
+  NVTE_CHECK(backend_is_nccl, "Currently only support NCCL boostrap for NVSHMEM");
+  auto datatensor =
+      torch::from_blob(reinterpret_cast<void *>(&id),
+                       {static_cast<int64_t>(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))},
+                       at::device(torch::kCPU).dtype(torch::kUInt8));
+  auto datatmp = (backend_is_nccl) ? datatensor.cuda() : datatensor;
+
+  c10d::BroadcastOptions bcast_opts;
+  bcast_opts.rootRank = 0;
+  std::vector<torch::Tensor> datachunk = {datatmp};
+  auto work = process_group->broadcast(datachunk, bcast_opts);
+  work->wait();
+
+  if (backend_is_nccl) {
+    datatensor.copy_(datatmp.cpu());
+    datatmp = torch::Tensor();
+  }
+
+  nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr);
+  nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
+
+  NVTE_CHECK(my_rank == nvshmem_my_pe(), "my_rank: ", my_rank,
+             " != nvshmem_my_pe(): ", nvshmem_my_pe());
+  NVTE_CHECK(num_ranks == nvshmem_n_pes(), "num_ranks: ", num_ranks,
+             " != nvshmem_n_pes(): ", nvshmem_n_pes());
+#else
+  NVTE_ERROR("Internal TE error: init_nvshmem_backend cannot be initialized with valid PyTorch ",
+             "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+
+void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) {
+#ifdef NVTE_ENABLE_NVSHMEM
+  uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
+  cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream();
+
+  WaitKind wait_kind_enum = WaitKind::STREAM_WAIT;
+
+  if (wait_kind == "kernel") {
+    wait_kind_enum = WaitKind::KERNEL_WAIT;
+  } else if (wait_kind == "nvshmem") {
+    wait_kind_enum = WaitKind::NVSHMEM_WAIT;
+  } else if (wait_kind == "stream") {
+    wait_kind_enum = WaitKind::STREAM_WAIT;
+  } else {
+    NVTE_ERROR("Invalid wait kind: ", wait_kind);
+  }
+  nvshmem_wait_on_stream(sig_addr, wait_kind_enum, cur_stream);
+
+#else
+  NVTE_ERROR(
+      "Internal TE error: nvshmem_wait_on_current_stream cannot be initialized with valid PyTorch ",
+      "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+
+torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
+#ifdef NVTE_ENABLE_NVSHMEM
+  auto option_gpu =
+      at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device());
+  auto size = torch::elementSize(dtype) *
+              std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
+  return at::from_blob(
+      nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
+#else
+  NVTE_ERROR("Internal TE error: create_nvshmem_tensor cannot be initialized with valid PyTorch ",
+             "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+
+void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
+                                    torch::Tensor signal) {
+#ifdef NVTE_ENABLE_NVSHMEM
+  void *src_ptr = reinterpret_cast<void *>(src.data_ptr());
+  void *dst_ptr = reinterpret_cast<void *>(dst.data_ptr());
+  uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
+  auto nelement = src.numel() * src.element_size();
+  uint64_t sigval = 1;
+  at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();
+
+  nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nelement, sig_addr, sigval, NVSHMEM_SIGNAL_SET,
+                                   peer, (cudaStream_t)cur_stream);
+#else
+  NVTE_ERROR(
+      "Internal TE error: nvshmem_send_on_current_stream cannot be initialized with valid PyTorch ",
+      "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+void nvshmem_finalize() {
+#ifdef NVTE_ENABLE_NVSHMEM
+  nvshmem_finalize();
+#else
+  NVTE_ERROR("Internal TE error: nvshmem_finalize cannot be initialized with valid PyTorch ",
+             "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
+#endif
+}
+}  // namespace nvshmem_api
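For context, a hedged sketch of how init_nvshmem_backend might be driven from Python once the bindings below are registered. The module name transformer_engine_torch and the torchrun-style launch (RANK/WORLD_SIZE/LOCAL_RANK set) are assumptions; only an NCCL-backed process group is accepted, since the NVSHMEM unique ID is broadcast through it.

import os
import torch
import torch.distributed as dist
import transformer_engine_torch as tex  # assumed module name for the extension

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl")  # NCCL is the only supported bootstrap backend
# Rank 0 obtains the NVSHMEM unique ID; the process group broadcasts it to all ranks.
tex.init_nvshmem_backend(dist.group.WORLD)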

transformer_engine/pytorch/csrc/extensions/pybind.cpp (+17)

@@ -234,6 +234,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Generate partitioned indices for inputs in THD format",
         py::call_guard<py::gil_scoped_release>());

+  // nvshmem functions
+  m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
+        "Initialize nvshmem backend with Pytorch distributed process groups",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
+        "Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
+        "Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
+        "Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
+        "stream",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
+        "Clean up and finalize the NVSHMEM communication backend and free associated resources",
+        py::call_guard<py::gil_scoped_release>());
+
   // multi-tensor functions
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors",
