Skip to content

Commit c84a75f

Browse files
authored
refactor(nccl/test): share NCCL context across unit tests (#239)
* refactor(nccl/test): make the `Ctx` a singleton This change allows to initialize the NCCL context only once instead of per test. This will significantly reduce the wall time of NCCL-based tests. * refactor(test): update main to initialize the NCCL ctx once * refactor(test): update tests to retrieve the global NCCL ctx Don't create a NCCL context per test, reuse the global instance. * fix(test): undefined NCCL/CUDA check macros Signed-off-by: Gabriel Dos Santos <gabriel.dossantos@cea.fr>
1 parent 31e873b commit c84a75f

19 files changed

Lines changed: 252 additions & 190 deletions

cmake/kc-test.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ function(kc_add_unit_test name)
1515
if(UT_CORE)
1616
target_link_libraries(${name} PRIVATE KokkosComm::KokkosComm MPI::MPI_CXX)
1717
if(KokkosComm_ENABLE_NCCL)
18+
target_sources(${name} PRIVATE nccl/utils.cpp)
1819
target_link_libraries(${name} PRIVATE NCCL::NCCL)
1920
endif()
2021
elseif(UT_MPI)

unit_tests/CMakeLists.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,42 +179,42 @@ if(KokkosComm_ENABLE_NCCL)
179179
test.nccl.p2p
180180
NCCL
181181
NUM_PES 2
182-
FILES test_main.cpp nccl/test_point_to_point.cpp
182+
FILES test_main.cpp nccl/test_point_to_point.cpp nccl/utils.cpp
183183
LIBRARIES KokkosComm::KokkosComm MPI::MPI_CXX
184184
)
185185
kc_add_unit_test(
186186
test.nccl.broadcast
187187
NCCL
188188
NUM_PES 2
189-
FILES test_main.cpp nccl/test_broadcast.cpp
189+
FILES test_main.cpp nccl/test_broadcast.cpp nccl/utils.cpp
190190
LIBRARIES KokkosComm::KokkosComm MPI::MPI_CXX
191191
)
192192
kc_add_unit_test(
193193
test.nccl.all-gather
194194
NCCL
195195
NUM_PES 2
196-
FILES test_main.cpp nccl/test_allgather.cpp
196+
FILES test_main.cpp nccl/test_allgather.cpp nccl/utils.cpp
197197
LIBRARIES KokkosComm::KokkosComm MPI::MPI_CXX
198198
)
199199
kc_add_unit_test(
200200
test.nccl.all-to-all
201201
NCCL
202202
NUM_PES 2
203-
FILES test_main.cpp nccl/test_alltoall.cpp
203+
FILES test_main.cpp nccl/test_alltoall.cpp nccl/utils.cpp
204204
LIBRARIES KokkosComm::KokkosComm MPI::MPI_CXX
205205
)
206206
kc_add_unit_test(
207207
test.nccl.all-reduce
208208
NCCL
209209
NUM_PES 2
210-
FILES test_main.cpp nccl/test_allreduce.cpp
210+
FILES test_main.cpp nccl/test_allreduce.cpp nccl/utils.cpp
211211
LIBRARIES KokkosComm::KokkosComm MPI::MPI_CXX
212212
)
213213
kc_add_unit_test(
214214
test.nccl.reduce
215215
NCCL
216216
NUM_PES 2
217-
FILES test_main.cpp nccl/test_reduce.cpp
217+
FILES test_main.cpp nccl/test_reduce.cpp nccl/utils.cpp
218218
LIBRARIES KokkosComm::KokkosComm MPI::MPI_CXX
219219
)
220220
endif()

unit_tests/logging.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
#include <cstdlib>
88
#include <string_view>
99

10-
#include <cuda.h>
10+
#include <KokkosComm/config.hpp>
1111
#include <mpi.h>
1212
#if defined(KOKKOSCOMM_ENABLE_NCCL)
1313
#include <nccl.h>
14+
#include <cuda_runtime.h>
1415
#endif
15-
1616
#include <fmt/core.h>
1717

