Skip to content

Commit a52e117

Browse files
authored
Merge pull request #156 from nicoleavans/persistence
2 parents c9da285 + ba6aac7 commit a52e117

File tree

7 files changed

+328
-6
lines changed

7 files changed

+328
-6
lines changed

perf_tests/mpi/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ target_sources(
2020
PRIVATE
2121
test_main.cpp
2222
test_sendrecv.cpp
23+
test_channel_sendrecv.cpp
2324
test_2dhalo.cpp
2425
test_osu_latency.cpp
2526
)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//@HEADER
2+
// ************************************************************************
3+
//
4+
// Kokkos v. 4.0
5+
// Copyright (2022) National Technology & Engineering
6+
// Solutions of Sandia, LLC (NTESS).
7+
//
8+
// Under the terms of Contract DE-NA0003525 with NTESS,
9+
// the U.S. Government retains certain rights in this software.
10+
//
11+
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://kokkos.org/LICENSE for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//@HEADER
16+
17+
#include "test_utils.hpp"
18+
19+
#include <KokkosComm/KokkosComm.hpp>
20+
21+
template <typename View>
22+
void channel_send_recv(benchmark::State &, MPI_Comm comm, int rank, const View &v) {
23+
int size;
24+
MPI_Comm_size(MPI_COMM_WORLD, &size);
25+
const int dest_rank = (rank + 1) % size; // send to next rank
26+
const int src_rank = (rank - 1 + size) % size; // recv from prev rank
27+
KokkosComm::Channel<> channel(dest_rank, src_rank, 42, comm);
28+
channel.sendinit(v);
29+
channel.recvinit(v);
30+
channel.start();
31+
channel.wait();
32+
}
33+
34+
void benchmark_channel_sendrecv(benchmark::State &state) {
35+
int rank, size;
36+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
37+
MPI_Comm_size(MPI_COMM_WORLD, &size);
38+
if (size < 2) {
39+
state.SkipWithError("benchmark_sendrecv needs at least 2 ranks");
40+
}
41+
42+
using Scalar = double;
43+
using view_type = Kokkos::View<Scalar *>;
44+
view_type a("", 1000000);
45+
46+
while (state.KeepRunning()) {
47+
do_iteration(state, MPI_COMM_WORLD, channel_send_recv<view_type>, rank, a);
48+
}
49+
50+
state.SetBytesProcessed(sizeof(Scalar) * state.iterations() * a.size() * 2);
51+
}
52+
53+
BENCHMARK(benchmark_channel_sendrecv)->UseManualTime()->Unit(benchmark::kMillisecond);

src/KokkosComm/CMakeLists.txt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,20 @@ if(KOKKOSCOMM_ENABLE_MPI)
5252
TYPE HEADERS
5353
BASE_DIRS ${PROJECT_SOURCE_DIR}/src
5454
FILES
55-
mpi/mpi_space.hpp
55+
mpi/allgather.hpp
56+
mpi/alltoall.hpp
57+
mpi/barrier.hpp
58+
mpi/channel.hpp
5659
mpi/comm_mode.hpp
5760
mpi/handle.hpp
58-
mpi/req.hpp
5961
mpi/irecv.hpp
6062
mpi/isend.hpp
63+
mpi/mpi_space.hpp
6164
mpi/recv.hpp
62-
mpi/send.hpp
63-
mpi/allgather.hpp
64-
mpi/alltoall.hpp
6565
mpi/allreduce.hpp
6666
mpi/reduce.hpp
67-
mpi/barrier.hpp
67+
mpi/req.hpp
68+
mpi/send.hpp
6869
mpi/broadcast.hpp
6970
mpi/scan.hpp
7071
)

src/KokkosComm/KokkosComm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#if defined(KOKKOSCOMM_ENABLE_MPI)
2626
#include "mpi/mpi_space.hpp"
2727

