Skip to content

Commit 8c45ae4

Browse files
committed
#3: comm: first step toward parameterized testing
1 parent ba83d00 commit 8c45ae4

File tree

5 files changed

+61
-46
lines changed

5 files changed

+61
-46
lines changed

src/vt-lb/comm/MPI/comm_mpi.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,7 @@ struct CommMPI {
170170
template <typename T>
171171
using HandleType = ClassHandle<T>;
172172

173-
/**
174-
* \brief Construct a new CommMPI instance
175-
* \param comm The MPI communicator to use (defaults to MPI_COMM_WORLD)
176-
*/
177-
explicit CommMPI(MPI_Comm comm = MPI_COMM_WORLD) : comm_(comm) { }
178-
173+
CommMPI() = default;
179174
CommMPI(CommMPI const&) = delete;
180175
CommMPI(CommMPI&&) = delete;
181176
CommMPI& operator=(CommMPI const&) = delete;
@@ -194,8 +189,14 @@ struct CommMPI {
194189
* \param argc Pointer to argument count
195190
* \param argv Pointer to argument array
196191
*/
197-
void init(int& argc, char**& argv) {
198-
MPI_Init(&argc, &argv);
192+
void init(int& argc, char**& argv, MPI_Comm comm = MPI_COMM_NULL) {
193+
if (comm == MPI_COMM_NULL) {
194+
MPI_Init(&argc, &argv);
195+
comm_ = MPI_COMM_WORLD;
196+
} else {
197+
interop_mode_ = true;
198+
comm_ = comm;
199+
}
199200
MPI_Comm_rank(comm_, &cached_rank_);
200201
MPI_Comm_size(comm_, &cached_size_);
201202
initTermination();
@@ -216,8 +217,10 @@ struct CommMPI {
216217
* \brief Finalize MPI
217218
*/
218219
void finalize() {
219-
// printf("%d: Finalizing MPI\n", cached_rank_);
220-
MPI_Finalize();
220+
if (!interop_mode_) {
221+
// printf("%d: Finalizing MPI\n", cached_rank_);
222+
MPI_Finalize();
223+
}
221224
comm_ = MPI_COMM_NULL;
222225
}
223226

@@ -444,6 +447,7 @@ struct CommMPI {
444447
private:
445448
void initTermination();
446449

450+
bool interop_mode_ = false;
447451
MPI_Comm comm_ = MPI_COMM_NULL;
448452
std::list<std::tuple<MPI_Request, std::unique_ptr<char[]>>> pending_;
449453
int next_class_index_ = 0;

src/vt-lb/comm/vt/comm_vt.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,13 @@
4747

4848
namespace vt_lb::comm {
4949

50-
void CommVT::init(int& argc, char**& argv) {
51-
vt::initialize(argc, argv);
50+
void CommVT::init(int& argc, char**& argv, MPI_Comm comm) {
51+
if (comm == MPI_COMM_NULL) {
52+
vt::initialize(argc, argv);
53+
} else {
54+
// interop mode
55+
vt::initialize(argc, argv, &comm);
56+
}
5257
vt::theTerm()->addDefaultAction([this]{ terminated_ = true; });
5358
}
5459

src/vt-lb/comm/vt/comm_vt.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ struct CommVT {
6565
CommVT(vt::EpochType epoch);
6666

6767
public:
68-
void init(int& argc, char**& argv);
68+
void init(int& argc, char**& argv, MPI_Comm comm = MPI_COMM_NULL);
6969
void finalize();
7070
int numRanks() const;
7171
int getRank() const;

tests/unit/test_helpers.h

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
#if !defined INCLUDED_VT_LB_UNIT_TEST_HELPERS_H
4545
#define INCLUDED_VT_LB_UNIT_TEST_HELPERS_H
4646

47-
#include "vt/context/context.h"
4847
#include <mpi.h>
4948
#include <gtest/gtest.h>
5049
#include <sstream>
@@ -58,23 +57,23 @@ extern char** test_argv;
5857
* Maximum number of ranks/nodes detected by CMake on this machine.
5958
* Defaults to number of processors detected on the host system.
6059
*/
61-
constexpr vt::NodeType CMAKE_DETECTED_MAX_NUM_NODES = vt_detected_max_num_nodes;
60+
//constexpr vt::NodeType CMAKE_DETECTED_MAX_NUM_NODES = vt_detected_max_num_nodes;
6261

6362
/**
6463
* Check whether we're oversubscribing on the current execution.
6564
* This is using MPI because it can be used before vt initializes.
6665
*/
67-
inline bool isOversubscribed() {
68-
// be sure to only call this from parallel tests
69-
int init = 0;
70-
MPI_Initialized(&init);
71-
if (!init) {
72-
MPI_Init(&test_argc, &test_argv);
73-
}
74-
int num_ranks = 0;
75-
MPI_Comm_size(MPI_COMM_WORLD, &num_ranks);
76-
return num_ranks > CMAKE_DETECTED_MAX_NUM_NODES;
77-
}
66+
//inline bool isOversubscribed() {
67+
// // be sure to only call this from parallel tests
68+
// int init = 0;
69+
// MPI_Initialized(&init);
70+
// if (!init) {
71+
// MPI_Init(&test_argc, &test_argv);
72+
// }
73+
// int num_ranks = 0;
74+
// MPI_Comm_size(MPI_COMM_WORLD, &num_ranks);
75+
// return num_ranks > CMAKE_DETECTED_MAX_NUM_NODES;
76+
//}
7877

7978
/**
8079
* Get a unique filename based on the unit test name.
@@ -94,12 +93,12 @@ inline std::string getUniqueFilename(const std::string& ext = "") {
9493
* concurrently-running tests will not cause file system race conditions.
9594
* Do not call this from .nompi.cc tests or from addAdditionalArgs().
9695
*/
97-
inline std::string getUniqueFilenameWithRanks(const std::string& ext = "") {
98-
auto ranks = vt::theContext()->getNumNodes();
99-
std::stringstream ss;
100-
ss << getUniqueFilename() << "_" << ranks << ext;
101-
return ss.str();
102-
}
96+
//inline std::string getUniqueFilenameWithRanks(const std::string& ext = "") {
97+
// auto ranks = vt::theContext()->getNumNodes();
98+
// std::stringstream ss;
99+
// ss << getUniqueFilename() << "_" << ranks << ext;
100+
// return ss.str();
101+
//}
103102

104103
/**
105104
* The following helper macros (these have to be macros, because GTEST_SKIP
@@ -116,7 +115,7 @@ inline std::string getUniqueFilenameWithRanks(const std::string& ext = "") {
116115
*/
117116
#define SET_MAX_NUM_NODES_CONSTRAINT(max_req_num_nodes) \
118117
{ \
119-
auto const num_nodes = vt::theContext()->getNumNodes(); \
118+
auto const num_nodes = comm.numRanks(); \
120119
if (num_nodes > max_req_num_nodes) { \
121120
GTEST_SKIP() << fmt::format( \
122121
"Skipping the run on {} nodes. This test should run on at most {} " \
@@ -132,7 +131,7 @@ inline std::string getUniqueFilenameWithRanks(const std::string& ext = "") {
132131
*/
133132
#define SET_MIN_NUM_NODES_CONSTRAINT(min_req_num_nodes) \
134133
{ \
135-
auto const num_nodes = vt::theContext()->getNumNodes(); \
134+
auto const num_nodes = comm.numRanks(); \
136135
if (num_nodes < min_req_num_nodes) { \
137136
GTEST_SKIP() << fmt::format( \
138137
"Skipping the run on {} nodes. This test should run on at least {} "\
@@ -148,7 +147,7 @@ inline std::string getUniqueFilenameWithRanks(const std::string& ext = "") {
148147
*/
149148
#define SET_NUM_NODES_CONSTRAINT(req_num_nodes) \
150149
{ \
151-
auto const num_nodes = vt::theContext()->getNumNodes(); \
150+
auto const num_nodes = comm.numRanks(); \
152151
if (num_nodes != req_num_nodes) { \
153152
GTEST_SKIP() << fmt::format( \
154153
"Skipping the run on {} nodes. This test should run only on {} " \

tests/unit/test_parallel_harness.h

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,17 @@
5050
#include "test_config.h"
5151
#include "test_harness.h"
5252

53-
#include "vt/transport.h"
53+
#include <vt-lb/comm/vt/comm_vt.h>
54+
#include <vt-lb/comm/MPI/comm_mpi.h>
5455

5556
namespace vt_lb { namespace tests { namespace unit {
5657

5758
extern int test_argc;
5859
extern char** test_argv;
5960

61+
using CommType = vt_lb::comm::CommVT;
62+
//using CommType = vt_lb::comm::CommMPI;
63+
6064
template <typename TestBase>
6165
struct TestParallelHarnessAny : TestHarnessAny<TestBase> {
6266
virtual void SetUp() override {
@@ -68,41 +72,44 @@ struct TestParallelHarnessAny : TestHarnessAny<TestBase> {
6872
if (!init) {
6973
MPI_Init(&test_argc, &test_argv);
7074
}
71-
MPI_Comm comm = MPI_COMM_WORLD;
75+
MPI_Comm mpi_comm = MPI_COMM_WORLD;
7276
auto const new_args = injectAdditionalArgs(test_argc, test_argv);
7377
auto custom_argc = new_args.first;
7478
auto custom_argv = new_args.second;
7579
vtAssert(
7680
custom_argv[custom_argc] == nullptr,
7781
"The value of argv[argc] should always be 0"
7882
);
79-
// communicator is duplicated.
80-
vt::initialize(custom_argc, custom_argv, &comm, nullptr, true);
83+
comm.init(custom_argc, custom_argv, mpi_comm);
8184

8285
#if DEBUG_TEST_HARNESS_PRINT
83-
auto const& my_node = vt::theContext()->getNode();
84-
auto const& num_nodes = vt::theContext()->getNumNodes();
85-
fmt::print("my_node={}, num_nodes={}\n", my_node, num_nodes);
86+
auto const& my_rank = comm.getRank();
87+
auto const& num_ranks = comm.numRanks();
88+
fmt::print("my_rank={}, num_ranks={}\n", my_rank, num_ranks);
8689
#endif
8790
}
8891

8992
virtual void TearDown() override {
9093
try {
91-
vt::theSched()->runSchedulerWhile([] { return !vt::rt->isTerminated(); });
94+
while (comm.poll()) {
95+
}
9296
} catch (std::exception& e) {
9397
ADD_FAILURE() << fmt::format("Caught an exception: {}\n", e.what());
9498
}
9599

96100
#if DEBUG_TEST_HARNESS_PRINT
97-
auto const& my_node = vt::theContext()->getNode();
98-
fmt::print("my_node={}, tearing down runtime\n", my_node);
101+
auto const& my_rank = comm.getRank();
102+
fmt::print("my_rank={}, tearing down runtime\n", my_rank);
99103
#endif
100104

101-
vt::finalize();
105+
comm.finalize();
102106

103107
TestHarnessAny<TestBase>::TearDown();
104108
}
105109

110+
public:
111+
CommType comm;
112+
106113
protected:
107114
template <typename Arg>
108115
void addArgs(Arg& arg) {

0 commit comments

Comments
 (0)