1818
namespace logging {
@@ -47,24 +47,24 @@ constexpr std::array level_txt{"FATAL"sv, "ERROR"sv, "WARNING"sv, "INFO"sv, "TRA
4747

4848
#define KC_CHECK(expr, ...) ((expr) ? void(0) : KC_FATAL(__VA_ARGS__))
4949

50-
#define KC_CUDA_CHECK(expr) \
51-
([&]() { \
52-
cudaError_t kc_res_ = (expr); \
53-
return kc_res_ == cudaSuccess ? void(0) \
54-
: KC_FATAL("CUDA check failed: `" #expr "`: {}", cudaGetErrorString(kc_res_)); \
55-
}())
56-
5750
#define KC_MPI_CHECK(expr) \
5851
([&]() { \
5952
int kc_res_ = (expr); \
6053
return kc_res_ == MPI_SUCCESS ? void(0) : KC_FATAL("MPI check failed: `" #expr "`: {}", kc_res_); \
6154
}())
6255

63-
#if defined(KOKKOSCOMM_ENABLE_NCCL)
56+
#ifdef KOKKOSCOMM_ENABLE_NCCL
6457
#define KC_NCCL_CHECK(expr) \
6558
([&]() { \
6659
ncclResult_t kc_res_ = (expr); \
6760
return kc_res_ == ncclSuccess ? void(0) \
6861
: KC_FATAL("NCCL check failed: `" #expr "`: {}", ncclGetErrorString(kc_res_)); \
6962
}())
63+
64+
#define KC_CUDA_CHECK(expr) \
65+
([&]() { \
66+
cudaError_t kc_res_ = (expr); \
67+
return kc_res_ == cudaSuccess ? void(0) \
68+
: KC_FATAL("CUDA check failed: `" #expr "`: {}", cudaGetErrorString(kc_res_)); \
69+
}())
7070
#endif

unit_tests/nccl/test_allgather.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ TYPED_TEST_SUITE(AllGather, ScalarTypes);
2121

2222
template <typename Scalar>
2323
auto allgather_0d() -> void {
24-
auto nccl_ctx = test_utils::nccl::Ctx::init();
24+
auto& nccl_ctx = test_utils::NcclCtx::get();
2525
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
2626
const auto comm = nccl_ctx.comm();
2727
const int size = nccl_ctx.size();
@@ -48,7 +48,7 @@ auto allgather_0d() -> void {
4848

4949
template <typename Scalar>
5050
auto allgather_contig_1d() -> void {
51-
auto nccl_ctx = test_utils::nccl::Ctx::init();
51+
auto& nccl_ctx = test_utils::NcclCtx::get();
5252
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
5353
const auto comm = nccl_ctx.comm();
5454
const int size = nccl_ctx.size();

unit_tests/nccl/test_allreduce.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ TYPED_TEST_SUITE(AllReduce, ScalarTypes);
2121

2222
template <typename Scalar>
2323
auto allreduce_0d() -> void {
24-
auto nccl_ctx = test_utils::nccl::Ctx::init();
24+
auto& nccl_ctx = test_utils::NcclCtx::get();
2525
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
2626
const auto comm = nccl_ctx.comm();
2727
const int size = nccl_ctx.size();
@@ -41,23 +41,23 @@ auto allreduce_0d() -> void {
4141

4242
int errs;
4343
Kokkos::parallel_reduce(
44-
rv.extent(0), KOKKOS_LAMBDA(const int, int &lsum) { lsum += (rv() != size * (size - 1) / 2); }, errs
44+
rv.extent(0), KOKKOS_LAMBDA(const int, int& lsum) { lsum += (rv() != size * (size - 1) / 2); }, errs
4545
);
4646
EXPECT_EQ(errs, 0);
4747
}
4848

4949
template <typename Scalar>
5050
auto allreduce_contig_1d() -> void {
51-
auto nccl_ctx = test_utils::nccl::Ctx::init();
51+
auto& nccl_ctx = test_utils::NcclCtx::get();
5252
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
5353
const auto comm = nccl_ctx.comm();
5454
const int size = nccl_ctx.size();
5555
const int rank = nccl_ctx.rank();
5656
const int root = 0;
5757

5858
const int n_contrib = 10;
59-
Kokkos::View<Scalar *> sv("sv", n_contrib);
60-
Kokkos::View<Scalar *> rv("rv", n_contrib);
59+
Kokkos::View<Scalar*> sv("sv", n_contrib);
60+
Kokkos::View<Scalar*> rv("rv", n_contrib);
6161

6262
// Prepare send buffer
6363
Kokkos::parallel_for(
@@ -69,7 +69,7 @@ auto allreduce_contig_1d() -> void {
6969

7070
int errs;
7171
Kokkos::parallel_reduce(
72-
rv.extent(0), KOKKOS_LAMBDA(const int i, int &lsum) { lsum += (rv(i) != size * (size - 1) / 2 + size * i); }, errs
72+
rv.extent(0), KOKKOS_LAMBDA(const int i, int& lsum) { lsum += (rv(i) != size * (size - 1) / 2 + size * i); }, errs
7373
);
7474
EXPECT_EQ(errs, 0);
7575
}

unit_tests/nccl/test_alltoall.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ TYPED_TEST_SUITE(AllToAll, ScalarTypes);
2121

2222
template <typename Scalar>
2323
auto alltoall_contig_1d() -> void {
24-
auto nccl_ctx = test_utils::nccl::Ctx::init();
24+
auto& nccl_ctx = test_utils::NcclCtx::get();
2525
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
2626
const auto comm = nccl_ctx.comm();
2727
const int size = nccl_ctx.size();
2828
const int rank = nccl_ctx.rank();
2929
const int root = 0;
3030

3131
const int n_contrib = 100;
32-
Kokkos::View<Scalar *> sv("sv", size * n_contrib);
33-
Kokkos::View<Scalar *> rv("rv", size * n_contrib);
32+
Kokkos::View<Scalar*> sv("sv", size * n_contrib);
33+
Kokkos::View<Scalar*> rv("rv", size * n_contrib);
3434

3535
// Prepare send view
3636
Kokkos::parallel_for(
@@ -43,7 +43,7 @@ auto alltoall_contig_1d() -> void {
4343
int errs;
4444
Kokkos::parallel_reduce(
4545
rv.extent(0),
46-
KOKKOS_LAMBDA(const int i, int &lsum) {
46+
KOKKOS_LAMBDA(const int i, int& lsum) {
4747
const int src = i / n_contrib; // who sent this data
4848
const int j = rank * n_contrib + (i % n_contrib); // what index i was at the source
4949
lsum += rv(i) != src + j;

unit_tests/nccl/test_broadcast.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ TYPED_TEST_SUITE(Broadcast, ScalarTypes);
2020

2121
template <typename Scalar>
2222
auto broadcast_0d() -> void {
23-
auto nccl_ctx = test_utils::nccl::Ctx::init();
23+
auto& nccl_ctx = test_utils::NcclCtx::get();
2424
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
2525
const auto comm = nccl_ctx.comm();
2626
const int size = nccl_ctx.size();
@@ -47,7 +47,7 @@ auto broadcast_0d() -> void {
4747

4848
template <typename Scalar>
4949
auto broadcast_contig_1d() -> void {
50-
auto nccl_ctx = test_utils::nccl::Ctx::init();
50+
auto& nccl_ctx = test_utils::NcclCtx::get();
5151
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
5252
const auto comm = nccl_ctx.comm();
5353
const int size = nccl_ctx.size();

unit_tests/nccl/test_point_to_point.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ TYPED_TEST_SUITE(PointToPoint, ScalarTypes);
2727

2828
template <typename Scalar>
2929
auto p2p_contig_1d() -> void {
30-
auto nccl_ctx = test_utils::nccl::Ctx::init();
30+
auto& nccl_ctx = test_utils::NcclCtx::get();
3131
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
3232
const auto comm = nccl_ctx.comm();
3333
const int size = nccl_ctx.size();
@@ -60,7 +60,7 @@ auto p2p_contig_1d() -> void {
6060

6161
template <typename Scalar>
6262
auto p2p_noncontig_1d() -> void {
63-
auto nccl_ctx = test_utils::nccl::Ctx::init();
63+
auto& nccl_ctx = test_utils::NcclCtx::get();
6464
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
6565
const auto comm = nccl_ctx.comm();
6666
const int size = nccl_ctx.size();

unit_tests/nccl/test_reduce.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ TYPED_TEST_SUITE(Reduce, ScalarTypes);
2323
/// operation is sum, so recvbuf[i] should be sum(0..size) + i * size
2424
template <typename Scalar>
2525
auto reduce_contig_1d() -> void {
26-
auto nccl_ctx = test_utils::nccl::Ctx::init();
26+
auto& nccl_ctx = test_utils::NcclCtx::get();
2727
const auto exec = Kokkos::Cuda(nccl_ctx.stream());
2828
const auto comm = nccl_ctx.comm();
2929
const int size = nccl_ctx.size();

unit_tests/nccl/utils.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#include <memory>
2+
#include <mutex>
3+
4+
#include <mpi.h>
5+
#include <nccl.h>
6+
#include <cuda_runtime.h>
7+
8+
#include "utils.hpp"
9+
#include "../logging.hpp"
10+
11+
namespace {
12+
13+
[[nodiscard]] auto get_local_rank(MPI_Comm comm, int my_rank) -> int {
14+
MPI_Comm node_comm;
15+
MPI_Comm_split_type(comm, MPI_COMM_TYPE_SHARED, my_rank, MPI_INFO_NULL, &node_comm);
16+
17+
int node_rank;
18+
MPI_Comm_rank(node_comm, &node_rank);
19+
20+
MPI_Comm_free(&node_comm);
21+
return node_rank;
22+
}
23+
24+
} // namespace
25+
26+
namespace test_utils {
27+
28+
std::unique_ptr<NcclCtx> NcclCtx::instance_{};
29+
std::once_flag NcclCtx::init_flag_{};
30+
31+
NcclCtx::NcclCtx(ncclComm_t comm, cudaStream_t stream, int dev, int size, int rank)
32+
: comm_(comm), stream_(stream), dev_(dev), size_(size), rank_(rank) {}
33+
34+
NcclCtx::~NcclCtx() {
35+
if (stream_ != nullptr) {
36+
cudaStreamDestroy(stream_);
37+
}
38+
if (comm_ != nullptr) {
39+
ncclCommDestroy(comm_);
40+
}
41+
}
42+
43+
auto NcclCtx::init(bool verbose) -> void {
44+
std::call_once(init_flag_, [verbose]() {
45+
int flag = 0;
46+
KC_MPI_CHECK(MPI_Initialized(&flag));
47+
KC_CHECK(flag != 0, "MPI is not initialized");
48+
49+
MPI_Comm mpi_comm = MPI_COMM_WORLD;
50+
51+
int size = 0;
52+
KC_MPI_CHECK(MPI_Comm_size(mpi_comm, &size));
53+
int rank = 0;
54+
KC_MPI_CHECK(MPI_Comm_rank(mpi_comm, &rank));
55+
56+
int local_rank = get_local_rank(mpi_comm, rank);
57+
58+
int devs = 0;
59+
KC_CUDA_CHECK(cudaGetDeviceCount(&devs));
60+
61+
if (verbose) {
62+
KC_INFO("P{} found {} CUDA devices", rank, devs);
63+
}
64+
65+
KC_CHECK(local_rank < devs, "P{} needs device #{} but only {} devices available", rank, local_rank, devs);
66+
67+
KC_CUDA_CHECK(cudaSetDevice(local_rank));
68+
69+
if (verbose) {
70+
KC_INFO("P{} assigned to CUDA device #{}", rank, local_rank);
71+
}
72+
73+
ncclUniqueId nccl_id{};
74+
if (rank == 0) {
75+
KC_NCCL_CHECK(ncclGetUniqueId(&nccl_id));
76+
}
77+
78+
KC_MPI_CHECK(MPI_Bcast(&nccl_id, NCCL_UNIQUE_ID_BYTES, MPI_CHAR, 0, mpi_comm));
79+
80+
ncclComm_t nccl_comm = nullptr;
81+
KC_NCCL_CHECK(ncclCommInitRank(&nccl_comm, size, nccl_id, rank));
82+
83+
cudaStream_t stream = nullptr;
84+
KC_CUDA_CHECK(cudaStreamCreate(&stream));
85+
86+
instance_ = std::unique_ptr<NcclCtx>(new NcclCtx(nccl_comm, stream, local_rank, size, rank));
87+
});
88+
}
89+
90+
auto NcclCtx::fini() -> void { instance_.reset(); }
91+
92+
auto NcclCtx::get() -> NcclCtx& {
93+
KC_CHECK(instance_ != nullptr, "NCCL context not initialized");
94+
return *instance_;
95+
}
96+
97+
auto NcclCtx::comm() const -> ncclComm_t { return comm_; }
98+
99+
auto NcclCtx::stream() const -> cudaStream_t { return stream_; }
100+
101+
auto NcclCtx::size() const -> int { return size_; }
102+
103+
auto NcclCtx::rank() const -> int { return rank_; }
104+
105+
auto NcclCtx::device() const -> int { return dev_; }
106+
107+
} // namespace test_utils

0 commit comments

Comments
 (0)