@@ -215,91 +215,129 @@ TEST_F(TensorTest, Transpose) {
215215
216216// 测试 sym_size
217217TEST_F (TensorTest, SymSize) {
218+ auto file_name = g_custom_param.get ();
219+ FileManerger file (file_name);
220+ file.createFile ();
218221 // 获取符号化的单个维度大小
219222 c10::SymInt sym_size_0 = tensor.sym_size (0 );
220223 c10::SymInt sym_size_1 = tensor.sym_size (1 );
221224 c10::SymInt sym_size_2 = tensor.sym_size (2 );
222-
223- // 验证符号化大小与实际大小一致
224- EXPECT_EQ (sym_size_0, 2 );
225- EXPECT_EQ (sym_size_1, 3 );
226- EXPECT_EQ (sym_size_2, 4 );
227-
225+ #if USE_PADDLE_API
226+ file << std::to_string (sym_size_0) << " " ;
227+ file << std::to_string (sym_size_1) << " " ;
228+ file << std::to_string (sym_size_2) << " " ;
229+ #else
230+ file << std::to_string (sym_size_0.guard_int (__FILE__, __LINE__)) << " " ;
231+ file << std::to_string (sym_size_1.guard_int (__FILE__, __LINE__)) << " " ;
232+ file << std::to_string (sym_size_2.guard_int (__FILE__, __LINE__)) << " " ;
233+ #endif
228234 // 测试负索引
229235 c10::SymInt sym_size_neg1 = tensor.sym_size (-1 );
230- EXPECT_EQ (sym_size_neg1, 4 );
236+ #if USE_PADDLE_API
237+ file << std::to_string (sym_size_neg1) << " " ;
238+ #else
239+ file << std::to_string (sym_size_neg1.guard_int (__FILE__, __LINE__)) << " " ;
240+ #endif
241+ file.saveFile ();
231242}
232243
233244// 测试 sym_stride
234245TEST_F (TensorTest, SymStride) {
246+ auto file_name = g_custom_param.get ();
247+ FileManerger file (file_name);
248+ file.createFile ();
235249 // 获取符号化的单个维度步长
236250 c10::SymInt sym_stride_0 = tensor.sym_stride (0 );
237251 c10::SymInt sym_stride_1 = tensor.sym_stride (1 );
238252 c10::SymInt sym_stride_2 = tensor.sym_stride (2 );
239-
240- // 验证符号化步长
241- EXPECT_GT (sym_stride_0, 0 );
242- EXPECT_GT (sym_stride_1, 0 );
243- EXPECT_GT (sym_stride_2, 0 );
244-
253+ #if USE_PADDLE_API
254+ file << std::to_string (sym_stride_0) << " " ;
255+ file << std::to_string (sym_stride_1) << " " ;
256+ file << std::to_string (sym_stride_2) << " " ;
257+ #else
258+ file << std::to_string (sym_stride_0.guard_int (__FILE__, __LINE__)) << " " ;
259+ file << std::to_string (sym_stride_1.guard_int (__FILE__, __LINE__)) << " " ;
260+ file << std::to_string (sym_stride_2.guard_int (__FILE__, __LINE__)) << " " ;
261+ #endif
245262 // 测试负索引
246263 c10::SymInt sym_stride_neg1 = tensor.sym_stride (-1 );
247- EXPECT_EQ (sym_stride_neg1, 1 ); // 最后一维步长通常为1
264+ #if USE_PADDLE_API
265+ file << std::to_string (sym_stride_neg1) << " " ;
266+ #else
267+ file << std::to_string (sym_stride_neg1.guard_int (__FILE__, __LINE__)) << " " ;
268+ #endif
269+ file.saveFile ();
248270}
249271
250272// 测试 sym_sizes
251273TEST_F (TensorTest, SymSizes) {
274+ auto file_name = g_custom_param.get ();
275+ FileManerger file (file_name);
276+ file.createFile ();
252277 // 获取符号化的所有维度大小
253278 c10::SymIntArrayRef sym_sizes = tensor.sym_sizes ();
254-
255- // 验证维度数量
256- EXPECT_EQ (sym_sizes.size (), 3U );
257-
258- // 验证每个维度的大小
259- EXPECT_EQ (sym_sizes[0 ], 2 );
260- EXPECT_EQ (sym_sizes[1 ], 3 );
261- EXPECT_EQ (sym_sizes[2 ], 4 );
279+ file << std::to_string (sym_sizes.size ()) << " " ;
280+ for (size_t i = 0 ; i < sym_sizes.size (); ++i) {
281+ #if USE_PADDLE_API
282+ file << std::to_string (sym_sizes[i]) << " " ;
283+ #else
284+ file << std::to_string (sym_sizes[i].guard_int (__FILE__, __LINE__)) << " " ;
285+ #endif
286+ }
287+ file.saveFile ();
262288}
263289
264290// 测试 sym_strides
265291TEST_F (TensorTest, SymStrides) {
292+ auto file_name = g_custom_param.get ();
293+ FileManerger file (file_name);
294+ file.createFile ();
266295 // 获取符号化的所有维度步长
267296 c10::SymIntArrayRef sym_strides = tensor.sym_strides ();
268-
269- // 验证维度数量
270- EXPECT_EQ (sym_strides.size (), 3U );
271-
272- // 验证步长值都大于0
297+ file << std::to_string (sym_strides.size ()) << " " ;
273298 for (size_t i = 0 ; i < sym_strides.size (); ++i) {
274- EXPECT_GT (sym_strides[i], 0 );
299+ #if USE_PADDLE_API
300+ file << std::to_string (sym_strides[i]) << " " ;
301+ #else
302+ file << std::to_string (sym_strides[i].guard_int (__FILE__, __LINE__)) << " " ;
303+ #endif
275304 }
305+ file.saveFile ();
276306}
277307
278308// 测试 sym_numel
279309TEST_F (TensorTest, SymNumel) {
310+ auto file_name = g_custom_param.get ();
311+ FileManerger file (file_name);
312+ file.createFile ();
280313 // 获取符号化的元素总数
281314 c10::SymInt sym_numel = tensor.sym_numel ();
282-
283- // 验证符号化元素数与实际元素数一致
284- EXPECT_EQ (sym_numel, 24 ); // 2*3*4
285-
286- // 验证与 numel() 结果一致
287- EXPECT_EQ (sym_numel, tensor.numel ());
315+ #if USE_PADDLE_API
316+ file << std::to_string (sym_numel) << " " ;
317+ #else
318+ file << std::to_string (sym_numel.guard_int (__FILE__, __LINE__)) << " " ;
319+ #endif
320+ file << std::to_string (tensor.numel ()) << " " ;
321+ file.saveFile ();
288322}
289323
290324// 测试 defined
291325TEST_F (TensorTest, Defined) {
292- // Tensor tensor(paddle_tensor_);
293-
294- EXPECT_TRUE (tensor.defined ());
326+ auto file_name = g_custom_param.get ();
327+ FileManerger file (file_name);
328+ file.createFile ();
329+ file << std::to_string (tensor.defined ()) << " " ;
330+ file.saveFile ();
295331}
296332
297333// 测试 reset
298334TEST_F (TensorTest, Reset) {
299- // Tensor tensor(paddle_tensor_);
300-
335+ auto file_name = g_custom_param.get ();
336+ FileManerger file (file_name);
337+ file.createFile ();
301338 tensor.reset ();
302- EXPECT_FALSE (tensor.defined ());
339+ file << std::to_string (tensor.defined ()) << " " ;
340+ file.saveFile ();
303341}
304342
305343} // namespace test
0 commit comments