Skip to content

Commit e227fdc

Browse files
Copilotchhwang
andcommitted
Convert mp_unit tests from gtest to framework.hpp
- Modified test/mp_unit/mp_unit_tests.hpp to use ../framework.hpp instead of gtest/gtest.h - Enhanced test/framework.hpp with GTest-compatible APIs: - Added Environment base class for global test setup/teardown - Added TestInfo and UnitTest classes for test metadata access - Added GTEST_SKIP macro support via SkipHelper class - Added namespace alias 'testing' for compatibility - Added InitGoogleTest and AddGlobalTestEnvironment helper functions - Updated test/framework.cc with implementations for new classes - All mp_unit test files now use framework.hpp through mp_unit_tests.hpp - Formatting applied via lint.sh Co-authored-by: chhwang <8018170+chhwang@users.noreply.github.com>
1 parent c881bc5 commit e227fdc

15 files changed

Lines changed: 310 additions & 194 deletions

test/executor_test.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,8 @@ double benchTime(int rank, std::shared_ptr<mscclpp::Bootstrap> bootstrap, std::s
9393

9494
int main(int argc, char* argv[]) {
9595
if (argc != 5 && argc != 6) {
96-
std::cerr << "Usage: " << argv[0] << " <buffer size>"
97-
<< " <execution plan path>"
98-
<< " <number of iterations>"
99-
<< " <number of graph iterations>"
100-
<< " (optional) <packet type>" << std::endl;
96+
std::cerr << "Usage: " << argv[0] << " <buffer size>" << " <execution plan path>" << " <number of iterations>"
97+
<< " <number of graph iterations>" << " (optional) <packet type>" << std::endl;
10198
return 1;
10299
}
103100

test/framework.cc

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,26 +161,51 @@ int runMultipleTests(
161161

162162
} // namespace utils
163163

164+
// UnitTest implementation
165+
UnitTest* UnitTest::GetInstance() {
166+
static UnitTest instance;
167+
return &instance;
168+
}
169+
164170
// TestRegistry implementation
165171
TestRegistry& TestRegistry::instance() {
166172
static TestRegistry registry;
167173
return registry;
168174
}
169175

170176
void TestRegistry::registerTest(const std::string& test_suite, const std::string& test_name, TestFactory factory) {
171-
TestInfo info;
177+
TestInfoInternal info;
172178
info.suite_name = test_suite;
173179
info.test_name = test_name;
174180
info.factory = factory;
175181
tests_.push_back(info);
176182
}
177183

184+
void TestRegistry::addGlobalTestEnvironment(Environment* env) { environments_.push_back(env); }
185+
186+
void TestRegistry::initGoogleTest(int* argc, char** argv) {
187+
// Parse command-line arguments if needed
188+
// For now, this is a no-op placeholder for compatibility
189+
}
190+
178191
int TestRegistry::runAllTests(int argc, char* argv[]) {
179192
// Initialize MPI if not already initialized
180193
if (!g_mpi_initialized) {
181194
utils::initializeMPI(argc, argv);
182195
}
183196

197+
// Set up global test environments
198+
for (auto* env : environments_) {
199+
try {
200+
env->SetUp();
201+
} catch (const std::exception& e) {
202+
if (g_mpi_rank == 0) {
203+
std::cerr << "Failed to set up test environment: " << e.what() << std::endl;
204+
}
205+
return 1;
206+
}
207+
}
208+
184209
int passed = 0;
185210
int failed = 0;
186211

@@ -196,6 +221,10 @@ int TestRegistry::runAllTests(int argc, char* argv[]) {
196221
std::cout << "[ RUN ] " << test_info.suite_name << "." << test_info.test_name << std::endl;
197222
}
198223

224+
// Set current test info for UnitTest::GetInstance()->current_test_info()
225+
TestInfo current_info(test_info.suite_name, test_info.test_name);
226+
UnitTest::GetInstance()->set_current_test_info(&current_info);
227+
199228
TestCase* test_case = nullptr;
200229
try {
201230
test_case = test_info.factory();
@@ -216,6 +245,9 @@ int TestRegistry::runAllTests(int argc, char* argv[]) {
216245

217246
delete test_case;
218247

248+
// Clear current test info
249+
UnitTest::GetInstance()->set_current_test_info(nullptr);
250+
219251
// Synchronize test status across all MPI processes
220252
int local_passed = g_current_test_passed ? 1 : 0;
221253
int global_passed = 1;
@@ -246,6 +278,17 @@ int TestRegistry::runAllTests(int argc, char* argv[]) {
246278
}
247279
}
248280

281+
// Tear down global test environments (in reverse order)
282+
for (auto it = environments_.rbegin(); it != environments_.rend(); ++it) {
283+
try {
284+
(*it)->TearDown();
285+
} catch (const std::exception& e) {
286+
if (g_mpi_rank == 0) {
287+
std::cerr << "Failed to tear down test environment: " << e.what() << std::endl;
288+
}
289+
}
290+
}
291+
249292
return failed > 0 ? 1 : 0;
250293
}
251294

test/framework.hpp

Lines changed: 244 additions & 165 deletions
Large diffs are not rendered by default.

test/mp_unit/mp_unit_tests.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
#ifndef MSCCLPP_MP_UNIT_TESTS_HPP_
55
#define MSCCLPP_MP_UNIT_TESTS_HPP_
66

7-
#include <gtest/gtest.h>
8-
97
#include <mscclpp/core.hpp>
108
#include <mscclpp/executor.hpp>
119
#include <mscclpp/memory_channel.hpp>
1210
#include <mscclpp/packet_device.hpp>
1311
#include <mscclpp/port_channel.hpp>
1412
#include <mscclpp/utils.hpp>
1513

14+
#include "../framework.hpp"
1615
#include "ib.hpp"
1716
#include "utils_internal.hpp"
1817

test/perf/framework.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ namespace mscclpp {
1212
namespace test {
1313

1414
// Global state for performance test results
15-
static std::vector<struct PerfTestResult {
15+
static std::vector < struct PerfTestResult {
1616
std::string test_name;
1717
std::string test_category;
1818
std::map<std::string, std::string> test_params;
1919
nlohmann::ordered_json metrics;
2020
int num_processes;
2121
int process_rank;
2222
std::string timestamp;
23-
}> g_perf_results;
23+
} > g_perf_results;
2424

2525
static std::string getCurrentTimestamp() {
2626
auto now = std::chrono::system_clock::now();

test/perf/framework.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
// This file is kept for backwards compatibility with perf tests
88
// The actual framework is now in test/framework.hpp
99

10-
#include "../framework.hpp"
11-
1210
#include <nlohmann/json.hpp>
1311

12+
#include "../framework.hpp"
13+
1414
namespace mscclpp {
1515
namespace test {
1616

test/unit/core_tests.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
// Licensed under the MIT license.
33

44
#include <gmock/gmock.h>
5-
#include "../framework.hpp"
65

76
#include <mscclpp/core.hpp>
87

8+
#include "../framework.hpp"
9+
910
class LocalCommunicatorTest : public ::testing::Test {
1011
protected:
1112
void SetUp() override {

test/unit/errors_tests.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT license.
33

4-
#include "../framework.hpp"
5-
64
#include <mscclpp/errors.hpp>
75

6+
#include "../framework.hpp"
7+
88
TEST(ErrorsTest, SystemError) {
99
mscclpp::Error error("test", mscclpp::ErrorCode::SystemError);
1010
EXPECT_EQ(error.getErrorCode(), mscclpp::ErrorCode::SystemError);

test/unit/fifo_tests.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT license.
33

4-
#include "../framework.hpp"
5-
64
#include <mscclpp/fifo.hpp>
75
#include <mscclpp/gpu_utils.hpp>
86
#include <mscclpp/numa.hpp>
97
#include <mscclpp/utils.hpp>
108

9+
#include "../framework.hpp"
1110
#include "utils_internal.hpp"
1211

1312
#define ITER 10000 // should be larger than the FIFO size for proper testing

test/unit/gpu_utils_tests.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT license.
33

4-
#include "../framework.hpp"
5-
64
#include <mscclpp/gpu_utils.hpp>
75

6+
#include "../framework.hpp"
7+
88
TEST(GpuUtilsTest, StreamPool) {
99
auto streamPool = mscclpp::gpuStreamPool();
1010
cudaStream_t s;

0 commit comments

Comments
 (0)