28+
#include "mpi/channel.hpp"
2829
#include "mpi/comm_mode.hpp"
2930
#include "mpi/handle.hpp"
3031
#include "mpi/req.hpp"

src/KokkosComm/mpi/channel.hpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//@HEADER
2+
// ************************************************************************
3+
//
4+
// Kokkos v. 4.0
5+
// Copyright (2025) National Technology & Engineering
6+
// Solutions of Sandia, LLC (NTESS).
7+
//
8+
// Under the terms of Contract DE-NA0003525 with NTESS,
9+
// the U.S. Government retains certain rights in this software.
10+
//
11+
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://kokkos.org/LICENSE for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//@HEADER
16+
17+
#pragma once
18+
19+
#include <Kokkos_Core.hpp>
20+
21+
#include <KokkosComm/traits.hpp>
22+
#include "req.hpp"
23+
24+
#include "impl/types.hpp"
25+
26+
namespace KokkosComm {
27+
28+
template <typename CommSpace = DefaultCommunicationSpace>
29+
class Channel {
30+
public:
31+
explicit Channel(int dest_rank, int src_rank, int tag, MPI_Comm comm)
32+
: dest_rank_(dest_rank), src_rank_(src_rank), tag_(tag), comm_(comm) {}
33+
34+
// Send initialization - dynamically adds a send request to the send queue
35+
template <class SendView>
36+
void sendinit(SendView view) {
37+
Kokkos::Tools::pushRegion("KokkosComm::Channel::sendinit");
38+
using value_type = typename SendView::value_type;
39+
// Add a new request to the send_reqs_ vector
40+
send_reqs_.emplace_back();
41+
MPI_Send_init(KokkosComm::data_handle(view), KokkosComm::span(view), KokkosComm::Impl::mpi_type_v<value_type>,
42+
dest_rank_, tag_, comm_, &(send_reqs_.back().mpi_request()));
43+
Kokkos::Tools::popRegion();
44+
}
45+
46+
// Receive initialization - dynamically adds a receive request to the recv queue
47+
template <class RecvView>
48+
void recvinit(RecvView view) {
49+
Kokkos::Tools::pushRegion("KokkosComm::Channel::recvinit");
50+
using value_type = typename RecvView::value_type;
51+
recv_reqs_.emplace_back();
52+
MPI_Recv_init(KokkosComm::data_handle(view), KokkosComm::span(view), KokkosComm::Impl::mpi_type_v<value_type>,
53+
src_rank_, tag_, comm_, &(recv_reqs_.back().mpi_request()));
54+
Kokkos::Tools::popRegion();
55+
}
56+
57+
void start() {
58+
Kokkos::Tools::pushRegion("KokkosComm::Channel::start");
59+
std::vector<MPI_Request> mpi_reqs;
60+
for (auto& req : send_reqs_) {
61+
mpi_reqs.push_back(req.mpi_request());
62+
}
63+
for (auto& req : recv_reqs_) {
64+
mpi_reqs.push_back(req.mpi_request());
65+
}
66+
Kokkos::fence();
67+
MPI_Startall(mpi_reqs.size(), mpi_reqs.data());
68+
Kokkos::Tools::popRegion();
69+
}
70+
71+
void wait() {
72+
Kokkos::Tools::pushRegion("KokkosComm::Channel::wait");
73+
std::vector<Req<Mpi>> reqs;
74+
reqs.reserve(send_reqs_.size() + recv_reqs_.size());
75+
reqs.insert(reqs.end(), send_reqs_.begin(), send_reqs_.end());
76+
reqs.insert(reqs.end(), recv_reqs_.begin(), recv_reqs_.end());
77+
wait_all(reqs);
78+
Kokkos::Tools::popRegion();
79+
}
80+
81+
private:
82+
std::vector<Req<Mpi>> send_reqs_; // Queue for send requests
83+
std::vector<Req<Mpi>> recv_reqs_; // Queue for receive requests
84+
int dest_rank_; // Destination rank for send
85+
int src_rank_; // Source rank for receive
86+
int tag_; // MPI tag
87+
MPI_Comm comm_; // MPI communicator
88+
};
89+
90+
} // namespace KokkosComm

