Skip to content

Commit d0da2c0

Browse files
authored
add sparse related API tests (#23)
1 parent 3429fa1 commit d0da2c0

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

test/TensorTest.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,32 @@ TEST_F(TensorTest, IsCuda) {
190190
file.saveFile();
191191
}
192192

193+
// 测试 is_sparse
194+
TEST_F(TensorTest, IsSparse) {
195+
// 密集张量应该返回 false
196+
EXPECT_FALSE(tensor.is_sparse());
197+
198+
// 创建稀疏 COO 张量 - 先创建模板,再使用 zeros_like
199+
at::TensorOptions sparse_options =
200+
at::TensorOptions().dtype(at::kFloat).layout(at::kSparse);
201+
at::Tensor sparse_template = at::empty({2, 3}, sparse_options);
202+
at::Tensor sparse_tensor = at::zeros_like(sparse_template);
203+
EXPECT_TRUE(sparse_tensor.is_sparse());
204+
}
205+
206+
// 测试 is_sparse_csr
207+
TEST_F(TensorTest, IsSparseCsr) {
208+
// 密集张量应该返回 false
209+
EXPECT_FALSE(tensor.is_sparse_csr());
210+
211+
// 创建稀疏 CSR 张量 - 先创建模板,再使用 zeros_like
212+
at::TensorOptions sparse_csr_options =
213+
at::TensorOptions().dtype(at::kFloat).layout(at::kSparseCsr);
214+
at::Tensor sparse_csr_template = at::empty({2, 3}, sparse_csr_options);
215+
at::Tensor sparse_csr_tensor = at::zeros_like(sparse_csr_template);
216+
EXPECT_TRUE(sparse_csr_tensor.is_sparse_csr());
217+
}
218+
193219
// 测试 reshape
194220
TEST_F(TensorTest, Reshape) {
195221
auto file_name = g_custom_param.get();

0 commit comments

Comments
 (0)