Skip to content

Commit ac33340

Browse files
authored
Add the function of comparing results between Paddle and Torch (#22)
1 parent 6e7d15d commit ac33340

File tree

8 files changed

+217
-3
lines changed

8 files changed

+217
-3
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ set(COMMAND_TO_RUN
5858
include_directories(${Python3_INCLUDE_DIRS})
5959
link_directories("${Python3_LIBRARY_DIRS}")
6060

61-
set(COMMON_INCLUDES ${PROJECT_SOURCE_DIR}/include)
61+
set(COMMON_INCLUDES ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src)
6262

6363
enable_testing()
6464
include_directories(${COMMON_INCLUDES})

cmake/build.cmake

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@ function(
1010
foreach(_test_file ${TEST_SRC_FILES})
1111
get_filename_component(_file_name ${_test_file} NAME_WE)
1212
set(_test_name ${BIN_PREFIX}${_file_name})
13-
add_executable(${_test_name} ${_test_file} ${TEST_BASE_FILES})
13+
add_executable(${_test_name} ${_test_file} ${TEST_BASE_FILES}
14+
${PROJECT_SOURCE_DIR}/src/file_manager.cpp)
1415
add_dependencies(${_test_name} "googletest.git")
1516
target_link_libraries(
1617
${_test_name} gtest gtest_main ${CMAKE_THREAD_LIBS_INIT}
1718
${DEPS_LIBRARIES} ${Python3_LIBRARIES})
1819
target_include_directories(${_test_name} PRIVATE ${Python3_INCLUDE_DIRS})
19-
target_include_directories(${_test_name} PRIVATE ${INCLUDE_DIR})
20+
target_include_directories(${_test_name} PRIVATE ${INCLUDE_DIR}
21+
${PROJECT_SOURCE_DIR}/src)
22+
target_include_directories(${_test_name} PRIVATE ${PROJECT_SOURCE_DIR}/src)
2023
message(STATUS "include dir: ${INCLUDE_DIR}")
2124
target_compile_definitions(${_test_name}
2225
PRIVATE USE_PADDLE_API=${USE_PADDLE_API})

src/file_manager.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include "src/file_manager.h"
2+
3+
#include <filesystem>
4+
#include <iostream>
5+
6+
namespace paddle_api_test {
7+
8+
void FileManerger::createFile() {
9+
std::unique_lock<std::shared_mutex> lock(mutex_);
10+
11+
std::error_code ec;
12+
if (!std::filesystem::create_directories(basic_path_, ec) && ec) {
13+
throw std::runtime_error("Failed to create directory: " + basic_path_ +
14+
", error: " + ec.message());
15+
}
16+
17+
std::string full_path = basic_path_ + file_name_;
18+
19+
if (std::filesystem::exists(full_path)) {
20+
std::filesystem::remove(full_path);
21+
}
22+
23+
file_stream_.open(full_path, std::ios::out | std::ios::trunc);
24+
if (!file_stream_.is_open()) {
25+
throw std::runtime_error("Failed to create file: " + full_path);
26+
}
27+
}
28+
29+
void FileManerger::writeString(const std::string& str) {
30+
std::shared_lock<std::shared_mutex> lock(mutex_);
31+
if (file_stream_.is_open()) {
32+
file_stream_ << str;
33+
} else {
34+
throw std::runtime_error(
35+
"File stream is not open. Call createFile() first.");
36+
}
37+
}
38+
39+
FileManerger& FileManerger::operator<<(const std::string& str) {
40+
writeString(str);
41+
return *this;
42+
}
43+
44+
void FileManerger::saveFile() {
45+
std::unique_lock<std::shared_mutex> lock(mutex_);
46+
if (file_stream_.is_open()) {
47+
file_stream_.flush();
48+
file_stream_.close();
49+
}
50+
}
51+
52+
void FileManerger::setFileName(const std::string& value) {
53+
std::unique_lock<std::shared_mutex> lock(mutex_);
54+
file_name_ = value;
55+
}
56+
} // namespace paddle_api_test

src/file_manager.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#pragma once
2+
#include <fstream>
3+
#include <mutex>
4+
#include <shared_mutex>
5+
#include <string>
6+
7+
namespace paddle_api_test {
8+
class FileManerger {
9+
public:
10+
FileManerger() = default;
11+
explicit FileManerger(const std::string& first) : file_name_(first) {}
12+
13+
void setFileName(const std::string& value);
14+
void createFile();
15+
void writeString(const std::string& str);
16+
FileManerger& operator<<(const std::string& str);
17+
void saveFile();
18+
19+
private:
20+
mutable std::shared_mutex mutex_;
21+
std::string basic_path_ = "/tmp/paddle_cpp_api_test/";
22+
std::string file_name_ = "";
23+
std::ofstream file_stream_;
24+
};
25+
26+
class ThreadSafeParam {
27+
private:
28+
std::string param_;
29+
mutable std::mutex mutex_;
30+
31+
public:
32+
void set(const std::string& value) {
33+
std::lock_guard<std::mutex> lock(mutex_);
34+
param_ = value;
35+
}
36+
37+
std::string get() const {
38+
std::lock_guard<std::mutex> lock(mutex_);
39+
return param_;
40+
}
41+
};
42+
} // namespace paddle_api_test

src/main.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,30 @@
1+
#include <cstdlib>
2+
#include <iostream>
3+
#include <mutex>
4+
15
#include "gtest/gtest.h"
26
#if USE_PADDLE_API
37
#include "paddle/extension.h"
48
#endif
59

10+
#include "src/file_manager.h"
11+
12+
paddle_api_test::ThreadSafeParam g_custom_param;
13+
14+
std::string extract_filename(const std::string& path) {
15+
size_t last_slash = path.find_last_of('/');
16+
if (last_slash != std::string::npos) {
17+
return path.substr(last_slash + 1);
18+
}
19+
return path;
20+
}
21+
622
int main(int argc, char** argv) { // NOLINT
723
testing::InitGoogleTest(&argc, argv);
824

25+
auto exe_cmd = std::string(argv[0]);
26+
g_custom_param.set(extract_filename(exe_cmd) + ".txt");
27+
928
int ret = RUN_ALL_TESTS();
1029

1130
return ret;

test/TensorTest.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
#include <vector>
88

9+
#include "../src/file_manager.h"
910
namespace at {
1011
namespace test {
1112

13+
using paddle_api_test::FileManerger;
1214
class TensorTest : public ::testing::Test {
1315
protected:
1416
void SetUp() override {

test/TensorTest_compare.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/ones.h>
4+
#include <gtest/gtest.h>
5+
#include <torch/all.h>
6+
7+
#include <string>
8+
#include <vector>
9+
10+
#include "../src/file_manager.h"
11+
12+
extern paddle_api_test::ThreadSafeParam g_custom_param;
13+
14+
namespace at {
15+
namespace test {
16+
17+
using paddle_api_test::FileManerger;
18+
using paddle_api_test::ThreadSafeParam;
19+
class TensorTest : public ::testing::Test {
20+
protected:
21+
void SetUp() override {
22+
std::vector<int64_t> shape = {2, 3, 4};
23+
tensor = at::ones(shape, at::kFloat);
24+
}
25+
26+
at::Tensor tensor;
27+
};
28+
29+
TEST_F(TensorTest, test) {
30+
auto file_name = g_custom_param.get();
31+
FileManerger file(file_name);
32+
file.createFile();
33+
file << std::to_string(tensor.dim()) << " ";
34+
file << std::to_string(tensor.numel()) << " ";
35+
file.saveFile();
36+
}
37+
38+
} // namespace test
39+
} // namespace at

test/result_cmp.sh

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/bin/bash
2+
set -e # 出错时退出
3+
4+
# using guide: ./result_cmp.sh <BUILD_PATH>
5+
BUILD_PATH=$1
6+
7+
PADDLE_PATH=${BUILD_PATH}/paddle/
8+
TORCH_PATH=${BUILD_PATH}/torch/
9+
RESULT_FILE_PATH="/tmp/paddle_cpp_api_test/"
10+
11+
# 记录PADDLE_PATH下所有可执行文件到列表
12+
echo "Collecting and executing Paddle executables..."
13+
PADDLE_EXECUTABLES=()
14+
for test_file in ${PADDLE_PATH}/*; do
15+
if [[ -x "$test_file" && -f "$test_file" ]]; then
16+
filename=$(basename $test_file)
17+
${PADDLE_PATH}${filename}
18+
PADDLE_EXECUTABLES+=("$filename")
19+
echo "Executing Paddle test: $filename"
20+
$test_file
21+
fi
22+
done
23+
24+
# 记录并执行TORCH_PATH下所有可执行文件
25+
echo "Collecting and executing Torch executables..."
26+
TORCH_EXECUTABLES=()
27+
for test_file in ${TORCH_PATH}/*; do
28+
if [[ -x "$test_file" && -f "$test_file" ]]; then
29+
filename=$(basename $test_file)
30+
${TORCH_PATH}${filename}
31+
TORCH_EXECUTABLES+=("$filename")
32+
echo "Executing Torch test: $filename"
33+
$test_file
34+
fi
35+
done
36+
37+
# 比较结果文件
38+
echo "Comparing result files..."
39+
for ((i=0; i<${#PADDLE_EXECUTABLES[@]}; i++)); do
40+
paddle_file="${RESULT_FILE_PATH}/${PADDLE_EXECUTABLES[i]}.txt"
41+
torch_file="${RESULT_FILE_PATH}/${TORCH_EXECUTABLES[i]}.txt"
42+
43+
if [[ -f "$paddle_file" && -f "$torch_file" ]]; then
44+
if diff -q "$paddle_file" "$torch_file" >/dev/null; then
45+
echo "MATCH: ${PADDLE_EXECUTABLES[i]} and ${TORCH_EXECUTABLES[i]}"
46+
else
47+
echo "DIFFER: ${PADDLE_EXECUTABLES[i]} and ${TORCH_EXECUTABLES[i]}"
48+
diff "$paddle_file" "$torch_file"
49+
fi
50+
else
51+
echo "MISSING: ${paddle_file} or ${torch_file}"
52+
fi
53+
done

0 commit comments

Comments
 (0)