Skip to content

Commit dbafc95

Browse files
authored
Merge pull request #164 from AdRi1t/inplace_all_gather
2 parents ae4a704 + 840f827 commit dbafc95

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

src/KokkosComm/mpi/allgather.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,19 @@ void allgather(const SendView &sv, const RecvView &rv, MPI_Comm comm) {
5050
}
5151

5252
// in-place allgather
53-
template <KokkosView RecvView>
54-
void allgather(const RecvView &rv, MPI_Comm comm) {
53+
template <KokkosExecutionSpace ExecSpace, KokkosView RecvView>
54+
void allgather(const ExecSpace &space, const RecvView &rv, const size_t recvCount, MPI_Comm comm) {
5555
Kokkos::Tools::pushRegion("KokkosComm::Mpi::allgather");
5656

57-
using RT = KokkosComm::Traits<RecvView>;
5857
using RecvScalar = typename RecvView::value_type;
5958

60-
static_assert(RT::rank() <= 1, "allgather for RecvView::rank > 1 not supported");
59+
static_assert(KokkosComm::rank<RecvView>() <= 1, "allgather for RecvView::rank > 1 not supported");
6160

62-
if (!RT::is_contiguous(rv)) {
61+
if (!KokkosComm::is_contiguous(rv)) {
6362
throw std::runtime_error("low-level allgather requires contiguous recv view");
6463
}
65-
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, RT::data_handle(rv), RT::span(rv),
64+
space.fence("fence before allgather"); // work in space may have been used to produce send view data
65+
MPI_Allgather(MPI_IN_PLACE, 0 /*ignored*/, MPI_DATATYPE_NULL /*ignored*/, KokkosComm::data_handle(rv), recvCount,
6666
KokkosComm::Impl::mpi_type_v<RecvScalar>, comm);
6767

6868
Kokkos::Tools::popRegion();

unit_tests/mpi/test_allgather.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,34 @@ void test_allgather_1d_contig() {
8383

8484
TYPED_TEST(Allgather, 1D_contig) { test_allgather_1d_contig<typename TestFixture::Scalar>(); }
8585

86+
template <typename Scalar>
87+
void test_allgather_1d_inplace_contig() {
88+
int rank, size;
89+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
90+
MPI_Comm_size(MPI_COMM_WORLD, &size);
91+
92+
const int nContrib = 10;
93+
94+
Kokkos::View<Scalar *> rv("rv", size * nContrib);
95+
96+
// fill send buffer
97+
Kokkos::parallel_for(
98+
nContrib, KOKKOS_LAMBDA(const int i) { rv(rank * nContrib + i) = rank + i; });
99+
100+
KokkosComm::mpi::allgather(Kokkos::DefaultExecutionSpace(), rv, nContrib, MPI_COMM_WORLD);
101+
102+
int errs;
103+
Kokkos::parallel_reduce(
104+
rv.extent(0),
105+
KOKKOS_LAMBDA(const int &i, int &lsum) {
106+
const int src = i / nContrib;
107+
const int j = i % nContrib;
108+
lsum += rv(i) != src + j;
109+
},
110+
errs);
111+
EXPECT_EQ(errs, 0);
112+
}
113+
114+
TYPED_TEST(Allgather, 1D_inplace_contig) { test_allgather_1d_inplace_contig<typename TestFixture::Scalar>(); }
115+
86116
} // namespace

0 commit comments

Comments
 (0)