Skip to content

Commit 5567058

Browse files
committed
#1: inform: fix some bugs, incl. termination, get inform working
1 parent 9afea2e commit 5567058

File tree

4 files changed

+168
-21
lines changed

4 files changed

+168
-21
lines changed

examples/test_example.cc

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,38 @@ int main(int argc, char** argv) {
1717
auto comm = vt_lb::comm::CommMPI();
1818
comm.init(argc, argv);
1919

20-
auto cls = std::make_unique<MyClass>();
21-
auto handle = comm.registerInstanceCollective(cls.get());
22-
auto rank = comm.getRank();
20+
//auto cls = std::make_unique<MyClass>();
21+
//auto handle = comm.registerInstanceCollective(cls.get());
22+
//auto rank = comm.getRank();
2323

24-
int value = 10;
25-
int recv_value = 0;
26-
handle.reduce(1, MPI_INT, MPI_SUM, &value, &recv_value, 1);
24+
// int value = 10;
25+
// int recv_value = 0;
26+
// handle.reduce(1, MPI_INT, MPI_SUM, &value, &recv_value, 1);
2727

28-
fmt::print("Rank {}: reduced value is {}\n", rank, recv_value);
28+
// fmt::print("Rank {}: reduced value is {}\n", rank, recv_value);
2929

30-
if (rank == 0) {
31-
handle[1].send<&MyClass::myHandler2>(std::string{"hello from rank 0"});
32-
}
33-
if (rank == 1) {
34-
handle[0].send<&MyClass::myHandler>(2, 10.3);
35-
}
36-
37-
while (comm.poll()) {
38-
}
30+
// if (rank == 0) {
31+
// handle[1].send<&MyClass::myHandler2>(std::string{"hello from rank 0"});
32+
// }
33+
// if (rank == 1) {
34+
// handle[0].send<&MyClass::myHandler>(2, 10.3);
35+
// }
3936

40-
printf("out of poll\n");
37+
printf("Running runLB\n");
38+
//comm.barrier();
4139

4240
vt_lb::runLB(
4341
vt_lb::DriverAlgoEnum::TemperedLB,
4442
comm,
45-
vt_lb::algo::temperedlb::Configuration{},
43+
vt_lb::algo::temperedlb::Configuration{comm.numRanks()},
4644
nullptr
4745
);
4846

47+
while (comm.poll()) {
48+
}
49+
50+
printf("out of poll\n");
51+
4952
comm.finalize();
5053
return 0;
5154
}

src/vt-lb/algo/temperedlb/temperedlb.h

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include <vt-lb/algo/baselb/baselb.h>
4949

5050
#include <limits>
51+
#include <random>
5152

5253
namespace vt_lb::algo::temperedlb {
5354

@@ -63,14 +64,22 @@ struct WorkModel {
6364
};
6465

6566
struct Configuration {
67+
Configuration() = default;
68+
Configuration(int num_ranks) {
69+
f_ = 2;
70+
k_max_ = std::ceil(std::sqrt(std::log(num_ranks)/std::log(2.0)));
71+
}
72+
6673
/// @brief Number of trials to perform
6774
int num_trials_ = 1;
6875
/// @brief Number of iterations per trial
6976
int num_iters_ = 10;
7077
/// @brief Fanout for information propagation
7178
int f_ = 2;
7279
/// @brief Number of rounds of information propagation
73-
int k_max_ = 0;
80+
int k_max_ = 1;
81+
82+
bool async_ip_ = true;
7483

7584
/// @brief Work model parameters (rank-alpha, beta, gamma, delta)
7685
WorkModel work_model_;
@@ -79,6 +88,71 @@ struct Configuration {
7988
double converge_tolerance_ = 0.01;
8089
};
8190

91+
/**
 * \brief One asynchronous information-propagation (gossip) stage.
 *
 * Every rank seeds a datum keyed by its own rank, then forwards its
 * accumulated map to up to \c f_ randomly chosen, not-yet-contacted peers,
 * for at most \c k_max_ rounds. Completion is detected by polling the cloned
 * communicator until its termination detection stops returning work.
 *
 * \tparam CommT communicator type; must provide clone(), poll(), getRank(),
 *               numRanks(), registerInstanceCollective() and HandleType
 * \tparam DataT per-rank payload stored in the propagated map
 * \tparam JoinT tag for the joining entity (currently unused by this class)
 */
template <typename CommT, typename DataT, typename JoinT>
struct InformationPropagation {
  using ThisType = InformationPropagation<CommT, DataT, JoinT>;
  using JoinedDataType = std::unordered_map<int, DataT>;

