Skip to content

Commit 5fdec36

Browse files
authored
Add cuda、is_pinned & pin_memory test (#33)
1 parent 1271912 commit 5fdec36

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

test/TensorTest.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,85 @@ TEST_F(TensorTest, Transpose) {
213213
file.saveFile();
214214
}
215215

216+
// 返回当前用例的结果文件名(用于逐个用例对比)
217+
static std::string GetTestCaseResultFileName() {
218+
std::string base = g_custom_param.get();
219+
std::string test_name =
220+
::testing::UnitTest::GetInstance()->current_test_info()->name();
221+
if (base.size() >= 4 && base.substr(base.size() - 4) == ".txt") {
222+
base.resize(base.size() - 4);
223+
}
224+
return base + "_" + test_name + ".txt";
225+
}
226+
227+
// 测试 cuda
228+
TEST_F(TensorTest, CudaResult) {
229+
FileManerger file(GetTestCaseResultFileName());
230+
file.createFile();
231+
try {
232+
at::Tensor cuda_tensor = tensor.cuda();
233+
file << "1 ";
234+
file << std::to_string(static_cast<int>(cuda_tensor.device().type()))
235+
<< " ";
236+
file << std::to_string(cuda_tensor.is_cuda() ? 1 : 0) << " ";
237+
file << std::to_string(cuda_tensor.numel()) << " ";
238+
} catch (const std::exception&) {
239+
file << "0 ";
240+
} catch (...) {
241+
file << "0 ";
242+
}
243+
file.saveFile();
244+
}
245+
246+
// 测试 is_pinned
247+
TEST_F(TensorTest, IsPinnedResult) {
248+
FileManerger file(GetTestCaseResultFileName());
249+
file.createFile();
250+
file << std::to_string(tensor.is_pinned() ? 1 : 0) << " ";
251+
int pinned_after_cuda = 0;
252+
try {
253+
at::Tensor cuda_tensor = tensor.cuda();
254+
at::Tensor pinned_tensor = cuda_tensor.pin_memory();
255+
pinned_after_cuda = pinned_tensor.is_pinned() ? 1 : 0;
256+
} catch (...) {
257+
pinned_after_cuda = 0;
258+
}
259+
file << std::to_string(pinned_after_cuda) << " ";
260+
file.saveFile();
261+
}
262+
263+
// 测试 pin_memory
264+
TEST_F(TensorTest, PinMemoryResult) {
265+
FileManerger file(GetTestCaseResultFileName());
266+
file.createFile();
267+
int gpu_pin_ok = 0;
268+
try {
269+
at::Tensor cuda_tensor = tensor.cuda();
270+
at::Tensor pinned_tensor = cuda_tensor.pin_memory();
271+
gpu_pin_ok = pinned_tensor.is_pinned() ? 1 : 0;
272+
} catch (...) {
273+
gpu_pin_ok = 0;
274+
}
275+
file << std::to_string(gpu_pin_ok) << " ";
276+
file.saveFile();
277+
// 测试 sym_size
278+
TEST_F(TensorTest, SymSize) {
279+
// 获取符号化的单个维度大小
280+
c10::SymInt sym_size_0 = tensor.sym_size(0);
281+
c10::SymInt sym_size_1 = tensor.sym_size(1);
282+
c10::SymInt sym_size_2 = tensor.sym_size(2);
283+
284+
// 验证符号化大小与实际大小一致
285+
EXPECT_EQ(sym_size_0, 2);
286+
EXPECT_EQ(sym_size_1, 3);
287+
EXPECT_EQ(sym_size_2, 4);
288+
289+
// 测试负索引
290+
c10::SymInt sym_size_neg1 = tensor.sym_size(-1);
291+
EXPECT_EQ(sym_size_neg1, 4);
292+
}
293+
}
294+
216295
// 测试 sym_size
217296
TEST_F(TensorTest, SymSize) {
218297
// 获取符号化的单个维度大小

0 commit comments

Comments
 (0)