Skip to content

Commit aa58444

Browse files
Merge branch 'develop' into refactor/reduction-ops
2 parents 37001ee + 785c85a commit aa58444

File tree

8 files changed

+104
-47
lines changed

8 files changed

+104
-47
lines changed

src/KokkosComm/collective.hpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,11 @@
2222

2323
namespace KokkosComm::Experimental {
2424

25-
/// Copy the `sv` view on the `root` rank to all ranks' `rv` view.
26-
///
27-
/// The `sv` view is only used on the `root` rank and ignored for all other ranks.
28-
template <KokkosView SendView, KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
29-
CommunicationSpace CommSpace = DefaultCommunicationSpace>
30-
auto broadcast(Handle<ExecSpace, CommSpace>& h, const SendView sv, RecvView rv, int root) -> Req<CommSpace> {
31-
return Impl::Broadcast<SendView, RecvView, ExecSpace, CommSpace>::execute(h, sv, rv, root);
32-
}
33-
34-
/// In-place variant of `broadcast`. Copy the `v` view from the `root` rank to all ranks' `v` view.
25+
/// Copy the `v` view from the `root` rank to all ranks' `v` view.
3526
template <KokkosView View, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
3627
CommunicationSpace CommSpace = DefaultCommunicationSpace>
3728
auto broadcast(Handle<ExecSpace, CommSpace>& h, View v, int root) -> Req<CommSpace> {
38-
return Impl::Broadcast<View, View, ExecSpace, CommSpace>::execute(h, v, v, root);
29+
return Impl::Broadcast<View, ExecSpace, CommSpace>::execute(h, v, root);
3930
}
4031

4132
/// Copy the `sv` view from each rank to the `rv` view, receiving data from rank `i` at offset

src/KokkosComm/fwd.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct Send;
4848
// Collectives are currently experimental functions
4949
namespace Experimental::Impl {
5050

51-
template <KokkosView SendView, KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
51+
template <KokkosView View, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
5252
CommunicationSpace CommSpace = DefaultCommunicationSpace>
5353
struct Broadcast;
5454

src/KokkosComm/mpi/allgather.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,38 @@
88

99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
11+
#include "mpi_space.hpp"
12+
#include "req.hpp"
1113

1214
#include "impl/types.hpp"
1315
#include "impl/error_handling.hpp"
1416

1517
namespace KokkosComm::mpi {
1618

19+
template <KokkosExecutionSpace ExecSpace, KokkosView SView, KokkosView RView>
20+
auto iallgather(const ExecSpace &space, const SView sv, RView rv, MPI_Comm comm) -> Req<Mpi> {
21+
using ST = typename SView::non_const_value_type;
22+
using RT = typename RView::non_const_value_type;
23+
static_assert(std::is_same_v<ST, RT>, "KokkosComm::mpi::iallgather: View value types must be identical");
24+
Kokkos::Tools::pushRegion("KokkosComm::mpi::iallgather");
25+
26+
fail_if(!is_contiguous(sv) || !is_contiguous(rv),
27+
"KokkosComm::mpi::iallgather: unimplemented for non-contiguous views");
28+
29+
// Sync: Work in space may have been used to produce view data.
30+
space.fence("fence before non-blocking all-gather");
31+
32+
Req<Mpi> req;
33+
// All ranks send/recv same count
34+
MPI_Iallgather(data_handle(sv), span(sv), Impl::mpi_type_v<ST>, data_handle(rv), span(sv), Impl::mpi_type_v<RT>, comm,
35+
&req.mpi_request());
36+
req.extend_view_lifetime(sv);
37+
req.extend_view_lifetime(rv);
38+
39+
Kokkos::Tools::popRegion();
40+
return req;
41+
}
42+
1743
template <KokkosView SendView, KokkosView RecvView>
1844
void allgather(const SendView &sv, const RecvView &rv, MPI_Comm comm) {
1945
Kokkos::Tools::pushRegion("KokkosComm::Mpi::allgather");

src/KokkosComm/mpi/alltoall.hpp

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,43 @@
88

99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
11+
#include "mpi_space.hpp"
12+
#include "req.hpp"
1113

1214
#include "impl/pack_traits.hpp"
1315
#include "impl/types.hpp"
1416
#include "impl/error_handling.hpp"
1517

16-
namespace KokkosComm::Impl {
18+
namespace KokkosComm::mpi {
19+
20+
template <KokkosExecutionSpace ExecSpace, KokkosView SView, KokkosView RView>
21+
auto ialltoall(const ExecSpace &space, const SView sv, RView rv, int count, MPI_Comm comm) -> Req<Mpi> {
22+
using ST = typename SView::non_const_value_type;
23+
using RT = typename RView::non_const_value_type;
24+
static_assert(std::is_same_v<ST, RT>, "KokkosComm::mpi::ialltoall: View value types must be identical");
25+
Kokkos::Tools::pushRegion("KokkosComm::mpi::ialltoall");
26+
27+
fail_if(!is_contiguous(sv) || !is_contiguous(rv),
28+
"KokkosComm::mpi::ialltoall: unimplemented for non-contiguous views");
29+
30+
// Sync: Work in space may have been used to produce view data.
31+
space.fence("fence before non-blocking all-gather");
32+
33+
Req<Mpi> req;
34+
// All ranks send/recv same count
35+
MPI_Ialltoall(data_handle(sv), count, Impl::mpi_type_v<ST>, data_handle(rv), count, Impl::mpi_type_v<RT>, comm,
36+
&req.mpi_request());
37+
req.extend_view_lifetime(sv);
38+
req.extend_view_lifetime(rv);
39+
40+
Kokkos::Tools::popRegion();
41+
return req;
42+
}
1743

1844
template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
1945
void alltoall(const ExecSpace &space, const SendView &sv, const size_t sendCount, const RecvView &rv,
2046
const size_t recvCount, MPI_Comm comm) {
21-
Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall");
47+
Kokkos::Tools::pushRegion("KokkosComm::mpi::alltoall");
2248

2349
using SendScalar = typename SendView::value_type;
2450
using RecvScalar = typename RecvView::value_type;
@@ -27,7 +53,7 @@ void alltoall(const ExecSpace &space, const SendView &sv, const size_t sendCount
2753
static_assert(KokkosComm::rank<RecvView>() <= 1, "alltoall for RecvView::rank > 1 not supported");
2854

2955
// Make sure views are ready
30-
space.fence("KokkosComm::Impl::alltoall");
56+
space.fence("KokkosComm::mpi::alltoall");
3157

3258
KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(sv) || !KokkosComm::is_contiguous(rv),
3359
"alltoall for non-contiguous views not implemented");
@@ -48,23 +74,23 @@ void alltoall(const ExecSpace &space, const SendView &sv, const size_t sendCount
4874
KokkosComm::mpi::fail_if(true, ss.str().data());
4975
}
5076

51-
MPI_Alltoall(KokkosComm::data_handle(sv), sendCount, mpi_type_v<SendScalar>, KokkosComm::data_handle(rv), recvCount,
52-
mpi_type_v<RecvScalar>, comm);
77+
MPI_Alltoall(KokkosComm::data_handle(sv), sendCount, Impl::mpi_type_v<SendScalar>, KokkosComm::data_handle(rv),
78+
recvCount, Impl::mpi_type_v<RecvScalar>, comm);
5379

5480
Kokkos::Tools::popRegion();
5581
}
5682

5783
// in-place alltoall
5884
template <KokkosExecutionSpace ExecSpace, KokkosView RecvView>
5985
void alltoall(const ExecSpace &space, const RecvView &rv, const size_t recvCount, MPI_Comm comm) {
60-
Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall");
86+
Kokkos::Tools::pushRegion("KokkosComm::mpi::alltoall");
6187

6288
using RecvScalar = typename RecvView::value_type;
6389

6490
static_assert(RecvView::rank <= 1, "alltoall for RecvView::rank > 1 not supported");
6591

6692
// Make sure views are ready
67-
space.fence("KokkosComm::Impl::alltoall");
93+
space.fence("KokkosComm::mpi::alltoall");
6894

6995
KokkosComm::mpi::fail_if(!KokkosComm::is_contiguous(rv), "alltoall for non-contiguous views not implemented");
7096

@@ -79,9 +105,9 @@ void alltoall(const ExecSpace &space, const RecvView &rv, const size_t recvCount
79105
}
80106

81107
MPI_Alltoall(MPI_IN_PLACE, 0 /*ignored*/, MPI_BYTE /*ignored*/, KokkosComm::data_handle(rv), recvCount,
82-
mpi_type_v<RecvScalar>, comm);
108+
Impl::mpi_type_v<RecvScalar>, comm);
83109

84110
Kokkos::Tools::popRegion();
85111
}
86112

87-
} // namespace KokkosComm::Impl
113+
} // namespace KokkosComm::mpi

src/KokkosComm/mpi/broadcast.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,31 @@
88

99
#include <KokkosComm/concepts.hpp>
1010
#include <KokkosComm/traits.hpp>
11+
#include "mpi_space.hpp"
12+
#include "req.hpp"
1113

1214
#include "impl/types.hpp"
15+
#include "impl/error_handling.hpp"
1316

1417
namespace KokkosComm::mpi {
1518

19+
template <KokkosExecutionSpace ExecSpace, KokkosView View>
20+
auto ibroadcast(const ExecSpace& space, View& v, int root, MPI_Comm comm) -> Req<Mpi> {
21+
using T = typename View::non_const_value_type;
22+
Kokkos::Tools::pushRegion("KokkosComm::mpi::ibroadcast");
23+
fail_if(!is_contiguous(v), "KokkosComm::mpi::ibroadcast: unimplemented for non-contiguous views");
24+
25+
// Sync: Work in space may have been used to produce view data.
26+
space.fence("fence before non-blocking broadcast");
27+
28+
Req<Mpi> req;
29+
MPI_Ibcast(data_handle(v), span(v), Impl::mpi_type_v<T>, root, comm, &req.mpi_request());
30+
req.extend_view_lifetime(v);
31+
32+
Kokkos::Tools::popRegion();
33+
return req;
34+
}
35+
1636
template <KokkosView View>
1737
void broadcast(View const& v, int root, MPI_Comm comm) {
1838
Kokkos::Tools::pushRegion("KokkosComm::mpi::broadcast");

src/KokkosComm/nccl/broadcast.hpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,20 @@ namespace nccl {
1717

1818
namespace KC = KokkosComm;
1919

20-
template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
21-
auto broadcast(const ExecSpace &space, const SendView &sv, const RecvView &rv, int root, ncclComm_t comm)
22-
-> Req<NcclSpace> {
23-
using ST = typename SendView::non_const_value_type;
24-
using RT = typename RecvView::non_const_value_type;
25-
static_assert(std::is_same_v<ST, RT>,
26-
"KokkosComm::Experimental::nccl::broadcast: View value types must be identical");
27-
static_assert(KC::rank<SendView>() <= 1 and KC::rank<RecvView>() <= 1,
20+
template <KokkosView View>
21+
auto broadcast(const Kokkos::Cuda& space, View& v, int root, ncclComm_t comm) -> Req<Nccl> {
22+
using T = typename View::non_const_value_type;
23+
static_assert(KC::rank<View>() <= 1,
2824
"KokkosComm::Experimental::nccl::broadcast: Views with rank higher than 1 are not supported");
2925
Kokkos::Tools::pushRegion("KokkosComm::Experimental::nccl::broadcast");
3026

31-
Req<NcclSpace> req{space.cuda_stream()};
32-
if (KC::is_contiguous(sv) and KC::is_contiguous(rv)) {
33-
ncclBroadcast(KC::data_handle(sv), KC::data_handle(rv), KC::span(sv), Impl::datatype_v<ST>, root, comm,
34-
space.cuda_stream());
27+
Req<Nccl> req{space.cuda_stream()};
28+
if (KC::is_contiguous(v)) {
29+
ncclBcast(KC::data_handle(v), KC::span(v), Impl::datatype_v<T>, root, comm, space.cuda_stream());
3530
} else {
3631
Kokkos::abort("KokkosComm::Experimental::nccl::broadcast: unimplemented for non-contiguous views");
3732
}
38-
req.extend_view_lifetime(sv);
39-
req.extend_view_lifetime(rv);
33+
req.extend_view_lifetime(v);
4034

4135
Kokkos::Tools::popRegion();
4236
return req;
@@ -45,10 +39,10 @@ auto broadcast(const ExecSpace &space, const SendView &sv, const RecvView &rv, i
4539
} // namespace nccl
4640
namespace Impl {
4741

48-
template <KokkosView SendView, KokkosView RecvView>
49-
struct Broadcast<SendView, RecvView, Kokkos::Cuda, NcclSpace> {
50-
static auto execute(Handle<Kokkos::Cuda, NcclSpace> &h, const SendView sv, RecvView rv, int root) -> Req<NcclSpace> {
51-
return nccl::broadcast(h.space(), sv, rv, root, h.comm());
42+
template <KokkosView View>
43+
struct Broadcast<View, Kokkos::Cuda, Nccl> {
44+
static auto execute(Handle<Kokkos::Cuda, Nccl>& h, View v, int root) -> Req<Nccl> {
45+
return nccl::broadcast(h.space(), v, root, h.comm());
5246
}
5347
};
5448

unit_tests/mpi/test_alltoall.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ void test_alltoall_1d_contig() {
3131
Kokkos::parallel_for(
3232
sv.extent(0), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; });
3333

34-
KokkosComm::Impl::alltoall(Kokkos::DefaultExecutionSpace(), sv, nContrib, rv, nContrib, MPI_COMM_WORLD);
34+
KokkosComm::mpi::alltoall(Kokkos::DefaultExecutionSpace(), sv, nContrib, rv, nContrib, MPI_COMM_WORLD);
3535

3636
int errs;
3737
Kokkos::parallel_reduce(
@@ -61,7 +61,7 @@ void test_alltoall_1d_inplace_contig() {
6161
Kokkos::parallel_for(
6262
rv.extent(0), KOKKOS_LAMBDA(const int i) { rv(i) = rank + i; });
6363

64-
KokkosComm::Impl::alltoall(Kokkos::DefaultExecutionSpace(), rv, nContrib, MPI_COMM_WORLD);
64+
KokkosComm::mpi::alltoall(Kokkos::DefaultExecutionSpace(), rv, nContrib, MPI_COMM_WORLD);
6565

6666
int errs;
6767
Kokkos::parallel_reduce(

unit_tests/nccl/test_broadcast.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,20 @@ auto broadcast_0d() -> void {
4242

4343
int errs;
4444
Kokkos::parallel_reduce(
45-
v.extent(0), KOKKOS_LAMBDA(const int, int &lsum) { lsum += v() != size; }, errs);
45+
v.extent(0), KOKKOS_LAMBDA(const int, int& lsum) { lsum += v() != size; }, errs);
4646
EXPECT_EQ(errs, 0);
4747
}
4848

4949
template <typename Scalar>
50-
auto broadcast_inplace_contig_1d() -> void {
50+
auto broadcast_contig_1d() -> void {
5151
auto nccl_ctx = test_utils::nccl::Ctx::init();
5252
ExecSpace space(nccl_ctx.stream());
5353
KokkosComm::Handle<ExecSpace, CommSpace> h(space, nccl_ctx.comm());
5454
int rank = h.rank();
5555
int size = h.size();
5656
int root = 0;
5757

58-
Kokkos::View<Scalar *> v("v", 100);
58+
Kokkos::View<Scalar*> v("v", 100);
5959
if (rank == root) {
6060
// Prepare broadcast view
6161
Kokkos::parallel_for(
@@ -67,11 +67,11 @@ auto broadcast_inplace_contig_1d() -> void {
6767

6868
int errs;
6969
Kokkos::parallel_reduce(
70-
v.extent(0), KOKKOS_LAMBDA(const int i, int &lsum) { lsum += (v(i) != size + i); }, errs);
70+
v.extent(0), KOKKOS_LAMBDA(const int i, int& lsum) { lsum += (v(i) != size + i); }, errs);
7171
EXPECT_EQ(errs, 0);
7272
}
7373

74-
TYPED_TEST(Broadcast, InPlace0D) { broadcast_0d<typename TestFixture::Scalar>(); }
75-
TYPED_TEST(Broadcast, InPlaceContiguous1D) { broadcast_inplace_contig_1d<typename TestFixture::Scalar>(); }
74+
TYPED_TEST(Broadcast, 0D) { broadcast_0d<typename TestFixture::Scalar>(); }
75+
TYPED_TEST(Broadcast, Contiguous1D) { broadcast_contig_1d<typename TestFixture::Scalar>(); }
7676

7777
} // namespace

0 commit comments

Comments
 (0)