  /**
   * \brief Clone the communicator (isolating this stage's traffic) and
   *        collectively register this instance for handler dispatch.
   * \param comm communicator to clone
   * \param f fanout per round
   * \param k_max maximum number of propagation rounds
   *
   * NOTE(review): the cloned communicator is duplicated but apparently never
   * freed — confirm its intended lifetime.
   */
  InformationPropagation(CommT& comm, int f, int k_max)
    : comm_(comm.clone()), f_(f), k_max_(k_max)
  {
    handle_ = comm_.template registerInstanceCollective<ThisType>(this);
  }

  /**
   * \brief Seed this rank's datum and drive the gossip to completion.
   * \param initial_data payload contributed by this rank
   */
  void run(DataT initial_data) {
    auto const me = comm_.getRank();
    // Mark ourselves as contacted so we never pick our own rank as a target.
    already_selected_.insert(me);

    local_data_[me] = initial_data;

    // Round 1: push what we know so far to the first fanout of peers.
    sendToFanout(1, local_data_);

    // Spin until the communicator's termination detection says we're done.
    while (comm_.poll()) {
    }

    printf("%d: done with poll: local_data size=%zu\n", comm_.getRank(), local_data_.size());
  }

  /**
   * \brief Send \p data to up to \c f_ randomly chosen peers not yet contacted.
   * \param round the propagation round the receivers will continue from
   * \param data accumulated map to forward
   */
  void sendToFanout(int round, JoinedDataType const& data) {
    int const n_ranks = comm_.numRanks();

    for (int pick = 1; pick <= f_; ++pick) {
      // Stop early once every rank (including us) has been contacted.
      if (already_selected_.size() >= static_cast<size_t>(n_ranks)) {
        return;
      }

      // Rejection-sample an unused peer; this terminates because the guard
      // above guarantees at least one rank remains unselected.
      std::uniform_int_distribution<int> peer_dist(0, n_ranks - 1);
      int peer = -1;
      do {
        peer = peer_dist(gen_select_);
      } while (already_selected_.count(peer) != 0);

      already_selected_.insert(peer);

      handle_[peer].template send<&ThisType::infoPropagateHandler>(round, data);
    }
  }

