Skip to content

Commit 8f2d046

Browse files
committed
rewrite test files
1 parent 6f26562 commit 8f2d046

File tree

3 files changed

+205
-103
lines changed

3 files changed

+205
-103
lines changed

test/TensorTest.cpp

Lines changed: 0 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -287,108 +287,5 @@ TEST_F(TensorTest, SymNumel) {
287287
EXPECT_EQ(sym_numel, tensor.numel());
288288
}
289289

290-
// 测试 squeeze
291-
TEST_F(TensorTest, Squeeze) {
292-
// 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4}
293-
at::Tensor tensor_with_ones = at::ones({2, 1, 3, 1, 4}, at::kFloat);
294-
295-
// 移除所有大小为1的维度
296-
at::Tensor squeezed = tensor_with_ones.squeeze();
297-
EXPECT_EQ(squeezed.dim(), 3);
298-
EXPECT_EQ(squeezed.sizes()[0], 2);
299-
EXPECT_EQ(squeezed.sizes()[1], 3);
300-
EXPECT_EQ(squeezed.sizes()[2], 4);
301-
EXPECT_EQ(squeezed.numel(), 24);
302-
303-
// 移除指定维度(维度1,大小为1)
304-
at::Tensor squeezed_dim1 = tensor_with_ones.squeeze(1);
305-
EXPECT_EQ(squeezed_dim1.dim(), 4);
306-
EXPECT_EQ(squeezed_dim1.sizes()[0], 2);
307-
EXPECT_EQ(squeezed_dim1.sizes()[1], 3);
308-
EXPECT_EQ(squeezed_dim1.sizes()[2], 1);
309-
EXPECT_EQ(squeezed_dim1.sizes()[3], 4);
310-
}
311-
312-
// 测试 squeeze_ (原位操作)
313-
TEST_F(TensorTest, SqueezeInplace) {
314-
// 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4}
315-
at::Tensor tensor_with_ones = at::ones({2, 1, 3, 1, 4}, at::kFloat);
316-
317-
// 记录原始数据指针
318-
void* original_ptr = tensor_with_ones.data_ptr();
319-
320-
// 原位移除所有大小为1的维度
321-
tensor_with_ones.squeeze_();
322-
EXPECT_EQ(tensor_with_ones.dim(), 3);
323-
EXPECT_EQ(tensor_with_ones.sizes()[0], 2);
324-
EXPECT_EQ(tensor_with_ones.sizes()[1], 3);
325-
EXPECT_EQ(tensor_with_ones.sizes()[2], 4);
326-
EXPECT_EQ(tensor_with_ones.numel(), 24);
327-
328-
// 验证是原位操作(数据指针未改变)
329-
EXPECT_EQ(tensor_with_ones.data_ptr(), original_ptr);
330-
331-
// 测试原位移除指定维度
332-
at::Tensor tensor_with_ones2 = at::ones({2, 1, 3, 1, 4}, at::kFloat);
333-
tensor_with_ones2.squeeze_(1);
334-
EXPECT_EQ(tensor_with_ones2.dim(), 4);
335-
EXPECT_EQ(tensor_with_ones2.sizes()[1], 3);
336-
}
337-
338-
// 测试 unsqueeze
339-
TEST_F(TensorTest, Unsqueeze) {
340-
// 在维度0之前添加一个大小为1的维度
341-
at::Tensor unsqueezed0 = tensor.unsqueeze(0);
342-
EXPECT_EQ(unsqueezed0.dim(), 4);
343-
EXPECT_EQ(unsqueezed0.sizes()[0], 1);
344-
EXPECT_EQ(unsqueezed0.sizes()[1], 2);
345-
EXPECT_EQ(unsqueezed0.sizes()[2], 3);
346-
EXPECT_EQ(unsqueezed0.sizes()[3], 4);
347-
EXPECT_EQ(unsqueezed0.numel(), 24);
348-
349-
// 在维度2之前添加一个大小为1的维度
350-
at::Tensor unsqueezed2 = tensor.unsqueeze(2);
351-
EXPECT_EQ(unsqueezed2.dim(), 4);
352-
EXPECT_EQ(unsqueezed2.sizes()[0], 2);
353-
EXPECT_EQ(unsqueezed2.sizes()[1], 3);
354-
EXPECT_EQ(unsqueezed2.sizes()[2], 1);
355-
EXPECT_EQ(unsqueezed2.sizes()[3], 4);
356-
357-
// 在最后添加一个大小为1的维度(使用负索引-1)
358-
at::Tensor unsqueezed_last = tensor.unsqueeze(-1);
359-
EXPECT_EQ(unsqueezed_last.dim(), 4);
360-
EXPECT_EQ(unsqueezed_last.sizes()[0], 2);
361-
EXPECT_EQ(unsqueezed_last.sizes()[1], 3);
362-
EXPECT_EQ(unsqueezed_last.sizes()[2], 4);
363-
EXPECT_EQ(unsqueezed_last.sizes()[3], 1);
364-
}
365-
366-
// 测试 unsqueeze_ (原位操作)
367-
TEST_F(TensorTest, UnsqueezeInplace) {
368-
// 创建一个新的tensor用于原位操作
369-
at::Tensor test_tensor = at::ones({2, 3, 4}, at::kFloat);
370-
371-
// 记录原始数据指针
372-
void* original_ptr = test_tensor.data_ptr();
373-
374-
// 原位在维度0之前添加一个大小为1的维度
375-
test_tensor.unsqueeze_(0);
376-
EXPECT_EQ(test_tensor.dim(), 4);
377-
EXPECT_EQ(test_tensor.sizes()[0], 1);
378-
EXPECT_EQ(test_tensor.sizes()[1], 2);
379-
EXPECT_EQ(test_tensor.sizes()[2], 3);
380-
EXPECT_EQ(test_tensor.sizes()[3], 4);
381-
EXPECT_EQ(test_tensor.numel(), 24);
382-
383-
// 验证是原位操作(数据指针未改变)
384-
EXPECT_EQ(test_tensor.data_ptr(), original_ptr);
385-
386-
// 测试使用负索引的原位操作
387-
at::Tensor test_tensor2 = at::ones({2, 3, 4}, at::kFloat);
388-
test_tensor2.unsqueeze_(-1);
389-
EXPECT_EQ(test_tensor2.dim(), 4);
390-
EXPECT_EQ(test_tensor2.sizes()[3], 1);
391-
}
392-
393290
} // namespace test
394291
} // namespace at

