@@ -287,108 +287,5 @@ TEST_F(TensorTest, SymNumel) {
287287 EXPECT_EQ (sym_numel, tensor.numel ());
288288}
289289
290- // 测试 squeeze
291- TEST_F (TensorTest, Squeeze) {
292- // 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4}
293- at::Tensor tensor_with_ones = at::ones ({2 , 1 , 3 , 1 , 4 }, at::kFloat );
294-
295- // 移除所有大小为1的维度
296- at::Tensor squeezed = tensor_with_ones.squeeze ();
297- EXPECT_EQ (squeezed.dim (), 3 );
298- EXPECT_EQ (squeezed.sizes ()[0 ], 2 );
299- EXPECT_EQ (squeezed.sizes ()[1 ], 3 );
300- EXPECT_EQ (squeezed.sizes ()[2 ], 4 );
301- EXPECT_EQ (squeezed.numel (), 24 );
302-
303- // 移除指定维度(维度1,大小为1)
304- at::Tensor squeezed_dim1 = tensor_with_ones.squeeze (1 );
305- EXPECT_EQ (squeezed_dim1.dim (), 4 );
306- EXPECT_EQ (squeezed_dim1.sizes ()[0 ], 2 );
307- EXPECT_EQ (squeezed_dim1.sizes ()[1 ], 3 );
308- EXPECT_EQ (squeezed_dim1.sizes ()[2 ], 1 );
309- EXPECT_EQ (squeezed_dim1.sizes ()[3 ], 4 );
310- }
311-
312- // 测试 squeeze_ (原位操作)
313- TEST_F (TensorTest, SqueezeInplace) {
314- // 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4}
315- at::Tensor tensor_with_ones = at::ones ({2 , 1 , 3 , 1 , 4 }, at::kFloat );
316-
317- // 记录原始数据指针
318- void * original_ptr = tensor_with_ones.data_ptr ();
319-
320- // 原位移除所有大小为1的维度
321- tensor_with_ones.squeeze_ ();
322- EXPECT_EQ (tensor_with_ones.dim (), 3 );
323- EXPECT_EQ (tensor_with_ones.sizes ()[0 ], 2 );
324- EXPECT_EQ (tensor_with_ones.sizes ()[1 ], 3 );
325- EXPECT_EQ (tensor_with_ones.sizes ()[2 ], 4 );
326- EXPECT_EQ (tensor_with_ones.numel (), 24 );
327-
328- // 验证是原位操作(数据指针未改变)
329- EXPECT_EQ (tensor_with_ones.data_ptr (), original_ptr);
330-
331- // 测试原位移除指定维度
332- at::Tensor tensor_with_ones2 = at::ones ({2 , 1 , 3 , 1 , 4 }, at::kFloat );
333- tensor_with_ones2.squeeze_ (1 );
334- EXPECT_EQ (tensor_with_ones2.dim (), 4 );
335- EXPECT_EQ (tensor_with_ones2.sizes ()[1 ], 3 );
336- }
337-
338- // 测试 unsqueeze
339- TEST_F (TensorTest, Unsqueeze) {
340- // 在维度0之前添加一个大小为1的维度
341- at::Tensor unsqueezed0 = tensor.unsqueeze (0 );
342- EXPECT_EQ (unsqueezed0.dim (), 4 );
343- EXPECT_EQ (unsqueezed0.sizes ()[0 ], 1 );
344- EXPECT_EQ (unsqueezed0.sizes ()[1 ], 2 );
345- EXPECT_EQ (unsqueezed0.sizes ()[2 ], 3 );
346- EXPECT_EQ (unsqueezed0.sizes ()[3 ], 4 );
347- EXPECT_EQ (unsqueezed0.numel (), 24 );
348-
349- // 在维度2之前添加一个大小为1的维度
350- at::Tensor unsqueezed2 = tensor.unsqueeze (2 );
351- EXPECT_EQ (unsqueezed2.dim (), 4 );
352- EXPECT_EQ (unsqueezed2.sizes ()[0 ], 2 );
353- EXPECT_EQ (unsqueezed2.sizes ()[1 ], 3 );
354- EXPECT_EQ (unsqueezed2.sizes ()[2 ], 1 );
355- EXPECT_EQ (unsqueezed2.sizes ()[3 ], 4 );
356-
357- // 在最后添加一个大小为1的维度(使用负索引-1)
358- at::Tensor unsqueezed_last = tensor.unsqueeze (-1 );
359- EXPECT_EQ (unsqueezed_last.dim (), 4 );
360- EXPECT_EQ (unsqueezed_last.sizes ()[0 ], 2 );
361- EXPECT_EQ (unsqueezed_last.sizes ()[1 ], 3 );
362- EXPECT_EQ (unsqueezed_last.sizes ()[2 ], 4 );
363- EXPECT_EQ (unsqueezed_last.sizes ()[3 ], 1 );
364- }
365-
366- // 测试 unsqueeze_ (原位操作)
367- TEST_F (TensorTest, UnsqueezeInplace) {
368- // 创建一个新的tensor用于原位操作
369- at::Tensor test_tensor = at::ones ({2 , 3 , 4 }, at::kFloat );
370-
371- // 记录原始数据指针
372- void * original_ptr = test_tensor.data_ptr ();
373-
374- // 原位在维度0之前添加一个大小为1的维度
375- test_tensor.unsqueeze_ (0 );
376- EXPECT_EQ (test_tensor.dim (), 4 );
377- EXPECT_EQ (test_tensor.sizes ()[0 ], 1 );
378- EXPECT_EQ (test_tensor.sizes ()[1 ], 2 );
379- EXPECT_EQ (test_tensor.sizes ()[2 ], 3 );
380- EXPECT_EQ (test_tensor.sizes ()[3 ], 4 );
381- EXPECT_EQ (test_tensor.numel (), 24 );
382-
383- // 验证是原位操作(数据指针未改变)
384- EXPECT_EQ (test_tensor.data_ptr (), original_ptr);
385-
386- // 测试使用负索引的原位操作
387- at::Tensor test_tensor2 = at::ones ({2 , 3 , 4 }, at::kFloat );
388- test_tensor2.unsqueeze_ (-1 );
389- EXPECT_EQ (test_tensor2.dim (), 4 );
390- EXPECT_EQ (test_tensor2.sizes ()[3 ], 1 );
391- }
392-
393290} // namespace test
394291} // namespace at
0 commit comments