Skip to content

Commit c80f236

Browse files
committed
add some compat API tests
1 parent d60f9f5 commit c80f236

File tree

8 files changed

+1026
-0
lines changed

8 files changed

+1026
-0
lines changed

test/ops/CoalesceTest.cpp

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/sparse_coo_tensor.h>
4+
#include <ATen/ops/zeros.h>
5+
#include <gtest/gtest.h>
6+
7+
#include <string>
8+
#include <vector>
9+
10+
#include "../../src/file_manager.h"
11+
12+
extern paddle_api_test::ThreadSafeParam g_custom_param;
13+
14+
namespace at {
15+
namespace test {
16+
17+
using paddle_api_test::FileManerger;
18+
using paddle_api_test::ThreadSafeParam;
19+
20+
class CoalesceTest : public ::testing::Test {
21+
protected:
22+
void SetUp() override {
23+
// 构建 3x4 的稀疏 COO tensor(无重复索引)
24+
// 索引:[[0,1,1],[0,1,2]],值:[1,2,3]
25+
at::Tensor idx = at::zeros({2, 3}, at::kLong);
26+
idx[0][0] = 0;
27+
idx[0][1] = 1;
28+
idx[0][2] = 1;
29+
idx[1][0] = 0;
30+
idx[1][1] = 1;
31+
idx[1][2] = 2;
32+
at::Tensor val = at::zeros({3}, at::kFloat);
33+
val[0] = 1.f;
34+
val[1] = 2.f;
35+
val[2] = 3.f;
36+
sparse_unique = at::sparse_coo_tensor(idx, val, {3, 4});
37+
38+
// 构建含重复索引的稀疏 tensor(位置 (1,1) 出现两次)
39+
// 索引:[[0,1,1],[0,1,1]],值:[1,2,10]
40+
// coalesce 后:(0,0)=1, (1,1)=12,nnz=2
41+
at::Tensor idx_dup = at::zeros({2, 3}, at::kLong);
42+
idx_dup[0][0] = 0;
43+
idx_dup[0][1] = 1;
44+
idx_dup[0][2] = 1;
45+
idx_dup[1][0] = 0;
46+
idx_dup[1][1] = 1;
47+
idx_dup[1][2] = 1;
48+
at::Tensor val_dup = at::zeros({3}, at::kFloat);
49+
val_dup[0] = 1.f;
50+
val_dup[1] = 2.f;
51+
val_dup[2] = 10.f;
52+
sparse_dup = at::sparse_coo_tensor(idx_dup, val_dup, {3, 4});
53+
}
54+
55+
at::Tensor sparse_unique;
56+
at::Tensor sparse_dup;
57+
};
58+
59+
// 测试 _nnz():返回存储的非零元素数(含重复计数)
60+
TEST_F(CoalesceTest, NnzBasic) {
61+
auto file_name = g_custom_param.get();
62+
FileManerger file(file_name);
63+
file.createFile();
64+
65+
file << std::to_string(sparse_unique._nnz()) << " "; // 3
66+
file << std::to_string(sparse_dup._nnz()) << " "; // 3(含重复)
67+
file.saveFile();
68+
}
69+
70+
// 测试 _values():返回稀疏 tensor 的 values 子 tensor
71+
TEST_F(CoalesceTest, ValuesBasic) {
72+
auto file_name = g_custom_param.get();
73+
FileManerger file(file_name);
74+
file.createFile();
75+
76+
at::Tensor v = sparse_unique._values();
77+
file << std::to_string(v.dim()) << " ";
78+
file << std::to_string(v.numel()) << " ";
79+
float* data = v.data_ptr<float>();
80+
file << std::to_string(data[0]) << " ";
81+
file << std::to_string(data[1]) << " ";
82+
file << std::to_string(data[2]) << " ";
83+
file.saveFile();
84+
}
85+
86+
// 测试 is_coalesced():未经 coalesce 调用的张量
87+
TEST_F(CoalesceTest, IsCoalescedInitial) {
88+
auto file_name = g_custom_param.get();
89+
FileManerger file(file_name);
90+
file.createFile();
91+
92+
// 含重复索引,未经显式 coalesce,初始状态 is_coalesced 为 false
93+
file << std::to_string(static_cast<int>(sparse_dup.is_coalesced())) << " ";
94+
file.saveFile();
95+
}
96+
97+
// 测试 coalesce():合并重复索引后 nnz 减少
98+
TEST_F(CoalesceTest, CoalesceReducesNnz) {
99+
auto file_name = g_custom_param.get();
100+
FileManerger file(file_name);
101+
file.createFile();
102+
103+
at::Tensor coalesced = sparse_dup.coalesce();
104+
file << std::to_string(coalesced._nnz()) << " "; // 2(重复已合并)
105+
file.saveFile();
106+
}
107+
108+
// 测试 coalesce() 后 is_coalesced() 返回 true
109+
TEST_F(CoalesceTest, CoalesceIsCoalesced) {
110+
auto file_name = g_custom_param.get();
111+
FileManerger file(file_name);
112+
file.createFile();
113+
114+
at::Tensor coalesced = sparse_dup.coalesce();
115+
file << std::to_string(static_cast<int>(coalesced.is_coalesced())) << " ";
116+
file.saveFile();
117+
}
118+
119+
// 测试 coalesce() 后重复索引值被累加
120+
TEST_F(CoalesceTest, CoalesceAccumulatesValues) {
121+
auto file_name = g_custom_param.get();
122+
FileManerger file(file_name);
123+
file.createFile();
124+
125+
at::Tensor coalesced = sparse_dup.coalesce();
126+
at::Tensor v = coalesced._values();
127+
// 按索引排序后:(0,0)=1, (1,1)=12
128+
file << std::to_string(v.numel()) << " ";
129+
// 值应出现 1.0 和 12.0(顺序取决于实现,输出全部值排序后比较)
130+
float* data = v.data_ptr<float>();
131+
float sum = 0.f;
132+
for (int i = 0; i < v.numel(); ++i) sum += data[i];
133+
file << std::to_string(sum) << " "; // 1 + 12 = 13
134+
file.saveFile();
135+
}
136+
137+
// 测试无重复索引的 sparse tensor coalesce 后 nnz 不变
138+
TEST_F(CoalesceTest, CoalesceUniqueIndicesPreservesNnz) {
139+
auto file_name = g_custom_param.get();
140+
FileManerger file(file_name);
141+
file.createFile();
142+
143+
at::Tensor coalesced = sparse_unique.coalesce();
144+
file << std::to_string(coalesced._nnz()) << " "; // 仍为 3
145+
file << std::to_string(static_cast<int>(coalesced.is_coalesced())) << " ";
146+
file.saveFile();
147+
}
148+
149+
} // namespace test
150+
} // namespace at

