|
59 | 59 | #include <sys/stat.h> |
60 | 60 | #include <zlib.h> |
61 | 61 | #include <map> |
| 62 | +#include <numeric> |
62 | 63 |
|
63 | 64 | namespace vt { |
64 | 65 | #if vt_check_enabled(trace_only) |
@@ -542,15 +543,42 @@ void TraceLite::flushTracesFile(bool useGlobalSync) { |
542 | 543 |
|
543 | 544 | void TraceLite::writeTracesFile(int flush, bool is_incremental_flush) { |
544 | 545 | auto const node = theContext()->getNode(); |
| 546 | + auto const comm = theContext()->getComm(); |
| 547 | + auto const comm_size = theContext()->getNumNodes(); |
545 | 548 |
|
546 | | - // Allreduce the hashed events to rank 0 before writing sts file |
| 549 | + // Gather all hashed events to rank 0 before writing sts file |
| 550 | + using events_t = std::vector<UserEventIDType>; |
547 | 551 | auto const root = 0; |
548 | | - std::vector<UserEventIDType> all_hashed_events; |
549 | | - auto msg = makeMessage<ReduceVecMsg<UserEventIDType>>( |
550 | | - theTrace()->user_hashed_events_); |
551 | | - theCollective()->global()->reduce< |
552 | | - PlusOp<std::vector<UserEventIDType>>, Verify<ReduceOP::Plus> |
553 | | - >(root, msg.get()); |
| 552 | + events_t local_hashed_events = theTrace()->user_hashed_events_; |
| 553 | + int local_size = local_hashed_events.size(); |
| 554 | + std::vector<int> all_sizes(comm_size); |
| 555 | + MPI_Gather(&local_size, 1, MPI_INT, all_sizes.data(), 1, MPI_INT, 0, comm); |
| 556 | + |
| 557 | + // Compute displacements |
| 558 | + std::vector<int> displs(comm_size, 0); |
| 559 | + if (node == 0) { |
| 560 | + std::partial_sum(all_sizes.begin(), all_sizes.end() - 1, displs.begin() + 1); |
| 561 | + } |
| 562 | + |
| 563 | + // Create vector in which to store all events |
| 564 | + events_t all_hashed_events; |
| 565 | + if (node == 0) { |
| 566 | + int total_size = std::accumulate(all_sizes.begin(), all_sizes.end(), 0); |
| 567 | + all_hashed_events.resize(total_size); |
| 568 | + } |
| 569 | + |
| 570 | + // Gather events |
| 571 | + MPI_Gatherv( |
| 572 | + local_hashed_events.data(), // Send buffer |
| 573 | + local_size, // Number of elements to send |
| 574 | + MPI_UINT32_T, // Data type (adjust to match UserEventIDType) |
| 575 | + all_hashed_events.data(), // Receive buffer (on root) |
| 576 | + all_sizes.data(), // Number of elements to receive from each rank |
| 577 | + displs.data(), // Displacements for each rank |
| 578 | + MPI_UINT32_T, // Data type (adjust to match UserEventIDType) |
| 579 | + root, // Root node |
| 580 | + comm // Communicator |
| 581 | + ); |
554 | 582 |
|
555 | 583 | size_t to_write = traces_.size(); |
556 | 584 |
|
|
0 commit comments