Skip to content

Commit 9aa9bb1

Browse files
committed
add pointer related API tests
1 parent ac33340 commit 9aa9bb1

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

test/TensorUtilTest.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/ones.h>
4+
#include <gtest/gtest.h>
5+
#include <torch/all.h>
6+
7+
#include <string>
8+
#include <vector>
9+
10+
#include "../src/file_manager.h"
11+
12+
extern paddle_api_test::ThreadSafeParam g_custom_param;
13+
14+
namespace at {
15+
namespace test {
16+
17+
using paddle_api_test::FileManerger;
18+
using paddle_api_test::ThreadSafeParam;
19+
20+
class TensorUtilTest : public ::testing::Test {
21+
protected:
22+
void SetUp() override {
23+
std::vector<int64_t> shape = {2, 3, 4};
24+
tensor = at::ones(shape, at::kFloat);
25+
}
26+
27+
at::Tensor tensor;
28+
};
29+
30+
// 测试 toString
31+
TEST_F(TensorUtilTest, ToString) {
32+
auto file_name = g_custom_param.get();
33+
FileManerger file(file_name);
34+
file.createFile();
35+
std::string tensor_str = tensor.toString();
36+
file << tensor_str << " ";
37+
file.saveFile();
38+
}
39+
40+
// 测试 is_contiguous_or_false
41+
TEST_F(TensorUtilTest, IsContiguousOrFalse) {
42+
auto file_name = g_custom_param.get();
43+
FileManerger file(file_name);
44+
file.createFile();
45+
file << std::to_string(tensor.is_contiguous_or_false()) << " ";
46+
47+
// 测试非连续的tensor
48+
at::Tensor transposed = tensor.transpose(0, 2);
49+
file << std::to_string(transposed.is_contiguous_or_false()) << " ";
50+
file.saveFile();
51+
}
52+
53+
// 测试 is_same
54+
TEST_F(TensorUtilTest, IsSame) {
55+
auto file_name = g_custom_param.get();
56+
FileManerger file(file_name);
57+
file.createFile();
58+
59+
// Test that tensor is same as itself
60+
file << std::to_string(tensor.is_same(tensor)) << " ";
61+
62+
// Test that two different tensors are not the same
63+
at::Tensor other_tensor = at::ones({2, 3, 4}, at::kFloat);
64+
file << std::to_string(tensor.is_same(other_tensor)) << " ";
65+
66+
// Test that a shallow copy points to the same tensor
67+
at::Tensor shallow_copy = tensor;
68+
file << std::to_string(tensor.is_same(shallow_copy)) << " ";
69+
70+
// Test that a view of the tensor
71+
at::Tensor view = tensor.view({24});
72+
file << std::to_string(tensor.is_same(view)) << " ";
73+
file.saveFile();
74+
}
75+
76+
// 测试 use_count
77+
TEST_F(TensorUtilTest, UseCount) {
78+
auto file_name = g_custom_param.get();
79+
FileManerger file(file_name);
80+
file.createFile();
81+
82+
// Get initial use count
83+
size_t initial_count = tensor.use_count();
84+
file << std::to_string(initial_count) << " ";
85+
86+
// Create a copy, should increase use count
87+
{
88+
at::Tensor copy = tensor;
89+
size_t new_count = tensor.use_count();
90+
file << std::to_string(new_count) << " ";
91+
file << std::to_string(new_count - initial_count) << " "; // 差值
92+
}
93+
94+
// After copy goes out of scope, use count should decrease
95+
size_t final_count = tensor.use_count();
96+
file << std::to_string(final_count) << " ";
97+
file.saveFile();
98+
}
99+
100+
// 测试 weak_use_count
101+
TEST_F(TensorUtilTest, WeakUseCount) {
102+
auto file_name = g_custom_param.get();
103+
FileManerger file(file_name);
104+
file.createFile();
105+
106+
// Get initial weak use count
107+
size_t initial_weak_count = tensor.weak_use_count();
108+
file << std::to_string(initial_weak_count) << " ";
109+
file.saveFile();
110+
}
111+
112+
// 测试 print
113+
TEST_F(TensorUtilTest, Print) {
114+
auto file_name = g_custom_param.get();
115+
FileManerger file(file_name);
116+
file.createFile();
117+
118+
// 创建一个小的tensor用于print测试
119+
at::Tensor small_tensor = at::ones({2, 2}, at::kFloat);
120+
121+
// 使用 captureStdout 捕获 print() 的输出
122+
file.captureStdout([&]() {
123+
tensor.print();
124+
small_tensor.print();
125+
});
126+
127+
file << std::to_string(1) << " "; // 如果执行到这里说明print()没有崩溃
128+
file.saveFile();
129+
}
130+
131+
} // namespace test
132+
} // namespace at

0 commit comments

Comments
 (0)