Skip to content

Commit c1f1b67

Browse files
committed
rework to support waitall
1 parent 0823dc8 commit c1f1b67

12 files changed

+107
-110
lines changed

perf_tests/test_2dhalo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ void send_recv(benchmark::State &, MPI_Comm comm, const Space &space, int nx, in
4646
auto ym1_s = Kokkos::subview(v, make_pair(1, nx + 1), 1, Kokkos::ALL);
4747
auto ym1_r = Kokkos::subview(v, make_pair(1, nx + 1), 0, Kokkos::ALL);
4848

49-
KokkosComm::Handle<Space> h = KokkosComm::plan(space, comm, [=](KokkosComm::Handle<Space> &handle) {
49+
std::vector<KokkosComm::Req<>> reqs = KokkosComm::plan(space, comm, [=](KokkosComm::Handle<Space> &handle) {
5050
KokkosComm::isend(handle, xp1_s, get_rank(xp1, ry), 0);
5151
KokkosComm::isend(handle, xm1_s, get_rank(xm1, ry), 1);
5252
KokkosComm::isend(handle, yp1_s, get_rank(rx, yp1), 2);
@@ -59,7 +59,7 @@ void send_recv(benchmark::State &, MPI_Comm comm, const Space &space, int nx, in
5959
KokkosComm::recv(space, yp1_r, get_rank(rx, yp1), 3, comm);
6060

6161
// wait for comm
62-
KokkosComm::wait(h);
62+
KokkosComm::wait_all(reqs);
6363
}
6464

6565
void benchmark_2dhalo(benchmark::State &state) {

src/KokkosComm_fwd.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
#pragma once
1818

19+
#include <vector>
20+
1921
#include "KokkosComm_concepts.hpp"
2022
#include "KokkosComm_config.hpp"
2123

@@ -28,15 +30,18 @@ using FallbackTransport = Mpi;
2830
#error at least one transport must be defined
2931
#endif
3032

33+
template <Transport TRANSPORT = DefaultTransport>
34+
class Req;
35+
3136
template <KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, Transport TRANSPORT = DefaultTransport>
3237
class Handle;
3338

34-
template <Dispatch DISPATCH, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
35-
Transport TRANSPORT = DefaultTransport>
36-
class Plan;
37-
3839
namespace Impl {
3940

41+
template <Dispatch TDispatch, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
42+
Transport TRANSPORT = DefaultTransport>
43+
struct Plan;
44+
4045
template <KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
4146
Transport TRANSPORT = DefaultTransport>
4247
struct Irecv;

src/KokkosComm_plan.hpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,18 @@
2222

2323
namespace KokkosComm {
2424

25-
template <Dispatch DISPATCH, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
25+
template <Dispatch TDispatch, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
2626
Transport TRANSPORT = DefaultTransport>
27-
Handle<ExecSpace, TRANSPORT> plan(const ExecSpace &space, MPI_Comm comm, DISPATCH d) {
28-
return Plan<DISPATCH, ExecSpace, TRANSPORT>(space, comm, d).handle();
27+
std::vector<Req<TRANSPORT>> plan(Handle<ExecSpace, TRANSPORT> &handle, TDispatch d) {
28+
return Impl::Plan<TDispatch, ExecSpace, TRANSPORT>(handle, d).reqs;
2929
}
3030

31-
template <Dispatch DISPATCH, KokkosExecutionSpace ExecSpace, Transport TRANSPORT>
32-
void plan(Handle<ExecSpace, TRANSPORT> &handle, DISPATCH d) {
33-
Plan<DISPATCH, ExecSpace, TRANSPORT>(handle, d);
31+
template <Dispatch TDispatch, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
32+
Transport TRANSPORT = DefaultTransport>
33+
std::vector<Req<TRANSPORT>> plan(const ExecSpace &space, MPI_Comm comm, TDispatch d) {
34+
Handle<ExecSpace, TRANSPORT> handle(space, comm);
35+
auto ret = plan<TDispatch, ExecSpace, TRANSPORT>(handle, d);
36+
return ret;
3437
}
3538

3639
} // namespace KokkosComm

src/KokkosComm_point_to_point.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,19 @@ void isend(Handle<ExecSpace, TRANSPORT> &h, SendView &sv, int dest, int tag) {
3939
// TODO: can these go in MPI somewhere?
4040
#if defined(KOKKOSCOMM_TRANSPORT_MPI)
4141
template <KokkosView SendView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace>
42-
KokkosComm::Handle<ExecSpace, Mpi> isend(const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm) {
43-
return KokkosComm::plan(space, comm,
44-
[=](Handle<ExecSpace, Mpi> &handle) { KokkosComm::isend(handle, sv, dest, tag); });
42+
Req<Mpi> isend(const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm) {
43+
auto reqs =
44+
KokkosComm::plan(space, comm, [=](Handle<ExecSpace, Mpi> &handle) { KokkosComm::isend(handle, sv, dest, tag); });
45+
assert(reqs.size() == 1 && "Internal KokkosComm developer error");
46+
return reqs[0];
4547
}
4648

4749
template <KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace>
48-
KokkosComm::Handle<ExecSpace, Mpi> irecv(const ExecSpace &space, const RecvView &rv, int dest, int tag, MPI_Comm comm) {
49-
return KokkosComm::plan(space, comm,
50-
[=](Handle<ExecSpace, Mpi> &handle) { KokkosComm::irecv(handle, rv, dest, tag); });
50+
Req<Mpi> irecv(const ExecSpace &space, const RecvView &rv, int dest, int tag, MPI_Comm comm) {
51+
auto reqs =
52+
KokkosComm::plan(space, comm, [=](Handle<ExecSpace, Mpi> &handle) { KokkosComm::irecv(handle, rv, dest, tag); });
53+
assert(reqs.size() == 1 && "Internal KokkosComm developer error");
54+
return reqs[0];
5155
}
5256

5357
#endif

src/impl/KokkosComm_irecv.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include <Kokkos_Core.hpp>
2222

2323
#include "KokkosComm_pack_traits.hpp"
24-
#include "KokkosComm_request.hpp"
2524
#include "KokkosComm_traits.hpp"
2625

2726
// impl
@@ -45,9 +44,9 @@ void irecv(RecvView &rv, int src, int tag, MPI_Comm comm, MPI_Request &req) {
4544
}
4645

4746
template <KokkosView RecvView>
48-
KokkosComm::Req irecv(RecvView &rv, int src, int tag, MPI_Comm comm) {
47+
KokkosComm::Req<KokkosComm::Mpi> irecv(RecvView &rv, int src, int tag, MPI_Comm comm) {
4948
Kokkos::Tools::pushRegion("KokkosComm::Impl::irecv");
50-
KokkosComm::Req req;
49+
KokkosComm::Req<KokkosComm::Mpi> req;
5150
irecv(rv, src, tag, comm, req.mpi_req());
5251
return req;
5352
}

src/impl/KokkosComm_isend.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
#include "KokkosComm_concepts.hpp"
2424
#include "KokkosComm_pack_traits.hpp"
25-
#include "KokkosComm_request.hpp"
2625
#include "KokkosComm_traits.hpp"
2726
#include "KokkosComm_comm_mode.hpp"
2827

@@ -46,10 +45,10 @@ void isend(const SendView &sv, int dest, int tag, MPI_Comm comm, MPI_Request &re
4645
}
4746

4847
template <CommMode SendMode = CommMode::Default, KokkosExecutionSpace ExecSpace, KokkosView SendView>
49-
KokkosComm::Req isend(const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm) {
48+
KokkosComm::Req<KokkosComm::Mpi> isend(const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm) {
5049
Kokkos::Tools::pushRegion("KokkosComm::Impl::isend");
5150

52-
KokkosComm::Req req;
51+
KokkosComm::Req<KokkosComm::Mpi> req;
5352

5453
using KCT = KokkosComm::Traits<SendView>;
5554
using KCPT = KokkosComm::PackTraits<SendView>;

src/mpi/KokkosComm_mpi_handle.hpp

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
#include "KokkosComm_fwd.hpp"
2020

21+
#include "KokkosComm_mpi_req.hpp"
22+
2123
namespace KokkosComm {
2224

2325
/*
@@ -35,41 +37,23 @@ class Handle<ExecSpace, Mpi> {
3537
public:
3638
using execution_space = ExecSpace;
3739

38-
// a type-erased view. Request uses these to keep temporary views alive for
39-
// the lifetime of "Immediate" MPI operations
40-
struct ViewHolderBase {
41-
virtual ~ViewHolderBase() {}
42-
};
43-
template <typename V>
44-
struct ViewHolder : ViewHolderBase {
45-
ViewHolder(const V &v) : v_(v) {}
46-
V v_;
47-
};
48-
4940
Handle(const execution_space &space, MPI_Comm comm) : space_(space), comm_(comm), preCommFence_(false) {}
5041

51-
// template <KokkosView View>
52-
// void impl_track_view(const View &v) {
53-
// views_.push_back(std::make_shared<ViewHolder<View>>(v));
54-
// }
42+
MPI_Comm &mpi_comm() { return comm_; }
43+
const execution_space &space() const { return space_; }
5544

5645
void impl_add_pre_comm_fence() { preCommFence_ = true; }
5746

58-
void impl_track_mpi_request(MPI_Request req) { reqs_.push_back(req); }
59-
6047
void impl_add_alloc(std::function<void()> f) { allocs_.push_back(f); }
6148

6249
void impl_add_pre_copy(std::function<void()> f) { preCopies_.push_back(f); }
6350

6451
void impl_add_comm(std::function<void()> f) { comms_.push_back(f); }
6552

66-
void impl_add_post_wait(std::function<void()> f) { postWaits_.push_back(f); }
67-
68-
MPI_Comm &mpi_comm() { return comm_; }
69-
70-
const execution_space &space() const { return space_; }
53+
void impl_track_req(const Req<Mpi> &req) { reqs_.push_back(req); }
7154

7255
void impl_run() {
56+
std::cerr << __FILE__ << ":" << __LINE__ << " impl_run\n";
7357
for (const auto &f : allocs_) f();
7458
for (const auto &f : preCopies_) f();
7559
if (preCommFence_) {
@@ -82,10 +66,9 @@ class Handle<ExecSpace, Mpi> {
8266
comms_.clear();
8367
}
8468

85-
private:
86-
template <KokkosExecutionSpace ES>
87-
friend void wait(Handle<ES, Mpi> &handle);
69+
std::vector<Req<Mpi>> &impl_reqs() { return reqs_; }
8870

71+
private:
8972
execution_space space_;
9073
MPI_Comm comm_;
9174

@@ -96,23 +79,8 @@ class Handle<ExecSpace, Mpi> {
9679
std::vector<std::function<void()>> comms_;
9780

9881
// wait variables
99-
std::vector<MPI_Request> reqs_;
82+
std::vector<Req<Mpi>> reqs_;
10083
std::vector<std::function<void()>> postWaits_;
101-
std::vector<std::shared_ptr<ViewHolderBase>> views_;
10284
};
10385

104-
template <KokkosExecutionSpace ExecSpace>
105-
void wait(Handle<ExecSpace, Mpi> &handle) {
106-
MPI_Waitall(handle.reqs_.size(), handle.reqs_.data(), MPI_STATUSES_IGNORE);
107-
handle.reqs_.clear();
108-
// std::cerr << __FILE__ << ":" << __LINE__ << " MPI_Waitall done\n";
109-
for (const auto &f : handle.postWaits_) {
110-
f();
111-
}
112-
// std::cerr << __FILE__ << ":" << __LINE__ << " postWaits_.clear()...\n";
113-
handle.postWaits_.clear();
114-
// views_.clear();
115-
// std::cerr << __FILE__ << ":" << __LINE__ << " wait() done\n";
116-
}
117-
11886
} // namespace KokkosComm

src/mpi/KokkosComm_mpi_plan.hpp

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,17 @@
1616

1717
#pragma once
1818

19-
namespace KokkosComm {
19+
namespace KokkosComm::Impl {
2020

21-
template <Dispatch DISPATCH, KokkosExecutionSpace ExecSpace>
22-
class Plan<DISPATCH, ExecSpace, Mpi> {
23-
public:
24-
using execution_space = ExecSpace;
25-
using handle_type = Handle<execution_space, Mpi>;
26-
Plan(const execution_space &space, MPI_Comm comm, DISPATCH d) : handle_(space, comm) {
27-
d(handle_);
28-
handle_.impl_run();
21+
template <Dispatch TDispatch, KokkosExecutionSpace ExecSpace>
22+
struct Plan<TDispatch, ExecSpace, Mpi> {
23+
Plan(Handle<ExecSpace, Mpi> &handle, TDispatch d) {
24+
d(handle);
25+
handle.impl_run();
26+
reqs = handle.impl_reqs();
2927
}
3028

31-
Plan(const execution_space &space, DISPATCH d) : Plan(space, MPI_COMM_WORLD, d) {}
32-
Plan(MPI_Comm comm, DISPATCH d) : Plan(Kokkos::DefaultExecutionSpace(), comm, d) {}
33-
Plan(DISPATCH d) : Plan(Kokkos::DefaultExecutionSpace(), MPI_COMM_WORLD, d) {}
34-
35-
handle_type handle() const { return handle_; }
36-
37-
private:
38-
handle_type handle_;
29+
std::vector<Req<Mpi>> reqs;
3930
};
4031

41-
} // namespace KokkosComm
32+
} // namespace KokkosComm::Impl

src/KokkosComm_request.hpp renamed to src/mpi/KokkosComm_mpi_req.hpp

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616

1717
#pragma once
1818

19-
#include <memory>
2019
#include <vector>
20+
#include <utility>
2121

22-
#include "KokkosComm_include_mpi.hpp"
22+
#include "KokkosComm_fwd.hpp"
23+
24+
#include "KokkosComm_mpi.hpp"
2325

2426
namespace KokkosComm {
2527

26-
class Req {
28+
template <>
29+
class Req<Mpi> {
2730
// a type-erased view. Request uses these to keep temporary views alive for
2831
// the lifetime of "Immediate" MPI operations
2932
struct ViewHolderBase {
@@ -38,27 +41,53 @@ class Req {
3841
struct Record {
3942
Record() : req_(MPI_REQUEST_NULL) {}
4043
MPI_Request req_;
41-
std::vector<std::shared_ptr<ViewHolderBase>> until_waits_;
44+
std::vector<std::function<void()>> postWaits_;
4245
};
4346

4447
public:
4548
Req() : record_(std::make_shared<Record>()) {}
4649

4750
MPI_Request &mpi_req() { return record_->req_; }
4851

49-
void wait() {
50-
MPI_Wait(&(record_->req_), MPI_STATUS_IGNORE);
51-
record_->until_waits_.clear(); // drop any views we're keeping alive until wait()
52-
}
53-
5452
// keep a reference to this view around until wait() is called
5553
template <typename View>
5654
void keep_until_wait(const View &v) {
57-
record_->until_waits_.push_back(std::make_shared<ViewHolder<View>>(v));
55+
record_->postWaits_.push_back([v]() {} /* capture v into a no-op lambda that does nothing*/);
5856
}
5957

6058
private:
6159
std::shared_ptr<Record> record_;
60+
61+
friend void wait(Req<Mpi> &req);
62+
friend void wait_all(std::vector<Req<Mpi>> &reqs);
63+
friend int wait_any(std::vector<Req<Mpi>> &reqs);
6264
};
6365

66+
/// Block until the MPI operation tracked by `req` has completed, then run the
/// request's post-wait callbacks.
///
/// The callbacks are invoked in registration order and then discarded, which
/// also destroys whatever they captured (keep_until_wait() pins views by
/// capturing them into such a callback).
///
/// @param req request to complete; its post-wait callback list is empty on
///            return.
inline void wait(Req<Mpi> &req) {
  // Deliberately no diagnostic output: this sits on the communication
  // completion path, and an unconditional stderr write per request distorts
  // timings (and relied on <iostream>, which this header does not include).
  MPI_Wait(&req.mpi_req(), MPI_STATUS_IGNORE);

  for (auto &f : req.record_->postWaits_) {
    f();
  }
  // Clearing releases the callbacks' captures now that the operation is done.
  req.record_->postWaits_.clear();
}
74+
75+
/// Complete every request in `reqs`.
///
/// Requests are waited on one at a time, in container order, through wait(),
/// so each request's post-wait callbacks run as soon as that request is done.
///
/// @param reqs requests to complete; an empty vector is a no-op.
inline void wait_all(std::vector<Req<Mpi>> &reqs) {
  // Deliberately no diagnostic output: an unconditional stderr write on every
  // call distorts benchmark timings (and relied on <iostream>, which this
  // header does not include).
  for (Req<Mpi> &req : reqs) {
    wait(req);
  }
}
81+
82+
/// Block until one of the active requests in `reqs` completes and report
/// which one it was.
///
/// Requests whose handle is MPI_REQUEST_NULL are skipped, so a request that
/// has already been completed is not reported a second time. The requests are
/// polled with MPI_Test until one finishes; that request's post-wait
/// callbacks are then run and discarded, exactly as wait() does.
///
/// @param reqs candidate requests.
/// @return index into `reqs` of the request that completed, or -1 if `reqs`
///         is empty or holds no active request (nothing to wait for).
inline int wait_any(std::vector<Req<Mpi>> &reqs) {
  for (;;) {
    bool any_active = false;

    for (size_t i = 0; i < reqs.size(); ++i) {
      Req<Mpi> &req = reqs[i];
      if (req.mpi_req() == MPI_REQUEST_NULL) {
        continue;  // already completed (or never started): not a candidate
      }
      any_active = true;

      int completed = 0;
      MPI_Test(&req.mpi_req(), &completed, MPI_STATUS_IGNORE);
      if (completed) {
        // Same completion protocol as wait(): run, then drop, the callbacks
        // so anything they keep alive is released.
        for (auto &f : req.record_->postWaits_) {
          f();
        }
        req.record_->postWaits_.clear();
        return static_cast<int>(i);
      }
    }

    // Without this guard an empty or fully-completed list would spin forever.
    if (!any_active) {
      return -1;
    }
  }
}
92+
6493
} // namespace KokkosComm

src/mpi/impl/KokkosComm_mpi_irecv_datatype.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ void irecv_datatype(HandleTy &h, RecvView rv, int src, int tag) {
2626
h.impl_add_pre_comm_fence();
2727

2828
h.impl_add_comm([&h, rv, src, tag]() {
29-
MPI_Request req;
30-
MPI_Irecv(KokkosComm::data_handle(rv), 1, view_mpi_type(rv), src, tag, h.mpi_comm(), &req);
31-
h.impl_track_mpi_request(req);
29+
Req<Mpi> req;
30+
MPI_Irecv(KokkosComm::data_handle(rv), 1, view_mpi_type(rv), src, tag, h.mpi_comm(), &req.mpi_req());
31+
h.impl_track_req(req);
3232
});
3333
}
3434

src/mpi/impl/KokkosComm_mpi_isend_datatype.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ void isend_datatype(HandleTy &h, SendView sv, int dst, int tag) {
2626
h.impl_add_pre_comm_fence();
2727

2828
h.impl_add_comm([&h, sv, dst, tag]() {
29-
MPI_Request req;
30-
MPI_Isend(KokkosComm::data_handle(sv), 1, view_mpi_type(sv), dst, tag, h.mpi_comm(), &req);
31-
h.impl_track_mpi_request(req);
29+
Req<Mpi> req;
30+
MPI_Isend(KokkosComm::data_handle(sv), 1, view_mpi_type(sv), dst, tag, h.mpi_comm(), &req.mpi_req());
31+
h.impl_track_req(req);
3232
});
3333
}
3434

0 commit comments

Comments
 (0)