Skip to content

Commit 225c7d6

Browse files
dssgabriel authored and cedricchevalier19 committed
feat(core): expose MPI-based broadcast
Also add the associated unit test.
Signed-off-by: Gabriel Dos Santos <[email protected]>
Signed-off-by: Cedric Chevalier <[email protected]>
1 parent 7c3d596 commit 225c7d6

File tree

3 files changed

+110
-2
lines changed

3 files changed

+110
-2
lines changed

src/KokkosComm/collective.hpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,12 @@
99

1010
#include "fwd.hpp"
1111
#include "reduction_op.hpp"
12+
#if defined(KOKKOSCOMM_ENABLE_MPI)
13+
#include "mpi/mpi_space.hpp"
14+
#include "mpi/handle.hpp"
15+
#include "mpi/req.hpp"
16+
#include "mpi/broadcast.hpp"
17+
#endif
1218
#if defined(KOKKOSCOMM_ENABLE_NCCL)
1319
#include "nccl/nccl_space.hpp"
1420
#include "nccl/handle.hpp"

src/KokkosComm/mpi/broadcast.hpp

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,8 @@
1414

1515
#include "impl/error_handling.hpp"
1616

17-
namespace KokkosComm::mpi {
17+
namespace KokkosComm {
18+
namespace mpi {
1819

1920
template <KokkosExecutionSpace ExecSpace, KokkosView View>
2021
auto ibroadcast(const ExecSpace& space, View& v, int root, MPI_Comm comm) -> Req<MpiSpace> {
@@ -56,4 +57,15 @@ void broadcast(ExecSpace const& space, View const& v, int root, MPI_Comm comm) {
5657
Kokkos::Tools::popRegion();
5758
}
5859

59-
} // namespace KokkosComm::mpi
60+
} // namespace mpi
61+
namespace Experimental::Impl {
62+
63+
/// Specialization of `Broadcast` for the `MpiSpace` communication space.
///
/// `execute` forwards to `KokkosComm::mpi::ibroadcast`, passing the execution space and the
/// communicator taken from the handle (`h.space()` and `h.comm()`), and returns the
/// `Req<MpiSpace>` produced by that call.
///
/// @tparam View      View type to broadcast (constrained by `KokkosView`).
/// @tparam ExecSpace Execution space type of the handle (constrained by `KokkosExecutionSpace`).
template <KokkosView View, KokkosExecutionSpace ExecSpace>
64+
struct Broadcast<View, ExecSpace, MpiSpace> {
65+
/// @param h    Handle providing the execution space and the communicator.
/// @param v    View forwarded to `ibroadcast`.
/// @param root Root rank forwarded to `ibroadcast`.
/// @return The request object returned by `KokkosComm::mpi::ibroadcast`.
static auto execute(Handle<ExecSpace, MpiSpace>& h, View& v, int root) -> Req<MpiSpace> {
66+
return KokkosComm::mpi::ibroadcast(h.space(), v, root, h.comm());
67+
}
68+
};
69+
70+
} // namespace Experimental::Impl
71+
} // namespace KokkosComm

unit_tests/test_broadcast.cpp

Lines changed: 90 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,90 @@
1+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2+
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
3+
4+
#include <gtest/gtest.h>
5+
#include <Kokkos_Core.hpp>
6+
#include <KokkosComm/KokkosComm.hpp>
7+
8+
#if defined(KOKKOSCOMM_ENABLE_NCCL)
9+
#include "nccl/utils.hpp"
10+
#endif
11+
12+
namespace {
13+
14+
// Typed test fixture for the broadcast tests.
// It carries no state; it only re-exports the type parameter `T` as `Scalar`, which the
// TYPED_TEST bodies read through `TestFixture::Scalar`.
template <typename T>
15+
class Broadcast : public testing::Test {
16+
public:
17+
using Scalar = T;
18+
};
19+
// Scalar types the `Broadcast` suite is instantiated with.
// When NCCL is enabled the list is limited to float, double, int and int64_t; otherwise it also
// covers Kokkos::complex<float>, Kokkos::complex<double>, unsigned and size_t.
// NOTE(review): presumably the shorter list reflects the datatypes the NCCL path supports - confirm.
#if defined(KOKKOSCOMM_ENABLE_NCCL)
20+
using ScalarTypes = testing::Types<float, double, int, int64_t>;
21+
#else
22+
using ScalarTypes =
23+
testing::Types<float, double, Kokkos::complex<float>, Kokkos::complex<double>, int, unsigned, int64_t, size_t>;
24+
#endif
25+
TYPED_TEST_SUITE(Broadcast, ScalarTypes);
26+
27+
// Broadcast a zero-dimensional view (`Kokkos::View<Scalar>`) from rank 0 and verify the result.
//
// The root rank stores the communicator size (`h.size()`) in `v`, the view is broadcast through
// `KokkosComm::Experimental::broadcast`, and after waiting on the returned request the function
// counts the entries that differ from `size` and expects that count to be zero.
template <typename Scalar>
28+
auto broadcast_0d() -> void {
29+
// Handle selection: with NCCL enabled, use the CUDA execution space and a communicator obtained from
// the NCCL test utilities; otherwise use a default-constructed MPI handle on the default execution space.
#if defined(KOKKOSCOMM_ENABLE_NCCL)
30+
using ExecSpace = Kokkos::Cuda;
31+
auto nccl_ctx = test_utils::nccl::Ctx::init();
32+
KokkosComm::Handle<ExecSpace, KokkosComm::Nccl> h(ExecSpace(), nccl_ctx.comm());
33+
#else
34+
using ExecSpace = Kokkos::DefaultExecutionSpace;
35+
KokkosComm::Handle<Kokkos::DefaultExecutionSpace, KokkosComm::Mpi> h{};
36+
#endif
37+
int rank = h.rank();
38+
int size = h.size();
39+
int root = 0;
40+
41+
Kokkos::View<Scalar> v("v");
42+
if (rank == root) {
43+
// Prepare broadcast view: only the root writes the value that will be broadcast
// NOTE(review): the range bound is `v.extent(0)` on a zero-dimensional view; this relies on
// that extent being 1 so the kernel body runs once - confirm against the Kokkos View documentation.
44+
Kokkos::parallel_for(
45+
Kokkos::RangePolicy(ExecSpace(), 0, v.extent(0)), KOKKOS_LAMBDA(const int) { v() = size; });
46+
}
47+
// The fill kernel above and the broadcast use the same execution space, so no explicit `fence` is needed between them
48+
auto req = KokkosComm::Experimental::broadcast(h, v, root);
49+
// Wait on the request returned by the broadcast before reading `v`
KokkosComm::wait(req);
50+
51+
// `errs` is the result argument of the reduction below: the number of entries that differ from `size`
int errs;
52+
Kokkos::parallel_reduce(
53+
v.extent(0), KOKKOS_LAMBDA(const int, int& lsum) { lsum += v() != size; }, errs);
54+
EXPECT_EQ(errs, 0);
55+
}
56+
57+
// Broadcast a one-dimensional view of 100 entries (`Kokkos::View<Scalar*>`) from rank 0 and verify the result.
//
// The root rank fills entry `i` with `size + i`, where `size` is `h.size()`. The view is broadcast through
// `KokkosComm::Experimental::broadcast`, and after waiting on the returned request the function counts
// the entries that differ from `size + i` and expects that count to be zero.
template <typename Scalar>
58+
auto broadcast_contig_1d() -> void {
59+
// Handle selection: with NCCL enabled, use the CUDA execution space and a communicator obtained from
// the NCCL test utilities; otherwise use a default-constructed MPI handle on the default execution space.
#if defined(KOKKOSCOMM_ENABLE_NCCL)
60+
using ExecSpace = Kokkos::Cuda;
61+
auto nccl_ctx = test_utils::nccl::Ctx::init();
62+
KokkosComm::Handle<ExecSpace, KokkosComm::Nccl> h(ExecSpace(), nccl_ctx.comm());
63+
#else
64+
using ExecSpace = Kokkos::DefaultExecutionSpace;
65+
KokkosComm::Handle<Kokkos::DefaultExecutionSpace, KokkosComm::Mpi> h{};
66+
#endif
67+
int rank = h.rank();
68+
int size = h.size();
69+
int root = 0;
70+
71+
Kokkos::View<Scalar*> v("v", 100);
72+
if (rank == root) {
73+
// Prepare broadcast view: only the root writes the values that will be broadcast
74+
Kokkos::parallel_for(
75+
Kokkos::RangePolicy(ExecSpace(), 0, v.extent(0)), KOKKOS_LAMBDA(const int i) { v(i) = size + i; });
76+
}
77+
// The fill kernel above and the broadcast use the same execution space, so no explicit `fence` is needed between them
78+
auto req = KokkosComm::Experimental::broadcast(h, v, root);
79+
// Wait on the request returned by the broadcast before reading `v`
KokkosComm::wait(req);
80+
81+
// `errs` is the result argument of the reduction below: the number of entries that differ from `size + i`
int errs;
82+
Kokkos::parallel_reduce(
83+
v.extent(0), KOKKOS_LAMBDA(const int i, int& lsum) { lsum += (v(i) != size + i); }, errs);
84+
EXPECT_EQ(errs, 0);
85+
}
86+
87+
// Register one typed test per helper; each runs the helper with the fixture's `Scalar` type.
TYPED_TEST(Broadcast, 0D) { broadcast_0d<typename TestFixture::Scalar>(); }
88+
TYPED_TEST(Broadcast, Contiguous1D) { broadcast_contig_1d<typename TestFixture::Scalar>(); }
89+
90+
} // namespace

0 commit comments

Comments
 (0)