|
12 | 12 |
|
13 | 13 | namespace klimov_m_torus { |
14 | 14 |
|
| 15 | +namespace { |
| 16 | + |
| 17 | +// Helper functions moved into an anonymous namespace to reduce the complexity of RelayMessage |
| 18 | + |
// Degenerate-route delivery: when source and destination are the same rank,
// the payload is simply published into `output` on that rank — no MPI calls.
// All other ranks leave `output` untouched.
void HandleSameNode(int current_rank, int src, const std::vector<int> &buffer, std::vector<int> &output) {
  if (current_rank != src) {
    return;
  }
  output = buffer;
}
| 24 | + |
| 25 | +void HandleSourceNode(int current_rank, int src, const std::vector<int> &route, const std::vector<int> &buffer, |
| 26 | + std::vector<int> &output) { |
| 27 | + output = buffer; |
| 28 | + if (current_rank == src && route.size() > 1) { |
| 29 | + int next_hop = route[1]; |
| 30 | + int send_len = static_cast<int>(buffer.size()); |
| 31 | + MPI_Send(&send_len, 1, MPI_INT, next_hop, 0, MPI_COMM_WORLD); |
| 32 | + if (send_len > 0) { |
| 33 | + MPI_Send(output.data(), send_len, MPI_INT, next_hop, 1, MPI_COMM_WORLD); |
| 34 | + } |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +void HandleIntermediateNode(int current_rank, int dst, const std::vector<int> &route, int my_pos, |
| 39 | + std::vector<int> &output) { |
| 40 | + int prev_hop = route[my_pos - 1]; |
| 41 | + int recv_len = 0; |
| 42 | + MPI_Recv(&recv_len, 1, MPI_INT, prev_hop, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); |
| 43 | + output.resize(recv_len); |
| 44 | + if (recv_len > 0) { |
| 45 | + MPI_Recv(output.data(), recv_len, MPI_INT, prev_hop, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE); |
| 46 | + } |
| 47 | + |
| 48 | + if (current_rank != dst && my_pos + 1 < static_cast<int>(route.size())) { |
| 49 | + int next_hop = route[my_pos + 1]; |
| 50 | + MPI_Send(&recv_len, 1, MPI_INT, next_hop, 0, MPI_COMM_WORLD); |
| 51 | + if (recv_len > 0) { |
| 52 | + MPI_Send(output.data(), recv_len, MPI_INT, next_hop, 1, MPI_COMM_WORLD); |
| 53 | + } |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +} // namespace |
| 58 | + |
// Constructs the communicator task: registers this task's static type with
// the task framework, copies the caller-provided input into task storage,
// and resets the output to an empty value (filled later by RunImpl).
TorusMeshCommunicator::TorusMeshCommunicator(const InType &in) {
  SetTypeOfTask(GetStaticTypeOfTask());  // presumably framework bookkeeping — type comes from the task base
  GetInput() = in;                       // store the input for PreProcessing/Run stages
  GetOutput() = {};                      // start with an empty output
}
20 | 64 |
|
21 | | -std::pair<int, int> TorusMeshCommunicator::CalculateGridSize(int totalProcesses) { |
22 | | - int rows = static_cast<int>(std::sqrt(static_cast<double>(totalProcesses))); |
23 | | - while (rows > 1 && (totalProcesses % rows != 0)) { |
| 65 | +std::pair<int, int> TorusMeshCommunicator::CalculateGridSize(int total_processes) { |
| 66 | + int rows = static_cast<int>(std::sqrt(static_cast<double>(total_processes))); |
| 67 | + while (rows > 1 && (total_processes % rows != 0)) { |
24 | 68 | --rows; |
25 | 69 | } |
26 | 70 | if (rows <= 0) { |
27 | 71 | rows = 1; |
28 | 72 | } |
29 | | - int cols = totalProcesses / rows; |
| 73 | + int cols = total_processes / rows; |
30 | 74 | if (cols <= 0) { |
31 | 75 | cols = 1; |
32 | 76 | } |
@@ -119,7 +163,8 @@ bool TorusMeshCommunicator::PreProcessingImpl() { |
119 | 163 | } |
120 | 164 |
|
121 | 165 | bool TorusMeshCommunicator::RunImpl() { |
122 | | - int sender = 0, receiver = 0; |
| 166 | + int sender = 0; |
| 167 | + int receiver = 0; |
123 | 168 | DistributeSenderReceiver(sender, receiver); |
124 | 169 |
|
125 | 170 | int data_len = 0; |
@@ -155,48 +200,23 @@ void TorusMeshCommunicator::DistributeDataLength(int src, int &len) const { |
155 | 200 | std::vector<int> TorusMeshCommunicator::AssembleSendBuffer(int src, int len) const { |
156 | 201 | std::vector<int> buffer(len); |
157 | 202 | if (current_rank_ == src && len > 0) { |
158 | | - std::copy(local_request_.data.begin(), local_request_.data.end(), buffer.begin()); |
| 203 | + std::ranges::copy(local_request_.data, buffer.begin()); |
159 | 204 | } |
160 | 205 | return buffer; |
161 | 206 | } |
162 | 207 |
|
163 | 208 | void TorusMeshCommunicator::RelayMessage(int src, int dst, const std::vector<int> &route, |
164 | 209 | const std::vector<int> &buffer, std::vector<int> &output) const { |
165 | | - const int route_len = static_cast<int>(route.size()); |
166 | | - auto it = std::find(route.begin(), route.end(), current_rank_); |
| 210 | + auto it = std::ranges::find(route, current_rank_); |
167 | 211 | bool on_route = (it != route.end()); |
168 | 212 | int my_pos = on_route ? static_cast<int>(std::distance(route.begin(), it)) : -1; |
169 | 213 |
|
170 | 214 | if (src == dst) { |
171 | | - if (current_rank_ == src) { |
172 | | - output = buffer; |
173 | | - } |
| 215 | + HandleSameNode(current_rank_, src, buffer, output); |
174 | 216 | } else if (current_rank_ == src) { |
175 | | - output = buffer; |
176 | | - if (route_len > 1) { |
177 | | - int next_hop = route[1]; |
178 | | - int send_len = static_cast<int>(buffer.size()); |
179 | | - MPI_Send(&send_len, 1, MPI_INT, next_hop, 0, MPI_COMM_WORLD); |
180 | | - if (send_len > 0) { |
181 | | - MPI_Send(output.data(), send_len, MPI_INT, next_hop, 1, MPI_COMM_WORLD); |
182 | | - } |
183 | | - } |
| 217 | + HandleSourceNode(current_rank_, src, route, buffer, output); |
184 | 218 | } else if (on_route) { |
185 | | - int prev_hop = route[my_pos - 1]; |
186 | | - int recv_len = 0; |
187 | | - MPI_Recv(&recv_len, 1, MPI_INT, prev_hop, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); |
188 | | - output.resize(recv_len); |
189 | | - if (recv_len > 0) { |
190 | | - MPI_Recv(output.data(), recv_len, MPI_INT, prev_hop, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE); |
191 | | - } |
192 | | - |
193 | | - if (current_rank_ != dst && my_pos + 1 < route_len) { |
194 | | - int next_hop = route[my_pos + 1]; |
195 | | - MPI_Send(&recv_len, 1, MPI_INT, next_hop, 0, MPI_COMM_WORLD); |
196 | | - if (recv_len > 0) { |
197 | | - MPI_Send(output.data(), recv_len, MPI_INT, next_hop, 1, MPI_COMM_WORLD); |
198 | | - } |
199 | | - } |
| 219 | + HandleIntermediateNode(current_rank_, dst, route, my_pos, output); |
200 | 220 | } |
201 | 221 | } |
202 | 222 |
|
|
0 commit comments