test/ops/SqueezeTest.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/ones.h>
4+
#include <gtest/gtest.h>
5+
6+
#include <string>
7+
#include <vector>
8+
9+
#include "../../src/file_manager.h"
10+
11+
extern paddle_api_test::ThreadSafeParam g_custom_param;
12+
13+
namespace at {
14+
namespace test {
15+
16+
using paddle_api_test::FileManerger;
17+
using paddle_api_test::ThreadSafeParam;
18+
19+
class SqueezeTest : public ::testing::Test {
20+
protected:
21+
void SetUp() override {
22+
// 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4}
23+
tensor_with_ones = at::ones({2, 1, 3, 1, 4}, at::kFloat);
24+
}
25+
at::Tensor tensor_with_ones;
26+
};
27+
28+
// 测试 squeeze - 移除所有大小为1的维度
29+
TEST_F(SqueezeTest, SqueezeAll) {
30+
auto file_name = g_custom_param.get();
31+
FileManerger file(file_name);
32+
file.createFile();
33+
at::Tensor squeezed = tensor_with_ones.squeeze();
34+
file << std::to_string(squeezed.dim()) << " ";
35+
file << std::to_string(squeezed.numel()) << " ";
36+
for (int64_t i = 0; i < squeezed.dim(); ++i) {
37+
file << std::to_string(squeezed.sizes()[i]) << " ";
38+
}
39+
file.saveFile();
40+
}
41+
42+
// 测试 squeeze - 移除指定维度
43+
TEST_F(SqueezeTest, SqueezeDim) {
44+
auto file_name = g_custom_param.get();
45+
FileManerger file(file_name);
46+
file.createFile();
47+
// 移除维度1(大小为1)
48+
at::Tensor squeezed_dim1 = tensor_with_ones.squeeze(1);
49+
file << std::to_string(squeezed_dim1.dim()) << " ";
50+
file << std::to_string(squeezed_dim1.numel()) << " ";
51+
for (int64_t i = 0; i < squeezed_dim1.dim(); ++i) {
52+
file << std::to_string(squeezed_dim1.sizes()[i]) << " ";
53+
}
54+
file.saveFile();
55+
}
56+
57+
// 测试 squeeze_ - 原位移除所有大小为1的维度
58+
TEST_F(SqueezeTest, SqueezeInplaceAll) {
59+
auto file_name = g_custom_param.get();
60+
FileManerger file(file_name);
61+
file.createFile();
62+
// 记录原始数据指针
63+
void* original_ptr = tensor_with_ones.data_ptr();
64+
// 原位移除所有大小为1的维度
65+
tensor_with_ones.squeeze_();
66+
file << std::to_string(tensor_with_ones.dim()) << " ";
67+
file << std::to_string(tensor_with_ones.numel()) << " ";
68+
for (int64_t i = 0; i < tensor_with_ones.dim(); ++i) {
69+
file << std::to_string(tensor_with_ones.sizes()[i]) << " ";
70+
}
71+
// 验证是原位操作(数据指针未改变)
72+
file << std::to_string(tensor_with_ones.data_ptr() == original_ptr) << " ";
73+
file.saveFile();
74+
}
75+
76+
// 测试 squeeze_ - 原位移除指定维度
77+
TEST_F(SqueezeTest, SqueezeInplaceDim) {
78+
auto file_name = g_custom_param.get();
79+
FileManerger file(file_name);
80+
file.createFile();
81+
// 记录原始数据指针
82+
void* original_ptr = tensor_with_ones.data_ptr();
83+
// 原位移除维度1
84+
tensor_with_ones.squeeze_(1);
85+
file << std::to_string(tensor_with_ones.dim()) << " ";
86+
file << std::to_string(tensor_with_ones.numel()) << " ";
87+
for (int64_t i = 0; i < tensor_with_ones.dim(); ++i) {
88+
file << std::to_string(tensor_with_ones.sizes()[i]) << " ";
89+
}
90+
// 验证是原位操作(数据指针未改变)
91+
file << std::to_string(tensor_with_ones.data_ptr() == original_ptr) << " ";
92+
file.saveFile();
93+
}
94+
95+
} // namespace test
96+
} // namespace at

