Skip to content

Commit fabb3fe

Browse files
committed
add clamp test
1 parent b7d76d2 commit fabb3fe

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

cmake/external.cmake

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,6 @@ function(ExternalProject repourl tag destination)
4040
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
4141
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
4242
PREFIX "${destination}"
43-
INSTALL_DIR "${destination}")
43+
INSTALL_DIR "${destination}"
44+
INSTALL_COMMAND "")
4445
endfunction()

test/TensorTest.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <gtest/gtest.h>
55
#include <torch/all.h>
66

7+
#include <optional>
78
#include <string>
89
#include <vector>
910

@@ -28,6 +29,7 @@ class TensorTest : public ::testing::Test {
2829
at::Tensor tensor;
2930
};
3031

32+
// Test: construct from a Paddle Tensor
3133
TEST_F(TensorTest, ConstructFromPaddleTensor) {
3234
auto file_name = g_custom_param.get();
3335
FileManerger file(file_name);
@@ -213,5 +215,114 @@ TEST_F(TensorTest, Transpose) {
213215
file.saveFile();
214216
}
215217

218+
// Serialize a tensor's rank, per-dimension sizes, and up to max_elems of its
// leading float elements into the result file, space-separated, so golden
// files can be compared across runs.
//
// @param f          destination file wrapper; must be non-null and already created
// @param t          tensor to dump; must hold float data (data_ptr<float> throws otherwise)
// @param max_elems  cap on the number of data elements written (default 6)
static void write_tensor_shape_and_data(FileManerger* f,
                                        const at::Tensor& t,
                                        int64_t max_elems = 6) {
  *f << std::to_string(t.dim()) << " ";
  for (int64_t i = 0; i < t.dim(); ++i) {
    *f << std::to_string(t.size(i)) << " ";
  }
  // Read from a contiguous copy so the element order matches the logical
  // (row-major) order even if the input is a non-contiguous view, e.g. the
  // result of a transpose. For already-contiguous tensors this is a no-op.
  at::Tensor contig = t.contiguous();
  int64_t n = std::min(contig.numel(), max_elems);
  const float* p = contig.data_ptr<float>();
  for (int64_t i = 0; i < n; ++i) {
    *f << std::to_string(p[i]) << " ";
  }
}
231+
232+
// 测试 clamp(scalar, scalar)
233+
TEST_F(TensorTest, ClampScalar) {
234+
auto file_name = g_custom_param.get();
235+
FileManerger file(file_name);
236+
file.createFile();
237+
std::vector<int64_t> shape = {2, 3};
238+
at::Tensor t = at::ones(shape, at::kFloat);
239+
for (int64_t i = 0; i < 6; ++i) {
240+
t.data_ptr<float>()[i] = static_cast<float>(i + 1);
241+
}
242+
at::Tensor out =
243+
t.clamp(std::optional<at::Scalar>(2.0), std::optional<at::Scalar>(5.0));
244+
write_tensor_shape_and_data(&file, out);
245+
file.saveFile();
246+
}
247+
248+
// 测试 clamp_min(scalar)
249+
TEST_F(TensorTest, ClampMinScalar) {
250+
auto file_name = g_custom_param.get();
251+
FileManerger file(file_name);
252+
file.createFile();
253+
std::vector<int64_t> shape = {2, 3};
254+
at::Tensor t = at::ones(shape, at::kFloat);
255+
for (int64_t i = 0; i < 6; ++i) {
256+
t.data_ptr<float>()[i] = static_cast<float>(i + 1);
257+
}
258+
at::Tensor out = t.clamp_min(at::Scalar(2.0));
259+
write_tensor_shape_and_data(&file, out);
260+
file.saveFile();
261+
}
262+
263+
// 测试 clamp_max(scalar)
264+
TEST_F(TensorTest, ClampMaxScalar) {
265+
auto file_name = g_custom_param.get();
266+
FileManerger file(file_name);
267+
file.createFile();
268+
std::vector<int64_t> shape = {2, 3};
269+
at::Tensor t = at::ones(shape, at::kFloat);
270+
for (int64_t i = 0; i < 6; ++i) {
271+
t.data_ptr<float>()[i] = static_cast<float>(i + 1);
272+
}
273+
at::Tensor out = t.clamp_max(at::Scalar(5.0));
274+
write_tensor_shape_and_data(&file, out);
275+
file.saveFile();
276+
}
277+
278+
// 测试 clamp_(scalar)
279+
TEST_F(TensorTest, ClampInplaceScalar) {
280+
auto file_name = g_custom_param.get();
281+
FileManerger file(file_name);
282+
file.createFile();
283+
std::vector<int64_t> shape = {2, 3};
284+
at::Tensor t = at::ones(shape, at::kFloat);
285+
for (int64_t i = 0; i < 6; ++i) {
286+
t.data_ptr<float>()[i] = static_cast<float>(i + 1);
287+
}
288+
t.clamp_(std::optional<at::Scalar>(2.0), std::optional<at::Scalar>(5.0));
289+
write_tensor_shape_and_data(&file, t);
290+
file.saveFile();
291+
}
292+
293+
// 测试 clamp_min(tensor)
294+
TEST_F(TensorTest, ClampMinTensor) {
295+
auto file_name = g_custom_param.get();
296+
FileManerger file(file_name);
297+
file.createFile();
298+
std::vector<int64_t> shape = {2, 3};
299+
at::Tensor t = at::ones(shape, at::kFloat);
300+
for (int64_t i = 0; i < 6; ++i) {
301+
t.data_ptr<float>()[i] = static_cast<float>(i + 1);
302+
}
303+
at::Tensor min_t = at::ones(shape, at::kFloat);
304+
min_t.fill_(2.0);
305+
at::Tensor out = t.clamp_min(min_t);
306+
write_tensor_shape_and_data(&file, out);
307+
file.saveFile();
308+
}
309+
310+
// 测试 clamp_max(tensor)
311+
TEST_F(TensorTest, ClampMaxTensor) {
312+
auto file_name = g_custom_param.get();
313+
FileManerger file(file_name);
314+
file.createFile();
315+
std::vector<int64_t> shape = {2, 3};
316+
at::Tensor t = at::ones(shape, at::kFloat);
317+
for (int64_t i = 0; i < 6; ++i) {
318+
t.data_ptr<float>()[i] = static_cast<float>(i + 1);
319+
}
320+
at::Tensor max_t = at::ones(shape, at::kFloat);
321+
max_t.fill_(5.0);
322+
at::Tensor out = t.clamp_max(max_t);
323+
write_tensor_shape_and_data(&file, out);
324+
file.saveFile();
325+
}
326+
216327
} // namespace test
217328
} // namespace at

0 commit comments

Comments
 (0)