Skip to content

Commit a477e87

Browse files
committed
rewrite test files
1 parent 6d8d6f6 commit a477e87

File tree

1 file changed

+78
-40
lines changed

1 file changed

+78
-40
lines changed

test/TensorTest.cpp

Lines changed: 78 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -215,91 +215,129 @@ TEST_F(TensorTest, Transpose) {
215215

216216
// 测试 sym_size
217217
TEST_F(TensorTest, SymSize) {
218+
auto file_name = g_custom_param.get();
219+
FileManerger file(file_name);
220+
file.createFile();
218221
// 获取符号化的单个维度大小
219222
c10::SymInt sym_size_0 = tensor.sym_size(0);
220223
c10::SymInt sym_size_1 = tensor.sym_size(1);
221224
c10::SymInt sym_size_2 = tensor.sym_size(2);
222-
223-
// 验证符号化大小与实际大小一致
224-
EXPECT_EQ(sym_size_0, 2);
225-
EXPECT_EQ(sym_size_1, 3);
226-
EXPECT_EQ(sym_size_2, 4);
227-
225+
#if USE_PADDLE_API
226+
file << std::to_string(sym_size_0) << " ";
227+
file << std::to_string(sym_size_1) << " ";
228+
file << std::to_string(sym_size_2) << " ";
229+
#else
230+
file << std::to_string(sym_size_0.guard_int(__FILE__, __LINE__)) << " ";
231+
file << std::to_string(sym_size_1.guard_int(__FILE__, __LINE__)) << " ";
232+
file << std::to_string(sym_size_2.guard_int(__FILE__, __LINE__)) << " ";
233+
#endif
228234
// 测试负索引
229235
c10::SymInt sym_size_neg1 = tensor.sym_size(-1);
230-
EXPECT_EQ(sym_size_neg1, 4);
236+
#if USE_PADDLE_API
237+
file << std::to_string(sym_size_neg1) << " ";
238+
#else
239+
file << std::to_string(sym_size_neg1.guard_int(__FILE__, __LINE__)) << " ";
240+
#endif
241+
file.saveFile();
231242
}
232243

233244
// 测试 sym_stride
234245
TEST_F(TensorTest, SymStride) {
246+
auto file_name = g_custom_param.get();
247+
FileManerger file(file_name);
248+
file.createFile();
235249
// 获取符号化的单个维度步长
236250
c10::SymInt sym_stride_0 = tensor.sym_stride(0);
237251
c10::SymInt sym_stride_1 = tensor.sym_stride(1);
238252
c10::SymInt sym_stride_2 = tensor.sym_stride(2);
239-
240-
// 验证符号化步长
241-
EXPECT_GT(sym_stride_0, 0);
242-
EXPECT_GT(sym_stride_1, 0);
243-
EXPECT_GT(sym_stride_2, 0);
244-
253+
#if USE_PADDLE_API
254+
file << std::to_string(sym_stride_0) << " ";
255+
file << std::to_string(sym_stride_1) << " ";
256+
file << std::to_string(sym_stride_2) << " ";
257+
#else
258+
file << std::to_string(sym_stride_0.guard_int(__FILE__, __LINE__)) << " ";
259+
file << std::to_string(sym_stride_1.guard_int(__FILE__, __LINE__)) << " ";
260+
file << std::to_string(sym_stride_2.guard_int(__FILE__, __LINE__)) << " ";
261+
#endif
245262
// 测试负索引
246263
c10::SymInt sym_stride_neg1 = tensor.sym_stride(-1);
247-
EXPECT_EQ(sym_stride_neg1, 1); // 最后一维步长通常为1
264+
#if USE_PADDLE_API
265+
file << std::to_string(sym_stride_neg1) << " ";
266+
#else
267+
file << std::to_string(sym_stride_neg1.guard_int(__FILE__, __LINE__)) << " ";
268+
#endif
269+
file.saveFile();
248270
}
249271

250272
// 测试 sym_sizes
251273
TEST_F(TensorTest, SymSizes) {
274+
auto file_name = g_custom_param.get();
275+
FileManerger file(file_name);
276+
file.createFile();
252277
// 获取符号化的所有维度大小
253278
c10::SymIntArrayRef sym_sizes = tensor.sym_sizes();
254-
255-
// 验证维度数量
256-
EXPECT_EQ(sym_sizes.size(), 3U);
257-
258-
// 验证每个维度的大小
259-
EXPECT_EQ(sym_sizes[0], 2);
260-
EXPECT_EQ(sym_sizes[1], 3);
261-
EXPECT_EQ(sym_sizes[2], 4);
279+
file << std::to_string(sym_sizes.size()) << " ";
280+
for (size_t i = 0; i < sym_sizes.size(); ++i) {
281+
#if USE_PADDLE_API
282+
file << std::to_string(sym_sizes[i]) << " ";
283+
#else
284+
file << std::to_string(sym_sizes[i].guard_int(__FILE__, __LINE__)) << " ";
285+
#endif
286+
}
287+
file.saveFile();
262288
}
263289

264290
// 测试 sym_strides
265291
TEST_F(TensorTest, SymStrides) {
292+
auto file_name = g_custom_param.get();
293+
FileManerger file(file_name);
294+
file.createFile();
266295
// 获取符号化的所有维度步长
267296
c10::SymIntArrayRef sym_strides = tensor.sym_strides();
268-
269-
// 验证维度数量
270-
EXPECT_EQ(sym_strides.size(), 3U);
271-
272-
// 验证步长值都大于0
297+
file << std::to_string(sym_strides.size()) << " ";
273298
for (size_t i = 0; i < sym_strides.size(); ++i) {
274-
EXPECT_GT(sym_strides[i], 0);
299+
#if USE_PADDLE_API
300+
file << std::to_string(sym_strides[i]) << " ";
301+
#else
302+
file << std::to_string(sym_strides[i].guard_int(__FILE__, __LINE__)) << " ";
303+
#endif
275304
}
305+
file.saveFile();
276306
}
277307

278308
// 测试 sym_numel
279309
TEST_F(TensorTest, SymNumel) {
310+
auto file_name = g_custom_param.get();
311+
FileManerger file(file_name);
312+
file.createFile();
280313
// 获取符号化的元素总数
281314
c10::SymInt sym_numel = tensor.sym_numel();
282-
283-
// 验证符号化元素数与实际元素数一致
284-
EXPECT_EQ(sym_numel, 24); // 2*3*4
285-
286-
// 验证与 numel() 结果一致
287-
EXPECT_EQ(sym_numel, tensor.numel());
315+
#if USE_PADDLE_API
316+
file << std::to_string(sym_numel) << " ";
317+
#else
318+
file << std::to_string(sym_numel.guard_int(__FILE__, __LINE__)) << " ";
319+
#endif
320+
file << std::to_string(tensor.numel()) << " ";
321+
file.saveFile();
288322
}
289323

290324
// 测试 defined
291325
TEST_F(TensorTest, Defined) {
292-
// Tensor tensor(paddle_tensor_);
293-
294-
EXPECT_TRUE(tensor.defined());
326+
auto file_name = g_custom_param.get();
327+
FileManerger file(file_name);
328+
file.createFile();
329+
file << std::to_string(tensor.defined()) << " ";
330+
file.saveFile();
295331
}
296332

297333
// 测试 reset
298334
TEST_F(TensorTest, Reset) {
299-
// Tensor tensor(paddle_tensor_);
300-
335+
auto file_name = g_custom_param.get();
336+
FileManerger file(file_name);
337+
file.createFile();
301338
tensor.reset();
302-
EXPECT_FALSE(tensor.defined());
339+
file << std::to_string(tensor.defined()) << " ";
340+
file.saveFile();
303341
}
304342

305343
} // namespace test

0 commit comments

Comments
 (0)