Skip to content

Commit bba7521

Browse files
dssgabrielcedricchevalier19
authored andcommitted
feat(core): expose MPI-based all-to-all
Also, add associated unit test. Signed-off-by: Gabriel Dos Santos <[email protected]>
1 parent 7c3d596 commit bba7521

File tree

3 files changed

+84
-2
lines changed

3 files changed

+84
-2
lines changed

src/KokkosComm/collective.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99

1010
#include "fwd.hpp"
1111
#include "reduction_op.hpp"
12+
#if defined(KOKKOSCOMM_ENABLE_MPI)
13+
#include "mpi/mpi_space.hpp"
14+
#include "mpi/handle.hpp"
15+
#include "mpi/req.hpp"
16+
#include "mpi/alltoall.hpp"
17+
#endif
1218
#if defined(KOKKOSCOMM_ENABLE_NCCL)
1319
#include "nccl/nccl_space.hpp"
1420
#include "nccl/handle.hpp"

src/KokkosComm/mpi/alltoall.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
#include "impl/pack_traits.hpp"
1616
#include "impl/error_handling.hpp"
1717

18-
namespace KokkosComm::mpi {
18+
namespace KokkosComm {
19+
namespace mpi {
1920

2021
template <KokkosExecutionSpace ExecSpace, KokkosView SView, KokkosView RView>
2122
auto ialltoall(const ExecSpace &space, const SView sv, RView rv, int count, MPI_Comm comm) -> Req<MpiSpace> {
@@ -110,4 +111,15 @@ void alltoall(const ExecSpace &space, const RecvView &rv, const size_t recvCount
110111
Kokkos::Tools::popRegion();
111112
}
112113

113-
} // namespace KokkosComm::mpi
114+
} // namespace mpi
115+
namespace Experimental::Impl {
116+
117+
template <KokkosView SendView, KokkosView RecvView, KokkosExecutionSpace ExecSpace>
118+
struct AllToAll<SendView, RecvView, ExecSpace, Mpi> {
119+
static auto execute(Handle<ExecSpace, Mpi> &h, const SendView sv, RecvView rv, int count) -> Req<Mpi> {
120+
return mpi::ialltoall(h.space(), sv, rv, count, h.comm());
121+
}
122+
};
123+
124+
} // namespace Experimental::Impl
125+
} // namespace KokkosComm

unit_tests/test_alltoall.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
3+
4+
#include <gtest/gtest.h>
5+
#include <Kokkos_Core.hpp>
6+
#include <KokkosComm/KokkosComm.hpp>
7+
8+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
9+
#include "nccl/utils.hpp"
10+
#endif
11+
12+
namespace {
13+
14+
template <typename T>
15+
class AllToAll : public testing::Test {
16+
public:
17+
using Scalar = T;
18+
};
19+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
20+
using ScalarTypes = testing::Types<float, double, int, int64_t>;
21+
#else
22+
using ScalarTypes =
23+
testing::Types<float, double, Kokkos::complex<float>, Kokkos::complex<double>, int, unsigned, int64_t, size_t>;
24+
#endif
25+
TYPED_TEST_SUITE(AllToAll, ScalarTypes);
26+
27+
template <typename Scalar>
28+
auto alltoall_contig_1d() -> void {
29+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
30+
using ExecSpace = Kokkos::Cuda;
31+
auto nccl_ctx = test_utils::nccl::Ctx::init();
32+
KokkosComm::Handle<ExecSpace, KokkosComm::Nccl> h(ExecSpace(), nccl_ctx.comm());
33+
#else
34+
using ExecSpace = Kokkos::DefaultExecutionSpace;
35+
KokkosComm::Handle<Kokkos::DefaultExecutionSpace, KokkosComm::Mpi> h{};
36+
#endif
37+
int rank = h.rank();
38+
int size = h.size();
39+
40+
int n_contrib = 100;
41+
Kokkos::View<Scalar *> sv("sv", size * n_contrib);
42+
Kokkos::View<Scalar *> rv("rv", size * n_contrib);
43+
44+
// Prepare send view
45+
Kokkos::parallel_for(
46+
Kokkos::RangePolicy(ExecSpace(), 0, sv.extent(0)), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; });
47+
// Using the same execution space for both operations lets us not need an explicit `fence`
48+
KokkosComm::Experimental::alltoall(h, sv, rv, n_contrib);
49+
50+
int errs;
51+
Kokkos::parallel_reduce(
52+
rv.extent(0),
53+
KOKKOS_LAMBDA(const int i, int &lsum) {
54+
const int src = i / n_contrib; // who sent this data
55+
const int j = rank * n_contrib + (i % n_contrib); // what index i was at the source
56+
lsum += rv(i) != src + j;
57+
},
58+
errs);
59+
EXPECT_EQ(errs, 0);
60+
}
61+
62+
TYPED_TEST(AllToAll, Contiguous1D) { alltoall_contig_1d<typename TestFixture::Scalar>(); }
63+
64+
} // namespace

0 commit comments

Comments
 (0)