Skip to content

Commit cd37563

Browse files
committed
add squeeze and unsqueeze test
1 parent 6e7d15d commit cd37563

File tree

1 file changed

+103
-0
lines changed

1 file changed

+103
-0
lines changed

test/TensorTest.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,5 +170,108 @@ TEST_F(TensorTest, Transpose) {
170170
EXPECT_EQ(transposed.sizes()[2], 2);
171171
}
172172

173+
// 测试 squeeze
174+
TEST_F(TensorTest, Squeeze) {
175+
// 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4}
176+
at::Tensor tensor_with_ones = at::ones({2, 1, 3, 1, 4}, at::kFloat);
177+
178+
// 移除所有大小为1的维度
179+
at::Tensor squeezed = tensor_with_ones.squeeze();
180+
EXPECT_EQ(squeezed.dim(), 3);
181+
EXPECT_EQ(squeezed.sizes()[0], 2);
182+
EXPECT_EQ(squeezed.sizes()[1], 3);
183+
EXPECT_EQ(squeezed.sizes()[2], 4);
184+
EXPECT_EQ(squeezed.numel(), 24);
185+
186+
// 移除指定维度(维度1,大小为1)
187+
at::Tensor squeezed_dim1 = tensor_with_ones.squeeze(1);
188+
EXPECT_EQ(squeezed_dim1.dim(), 4);
189+
EXPECT_EQ(squeezed_dim1.sizes()[0], 2);
190+
EXPECT_EQ(squeezed_dim1.sizes()[1], 3);
191+
EXPECT_EQ(squeezed_dim1.sizes()[2], 1);
192+
EXPECT_EQ(squeezed_dim1.sizes()[3], 4);
193+
}
194+
195+
// 测试 squeeze_ (原位操作)
196+
TEST_F(TensorTest, SqueezeInplace) {
197+
// 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4}
198+
at::Tensor tensor_with_ones = at::ones({2, 1, 3, 1, 4}, at::kFloat);
199+
200+
// 记录原始数据指针
201+
void* original_ptr = tensor_with_ones.data_ptr();
202+
203+
// 原位移除所有大小为1的维度
204+
tensor_with_ones.squeeze_();
205+
EXPECT_EQ(tensor_with_ones.dim(), 3);
206+
EXPECT_EQ(tensor_with_ones.sizes()[0], 2);
207+
EXPECT_EQ(tensor_with_ones.sizes()[1], 3);
208+
EXPECT_EQ(tensor_with_ones.sizes()[2], 4);
209+
EXPECT_EQ(tensor_with_ones.numel(), 24);
210+
211+
// 验证是原位操作(数据指针未改变)
212+
EXPECT_EQ(tensor_with_ones.data_ptr(), original_ptr);
213+
214+
// 测试原位移除指定维度
215+
at::Tensor tensor_with_ones2 = at::ones({2, 1, 3, 1, 4}, at::kFloat);
216+
tensor_with_ones2.squeeze_(1);
217+
EXPECT_EQ(tensor_with_ones2.dim(), 4);
218+
EXPECT_EQ(tensor_with_ones2.sizes()[1], 3);
219+
}
220+
221+
// 测试 unsqueeze
222+
TEST_F(TensorTest, Unsqueeze) {
223+
// 在维度0之前添加一个大小为1的维度
224+
at::Tensor unsqueezed0 = tensor.unsqueeze(0);
225+
EXPECT_EQ(unsqueezed0.dim(), 4);
226+
EXPECT_EQ(unsqueezed0.sizes()[0], 1);
227+
EXPECT_EQ(unsqueezed0.sizes()[1], 2);
228+
EXPECT_EQ(unsqueezed0.sizes()[2], 3);
229+
EXPECT_EQ(unsqueezed0.sizes()[3], 4);
230+
EXPECT_EQ(unsqueezed0.numel(), 24);
231+
232+
// 在维度2之前添加一个大小为1的维度
233+
at::Tensor unsqueezed2 = tensor.unsqueeze(2);
234+
EXPECT_EQ(unsqueezed2.dim(), 4);
235+
EXPECT_EQ(unsqueezed2.sizes()[0], 2);
236+
EXPECT_EQ(unsqueezed2.sizes()[1], 3);
237+
EXPECT_EQ(unsqueezed2.sizes()[2], 1);
238+
EXPECT_EQ(unsqueezed2.sizes()[3], 4);
239+
240+
// 在最后添加一个大小为1的维度(使用负索引-1)
241+
at::Tensor unsqueezed_last = tensor.unsqueeze(-1);
242+
EXPECT_EQ(unsqueezed_last.dim(), 4);
243+
EXPECT_EQ(unsqueezed_last.sizes()[0], 2);
244+
EXPECT_EQ(unsqueezed_last.sizes()[1], 3);
245+
EXPECT_EQ(unsqueezed_last.sizes()[2], 4);
246+
EXPECT_EQ(unsqueezed_last.sizes()[3], 1);
247+
}
248+
249+
// 测试 unsqueeze_ (原位操作)
250+
TEST_F(TensorTest, UnsqueezeInplace) {
251+
// 创建一个新的tensor用于原位操作
252+
at::Tensor test_tensor = at::ones({2, 3, 4}, at::kFloat);
253+
254+
// 记录原始数据指针
255+
void* original_ptr = test_tensor.data_ptr();
256+
257+
// 原位在维度0之前添加一个大小为1的维度
258+
test_tensor.unsqueeze_(0);
259+
EXPECT_EQ(test_tensor.dim(), 4);
260+
EXPECT_EQ(test_tensor.sizes()[0], 1);
261+
EXPECT_EQ(test_tensor.sizes()[1], 2);
262+
EXPECT_EQ(test_tensor.sizes()[2], 3);
263+
EXPECT_EQ(test_tensor.sizes()[3], 4);
264+
EXPECT_EQ(test_tensor.numel(), 24);
265+
266+
// 验证是原位操作(数据指针未改变)
267+
EXPECT_EQ(test_tensor.data_ptr(), original_ptr);
268+
269+
// 测试使用负索引的原位操作
270+
at::Tensor test_tensor2 = at::ones({2, 3, 4}, at::kFloat);
271+
test_tensor2.unsqueeze_(-1);
272+
EXPECT_EQ(test_tensor2.dim(), 4);
273+
EXPECT_EQ(test_tensor2.sizes()[3], 1);
274+
}
275+
173276
} // namespace test
174277
} // namespace at

0 commit comments

Comments
 (0)