Skip to content

Commit 358a753

Browse files
committed
#1: LB: add deterministic mode
1 parent cdea1f7 commit 358a753

File tree

2 files changed

+69
-14
lines changed

2 files changed

+69
-14
lines changed

examples/test_example.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,15 @@ int main(int argc, char** argv) {
3737
printf("Running runLB\n");
3838
//comm.barrier();
3939

40+
vt_lb::algo::temperedlb::Configuration config{comm.numRanks()};
41+
config.deterministic_ = true;
42+
config.seed_ = 97;
43+
//config.k_max_ = 1;
44+
4045
vt_lb::runLB(
4146
vt_lb::DriverAlgoEnum::TemperedLB,
4247
comm,
43-
vt_lb::algo::temperedlb::Configuration{comm.numRanks()},
48+
config,
4449
nullptr
4550
);
4651

src/vt-lb/algo/temperedlb/temperedlb.h

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ struct Configuration {
7878
int f_ = 2;
7979
/// @brief Number of rounds of information propagation
8080
int k_max_ = 1;
81+
/// @brief Whether to use deterministic selection
82+
bool deterministic_ = true;
83+
/// @brief Base seed for random number generation when deterministic_ is true (each rank seeds with seed_ + its rank)
84+
int seed_ = 29;
8185

8286
bool async_ip_ = true;
8387

@@ -94,14 +98,26 @@ struct InformationPropagation {
9498
using JoinedDataType = std::unordered_map<int, DataT>;
9599
using HandleType = typename CommT::template HandleType<ThisType>;
96100

97-
/// @brief Construct information propagation instance
98-
/// @param comm Communication interface -- n.b., we clone comm to create a new termination scope
99-
/// @param f Fanout parameter
100-
/// @param k_max Maximum number of rounds
101-
InformationPropagation(CommT& comm, int f, int k_max)
102-
: comm_(comm.clone()), f_(f), k_max_(k_max)
101+
/**
102+
* @brief Construct information propagation instance
103+
*
104+
* @param comm Communication interface -- n.b., we clone comm to create a new termination scope
105+
* @param f Fanout parameter
106+
* @param k_max Maximum number of rounds
107+
* @param deterministic Whether to use deterministic selection
108+
* @param seed Base seed; when deterministic is true, gen_select_ is seeded with seed + this rank
109+
*/
110+
InformationPropagation(CommT& comm, int f, int k_max, bool deterministic, int seed)
111+
: comm_(comm.clone()), // collective operation
112+
f_(f),
113+
k_max_(k_max),
114+
deterministic_(deterministic)
103115
{
104116
handle_ = comm_.template registerInstanceCollective<ThisType>(this);
117+
118+
if (deterministic_) {
119+
gen_select_.seed(seed + comm_.getRank());
120+
}
105121
}
106122

107123
void run(DataT initial_data) {
@@ -121,7 +137,11 @@ struct InformationPropagation {
121137
}
122138

123139
void sendToFanout(int round, JoinedDataType const& data) {
124-
int num_ranks = comm_.numRanks();
140+
int const rank = comm_.getRank();
141+
int const num_ranks = comm_.numRanks();
142+
143+
sent_count_ = 0;
144+
recv_count_ = 0;
125145

126146
for (int i = 1; i <= f_; ++i) {
127147
if (already_selected_.size() >= static_cast<size_t>(num_ranks)) {
@@ -137,22 +157,48 @@ struct InformationPropagation {
137157
already_selected_.insert(target);
138158

139159
//printf("rank %d sending to rank %d\n", comm_.getRank(), target);
140-
handle_[target].template send<&ThisType::infoPropagateHandler>(round, data);
160+
sent_count_++;
161+
handle_[target].template send<&ThisType::infoPropagateHandler>(rank, round, data);
162+
}
163+
164+
if (deterministic_) {
165+
// In deterministic mode, we expect an ack from each sent message
166+
while (sent_count_ != recv_count_) {
167+
comm_.poll();
168+
}
169+
170+
if (round < k_max_) {
171+
sendToFanout(round + 1, local_data_);
172+
}
141173
}
142174
}
143175

144-
void infoPropagateHandler(int round, JoinedDataType incoming_data) {
176+
void infoAckHandler() {
177+
recv_count_++;
178+
//printf("rank %d received ack %d/%d\n", comm_.getRank(), recv_count_, sent_count_);
179+
}
180+
181+
void infoPropagateHandler(int from_rank, int round, JoinedDataType incoming_data) {
145182
// Process incoming data and add to local data
146183
local_data_.insert(incoming_data.begin(), incoming_data.end());
147-
if (round < k_max_) {
148-
sendToFanout(round + 1, local_data_);
184+
185+
if (deterministic_) {
186+
// Acknowledge receipt of message to sender before we go to the next round
187+
handle_[from_rank].template send<&ThisType::infoAckHandler>();
188+
} else {
189+
if (round < k_max_) {
190+
sendToFanout(round + 1, local_data_);
191+
}
149192
}
150193
}
151194

152195
private:
153196
CommT comm_;
154197
int f_ = 2;
155-
int k_max_ = 0;
198+
int k_max_ = 2;
199+
bool deterministic_ = false;
200+
int sent_count_ = 0;
201+
int recv_count_ = 0;
156202
std::unordered_set<int> already_selected_;
157203
std::unordered_map<int, DataT> local_data_;
158204
std::mt19937 gen_select_{std::random_device{}()};
@@ -190,7 +236,11 @@ struct TemperedLB : baselb::BaseLB {
190236
using LoadType = double;
191237
printf("start InformationPropagation\n");
192238
auto ip = InformationPropagation<CommT, LoadType, TemperedLB<CommT>>(
193-
comm_, config_.f_, config_.k_max_
239+
comm_,
240+
config_.f_,
241+
config_.k_max_,
242+
config_.deterministic_,
243+
config_.seed_
194244
);
195245
ip.run(10.0);
196246
}

0 commit comments

Comments
 (0)