Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions test/TensorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,5 +333,136 @@ TEST_F(TensorTest, SymNumel) {
EXPECT_EQ(sym_numel, tensor.numel());
}

// 测试 all() - 检查所有元素是否为真(非零)
TEST_F(TensorTest, All) {
FileManerger file(GetTestCaseResultFileName());
file.createFile();

// 测试全1张量 - all() 应返回 true
at::Tensor all_ones = at::ones({2, 2}, at::kInt);
bool result1 = all_ones.all().item<bool>();
file << std::to_string(result1) << " ";

// 测试全0张量 - all() 应返回 false
at::Tensor all_zeros = at::zeros({2, 2}, at::kInt);
bool result2 = all_zeros.all().item<bool>();
file << std::to_string(result2) << " ";

// 测试混合张量(有0有1)- all() 应返回 false
std::vector<int> data3 = {1, 0, 1, 1};
at::Tensor mixed = at::from_blob(data3.data(), {2, 2}, at::kInt).clone();
bool result3 = mixed.all().item<bool>();
file << std::to_string(result3) << " ";

// 测试全为负数张量 - all() 应返回 true(非零)
std::vector<int> data4 = {-1, -2, -3, -4};
at::Tensor all_neg = at::from_blob(data4.data(), {2, 2}, at::kInt).clone();
bool result4 = all_neg.all().item<bool>();
file << std::to_string(result4) << " ";

file.saveFile();
}

// 测试 all(dim, keepdim) - 沿指定维度检查
TEST_F(TensorTest, AllDim) {
FileManerger file(GetTestCaseResultFileName());
file.createFile();

std::vector<int> data = {1, 0, 1, 1, 1, 1};
at::Tensor tensor = at::from_blob(data.data(), {2, 3}, at::kInt).clone();

// 沿 dim=0 检查 - 每列所有行
at::Tensor result_dim0 = tensor.all(0, false);
file << std::to_string(result_dim0.sizes()[0]) << " ";
file << std::to_string(result_dim0.sizes()[1]) << " ";
// 第一列有0,应为false;第二列全为1,应为true;第三列有0,应为false
file << std::to_string(result_dim0[0].item<bool>()) << " ";
file << std::to_string(result_dim0[1].item<bool>()) << " ";
file << std::to_string(result_dim0[2].item<bool>()) << " ";

// 沿 dim=1 检查 - 每行所有列
at::Tensor result_dim1 = tensor.all(1, false);
file << std::to_string(result_dim1.sizes()[0]) << " ";
// 第一行有0,应为false;第二行全为1,应为true
file << std::to_string(result_dim1[0].item<bool>()) << " ";
file << std::to_string(result_dim1[1].item<bool>()) << " ";

// 测试 keepdim=true
at::Tensor result_keepdim = tensor.all(1, true);
file << std::to_string(result_keepdim.sizes()[0]) << " ";
file << std::to_string(result_keepdim.sizes()[1]) << " ";

file.saveFile();
}

// 测试 all(at::OptionalIntArrayRef dim, bool keepdim)
TEST_F(TensorTest, AllOptionalDim) {
FileManerger file(GetTestCaseResultFileName());
file.createFile();

std::vector<int> data = {1, 0, 1, 1, 1, 1};
at::Tensor tensor = at::from_blob(data.data(), {2, 3}, at::kInt).clone();

// 不指定维度 - 检查所有元素
at::Tensor result_no_dim = tensor.all(c10::nullopt, false);
file << std::to_string(result_no_dim.item<bool>()) << " ";

// 指定单个维度
at::Tensor result_single_dim = tensor.all({0}, false);
file << std::to_string(result_single_dim[0].item<bool>()) << " ";
file << std::to_string(result_single_dim[1].item<bool>()) << " ";
file << std::to_string(result_single_dim[2].item<bool>()) << " ";

// 指定多个维度
at::Tensor result_multi_dim = tensor.all({0, 1}, false);
file << std::to_string(result_multi_dim.item<bool>()) << " ";

file.saveFile();
}

// 测试 allclose - 检查两个张量是否接近
TEST_F(TensorTest, Allclose) {
FileManerger file(GetTestCaseResultFileName());
file.createFile();

// 测试1: 完全相同的张量 - 应返回 true
std::vector<float> data1 = {1.0f, 2.0f, 3.0f};
at::Tensor t1 = at::from_blob(data1.data(), {3}, at::kFloat).clone();
at::Tensor t1_copy = at::from_blob(data1.data(), {3}, at::kFloat).clone();
bool result1 = t1.allclose(t1_copy);
file << std::to_string(result1) << " ";

// 测试2: 在默认 rtol/atol 范围内的张量 - 应返回 true
std::vector<float> data2 = {1.0f, 2.0f, 3.0f};
std::vector<float> data2_slight = {1.0f + 1e-6f, 2.0f - 1e-6f, 3.0f};
at::Tensor t2 = at::from_blob(data2.data(), {3}, at::kFloat).clone();
at::Tensor t2_slight =
at::from_blob(data2_slight.data(), {3}, at::kFloat).clone();
bool result2 = t2.allclose(t2_slight);
file << std::to_string(result2) << " ";

// 测试3: 超出默认容差的张量 - 应返回 false
std::vector<float> data3 = {1.0f, 2.0f, 3.0f};
std::vector<float> data3_diff = {1.5f, 2.0f, 3.0f}; // 差异 0.5 > 默认 atol
at::Tensor t3 = at::from_blob(data3.data(), {3}, at::kFloat).clone();
at::Tensor t3_diff =
at::from_blob(data3_diff.data(), {3}, at::kFloat).clone();
bool result3 = t3.allclose(t3_diff);
file << std::to_string(result3) << " ";

// 测试4: 使用较大 rtol 的张量 - 应返回 true
bool result4 = t3.allclose(t3_diff, 0.5, 0.1, false);
file << std::to_string(result4) << " ";

// 测试7: 多维张量
std::vector<float> data7 = {1.0f, 2.0f, 3.0f, 4.0f};
at::Tensor t7 = at::from_blob(data7.data(), {2, 2}, at::kFloat).clone();
at::Tensor t7_copy = at::from_blob(data7.data(), {2, 2}, at::kFloat).clone();
bool result7 = t7.allclose(t7_copy);
file << std::to_string(result7) << " ";

file.saveFile();
}

} // namespace test
} // namespace at