Skip to content

Commit e527317

Browse files
committed
Clean up mpi-driver
1 parent a33e593 commit e527317

File tree

1 file changed

+27
-49
lines changed

1 file changed

+27
-49
lines changed

src/main.cpp

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,22 @@ int main(int argc, char *argv[])
8787
{
8888
#if USE_MPI
8989
int provided;
90+
int localRank;
91+
9092
MPI_Init_thread(&argc, &argv, MPI_THREAD_FUNNELED, &provided);
91-
if (provided < MPI_THREAD_FUNNELED) {
93+
94+
if (provided < MPI_THREAD_FUNNELED)
9295
MPI_Abort(MPI_COMM_WORLD, provided);
93-
}
9496

9597
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
9698
MPI_Comm_size(MPI_COMM_WORLD, &procs);
9799

98-
// Each local rank on a given node will own a single device/GCD
99-
MPI_Comm shmcomm;
100+
// Each rank will run the benchmark on a single device
101+
MPI_Comm shared_comm;
100102
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0,
101-
MPI_INFO_NULL, &shmcomm);
102-
int localRank;
103-
MPI_Comm_rank(shmcomm, &localRank);
103+
MPI_INFO_NULL, &shared_comm);
104+
MPI_Comm_rank(shared_comm, &localRank);
105+
104106
// Set device index to be the local MPI rank
105107
deviceIndex = localRank;
106108
#endif
@@ -110,16 +112,17 @@ int main(int argc, char *argv[])
110112
if (!output_as_csv)
111113
{
112114
#if USE_MPI
113-
if (rank == 0) {
115+
if (rank == 0)
114116
#endif
117+
{
115118
std::cout
116119
<< "BabelStream" << std::endl
117120
<< "Version: " << VERSION_STRING << std::endl
118121
<< "Implementation: " << IMPLEMENTATION_STRING << std::endl;
119122
#if USE_MPI
120123
std::cout << "Number of MPI ranks: " << procs << std::endl;
121-
}
122124
#endif
125+
}
123126
}
124127

125128
if (use_float)
@@ -145,54 +148,48 @@ std::vector<std::vector<double>> run_all(Stream<T> *stream, T& sum)
145148
// Declare timers
146149
std::chrono::high_resolution_clock::time_point t1, t2;
147150

151+
#if USE_MPI
152+
// Set MPI data type for the dot-product reduction
153+
MPI_Datatype MPI_DTYPE = use_float ? MPI_FLOAT : MPI_DOUBLE;
154+
#endif
155+
148156
// Main loop
149157
for (unsigned int k = 0; k < num_times; k++)
150158
{
151-
#if USE_MPI
152-
MPI_Barrier(MPI_COMM_WORLD);
153-
#endif
154159

155160
// Execute Copy
156161
t1 = std::chrono::high_resolution_clock::now();
157162
stream->copy();
158-
#if USE_MPI
159-
MPI_Barrier(MPI_COMM_WORLD);
160-
#endif
161163
t2 = std::chrono::high_resolution_clock::now();
162164
timings[0].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());
163165

164166
// Execute Mul
165167
t1 = std::chrono::high_resolution_clock::now();
166168
stream->mul();
167-
#if USE_MPI
168-
MPI_Barrier(MPI_COMM_WORLD);
169-
#endif
170169
t2 = std::chrono::high_resolution_clock::now();
171170
timings[1].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());
172171

173172
// Execute Add
174173
t1 = std::chrono::high_resolution_clock::now();
175174
stream->add();
176-
#if USE_MPI
177-
MPI_Barrier(MPI_COMM_WORLD);
178-
#endif
179175
t2 = std::chrono::high_resolution_clock::now();
180176
timings[2].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());
181177

182178
// Execute Triad
183179
t1 = std::chrono::high_resolution_clock::now();
184180
stream->triad();
185-
#if USE_MPI
186-
MPI_Barrier(MPI_COMM_WORLD);
187-
#endif
188181
t2 = std::chrono::high_resolution_clock::now();
189182
timings[3].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());
190183

191184
// Execute Dot
185+
#if USE_MPI
186+
// Synchronize ranks before computing dot-product
187+
MPI_Barrier(MPI_COMM_WORLD);
188+
#endif
192189
t1 = std::chrono::high_resolution_clock::now();
193190
sum = stream->dot();
194191
#if USE_MPI
195-
MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
192+
MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_DTYPE, MPI_SUM, MPI_COMM_WORLD);
196193
#endif
197194
t2 = std::chrono::high_resolution_clock::now();
198195
timings[4].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());
@@ -217,9 +214,6 @@ std::vector<std::vector<double>> run_triad(Stream<T> *stream)
217214
t1 = std::chrono::high_resolution_clock::now();
218215
for (unsigned int k = 0; k < num_times; k++)
219216
{
220-
#if USE_MPI
221-
MPI_Barrier(MPI_COMM_WORLD);
222-
#endif
223217
stream->triad();
224218
}
225219
t2 = std::chrono::high_resolution_clock::now();
@@ -241,14 +235,8 @@ std::vector<std::vector<double>> run_nstream(Stream<T> *stream)
241235

242236
// Run nstream in loop
243237
for (int k = 0; k < num_times; k++) {
244-
#if USE_MPI
245-
MPI_Barrier(MPI_COMM_WORLD);
246-
#endif
247238
t1 = std::chrono::high_resolution_clock::now();
248239
stream->nstream();
249-
#if USE_MPI
250-
MPI_Barrier(MPI_COMM_WORLD);
251-
#endif
252240
t2 = std::chrono::high_resolution_clock::now();
253241
timings[0].push_back(std::chrono::duration_cast<std::chrono::duration<double> >(t2 - t1).count());
254242
}
@@ -416,10 +404,6 @@ void run()
416404

417405

418406
stream->read_arrays(a, b, c);
419-
#if USE_MPI
420-
// Only check solutions on rank 0 in case verificaiton fails
421-
if (rank == 0)
422-
#endif
423407
check_solution<T>(num_times, a, b, c, sum);
424408

425409
// Display timing results
@@ -485,17 +469,11 @@ void run()
485469
double max = *minmax.second;
486470

487471
#if USE_MPI
488-
// Collate timings
489-
if (rank == 0)
490-
{
491-
MPI_Reduce(MPI_IN_PLACE, &min, 1, MPI_DOUBLE, MPI_MIN, 0, MPI_COMM_WORLD);
492-
MPI_Reduce(MPI_IN_PLACE, &max, 1, MPI_DOUBLE, MPI_MAX, 0, MPI_COMM_WORLD);
493-
}
494-
else
495-
{
496-
MPI_Reduce(&min, NULL, 1, MPI_DOUBLE, MPI_MIN, 0, MPI_COMM_WORLD);
497-
MPI_Reduce(&max, NULL, 1, MPI_DOUBLE, MPI_MAX, 0, MPI_COMM_WORLD);
498-
}
472+
MPI_Datatype MPI_DTYPE = use_float ? MPI_FLOAT : MPI_DOUBLE;
473+
474+
// Collect global min/max timings
475+
MPI_Allreduce(MPI_IN_PLACE, &min, 1, MPI_DTYPE, MPI_MIN, MPI_COMM_WORLD);
476+
MPI_Allreduce(MPI_IN_PLACE, &max, 1, MPI_DTYPE, MPI_MAX, MPI_COMM_WORLD);
499477
sizes[i] *= procs;
500478
#endif
501479

0 commit comments

Comments
 (0)