Skip to content

Commit d5f5be9

Browse files
committed
add select, detach, reciprocal and split related tests
1 parent d19ebe4 commit d5f5be9

File tree

4 files changed

+795
-0
lines changed

4 files changed

+795
-0
lines changed

test/ops/DetachTest.cpp

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/zeros.h>
4+
#include <gtest/gtest.h>
5+
6+
#include <string>
7+
#include <vector>
8+
9+
#include "../../src/file_manager.h"
10+
11+
extern paddle_api_test::ThreadSafeParam g_custom_param;
12+
13+
namespace at {
14+
namespace test {
15+
16+
using paddle_api_test::FileManerger;
17+
using paddle_api_test::ThreadSafeParam;
18+
19+
class DetachTest : public ::testing::Test {
20+
protected:
21+
void SetUp() override {
22+
std::vector<int64_t> shape = {3, 4};
23+
test_tensor = at::zeros(shape, at::kFloat);
24+
float* data = test_tensor.data_ptr<float>();
25+
for (int64_t i = 0; i < 12; ++i) {
26+
data[i] = static_cast<float>(i + 1);
27+
}
28+
}
29+
at::Tensor test_tensor;
30+
};
31+
32+
static void write_detach_result_to_file(FileManerger* file,
33+
const at::Tensor& result,
34+
const at::Tensor& original) {
35+
*file << std::to_string(result.dim()) << " ";
36+
*file << std::to_string(result.numel()) << " ";
37+
38+
// 写入形状信息
39+
for (int64_t i = 0; i < result.dim(); ++i) {
40+
*file << std::to_string(result.sizes()[i]) << " ";
41+
}
42+
43+
// 写入数据内容
44+
float* result_data = result.data_ptr<float>();
45+
for (int64_t i = 0; i < result.numel(); ++i) {
46+
*file << std::to_string(result_data[i]) << " ";
47+
}
48+
49+
// 验证数据指针是否相同(共享存储)
50+
*file << std::to_string(result.data_ptr<float>() ==
51+
original.data_ptr<float>())
52+
<< " ";
53+
}
54+
55+
// 测试 detach() 方法 - 创建新的 tensor,不跟踪梯度
56+
TEST_F(DetachTest, BasicDetach) {
57+
at::Tensor detached = test_tensor.detach();
58+
auto file_name = g_custom_param.get();
59+
FileManerger file(file_name);
60+
file.createFile();
61+
write_detach_result_to_file(&file, detached, test_tensor);
62+
file.saveFile();
63+
}
64+
65+
// 测试 detach_() in-place 方法
66+
TEST_F(DetachTest, InplaceDetach) {
67+
auto file_name = g_custom_param.get();
68+
FileManerger file(file_name);
69+
file.createFile();
70+
71+
// 保存原始指针
72+
float* original_ptr = test_tensor.data_ptr<float>();
73+
74+
// 调用 in-place 版本
75+
at::Tensor& result = test_tensor.detach_();
76+
77+
// 验证返回的是同一个 tensor
78+
file << std::to_string(result.data_ptr<float>() == original_ptr) << " ";
79+
80+
// 写入数据
81+
float* data = result.data_ptr<float>();
82+
for (int64_t i = 0; i < result.numel(); ++i) {
83+
file << std::to_string(data[i]) << " ";
84+
}
85+
file.saveFile();
86+
}
87+
88+
// 测试 detach 后修改数据
89+
TEST_F(DetachTest, DetachAndModify) {
90+
at::Tensor detached = test_tensor.detach();
91+
92+
// 修改 detached tensor 的数据
93+
float* detached_data = detached.data_ptr<float>();
94+
detached_data[0] = 99.0f;
95+
detached_data[1] = 88.0f;
96+
97+
auto file_name = g_custom_param.get();
98+
FileManerger file(file_name);
99+
file.openAppend();
100+
101+
// 验证原始 tensor 的数据也被修改了(因为共享存储)
102+
float* original_data = test_tensor.data_ptr<float>();
103+
file << std::to_string(original_data[0]) << " ";
104+
file << std::to_string(original_data[1]) << " ";
105+
file << std::to_string(detached_data[0]) << " ";
106+
file << std::to_string(detached_data[1]) << " ";
107+
file.saveFile();
108+
}
109+
110+
// 测试不同类型 tensor 的 detach
111+
TEST_F(DetachTest, DetachDifferentTensor) {
112+
at::Tensor different_tensor = at::zeros({2, 2}, at::kFloat);
113+
float* data = different_tensor.data_ptr<float>();
114+
data[0] = 1.0f;
115+
data[1] = 2.0f;
116+
data[2] = 3.0f;
117+
data[3] = 4.0f;
118+
119+
at::Tensor detached = different_tensor.detach();
120+
121+
auto file_name = g_custom_param.get();
122+
FileManerger file(file_name);
123+
file.openAppend();
124+
125+
file << std::to_string(detached.numel()) << " ";
126+
file << std::to_string(detached.dim()) << " ";
127+
128+
float* detached_data = detached.data_ptr<float>();
129+
for (int64_t i = 0; i < detached.numel(); ++i) {
130+
file << std::to_string(detached_data[i]) << " ";
131+
}
132+
file.saveFile();
133+
}
134+
135+
// 测试多维 tensor 的 detach
136+
TEST_F(DetachTest, MultiDimensionalDetach) {
137+
at::Tensor multi_tensor = at::zeros({2, 3, 4}, at::kFloat);
138+
float* data = multi_tensor.data_ptr<float>();
139+
for (int64_t i = 0; i < 24; ++i) {
140+
data[i] = static_cast<float>(i);
141+
}
142+
143+
at::Tensor detached = multi_tensor.detach();
144+
145+
auto file_name = g_custom_param.get();
146+
FileManerger file(file_name);
147+
file.openAppend();
148+
149+
file << std::to_string(detached.dim()) << " ";
150+
file << std::to_string(detached.sizes()[0]) << " ";
151+
file << std::to_string(detached.sizes()[1]) << " ";
152+
file << std::to_string(detached.sizes()[2]) << " ";
153+
file << std::to_string(detached.numel()) << " ";
154+
155+
// 验证数据共享
156+
file << std::to_string(detached.data_ptr<float>() ==
157+
multi_tensor.data_ptr<float>())
158+
<< " ";
159+
file.saveFile();
160+
}
161+
162+
} // namespace test
163+
} // namespace at

