Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions test/TensorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,106 @@ TEST_F(TensorTest, SymNumel) {
EXPECT_EQ(sym_numel, tensor.numel());
}

// 测试 any
TEST_F(TensorTest, Any) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
at::Tensor test_tensor = at::ones({2, 2}, at::kFloat);
test_tensor.fill_(0.0);
test_tensor.data_ptr<float>()[0] = 1.0;
bool any_result = test_tensor.any().item<bool>();
file << std::to_string(any_result) << " ";
auto any_dim_result = test_tensor.any(0);
file << std::to_string(any_dim_result.sizes()[0]) << " ";
file.saveFile();
}

// 测试 chunk
TEST_F(TensorTest, Chunk) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
at::Tensor test_tensor = at::ones({4, 4}, at::kFloat);
std::vector<at::Tensor> chunks = test_tensor.chunk(2, 0);
file << std::to_string(chunks.size()) << " ";
file << std::to_string(chunks[0].sizes()[0]) << " ";
file << std::to_string(chunks[1].sizes()[0]) << " ";
file.saveFile();
}

// 测试 rename - Paddle不支持Dimname,返回原tensor
TEST_F(TensorTest, Rename) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
at::Tensor renamed = tensor.rename(std::nullopt);
file << std::to_string(renamed.sizes().size()) << " ";
file.saveFile();
}

// 测试 new_empty
TEST_F(TensorTest, NewEmpty) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
at::Tensor empty_tensor = tensor.new_empty({3, 4});
file << std::to_string(empty_tensor.sizes()[0]) << " ";
file << std::to_string(empty_tensor.sizes()[1]) << " ";
file << std::to_string(empty_tensor.dtype() == tensor.dtype()) << " ";
file.saveFile();
}

// 测试 new_full
TEST_F(TensorTest, NewFull) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
at::Tensor full_tensor = tensor.new_full({2, 3}, 7.5);
file << std::to_string(full_tensor.sizes()[0]) << " ";
file << std::to_string(full_tensor.sizes()[1]) << " ";
file << std::to_string(full_tensor.data_ptr<float>()[0]) << " ";
file.saveFile();
}

// 测试 new_zeros
TEST_F(TensorTest, NewZeros) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
at::Tensor zeros_tensor = tensor.new_zeros({2, 3});
file << std::to_string(zeros_tensor.sizes()[0]) << " ";
file << std::to_string(zeros_tensor.sizes()[1]) << " ";
file << std::to_string(zeros_tensor.data_ptr<float>()[0]) << " ";
file.saveFile();
}

// 测试 new_ones
TEST_F(TensorTest, NewOnes) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
at::Tensor ones_tensor = tensor.new_ones({2, 3});
file << std::to_string(ones_tensor.sizes()[0]) << " ";
file << std::to_string(ones_tensor.sizes()[1]) << " ";
file << std::to_string(ones_tensor.data_ptr<float>()[0]) << " ";
file.saveFile();
}

// 测试 resize_ - Paddle不支持,会抛出异常
TEST_F(TensorTest, Resize) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
try {
tensor.resize_({4, 5});
file << "0 ";
} catch (const std::exception& e) {
file << "1 ";
}
file.saveFile();
}

// 测试 cpu()
TEST_F(TensorTest, CpuMethod) {
FileManerger file(GetTestCaseResultFileName());
Expand Down Expand Up @@ -463,6 +563,32 @@ TEST_F(TensorTest, ItemTemplate) {
file.saveFile();
}

// 测试 expand
TEST_F(TensorTest, Expand) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
at::Tensor small = at::ones({1, 3}, at::kFloat);
at::Tensor expanded = small.expand({4, 3});
file << std::to_string(expanded.sizes()[0]) << " ";
file << std::to_string(expanded.sizes()[1]) << " ";
file.saveFile();
}

// 测试 expand_as - 只测试维度等于1的情况(libtorch和paddle都支持)
TEST_F(TensorTest, ExpandAs) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
// 只有维度等于1时才能扩展,这是libtorch的语义
at::Tensor small = at::ones({1, 3}, at::kFloat);
at::Tensor target = at::ones({4, 3}, at::kFloat);
at::Tensor expanded = small.expand_as(target);
file << std::to_string(expanded.sizes()[0]) << " ";
file << std::to_string(expanded.sizes()[1]) << " ";
file.saveFile();
}

// 测试 clamp(min, max) with Scalar
TEST_F(TensorTest, ClampScalarMinMax) {
FileManerger file(GetTestCaseResultFileName());
Expand Down