test/ops/EyeTest.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/eye.h>
4+
#include <gtest/gtest.h>
5+
6+
#include <string>
7+
#include <vector>
8+
9+
#include "../../src/file_manager.h"
10+
11+
extern paddle_api_test::ThreadSafeParam g_custom_param;
12+
13+
namespace at {
14+
namespace test {
15+
16+
using paddle_api_test::FileManerger;
17+
using paddle_api_test::ThreadSafeParam;
18+
19+
class EyeTest : public ::testing::Test {
20+
protected:
21+
void SetUp() override {}
22+
};
23+
24+
static void write_eye_result_to_file(FileManerger* file,
25+
const at::Tensor& result) {
26+
*file << std::to_string(result.dim()) << " ";
27+
*file << std::to_string(result.numel()) << " ";
28+
for (int64_t i = 0; i < result.dim(); ++i) {
29+
*file << std::to_string(result.sizes()[i]) << " ";
30+
}
31+
*file << std::to_string(static_cast<int>(result.scalar_type())) << " ";
32+
float* data = result.data_ptr<float>();
33+
for (int64_t i = 0; i < result.numel(); ++i) {
34+
*file << std::to_string(data[i]) << " ";
35+
}
36+
}
37+
38+
// 基本 3×3 单位矩阵(默认 float dtype)
39+
TEST_F(EyeTest, BasicEyeSquare) {
40+
at::Tensor result = at::eye(3);
41+
auto file_name = g_custom_param.get();
42+
FileManerger file(file_name);
43+
file.createFile();
44+
write_eye_result_to_file(&file, result);
45+
file.saveFile();
46+
}
47+
48+
// 1×1 单位矩阵
49+
TEST_F(EyeTest, EyeSingleElement) {
50+
at::Tensor result = at::eye(1);
51+
auto file_name = g_custom_param.get();
52+
FileManerger file(file_name);
53+
file.openAppend();
54+
write_eye_result_to_file(&file, result);
55+
file.saveFile();
56+
}
57+
58+
// 指定 double dtype 的 4×4 单位矩阵
59+
TEST_F(EyeTest, EyeWithDoubleDtype) {
60+
at::Tensor result = at::eye(4, at::TensorOptions().dtype(at::kDouble));
61+
auto file_name = g_custom_param.get();
62+
FileManerger file(file_name);
63+
file.openAppend();
64+
file << std::to_string(result.dim()) << " ";
65+
file << std::to_string(result.numel()) << " ";
66+
for (int64_t i = 0; i < result.dim(); ++i) {
67+
file << std::to_string(result.sizes()[i]) << " ";
68+
}
69+
file << std::to_string(static_cast<int>(result.scalar_type())) << " ";
70+
double* data = result.data_ptr<double>();
71+
for (int64_t i = 0; i < result.numel(); ++i) {
72+
file << std::to_string(data[i]) << " ";
73+
}
74+
file.saveFile();
75+
}
76+
77+
// 行数 < 列数的矩形单位矩阵(3×5)
78+
TEST_F(EyeTest, EyeRectangularMoreCols) {
79+
at::Tensor result = at::eye(3, 5);
80+
auto file_name = g_custom_param.get();
81+
FileManerger file(file_name);
82+
file.openAppend();
83+
write_eye_result_to_file(&file, result);
84+
file.saveFile();
85+
}
86+
87+
// 行数 > 列数的矩形单位矩阵(5×3)
88+
TEST_F(EyeTest, EyeRectangularMoreRows) {
89+
at::Tensor result = at::eye(5, 3);
90+
auto file_name = g_custom_param.get();
91+
FileManerger file(file_name);
92+
file.openAppend();
93+
write_eye_result_to_file(&file, result);
94+
file.saveFile();
95+
}
96+
97+
// 指定 int dtype 的矩形单位矩阵(2×4)
98+
TEST_F(EyeTest, EyeRectangularWithIntDtype) {
99+
at::Tensor result = at::eye(2, 4, at::TensorOptions().dtype(at::kInt));
100+
auto file_name = g_custom_param.get();
101+
FileManerger file(file_name);
102+
file.openAppend();
103+
file << std::to_string(result.dim()) << " ";
104+
file << std::to_string(result.numel()) << " ";
105+
for (int64_t i = 0; i < result.dim(); ++i) {
106+
file << std::to_string(result.sizes()[i]) << " ";
107+
}
108+
file << std::to_string(static_cast<int>(result.scalar_type())) << " ";
109+
int* data = result.data_ptr<int>();
110+
for (int64_t i = 0; i < result.numel(); ++i) {
111+
file << std::to_string(data[i]) << " ";
112+
}
113+
file.saveFile();
114+
}
115+
116+
} // namespace test
117+
} // namespace at