test/ops/ReciprocalTest.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/reciprocal.h>
4+
#include <ATen/ops/zeros.h>
5+
#include <gtest/gtest.h>
6+
7+
#include <string>
8+
#include <vector>
9+
10+
#include "../../src/file_manager.h"
11+
12+
extern paddle_api_test::ThreadSafeParam g_custom_param;
13+
14+
namespace at {
15+
namespace test {
16+
17+
using paddle_api_test::FileManerger;
18+
using paddle_api_test::ThreadSafeParam;
19+
20+
class ReciprocalTest : public ::testing::Test {
21+
protected:
22+
void SetUp() override {
23+
std::vector<int64_t> shape = {4};
24+
test_tensor = at::zeros(shape, at::kFloat);
25+
float* data = test_tensor.data_ptr<float>();
26+
data[0] = 1.0f;
27+
data[1] = 2.0f;
28+
data[2] = 0.5f;
29+
data[3] = 4.0f;
30+
}
31+
at::Tensor test_tensor;
32+
};
33+
34+
static void write_reciprocal_result_to_file(FileManerger* file,
35+
const at::Tensor& result) {
36+
*file << std::to_string(result.dim()) << " ";
37+
*file << std::to_string(result.numel()) << " ";
38+
float* result_data = result.data_ptr<float>();
39+
for (int64_t i = 0; i < result.numel(); ++i) {
40+
*file << std::to_string(result_data[i]) << " ";
41+
}
42+
}
43+
44+
// 测试 reciprocal() 方法
45+
TEST_F(ReciprocalTest, BasicReciprocal) {
46+
at::Tensor result = test_tensor.reciprocal();
47+
auto file_name = g_custom_param.get();
48+
FileManerger file(file_name);
49+
file.createFile();
50+
write_reciprocal_result_to_file(&file, result);
51+
52+
// 验证原始 tensor 未被修改
53+
float* original_data = test_tensor.data_ptr<float>();
54+
file << std::to_string(original_data[0]) << " ";
55+
file << std::to_string(original_data[1]) << " ";
56+
file.saveFile();
57+
}
58+
59+
// 测试 reciprocal_() in-place 方法
60+
TEST_F(ReciprocalTest, InplaceReciprocal) {
61+
auto file_name = g_custom_param.get();
62+
FileManerger file(file_name);
63+
file.createFile();
64+
65+
// 保存原始数据指针
66+
float* original_ptr = test_tensor.data_ptr<float>();
67+
68+
// 调用 in-place 版本
69+
at::Tensor& result = test_tensor.reciprocal_();
70+
71+
// 验证返回的是同一个 tensor
72+
file << std::to_string(result.data_ptr<float>() == original_ptr) << " ";
73+
74+
write_reciprocal_result_to_file(&file, result);
75+
file.saveFile();
76+
}
77+
78+
// 测试不同值的 reciprocal
79+
TEST_F(ReciprocalTest, VariousValues) {
80+
at::Tensor various_tensor = at::zeros({5}, at::kFloat);
81+
float* data = various_tensor.data_ptr<float>();
82+
data[0] = 10.0f;
83+
data[1] = 0.1f;
84+
data[2] = -2.0f;
85+
data[3] = -0.5f;
86+
data[4] = 100.0f;
87+
88+
at::Tensor result = various_tensor.reciprocal();
89+
auto file_name = g_custom_param.get();
90+
FileManerger file(file_name);
91+
file.openAppend();
92+
write_reciprocal_result_to_file(&file, result);
93+
file.saveFile();
94+
}
95+
96+
// 测试多维 tensor 的 reciprocal
97+
TEST_F(ReciprocalTest, MultiDimensionalTensor) {
98+
at::Tensor multi_dim_tensor = at::zeros({2, 3}, at::kFloat);
99+
float* data = multi_dim_tensor.data_ptr<float>();
100+
data[0] = 1.0f;
101+
data[1] = 2.0f;
102+
data[2] = 4.0f;
103+
data[3] = 0.25f;
104+
data[4] = 0.5f;
105+
data[5] = 8.0f;
106+
107+
at::Tensor result = multi_dim_tensor.reciprocal();
108+
auto file_name = g_custom_param.get();
109+
FileManerger file(file_name);
110+
file.openAppend();
111+
112+
file << std::to_string(result.dim()) << " ";
113+
file << std::to_string(result.sizes()[0]) << " ";
114+
file << std::to_string(result.sizes()[1]) << " ";
115+
write_reciprocal_result_to_file(&file, result);
116+
file.saveFile();
117+
}
118+
119+
// 测试使用 at::reciprocal 全局函数
120+
TEST_F(ReciprocalTest, GlobalReciprocal) {
121+
at::Tensor result = at::reciprocal(test_tensor);
122+
auto file_name = g_custom_param.get();
123+
FileManerger file(file_name);
124+
file.openAppend();
125+
write_reciprocal_result_to_file(&file, result);
126+
file.saveFile();
127+
}
128+
129+
} // namespace test
130+
} // namespace at

0 commit comments

Comments
 (0)