Skip to content

Commit 840f827

Browse files
committed
Correct allgather in place
Signed-off-by: Adrien Taberner <[email protected]>
1 parent 2b04837 commit 840f827

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/KokkosComm/mpi/allgather.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ void allgather(const SendView &sv, const RecvView &rv, MPI_Comm comm) {
5050
}
5151

5252
// in-place allgather
53-
template <KokkosView RecvView>
54-
void allgather(const RecvView &rv, MPI_Comm comm) {
53+
template <KokkosExecutionSpace ExecSpace, KokkosView RecvView>
54+
void allgather(const ExecSpace &space, const RecvView &rv, const size_t recvCount, MPI_Comm comm) {
5555
Kokkos::Tools::pushRegion("KokkosComm::Mpi::allgather");
5656

5757
using RecvScalar = typename RecvView::value_type;
@@ -61,7 +61,8 @@ void allgather(const RecvView &rv, MPI_Comm comm) {
6161
if (!KokkosComm::is_contiguous(rv)) {
6262
throw std::runtime_error("low-level allgather requires contiguous recv view");
6363
}
64-
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, KokkosComm::data_handle(rv), KokkosComm::span(rv),
64+
space.fence("fence before allgather"); // work in space may have been used to produce send view data
65+
MPI_Allgather(MPI_IN_PLACE, 0 /*ignored*/, MPI_DATATYPE_NULL /*ignored*/, KokkosComm::data_handle(rv), recvCount,
6566
KokkosComm::Impl::mpi_type_v<RecvScalar>, comm);
6667

6768
Kokkos::Tools::popRegion();

unit_tests/mpi/test_allgather.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ void test_allgather_1d_inplace_contig() {
9595

9696
// fill send buffer
9797
Kokkos::parallel_for(
98-
nContrib, KOKKOS_LAMBDA(const int i) { rv(i + rank * nContrib) = rank + i; });
98+
nContrib, KOKKOS_LAMBDA(const int i) { rv(rank * nContrib + i) = rank + i; });
9999

100-
KokkosComm::mpi::allgather(rv, MPI_COMM_WORLD);
100+
KokkosComm::mpi::allgather(Kokkos::DefaultExecutionSpace(), rv, nContrib, MPI_COMM_WORLD);
101101

102102
int errs;
103103
Kokkos::parallel_reduce(

0 commit comments

Comments
 (0)