Skip to content

Commit 41fb0ac

Browse files
committed
feat(core): expose MPI-based allreduce
Added associated unit test. Signed-off-by: Gabriel Dos Santos <[email protected]>
1 parent e73d027 commit 41fb0ac

File tree

2 files changed

+110
-2
lines changed

2 files changed

+110
-2
lines changed

src/KokkosComm/mpi/allreduce.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
#include "impl/error_handling.hpp"
1616
#include "impl/types.hpp"
1717

18-
namespace KokkosComm::mpi {
18+
namespace KokkosComm {
19+
namespace mpi {
1920

2021
template <KokkosView SView, KokkosView RView, KokkosExecutionSpace ExecSpace>
2122
auto iallreduce(const ExecSpace &space, const SView sv, RView rv, MPI_Op op, MPI_Comm comm) -> Req<MpiSpace> {
@@ -104,4 +105,15 @@ void allreduce(ExecSpace const &space, View const &v, MPI_Op op, MPI_Comm comm)
104105
Kokkos::Tools::popRegion();
105106
}
106107

107-
} // namespace KokkosComm::mpi
108+
} // namespace mpi
109+
namespace Experimental::Impl {
110+
111+
template <KokkosView SendView, KokkosView RecvView, ReductionOperator RedOp, KokkosExecutionSpace ExecSpace>
112+
struct AllReduce<SendView, RecvView, RedOp, ExecSpace, MpiSpace> {
113+
static auto execute(Handle<ExecSpace, MpiSpace> &h, const SendView &sv, RecvView rv) -> Req<MpiSpace> {
114+
return mpi::iallreduce(h.space(), sv, rv, reduction_op<MpiSpace, RedOp>(), h.comm());
115+
}
116+
};
117+
118+
} // namespace Experimental::Impl
119+
} // namespace KokkosComm

unit_tests/test_allreduce.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
3+
4+
#include <gtest/gtest.h>
5+
#include <mpi.h>
6+
#include <Kokkos_Core.hpp>
7+
#include <KokkosComm/KokkosComm.hpp>
8+
9+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
10+
#include "nccl/utils.hpp"
11+
#endif
12+
13+
namespace {
14+
15+
template <typename T>
16+
class AllReduce : public testing::Test {
17+
public:
18+
using Scalar = T;
19+
};
20+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
21+
using ScalarTypes = testing::Types<float, double, int, int64_t>;
22+
#else
23+
using ScalarTypes =
24+
testing::Types<float, double, Kokkos::complex<float>, Kokkos::complex<double>, int, unsigned, int64_t, size_t>;
25+
#endif
26+
TYPED_TEST_SUITE(AllReduce, ScalarTypes);
27+
28+
template <typename Scalar>
29+
auto allreduce_0d() -> void {
30+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
31+
using ExecSpace = Kokkos::Cuda;
32+
auto nccl_ctx = test_utils::nccl::Ctx::init();
33+
ExecSpace space(nccl_ctx.stream());
34+
KokkosComm::Handle<ExecSpace, KokkosComm::NcclSpace> h(space, nccl_ctx.comm());
35+
#else
36+
using ExecSpace = Kokkos::DefaultExecutionSpace;
37+
ExecSpace space{};
38+
KokkosComm::Handle<ExecSpace, KokkosComm::MpiSpace> h{};
39+
#endif
40+
int rank = h.rank();
41+
int size = h.size();
42+
43+
Kokkos::View<Scalar> sv("sv");
44+
Kokkos::View<Scalar> rv("rv");
45+
46+
// Prepare send buffer
47+
Kokkos::parallel_for(
48+
Kokkos::RangePolicy(space, 0, sv.extent(0)), KOKKOS_LAMBDA(const int) { sv() = rank; });
49+
space.fence();
50+
auto req = KokkosComm::Experimental::allreduce(h, sv, rv, KokkosComm::Sum{});
51+
KokkosComm::wait(req);
52+
53+
int errs;
54+
Kokkos::parallel_reduce(
55+
rv.extent(0), KOKKOS_LAMBDA(const int, int &lsum) { lsum += (rv() != size * (size - 1) / 2); }, errs);
56+
EXPECT_EQ(errs, 0);
57+
}
58+
59+
template <typename Scalar>
60+
auto allreduce_contig_1d() -> void {
61+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
62+
using ExecSpace = Kokkos::Cuda;
63+
auto nccl_ctx = test_utils::nccl::Ctx::init();
64+
auto space = ExecSpace(nccl_ctx.stream());
65+
KokkosComm::Handle<ExecSpace, KokkosComm::NcclSpace> h(space, nccl_ctx.comm());
66+
#else
67+
using ExecSpace = Kokkos::DefaultExecutionSpace;
68+
auto space = ExecSpace();
69+
KokkosComm::Handle<ExecSpace, KokkosComm::MpiSpace> h(space, MPI_COMM_WORLD);
70+
#endif
71+
int rank = h.rank();
72+
int size = h.size();
73+
74+
int n_contrib = 10;
75+
Kokkos::View<Scalar *> sv("sv", n_contrib);
76+
Kokkos::View<Scalar *> rv("rv", n_contrib);
77+
78+
// Prepare send buffer
79+
Kokkos::parallel_for(
80+
Kokkos::RangePolicy(space, 0, sv.extent(0)), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; });
81+
space.fence();
82+
auto req = KokkosComm::Experimental::allreduce(h, sv, rv, KokkosComm::Sum{});
83+
KokkosComm::wait(req);
84+
85+
int errs;
86+
// Fill in this reduction which verifies that each element computed the correct value
87+
Kokkos::parallel_reduce(
88+
rv.extent(0),
89+
KOKKOS_LAMBDA(const int i, int &lsum) { lsum += (rv(i) != ((size * (size - 1)) / 2 + (size * i))); }, errs);
90+
EXPECT_EQ(errs, 0);
91+
}
92+
93+
TYPED_TEST(AllReduce, 0D) { allreduce_0d<typename TestFixture::Scalar>(); }
94+
TYPED_TEST(AllReduce, Contiguous1D) { allreduce_contig_1d<typename TestFixture::Scalar>(); }
95+
96+
} // namespace

0 commit comments

Comments
 (0)