unit_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,4 @@ add_mpi_test(test-allreduce 2 mpi/test_allreduce.cpp)
149149
add_mpi_test(test-alltoall 2 mpi/test_alltoall.cpp)
150150
add_mpi_test(test-allgather 2 mpi/test_allgather.cpp)
151151
add_mpi_test(test-scan 2 mpi/test_scan.cpp)
152+
add_mpi_test(test-channel 2 mpi/test_channel.cpp)

unit_tests/mpi/test_channel.cpp

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
//@HEADER
2+
// ************************************************************************
3+
//
4+
// Kokkos v. 4.0
5+
// Copyright (2025) National Technology & Engineering
6+
// Solutions of Sandia, LLC (NTESS).
7+
//
8+
// Under the terms of Contract DE-NA0003525 with NTESS,
9+
// the U.S. Government retains certain rights in this software.
10+
//
11+
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://kokkos.org/LICENSE for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//@HEADER
16+
17+
#include <gtest/gtest.h>
18+
#include <type_traits>
19+
20+
#include <KokkosComm/KokkosComm.hpp>
21+
22+
namespace {
23+
24+
using namespace KokkosComm::mpi;
25+
26+
template <typename T>
27+
class ChannelSendRecv : public testing::Test {
28+
public:
29+
using Scalar = T;
30+
};
31+
32+
using ScalarTypes = ::testing::Types<int, int64_t, float, double, Kokkos::complex<float>, Kokkos::complex<double>>;
33+
TYPED_TEST_SUITE(ChannelSendRecv, ScalarTypes);
34+
35+
template <typename Scalar>
36+
void test_channel_hostspace() {
37+
int rank, size;
38+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
39+
MPI_Comm_size(MPI_COMM_WORLD, &size);
40+
41+
if (size < 2) {
42+
GTEST_SKIP() << "This test requires at least 2 MPI processes";
43+
}
44+
45+
const int dest_rank = (rank + 1) % size; // send to next rank
46+
const int src_rank = (rank - 1 + size) % size; // recv from prev rank
47+
const int tag = 42;
48+
49+
KokkosComm::Channel<> channel(dest_rank, src_rank, tag, MPI_COMM_WORLD);
50+
51+
const int N = 10;
52+
53+
// Create host views
54+
Kokkos::View<Scalar*, Kokkos::HostSpace> send_host("send_host", N);
55+
Kokkos::View<Scalar*, Kokkos::HostSpace> recv_host("recv_host", N);
56+
57+
for (int i = 0; i < N; i++) send_host(i) = static_cast<Scalar>(rank * N + i);
58+
59+
channel.sendinit(send_host);
60+
channel.recvinit(recv_host);
61+
62+
channel.start();
63+
channel.wait();
64+
65+
int errs = 0;
66+
for (int i = 0; i < N; i++) {
67+
const Scalar expected = static_cast<Scalar>(src_rank * N + i);
68+
if (recv_host(i) != expected) {
69+
errs++;
70+
}
71+
}
72+
EXPECT_EQ(errs, 0);
73+
}
74+
75+
template <typename Scalar>
76+
void test_channel_execspace() {
77+
int rank, size;
78+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
79+
MPI_Comm_size(MPI_COMM_WORLD, &size);
80+
81+
if (size < 2) {
82+
GTEST_SKIP() << "This test requires at least 2 MPI processes";
83+
}
84+
85+
const int dest_rank = (rank + 1) % size; // send to next rank
86+
const int src_rank = (rank - 1 + size) % size; // recv from prev rank
87+
const int tag = 42;
88+
89+
KokkosComm::Channel<> channel(dest_rank, src_rank, tag, MPI_COMM_WORLD);
90+
91+
const int N = 10;
92+
93+
// Create host view
94+
Kokkos::View<Scalar*, Kokkos::HostSpace> recv_host("recv_host", N);
95+
// Create device views
96+
Kokkos::View<Scalar*, Kokkos::DefaultExecutionSpace> send_dev("send_dev", N);
97+
Kokkos::View<Scalar*, Kokkos::DefaultExecutionSpace> recv_dev("recv_dev", N);
98+
99+
GTEST_LOG_(INFO) << "Kokkos Execution Space: " << Kokkos::DefaultExecutionSpace::name();
100+
Kokkos::parallel_for(
101+
"init_send_dev", N, KOKKOS_LAMBDA(int i) { send_dev(i) = static_cast<Scalar>(rank * N + i); });
102+
Kokkos::fence();
103+
104+
channel.sendinit(send_dev);
105+
channel.recvinit(recv_dev);
106+
107+
channel.start();
108+
channel.wait();
109+
110+
Kokkos::deep_copy(recv_host, recv_dev);
111+
112+
int errs = 0;
113+
for (int i = 0; i < N; i++) {
114+
const Scalar expected = static_cast<Scalar>(src_rank * N + i);
115+
if (recv_host(i) != expected) {
116+
errs++;
117+
}
118+
}
119+
EXPECT_EQ(errs, 0);
120+
}
121+
122+
template <typename Scalar>
123+
void test_channel_reuse() {
124+
int rank, size;
125+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
126+
MPI_Comm_size(MPI_COMM_WORLD, &size);
127+
128+
if (size < 2) {
129+
GTEST_SKIP() << "This test requires at least 2 MPI processes";
130+
}
131+
132+
const int dest_rank = (rank + 1) % size; // send to next rank
133+
const int src_rank = (rank - 1 + size) % size; // recv from prev rank
134+
const int tag = 42;
135+
136+
KokkosComm::Channel<> channel(dest_rank, src_rank, tag, MPI_COMM_WORLD);
137+
138+
const int N = 10;
139+
140+
// Create host view
141+
Kokkos::View<Scalar*, Kokkos::HostSpace> recv_host("recv_host", N);
142+
// Create device views
143+
Kokkos::View<Scalar*, Kokkos::DefaultExecutionSpace> send_dev("send_dev", N);
144+
Kokkos::View<Scalar*, Kokkos::DefaultExecutionSpace> recv_dev("recv_dev", N);
145+
146+
GTEST_LOG_(INFO) << "Kokkos Execution Space: " << Kokkos::DefaultExecutionSpace::name();
147+
int errs = 0;
148+
for (int i = 0; i < 3; i++) {
149+
Kokkos::parallel_for(
150+
"init_send_dev", N, KOKKOS_LAMBDA(int i) { send_dev(i) = static_cast<Scalar>(rank * N + i); });
151+
Kokkos::fence();
152+
153+
channel.sendinit(send_dev);
154+
channel.recvinit(recv_dev);
155+
156+
channel.start();
157+
channel.wait();
158+
159+
Kokkos::deep_copy(recv_host, recv_dev);
160+
161+
for (int j = 0; j < N; j++) {
162+
const Scalar expected = static_cast<Scalar>(src_rank * N + j);
163+
if (recv_host(j) != expected) {
164+
errs++;
165+
}
166+
}
167+
}
168+
EXPECT_EQ(errs, 0);
169+
}
170+
171+
TYPED_TEST(ChannelSendRecv, 1D_contig_sendrecv_hostspace) { test_channel_hostspace<typename TestFixture::Scalar>(); }
172+
TYPED_TEST(ChannelSendRecv, 1D_contig_sendrecv_execspace) { test_channel_execspace<typename TestFixture::Scalar>(); }
173+
TYPED_TEST(ChannelSendRecv, 1D_contig_sendrecv_reuse) { test_channel_reuse<typename TestFixture::Scalar>(); }
174+
175+
} // namespace

0 commit comments

Comments
 (0)