test/ops/UnsqueezeTest.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/ones.h>
4+
#include <gtest/gtest.h>
5+
6+
#include <string>
7+
#include <vector>
8+
9+
#include "../../src/file_manager.h"
10+
11+
extern paddle_api_test::ThreadSafeParam g_custom_param;
12+
13+
namespace at {
14+
namespace test {
15+
16+
using paddle_api_test::FileManerger;
17+
using paddle_api_test::ThreadSafeParam;
18+
19+
class UnsqueezeTest : public ::testing::Test {
20+
protected:
21+
void SetUp() override {
22+
// 创建一个基础tensor: shape = {2, 3, 4}
23+
tensor = at::ones({2, 3, 4}, at::kFloat);
24+
}
25+
at::Tensor tensor;
26+
};
27+
28+
// 测试 unsqueeze - 在维度0之前添加维度
29+
TEST_F(UnsqueezeTest, UnsqueezeDim0) {
30+
auto file_name = g_custom_param.get();
31+
FileManerger file(file_name);
32+
file.createFile();
33+
at::Tensor unsqueezed0 = tensor.unsqueeze(0);
34+
file << std::to_string(unsqueezed0.dim()) << " ";
35+
file << std::to_string(unsqueezed0.numel()) << " ";
36+
for (int64_t i = 0; i < unsqueezed0.dim(); ++i) {
37+
file << std::to_string(unsqueezed0.sizes()[i]) << " ";
38+
}
39+
file.saveFile();
40+
}
41+
42+
// 测试 unsqueeze - 在维度2之前添加维度
43+
TEST_F(UnsqueezeTest, UnsqueezeDim2) {
44+
auto file_name = g_custom_param.get();
45+
FileManerger file(file_name);
46+
file.createFile();
47+
at::Tensor unsqueezed2 = tensor.unsqueeze(2);
48+
file << std::to_string(unsqueezed2.dim()) << " ";
49+
file << std::to_string(unsqueezed2.numel()) << " ";
50+
for (int64_t i = 0; i < unsqueezed2.dim(); ++i) {
51+
file << std::to_string(unsqueezed2.sizes()[i]) << " ";
52+
}
53+
file.saveFile();
54+
}
55+
56+
// 测试 unsqueeze - 使用负索引在最后添加维度
57+
TEST_F(UnsqueezeTest, UnsqueezeNegativeDim) {
58+
auto file_name = g_custom_param.get();
59+
FileManerger file(file_name);
60+
file.createFile();
61+
at::Tensor unsqueezed_last = tensor.unsqueeze(-1);
62+
file << std::to_string(unsqueezed_last.dim()) << " ";
63+
file << std::to_string(unsqueezed_last.numel()) << " ";
64+
for (int64_t i = 0; i < unsqueezed_last.dim(); ++i) {
65+
file << std::to_string(unsqueezed_last.sizes()[i]) << " ";
66+
}
67+
file.saveFile();
68+
}
69+
70+
// 测试 unsqueeze_ - 原位在维度0之前添加维度
71+
TEST_F(UnsqueezeTest, UnsqueezeInplaceDim0) {
72+
auto file_name = g_custom_param.get();
73+
FileManerger file(file_name);
74+
file.createFile();
75+
// 记录原始数据指针
76+
void* original_ptr = tensor.data_ptr();
77+
// 原位在维度0之前添加维度
78+
tensor.unsqueeze_(0);
79+
file << std::to_string(tensor.dim()) << " ";
80+
file << std::to_string(tensor.numel()) << " ";
81+
for (int64_t i = 0; i < tensor.dim(); ++i) {
82+
file << std::to_string(tensor.sizes()[i]) << " ";
83+
}
84+
// 验证是原位操作(数据指针未改变)
85+
file << std::to_string(tensor.data_ptr() == original_ptr) << " ";
86+
file.saveFile();
87+
}
88+
89+
// 测试 unsqueeze_ - 原位使用负索引添加维度
90+
TEST_F(UnsqueezeTest, UnsqueezeInplaceNegativeDim) {
91+
auto file_name = g_custom_param.get();
92+
FileManerger file(file_name);
93+
file.createFile();
94+
// 记录原始数据指针
95+
void* original_ptr = tensor.data_ptr();
96+
// 原位在最后添加维度
97+
tensor.unsqueeze_(-1);
98+
file << std::to_string(tensor.dim()) << " ";
99+
file << std::to_string(tensor.numel()) << " ";
100+
for (int64_t i = 0; i < tensor.dim(); ++i) {
101+
file << std::to_string(tensor.sizes()[i]) << " ";
102+
}
103+
// 验证是原位操作(数据指针未改变)
104+
file << std::to_string(tensor.data_ptr() == original_ptr) << " ";
105+
file.saveFile();
106+
}
107+
108+
} // namespace test
109+
} // namespace at

0 commit comments

Comments
 (0)