test/ops/ItemTest.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/ops/zeros.h>
4+
#include <gtest/gtest.h>
5+
6+
#include <string>
7+
#include <vector>
8+
9+
#include "../../src/file_manager.h"
10+
11+
extern paddle_api_test::ThreadSafeParam g_custom_param;
12+
13+
namespace at {
14+
namespace test {
15+
16+
using paddle_api_test::FileManerger;
17+
using paddle_api_test::ThreadSafeParam;
18+
19+
class ItemTest : public ::testing::Test {
20+
protected:
21+
void SetUp() override {
22+
scalar_float = at::zeros({}, at::kFloat);
23+
scalar_float.fill_(3.14f);
24+
25+
scalar_int = at::zeros({}, at::kInt);
26+
scalar_int.fill_(42);
27+
28+
scalar_double = at::zeros({}, at::kDouble);
29+
scalar_double.fill_(2.718281828);
30+
}
31+
32+
at::Tensor scalar_float;
33+
at::Tensor scalar_int;
34+
at::Tensor scalar_double;
35+
};
36+
37+
// 测试 item() 从 float 0-dim tensor 获取标量(返回 at::Scalar)
38+
TEST_F(ItemTest, ItemFloatScalar) {
39+
auto file_name = g_custom_param.get();
40+
FileManerger file(file_name);
41+
file.createFile();
42+
43+
at::Scalar s = scalar_float.item();
44+
file << std::to_string(s.to<float>()) << " ";
45+
file.saveFile();
46+
}
47+
48+
// 测试 item<float>() 模板形式
49+
TEST_F(ItemTest, ItemTemplateFloat) {
50+
auto file_name = g_custom_param.get();
51+
FileManerger file(file_name);
52+
file.createFile();
53+
54+
float val = scalar_float.item<float>();
55+
file << std::to_string(val) << " ";
56+
file.saveFile();
57+
}
58+
59+
// 测试 item<int>() 从 int tensor
60+
TEST_F(ItemTest, ItemTemplateInt) {
61+
auto file_name = g_custom_param.get();
62+
FileManerger file(file_name);
63+
file.createFile();
64+
65+
int val = scalar_int.item<int>();
66+
file << std::to_string(val) << " ";
67+
file.saveFile();
68+
}
69+
70+
// 测试 item<double>() 获取 double 精度值
71+
TEST_F(ItemTest, ItemTemplateDouble) {
72+
auto file_name = g_custom_param.get();
73+
FileManerger file(file_name);
74+
file.createFile();
75+
76+
double val = scalar_double.item<double>();
77+
// 保留 9 位有效数字
78+
file << std::to_string(val) << " ";
79+
file.saveFile();
80+
}
81+
82+
// 测试 item<int64_t>()
83+
TEST_F(ItemTest, ItemTemplateInt64) {
84+
auto file_name = g_custom_param.get();
85+
FileManerger file(file_name);
86+
file.createFile();
87+
88+
at::Tensor t = at::zeros({}, at::kLong);
89+
t.fill_(static_cast<int64_t>(1234567890));
90+
int64_t val = t.item<int64_t>();
91+
file << std::to_string(val) << " ";
92+
file.saveFile();
93+
}
94+
95+
// 测试 item() 对单元素 1-dim tensor(squeeze 后语义)
96+
TEST_F(ItemTest, ItemFromSingleElementTensor) {
97+
auto file_name = g_custom_param.get();
98+
FileManerger file(file_name);
99+
file.createFile();
100+
101+
at::Tensor t = at::zeros({1}, at::kFloat);
102+
t.fill_(7.5f);
103+
float val = t.item<float>();
104+
file << std::to_string(val) << " ";
105+
file.saveFile();
106+
}
107+
108+
// 测试 item() 跨类型转换:double tensor 通过 item<float>()
109+
TEST_F(ItemTest, ItemCrossTypeCast) {
110+
auto file_name = g_custom_param.get();
111+
FileManerger file(file_name);
112+
file.createFile();
113+
114+
float val = scalar_double.item<float>();
115+
file << std::to_string(val) << " ";
116+
file.saveFile();
117+
}
118+
119+
} // namespace test
120+
} // namespace at

0 commit comments

Comments
 (0)