Skip to content

Commit b06127e

Browse files
committed
feat(core): expose MPI-based all-gather
Also add associated unit test. Signed-off-by: Gabriel Dos Santos <[email protected]>
1 parent 8c95907 commit b06127e

File tree

3 files changed

+114
-1
lines changed

3 files changed

+114
-1
lines changed

src/KokkosComm/collective.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99

1010
#include "fwd.hpp"
1111
#include "reduction_op.hpp"
12+
#if defined(KOKKOSCOMM_ENABLE_MPI)
13+
#include "mpi/mpi_space.hpp"
14+
#include "mpi/handle.hpp"
15+
#include "mpi/req.hpp"
16+
#include "mpi/allgather.hpp"
17+
#endif
1218
#if defined(KOKKOSCOMM_ENABLE_NCCL)
1319
#include "nccl/nccl_space.hpp"
1420
#include "nccl/handle.hpp"

src/KokkosComm/mpi/allgather.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
#include "impl/types.hpp"
1515
#include "impl/error_handling.hpp"
1616

17-
namespace KokkosComm::mpi {
17+
namespace KokkosComm {
18+
namespace mpi {
1819

1920
template <KokkosExecutionSpace ExecSpace, KokkosView SView, KokkosView RView>
2021
auto iallgather(const ExecSpace& space, const SView sv, RView rv, MPI_Comm comm) -> Req<Mpi> {
@@ -92,3 +93,14 @@ void allgather(const ExecSpace &space, const SendView &sv, const RecvView &rv, M
9293
}
9394

9495
} // namespace mpi
96+
namespace Experimental::Impl {
97+
98+
// Partial specialization of `AllGather` for the `Mpi` communication space.
// It adapts the handle-based interface to the MPI free function in
// `KokkosComm::mpi`.
template <KokkosView SendView, KokkosView RecvView, KokkosExecutionSpace ExecSpace>
99+
struct AllGather<SendView, RecvView, ExecSpace, Mpi> {
100+
// Forwards the handle's execution space and communicator, together with the
// send view `sv` and receive view `rv`, to `mpi::iallgather` and returns the
// `Req<Mpi>` it produces.
static auto execute(Handle<ExecSpace, Mpi>& h, const SendView sv, RecvView rv) -> Req<Mpi> {
101+
return mpi::iallgather(h.space(), sv, rv, h.comm());
102+
}
103+
};
104+
105+
} // namespace Experimental::Impl
106+
} // namespace KokkosComm

unit_tests/test_allgather.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
3+
4+
#include <gtest/gtest.h>
5+
#include <Kokkos_Core.hpp>
6+
#include <KokkosComm/KokkosComm.hpp>
7+
8+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
9+
#include "nccl/utils.hpp"
10+
#endif
11+
12+
namespace {
13+
14+
// Typed GoogleTest fixture for the all-gather tests. It carries no state; it
// only re-exports its type parameter as `Scalar` so that test bodies can
// forward it to the templated helper functions.
template <typename T>
15+
class AllGather : public testing::Test {
16+
public:
17+
// Element type used for the send/receive views of the current instantiation.
using Scalar = T;
18+
};
19+
// List of element types the `AllGather` suite is instantiated with.
// With NCCL enabled only float, double, int and int64_t are exercised.
#if defined(KOKKOSCOMM_ENABLE_NCCL)
20+
using ScalarTypes = testing::Types<float, double, int, int64_t>;
21+
#else
22+
// Without NCCL the list additionally covers Kokkos::complex<float/double>,
// unsigned and size_t.
using ScalarTypes =
23+
testing::Types<float, double, Kokkos::complex<float>, Kokkos::complex<double>, int, unsigned, int64_t, size_t>;
24+
#endif
25+
// Registers the fixture with the type list above.
TYPED_TEST_SUITE(AllGather, ScalarTypes);
26+
27+
// All-gather test with a rank-0 (single value) send view.
// Every rank contributes one value, its own rank, and then checks that the
// receive view holds `src` at index `src` for every index in [0, size).
// The check is reported through EXPECT_EQ on the number of mismatches.
template <typename Scalar>
28+
auto allgather_0d() -> void {
29+
// Build the communication handle: an NCCL handle on Kokkos::Cuda when NCCL is
// enabled, otherwise a default-constructed MPI handle.
#if defined(KOKKOSCOMM_ENABLE_NCCL)
30+
using ExecSpace = Kokkos::Cuda;
31+
auto nccl_ctx = test_utils::nccl::Ctx::init();
32+
KokkosComm::Handle<ExecSpace, KokkosComm::Nccl> h(ExecSpace(), nccl_ctx.comm());
33+
#else
34+
using ExecSpace = Kokkos::DefaultExecutionSpace;
35+
KokkosComm::Handle<Kokkos::DefaultExecutionSpace, KokkosComm::Mpi> h{};
36+
#endif
37+
int rank = h.rank();
38+
int size = h.size();
39+
40+
// Send view is a single value; receive view has one slot per participant.
Kokkos::View<Scalar> sv("sv");
41+
Kokkos::View<Scalar *> rv("rv", size);
42+
43+
// Prepare send view, 1 element per sender: their rank
// NOTE(review): the range is [0, sv.extent(0)) on a rank-0 view; presumably
// extent(0) is 1 here so the lambda runs exactly once — confirm.
44+
Kokkos::parallel_for(
45+
Kokkos::RangePolicy(ExecSpace(), 0, sv.extent(0)), KOKKOS_LAMBDA(const int) { sv() = rank; });
46+
// Using the same execution space for both operations lets us not need an explicit `fence`
// NOTE(review): this relies on the handle `h` using an execution space
// instance that is ordered with the ExecSpace() used above — confirm.
47+
auto req = KokkosComm::Experimental::allgather(h, sv, rv);
48+
KokkosComm::wait(req);
49+
50+
// Count the entries of rv that differ from their own index.
int errs;
51+
Kokkos::parallel_reduce(
52+
rv.extent(0), KOKKOS_LAMBDA(const int src, int &lsum) { lsum += rv(src) != src; }, errs);
53+
EXPECT_EQ(errs, 0);
54+
}
55+
56+
// All-gather test with a 1D send view of `n_contrib` elements per rank.
// Rank r sends the values r + i for i in [0, n_contrib). The receive view has
// size * n_contrib elements, and the test expects element i to equal
// (i / n_contrib) + (i % n_contrib), i.e. the chunk at position `src` holds
// the contribution of rank `src`.
template <typename Scalar>
57+
auto allgather_contig_1d() -> void {
58+
// Build the communication handle: an NCCL handle on Kokkos::Cuda when NCCL is
// enabled, otherwise a default-constructed MPI handle.
#if defined(KOKKOSCOMM_ENABLE_NCCL)
59+
using ExecSpace = Kokkos::Cuda;
60+
auto nccl_ctx = test_utils::nccl::Ctx::init();
61+
KokkosComm::Handle<ExecSpace, KokkosComm::Nccl> h(ExecSpace(), nccl_ctx.comm());
62+
#else
63+
using ExecSpace = Kokkos::DefaultExecutionSpace;
64+
KokkosComm::Handle<Kokkos::DefaultExecutionSpace, KokkosComm::Mpi> h{};
65+
#endif
66+
int rank = h.rank();
67+
int size = h.size();
68+
69+
// Number of elements each rank contributes.
const int n_contrib = 100;
70+
Kokkos::View<Scalar *> sv("sv", n_contrib);
71+
Kokkos::View<Scalar *> rv("rv", size * n_contrib);
72+
73+
// Prepare send view
// Element i of the send view is set to rank + i.
74+
Kokkos::parallel_for(
75+
Kokkos::RangePolicy(ExecSpace(), 0, sv.extent(0)), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; });
76+
// Using the same execution space for both operations lets us not need an explicit `fence`
// NOTE(review): this relies on the handle `h` using an execution space
// instance that is ordered with the ExecSpace() used above — confirm.
77+
auto req = KokkosComm::Experimental::allgather(h, sv, rv);
78+
KokkosComm::wait(req);
79+
80+
// Count the entries of rv that differ from the value their sender wrote.
int errs;
81+
Kokkos::parallel_reduce(
82+
rv.extent(0),
83+
KOKKOS_LAMBDA(const int i, int &lsum) {
84+
// src: which rank's chunk index i falls into; j: offset inside that chunk.
const int src = i / n_contrib;
85+
const int j = i % n_contrib;
86+
lsum += rv(i) != src + j;
87+
},
88+
errs);
89+
EXPECT_EQ(errs, 0);
90+
}
91+
92+
// Typed test cases of the AllGather suite: each one calls the matching helper
// with the fixture's `Scalar` type.
TYPED_TEST(AllGather, 0D) { allgather_0d<typename TestFixture::Scalar>(); }
92+
TYPED_TEST(AllGather, Contiguous1D) { allgather_contig_1d<typename TestFixture::Scalar>(); }
94+
95+
} // namespace

0 commit comments

Comments
 (0)