Skip to content

Commit d25f3f2

Browse files
committed
#2387: use gather instead of allreduce
1 parent 779e46d commit d25f3f2

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

src/vt/trace/trace_lite.cc

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
#include <sys/stat.h>
6060
#include <zlib.h>
6161
#include <map>
62+
#include <numeric>
6263

6364
namespace vt {
6465
#if vt_check_enabled(trace_only)
@@ -542,15 +543,42 @@ void TraceLite::flushTracesFile(bool useGlobalSync) {
542543

543544
void TraceLite::writeTracesFile(int flush, bool is_incremental_flush) {
544545
auto const node = theContext()->getNode();
546+
auto const comm = theContext()->getComm();
547+
auto const comm_size = theContext()->getNumNodes();
545548

546-
// Allreduce the hashed events to rank 0 before writing sts file
549+
// Gather all hashed events to rank 0 before writing sts file
550+
using events_t = std::vector<UserEventIDType>;
547551
auto const root = 0;
548-
std::vector<UserEventIDType> all_hashed_events;
549-
auto msg = makeMessage<ReduceVecMsg<UserEventIDType>>(
550-
theTrace()->user_hashed_events_);
551-
theCollective()->global()->reduce<
552-
PlusOp<std::vector<UserEventIDType>>, Verify<ReduceOP::Plus>
553-
>(root, msg.get());
552+
events_t local_hashed_events = theTrace()->user_hashed_events_;
553+
int local_size = local_hashed_events.size();
554+
std::vector<int> all_sizes(comm_size);
555+
MPI_Gather(&local_size, 1, MPI_INT, all_sizes.data(), 1, MPI_INT, 0, comm);
556+
557+
// Compute displacements
558+
std::vector<int> displs(comm_size, 0);
559+
if (node == 0) {
560+
std::partial_sum(all_sizes.begin(), all_sizes.end() - 1, displs.begin() + 1);
561+
}
562+
563+
// Create vector in which to store all events
564+
events_t all_hashed_events;
565+
if (node == 0) {
566+
int total_size = std::accumulate(all_sizes.begin(), all_sizes.end(), 0);
567+
all_hashed_events.resize(total_size);
568+
}
569+
570+
// Gather events
571+
MPI_Gatherv(
572+
local_hashed_events.data(), // Send buffer
573+
local_size, // Number of elements to send
574+
MPI_UINT32_T, // Data type (adjust to match UserEventIDType)
575+
all_hashed_events.data(), // Receive buffer (on root)
576+
all_sizes.data(), // Number of elements to receive from each rank
577+
displs.data(), // Displacements for each rank
578+
MPI_UINT32_T, // Data type (adjust to match UserEventIDType)
579+
root, // Root node
580+
comm // Communicator
581+
);
554582

555583
size_t to_write = traces_.size();
556584

src/vt/trace/trace_user_event.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ UserEventIDType UserEventRegistry::hash(std::string const& in_event_name) {
8484
auto id = std::get<0>(ret);
8585
auto inserted = std::get<1>(ret);
8686
if (inserted) {
87-
theTrace->addHashedEvent(id);
87+
vt::theTrace()->addHashedEvent(id);
8888
}
8989
return id;
9090
}

0 commit comments

Comments
 (0)