  /**
   * \brief Handler: merge a peer's map into ours, then keep gossiping while
   *        rounds remain. Existing keys win on merge
   *        (unordered_map::insert does not overwrite).
   */
  void infoPropagateHandler(int round, JoinedDataType incoming_data) {
    local_data_.insert(incoming_data.begin(), incoming_data.end());
    if (round < k_max_) {
      sendToFanout(round + 1, local_data_);
    }
  }

private:
  CommT comm_;                                   // private clone of the caller's communicator
  int f_ = 2;                                    // fanout per round
  int k_max_ = 0;                                // maximum rounds of propagation
  std::unordered_set<int> already_selected_;     // ranks already contacted (incl. self)
  std::unordered_map<int, DataT> local_data_;    // accumulated datum per contributing rank
  std::mt19937 gen_select_{std::random_device{}()};  // RNG for peer selection
  typename CommT::template HandleType<ThisType> handle_;  // collective handle for sends
};
155+
82156
template <typename CommT>
83157
struct TemperedLB : baselb::BaseLB {
84158

@@ -97,11 +171,22 @@ struct TemperedLB : baselb::BaseLB {
97171
{ }
98172

99173
void makeHandle() {
174+
// printf("makeHandle\n");
100175
handle_ = comm_.template registerInstanceCollective<TemperedLB<CommT>>(this);
101176
}
102177

103178
void run() {
104179
// Implementation of the TemperedLB algorithm would go here
180+
auto& wm = config_.work_model_;
181+
if (wm.beta == 0.0 && wm.gamma == 0.0 && wm.delta == 0.0) {
182+
using LoadType = double;
183+
printf("start InformationPropagation\n");
184+
auto ip = InformationPropagation<CommT, LoadType, TemperedLB<CommT>>(
185+
comm_, config_.f_, config_.k_max_
186+
);
187+
ip.run(10.0);
188+
}
189+
105190
}
106191

107192
private:

src/vt-lb/comm/comm_mpi.h

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,23 +174,51 @@ struct CommMPI {
174174
* \brief Construct a new CommMPI instance
175175
* \param comm The MPI communicator to use (defaults to MPI_COMM_WORLD)
176176
*/
177-
CommMPI(MPI_Comm comm = MPI_COMM_WORLD) : comm_(comm) {}
177+
explicit CommMPI(MPI_Comm comm = MPI_COMM_WORLD) : comm_(comm) { }
178178

179+
CommMPI(CommMPI const&) = delete;
180+
CommMPI(CommMPI&&) = delete;
181+
CommMPI& operator=(CommMPI const&) = delete;
182+
CommMPI& operator=(CommMPI&&) = delete;
183+
184+
private:
185+
CommMPI(MPI_Comm comm, int rank, int size)
186+
: comm_(comm), cached_rank_(rank), cached_size_(size)
187+
{
188+
initTermination();
189+
}
190+
191+
public:
179192
/**
180193
* \brief Initialize MPI
181194
* \param argc Pointer to argument count
182195
* \param argv Pointer to argument array
183196
*/
184197
void init(int& argc, char**& argv) {
185198
MPI_Init(&argc, &argv);
199+
MPI_Comm_rank(comm_, &cached_rank_);
200+
MPI_Comm_size(comm_, &cached_size_);
186201
initTermination();
202+
// printf("%d: Initialized MPI with %d ranks\n", cached_rank_, cached_size_);
203+
}
204+
205+
// Create an independent CommMPI over a duplicate of this communicator,
// carrying over the cached rank/size (the private ctor also initializes
// termination detection for the clone).
// NOTE(review): the duplicated communicator is never MPI_Comm_free'd —
// confirm clones are meant to live until MPI_Finalize.
CommMPI clone() const {
  MPI_Comm dup_comm;
  MPI_Comm_dup(comm_, &dup_comm);
  return CommMPI{dup_comm, cached_rank_, cached_size_};
}
188210

189211
/**
190212
* \brief Finalize MPI
191213
*/
192214
void finalize() {
215+
// printf("%d: Finalizing MPI\n", cached_rank_);
193216
MPI_Finalize();
217+
comm_ = MPI_COMM_NULL;
218+
}
219+
220+
// Block until all ranks in the communicator have entered the barrier.
void barrier() {
  MPI_Barrier(comm_);
}
195223

196224
/**
@@ -287,6 +315,13 @@ struct CommMPI {
287315
buf_interpreter.classIndex() = idx;
288316
buf_interpreter.isTermination() = is_termination_msg ? 1 : 0;
289317

318+
// printf("%d: MPI_Send to %d handler_index=%d class_index=%d is_termination=%d\n",
319+
// cached_rank_,
320+
// dest,
321+
// buf_interpreter.handlerIndex(),
322+
// buf_interpreter.classIndex(),
323+
// buf_interpreter.isTermination()
324+
// );
290325
MPI_Request req;
291326
MPI_Isend(
292327
send_ptr, send_ptr_size, MPI_BYTE, dest, 0, comm_,
@@ -310,6 +345,7 @@ struct CommMPI {
310345
{
311346
int flag = 0;
312347
MPI_Status status;
348+
// printf("%d: MPI_Iprobe\n", cached_rank_);
313349
MPI_Iprobe(MPI_ANY_SOURCE, 0, comm_, &flag, &status);
314350
if (flag) {
315351
int count = 0;
@@ -326,12 +362,19 @@ struct CommMPI {
326362
throw std::runtime_error("Buffer alignment error");
327363
}
328364

365+
// printf("%d: MPI_Recv from %d of size %d\n", cached_rank_, status.MPI_SOURCE, count);
329366
MPI_Recv(buf.data(), count, MPI_BYTE, status.MPI_SOURCE, status.MPI_TAG, comm_, MPI_STATUS_IGNORE);
330367
BufferIntInterpreter buf_interpreter(buf.data());
331368
int handler_index = buf_interpreter.handlerIndex();
332369
int class_index = buf_interpreter.classIndex();
333370
bool is_termination = buf_interpreter.isTermination() != 0;
334371

372+
// printf("%d: Received message: handler_index=%d class_index=%d is_termination=%d\n",
373+
// cached_rank_,
374+
// handler_index,
375+
// class_index,
376+
// is_termination ? 1 : 0
377+
// );
335378
auto mem_fn = detail::getMember(handler_index);
336379
assert(class_map_.find(class_index) != class_map_.end() && "Class index not found");
337380
mem_fn->dispatch(
@@ -367,6 +410,10 @@ struct CommMPI {
367410
*/
368411
int numRanks() const {
369412
int size = 0;
413+
// printf("%d: numRanks\n", cached_rank_);
414+
if (comm_ == MPI_COMM_NULL) {
415+
throw std::runtime_error("Communicator is not initialized");
416+
}
370417
MPI_Comm_size(comm_, &size);
371418
return size;
372419
}
@@ -377,18 +424,25 @@ struct CommMPI {
377424
*/
378425
int getRank() const {
379426
int rank = 0;
427+
// printf("%d: getRank\n", cached_rank_);
428+
if (comm_ == MPI_COMM_NULL) {
429+
throw std::runtime_error("Communicator is not initialized");
430+
}
380431
MPI_Comm_rank(comm_, &rank);
381432
return rank;
382433
}
383434

384435
private:
385436
void initTermination();
386437

387-
MPI_Comm comm_;
438+
MPI_Comm comm_ = MPI_COMM_NULL;
388439
std::list<std::tuple<MPI_Request, std::unique_ptr<char[]>>> pending_;
389440
int next_class_index_ = 0;
390441
std::unordered_map<int, void*> class_map_;
391442
std::unique_ptr<detail::TerminationDetector> termination_detector_;
443+
444+
int cached_rank_ = -1;
445+
int cached_size_ = -1;
392446
};
393447

394448
inline void CommMPI::initTermination() {

src/vt-lb/comm/termination.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ void TerminationDetector::onResponse(uint64_t in_sent, uint64_t in_recv) {
119119
if (rank_ == 0) {
120120
// Root checks for termination
121121

122+
global_sent1_ += sent_;
123+
global_recv1_ += recv_;
124+
122125
#if DEBUG_TERMINATION
123126
printf("Root total: s1=%lld, r1=%lld, s2=%lld, r2=%lld\n",
124127
global_sent1_, global_recv1_, global_sent2_, global_recv2_);
@@ -167,7 +170,9 @@ void TerminationDetector::notifyMessageReceive() {
167170
}
168171

169172
void TerminationDetector::terminated() {
173+
#if DEBUG_TERMINATION
170174
printf("%d: Terminated!\n", rank_);
175+
#endif
171176
terminated_ = true;
172177
for (int i = 0; i < num_children_; i++) {
173178
handle_[first_child_ + i].sendTerm<&TerminationDetector::terminated>();

0 commit comments

Comments
 (0)