|
| 1 | +#include <ATen/ATen.h> |
| 2 | +#include <ATen/core/Tensor.h> |
| 3 | +#include <ATen/ops/zeros.h> |
| 4 | +#include <gtest/gtest.h> |
| 5 | + |
| 6 | +#include <string> |
| 7 | +#include <vector> |
| 8 | + |
| 9 | +#include "../../src/file_manager.h" |
| 10 | + |
| 11 | +extern paddle_api_test::ThreadSafeParam g_custom_param; |
| 12 | + |
| 13 | +namespace at { |
| 14 | +namespace test { |
| 15 | + |
| 16 | +using paddle_api_test::FileManerger; |
| 17 | +using paddle_api_test::ThreadSafeParam; |
| 18 | + |
| 19 | +class DetachTest : public ::testing::Test { |
| 20 | + protected: |
| 21 | + void SetUp() override { |
| 22 | + std::vector<int64_t> shape = {3, 4}; |
| 23 | + test_tensor = at::zeros(shape, at::kFloat); |
| 24 | + float* data = test_tensor.data_ptr<float>(); |
| 25 | + for (int64_t i = 0; i < 12; ++i) { |
| 26 | + data[i] = static_cast<float>(i + 1); |
| 27 | + } |
| 28 | + } |
| 29 | + at::Tensor test_tensor; |
| 30 | +}; |
| 31 | + |
| 32 | +static void write_detach_result_to_file(FileManerger* file, |
| 33 | + const at::Tensor& result, |
| 34 | + const at::Tensor& original) { |
| 35 | + *file << std::to_string(result.dim()) << " "; |
| 36 | + *file << std::to_string(result.numel()) << " "; |
| 37 | + |
| 38 | + // 写入形状信息 |
| 39 | + for (int64_t i = 0; i < result.dim(); ++i) { |
| 40 | + *file << std::to_string(result.sizes()[i]) << " "; |
| 41 | + } |
| 42 | + |
| 43 | + // 写入数据内容 |
| 44 | + float* result_data = result.data_ptr<float>(); |
| 45 | + for (int64_t i = 0; i < result.numel(); ++i) { |
| 46 | + *file << std::to_string(result_data[i]) << " "; |
| 47 | + } |
| 48 | + |
| 49 | + // 验证数据指针是否相同(共享存储) |
| 50 | + *file << std::to_string(result.data_ptr<float>() == |
| 51 | + original.data_ptr<float>()) |
| 52 | + << " "; |
| 53 | +} |
| 54 | + |
| 55 | +// 测试 detach() 方法 - 创建新的 tensor,不跟踪梯度 |
| 56 | +TEST_F(DetachTest, BasicDetach) { |
| 57 | + at::Tensor detached = test_tensor.detach(); |
| 58 | + auto file_name = g_custom_param.get(); |
| 59 | + FileManerger file(file_name); |
| 60 | + file.createFile(); |
| 61 | + write_detach_result_to_file(&file, detached, test_tensor); |
| 62 | + file.saveFile(); |
| 63 | +} |
| 64 | + |
| 65 | +// 测试 detach_() in-place 方法 |
| 66 | +TEST_F(DetachTest, InplaceDetach) { |
| 67 | + auto file_name = g_custom_param.get(); |
| 68 | + FileManerger file(file_name); |
| 69 | + file.createFile(); |
| 70 | + |
| 71 | + // 保存原始指针 |
| 72 | + float* original_ptr = test_tensor.data_ptr<float>(); |
| 73 | + |
| 74 | + // 调用 in-place 版本 |
| 75 | + at::Tensor& result = test_tensor.detach_(); |
| 76 | + |
| 77 | + // 验证返回的是同一个 tensor |
| 78 | + file << std::to_string(result.data_ptr<float>() == original_ptr) << " "; |
| 79 | + |
| 80 | + // 写入数据 |
| 81 | + float* data = result.data_ptr<float>(); |
| 82 | + for (int64_t i = 0; i < result.numel(); ++i) { |
| 83 | + file << std::to_string(data[i]) << " "; |
| 84 | + } |
| 85 | + file.saveFile(); |
| 86 | +} |
| 87 | + |
| 88 | +// 测试 detach 后修改数据 |
| 89 | +TEST_F(DetachTest, DetachAndModify) { |
| 90 | + at::Tensor detached = test_tensor.detach(); |
| 91 | + |
| 92 | + // 修改 detached tensor 的数据 |
| 93 | + float* detached_data = detached.data_ptr<float>(); |
| 94 | + detached_data[0] = 99.0f; |
| 95 | + detached_data[1] = 88.0f; |
| 96 | + |
| 97 | + auto file_name = g_custom_param.get(); |
| 98 | + FileManerger file(file_name); |
| 99 | + file.openAppend(); |
| 100 | + |
| 101 | + // 验证原始 tensor 的数据也被修改了(因为共享存储) |
| 102 | + float* original_data = test_tensor.data_ptr<float>(); |
| 103 | + file << std::to_string(original_data[0]) << " "; |
| 104 | + file << std::to_string(original_data[1]) << " "; |
| 105 | + file << std::to_string(detached_data[0]) << " "; |
| 106 | + file << std::to_string(detached_data[1]) << " "; |
| 107 | + file.saveFile(); |
| 108 | +} |
| 109 | + |
| 110 | +// 测试不同类型 tensor 的 detach |
| 111 | +TEST_F(DetachTest, DetachDifferentTensor) { |
| 112 | + at::Tensor different_tensor = at::zeros({2, 2}, at::kFloat); |
| 113 | + float* data = different_tensor.data_ptr<float>(); |
| 114 | + data[0] = 1.0f; |
| 115 | + data[1] = 2.0f; |
| 116 | + data[2] = 3.0f; |
| 117 | + data[3] = 4.0f; |
| 118 | + |
| 119 | + at::Tensor detached = different_tensor.detach(); |
| 120 | + |
| 121 | + auto file_name = g_custom_param.get(); |
| 122 | + FileManerger file(file_name); |
| 123 | + file.openAppend(); |
| 124 | + |
| 125 | + file << std::to_string(detached.numel()) << " "; |
| 126 | + file << std::to_string(detached.dim()) << " "; |
| 127 | + |
| 128 | + float* detached_data = detached.data_ptr<float>(); |
| 129 | + for (int64_t i = 0; i < detached.numel(); ++i) { |
| 130 | + file << std::to_string(detached_data[i]) << " "; |
| 131 | + } |
| 132 | + file.saveFile(); |
| 133 | +} |
| 134 | + |
| 135 | +// 测试多维 tensor 的 detach |
| 136 | +TEST_F(DetachTest, MultiDimensionalDetach) { |
| 137 | + at::Tensor multi_tensor = at::zeros({2, 3, 4}, at::kFloat); |
| 138 | + float* data = multi_tensor.data_ptr<float>(); |
| 139 | + for (int64_t i = 0; i < 24; ++i) { |
| 140 | + data[i] = static_cast<float>(i); |
| 141 | + } |
| 142 | + |
| 143 | + at::Tensor detached = multi_tensor.detach(); |
| 144 | + |
| 145 | + auto file_name = g_custom_param.get(); |
| 146 | + FileManerger file(file_name); |
| 147 | + file.openAppend(); |
| 148 | + |
| 149 | + file << std::to_string(detached.dim()) << " "; |
| 150 | + file << std::to_string(detached.sizes()[0]) << " "; |
| 151 | + file << std::to_string(detached.sizes()[1]) << " "; |
| 152 | + file << std::to_string(detached.sizes()[2]) << " "; |
| 153 | + file << std::to_string(detached.numel()) << " "; |
| 154 | + |
| 155 | + // 验证数据共享 |
| 156 | + file << std::to_string(detached.data_ptr<float>() == |
| 157 | + multi_tensor.data_ptr<float>()) |
| 158 | + << " "; |
| 159 | + file.saveFile(); |
| 160 | +} |
| 161 | + |
| 162 | +} // namespace test |
| 163 | +} // namespace at |
0 commit comments