@@ -170,5 +170,108 @@ TEST_F(TensorTest, Transpose) {
170170 EXPECT_EQ (transposed.sizes ()[2 ], 2 );
171171}
172172
173+ // 测试 squeeze
174+ TEST_F (TensorTest, Squeeze) {
175+ // 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4}
176+ at::Tensor tensor_with_ones = at::ones ({2 , 1 , 3 , 1 , 4 }, at::kFloat );
177+
178+ // 移除所有大小为1的维度
179+ at::Tensor squeezed = tensor_with_ones.squeeze ();
180+ EXPECT_EQ (squeezed.dim (), 3 );
181+ EXPECT_EQ (squeezed.sizes ()[0 ], 2 );
182+ EXPECT_EQ (squeezed.sizes ()[1 ], 3 );
183+ EXPECT_EQ (squeezed.sizes ()[2 ], 4 );
184+ EXPECT_EQ (squeezed.numel (), 24 );
185+
186+ // 移除指定维度(维度1,大小为1)
187+ at::Tensor squeezed_dim1 = tensor_with_ones.squeeze (1 );
188+ EXPECT_EQ (squeezed_dim1.dim (), 4 );
189+ EXPECT_EQ (squeezed_dim1.sizes ()[0 ], 2 );
190+ EXPECT_EQ (squeezed_dim1.sizes ()[1 ], 3 );
191+ EXPECT_EQ (squeezed_dim1.sizes ()[2 ], 1 );
192+ EXPECT_EQ (squeezed_dim1.sizes ()[3 ], 4 );
193+ }
194+
195+ // 测试 squeeze_ (原位操作)
196+ TEST_F (TensorTest, SqueezeInplace) {
197+ // 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4}
198+ at::Tensor tensor_with_ones = at::ones ({2 , 1 , 3 , 1 , 4 }, at::kFloat );
199+
200+ // 记录原始数据指针
201+ void * original_ptr = tensor_with_ones.data_ptr ();
202+
203+ // 原位移除所有大小为1的维度
204+ tensor_with_ones.squeeze_ ();
205+ EXPECT_EQ (tensor_with_ones.dim (), 3 );
206+ EXPECT_EQ (tensor_with_ones.sizes ()[0 ], 2 );
207+ EXPECT_EQ (tensor_with_ones.sizes ()[1 ], 3 );
208+ EXPECT_EQ (tensor_with_ones.sizes ()[2 ], 4 );
209+ EXPECT_EQ (tensor_with_ones.numel (), 24 );
210+
211+ // 验证是原位操作(数据指针未改变)
212+ EXPECT_EQ (tensor_with_ones.data_ptr (), original_ptr);
213+
214+ // 测试原位移除指定维度
215+ at::Tensor tensor_with_ones2 = at::ones ({2 , 1 , 3 , 1 , 4 }, at::kFloat );
216+ tensor_with_ones2.squeeze_ (1 );
217+ EXPECT_EQ (tensor_with_ones2.dim (), 4 );
218+ EXPECT_EQ (tensor_with_ones2.sizes ()[1 ], 3 );
219+ }
220+
221+ // 测试 unsqueeze
222+ TEST_F (TensorTest, Unsqueeze) {
223+ // 在维度0之前添加一个大小为1的维度
224+ at::Tensor unsqueezed0 = tensor.unsqueeze (0 );
225+ EXPECT_EQ (unsqueezed0.dim (), 4 );
226+ EXPECT_EQ (unsqueezed0.sizes ()[0 ], 1 );
227+ EXPECT_EQ (unsqueezed0.sizes ()[1 ], 2 );
228+ EXPECT_EQ (unsqueezed0.sizes ()[2 ], 3 );
229+ EXPECT_EQ (unsqueezed0.sizes ()[3 ], 4 );
230+ EXPECT_EQ (unsqueezed0.numel (), 24 );
231+
232+ // 在维度2之前添加一个大小为1的维度
233+ at::Tensor unsqueezed2 = tensor.unsqueeze (2 );
234+ EXPECT_EQ (unsqueezed2.dim (), 4 );
235+ EXPECT_EQ (unsqueezed2.sizes ()[0 ], 2 );
236+ EXPECT_EQ (unsqueezed2.sizes ()[1 ], 3 );
237+ EXPECT_EQ (unsqueezed2.sizes ()[2 ], 1 );
238+ EXPECT_EQ (unsqueezed2.sizes ()[3 ], 4 );
239+
240+ // 在最后添加一个大小为1的维度(使用负索引-1)
241+ at::Tensor unsqueezed_last = tensor.unsqueeze (-1 );
242+ EXPECT_EQ (unsqueezed_last.dim (), 4 );
243+ EXPECT_EQ (unsqueezed_last.sizes ()[0 ], 2 );
244+ EXPECT_EQ (unsqueezed_last.sizes ()[1 ], 3 );
245+ EXPECT_EQ (unsqueezed_last.sizes ()[2 ], 4 );
246+ EXPECT_EQ (unsqueezed_last.sizes ()[3 ], 1 );
247+ }
248+
249+ // 测试 unsqueeze_ (原位操作)
250+ TEST_F (TensorTest, UnsqueezeInplace) {
251+ // 创建一个新的tensor用于原位操作
252+ at::Tensor test_tensor = at::ones ({2 , 3 , 4 }, at::kFloat );
253+
254+ // 记录原始数据指针
255+ void * original_ptr = test_tensor.data_ptr ();
256+
257+ // 原位在维度0之前添加一个大小为1的维度
258+ test_tensor.unsqueeze_ (0 );
259+ EXPECT_EQ (test_tensor.dim (), 4 );
260+ EXPECT_EQ (test_tensor.sizes ()[0 ], 1 );
261+ EXPECT_EQ (test_tensor.sizes ()[1 ], 2 );
262+ EXPECT_EQ (test_tensor.sizes ()[2 ], 3 );
263+ EXPECT_EQ (test_tensor.sizes ()[3 ], 4 );
264+ EXPECT_EQ (test_tensor.numel (), 24 );
265+
266+ // 验证是原位操作(数据指针未改变)
267+ EXPECT_EQ (test_tensor.data_ptr (), original_ptr);
268+
269+ // 测试使用负索引的原位操作
270+ at::Tensor test_tensor2 = at::ones ({2 , 3 , 4 }, at::kFloat );
271+ test_tensor2.unsqueeze_ (-1 );
272+ EXPECT_EQ (test_tensor2.dim (), 4 );
273+ EXPECT_EQ (test_tensor2.sizes ()[3 ], 1 );
274+ }
275+
173276} // namespace test
174277} // namespace at
0 commit comments