@@ -78,6 +78,10 @@ struct Configuration {
7878 int f_ = 2 ;
7979 // / @brief Number of rounds of information propagation
8080 int k_max_ = 1 ;
81+ // / @brief Whether to use deterministic selection
82+ bool deterministic_ = true ;
83+ // / @brief Seed for random number generation when deterministic_ is true
84+ int seed_ = 29 ;
8185
8286 bool async_ip_ = true ;
8387
@@ -94,14 +98,26 @@ struct InformationPropagation {
9498 using JoinedDataType = std::unordered_map<int , DataT>;
9599 using HandleType = typename CommT::template HandleType<ThisType>;
96100
97- // / @brief Construct information propagation instance
98- // / @param comm Communication interface -- n.b., we clone comm to create a new termination scope
99- // / @param f Fanout parameter
100- // / @param k_max Maximum number of rounds
101- InformationPropagation (CommT& comm, int f, int k_max)
102- : comm_(comm.clone()), f_(f), k_max_(k_max)
101+ /* *
102+ * @brief Construct information propagation instance
103+ *
104+ * @param comm Communication interface -- n.b., we clone comm to create a new termination scope
105+ * @param f Fanout parameter
106+ * @param k_max Maximum number of rounds
107+ * @param deterministic Whether to use deterministic selection
108+ *
109+ */
110+ InformationPropagation (CommT& comm, int f, int k_max, bool deterministic, int seed)
111+ : comm_(comm.clone()), // collective operation
112+ f_ (f),
113+ k_max_(k_max),
114+ deterministic_(deterministic)
103115 {
104116 handle_ = comm_.template registerInstanceCollective <ThisType>(this );
117+
118+ if (deterministic_) {
119+ gen_select_.seed (seed + comm_.getRank ());
120+ }
105121 }
106122
107123 void run (DataT initial_data) {
@@ -121,7 +137,11 @@ struct InformationPropagation {
121137 }
122138
123139 void sendToFanout (int round, JoinedDataType const & data) {
124- int num_ranks = comm_.numRanks ();
140+ int const rank = comm_.getRank ();
141+ int const num_ranks = comm_.numRanks ();
142+
143+ sent_count_ = 0 ;
144+ recv_count_ = 0 ;
125145
126146 for (int i = 1 ; i <= f_; ++i) {
127147 if (already_selected_.size () >= static_cast <size_t >(num_ranks)) {
@@ -137,22 +157,48 @@ struct InformationPropagation {
137157 already_selected_.insert (target);
138158
139159 // printf("rank %d sending to rank %d\n", comm_.getRank(), target);
140- handle_[target].template send <&ThisType::infoPropagateHandler>(round, data);
160+ sent_count_++;
161+ handle_[target].template send <&ThisType::infoPropagateHandler>(rank, round, data);
162+ }
163+
164+ if (deterministic_) {
165+ // In deterministic mode, we expect an ack from each sent message
166+ while (sent_count_ != recv_count_) {
167+ comm_.poll ();
168+ }
169+
170+ if (round < k_max_) {
171+ sendToFanout (round + 1 , local_data_);
172+ }
141173 }
142174 }
143175
144- void infoPropagateHandler (int round, JoinedDataType incoming_data) {
176+ void infoAckHandler () {
177+ recv_count_++;
178+ // printf("rank %d received ack %d/%d\n", comm_.getRank(), recv_count_, sent_count_);
179+ }
180+
181+ void infoPropagateHandler (int from_rank, int round, JoinedDataType incoming_data) {
145182 // Process incoming data and add to local data
146183 local_data_.insert (incoming_data.begin (), incoming_data.end ());
147- if (round < k_max_) {
148- sendToFanout (round + 1 , local_data_);
184+
185+ if (deterministic_) {
186+ // Acknowledge receipt of message to sender before we go to the next round
187+ handle_[from_rank].template send <&ThisType::infoAckHandler>();
188+ } else {
189+ if (round < k_max_) {
190+ sendToFanout (round + 1 , local_data_);
191+ }
149192 }
150193 }
151194
152195private:
153196 CommT comm_;
154197 int f_ = 2 ;
155- int k_max_ = 0 ;
198+ int k_max_ = 2 ;
199+ bool deterministic_ = false ;
200+ int sent_count_ = 0 ;
201+ int recv_count_ = 0 ;
156202 std::unordered_set<int > already_selected_;
157203 std::unordered_map<int , DataT> local_data_;
158204 std::mt19937 gen_select_{std::random_device{}()};
@@ -190,7 +236,11 @@ struct TemperedLB : baselb::BaseLB {
190236 using LoadType = double ;
191237 printf (" start InformationPropagation\n " );
192238 auto ip = InformationPropagation<CommT, LoadType, TemperedLB<CommT>>(
193- comm_, config_.f_ , config_.k_max_
239+ comm_,
240+ config_.f_ ,
241+ config_.k_max_ ,
242+ config_.deterministic_ ,
243+ config_.seed_
194244 );
195245 ip.run (10.0 );
196246 }
0 commit comments