Skip to content

Commit 6a7a8f3

Browse files
committed
add sparse related API tests
1 parent 6e7d15d commit 6a7a8f3

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

test/TensorTest.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,32 @@ TEST_F(TensorTest, IsCuda) {
151151
EXPECT_FALSE(tensor.is_cuda());
152152
}
153153

154+
// 测试 is_sparse
155+
TEST_F(TensorTest, IsSparse) {
156+
// 密集张量应该返回 false
157+
EXPECT_FALSE(tensor.is_sparse());
158+
159+
// 创建稀疏 COO 张量 - 先创建模板,再使用 zeros_like
160+
at::TensorOptions sparse_options =
161+
at::TensorOptions().dtype(at::kFloat).layout(at::kSparse);
162+
at::Tensor sparse_template = at::empty({2, 3}, sparse_options);
163+
at::Tensor sparse_tensor = at::zeros_like(sparse_template);
164+
EXPECT_TRUE(sparse_tensor.is_sparse());
165+
}
166+
167+
// 测试 is_sparse_csr
168+
TEST_F(TensorTest, IsSparseCsr) {
169+
// 密集张量应该返回 false
170+
EXPECT_FALSE(tensor.is_sparse_csr());
171+
172+
// 创建稀疏 CSR 张量 - 先创建模板,再使用 zeros_like
173+
at::TensorOptions sparse_csr_options =
174+
at::TensorOptions().dtype(at::kFloat).layout(at::kSparseCsr);
175+
at::Tensor sparse_csr_template = at::empty({2, 3}, sparse_csr_options);
176+
at::Tensor sparse_csr_tensor = at::zeros_like(sparse_csr_template);
177+
EXPECT_TRUE(sparse_csr_tensor.is_sparse_csr());
178+
}
179+
154180
// 测试 reshape
155181
TEST_F(TensorTest, Reshape) {
156182
// Tensor tensor(paddle_tensor_);

0 commit comments

Comments
 (0)