44#include < gtest/gtest.h>
55#include < torch/all.h>
66
7+ #include < optional>
78#include < string>
89#include < vector>
910
@@ -28,6 +29,7 @@ class TensorTest : public ::testing::Test {
2829 at::Tensor tensor;
2930};
3031
32+ // 测试 从 Paddle Tensor 构造
3133TEST_F (TensorTest, ConstructFromPaddleTensor) {
3234 auto file_name = g_custom_param.get ();
3335 FileManerger file (file_name);
@@ -213,5 +215,114 @@ TEST_F(TensorTest, Transpose) {
213215 file.saveFile ();
214216}
215217
218+ static void write_tensor_shape_and_data (FileManerger* f,
219+ const at::Tensor& t,
220+ int64_t max_elems = 6 ) {
221+ *f << std::to_string (t.dim ()) << " " ;
222+ for (int64_t i = 0 ; i < t.dim (); ++i) {
223+ *f << std::to_string (t.size (i)) << " " ;
224+ }
225+ int64_t n = std::min (t.numel (), max_elems);
226+ float * p = t.data_ptr <float >();
227+ for (int64_t i = 0 ; i < n; ++i) {
228+ *f << std::to_string (p[i]) << " " ;
229+ }
230+ }
231+
232+ // 测试 clamp(scalar, scalar)
233+ TEST_F (TensorTest, ClampScalar) {
234+ auto file_name = g_custom_param.get ();
235+ FileManerger file (file_name);
236+ file.createFile ();
237+ std::vector<int64_t > shape = {2 , 3 };
238+ at::Tensor t = at::ones (shape, at::kFloat );
239+ for (int64_t i = 0 ; i < 6 ; ++i) {
240+ t.data_ptr <float >()[i] = static_cast <float >(i + 1 );
241+ }
242+ at::Tensor out =
243+ t.clamp (std::optional<at::Scalar>(2.0 ), std::optional<at::Scalar>(5.0 ));
244+ write_tensor_shape_and_data (&file, out);
245+ file.saveFile ();
246+ }
247+
248+ // 测试 clamp_min(scalar)
249+ TEST_F (TensorTest, ClampMinScalar) {
250+ auto file_name = g_custom_param.get ();
251+ FileManerger file (file_name);
252+ file.createFile ();
253+ std::vector<int64_t > shape = {2 , 3 };
254+ at::Tensor t = at::ones (shape, at::kFloat );
255+ for (int64_t i = 0 ; i < 6 ; ++i) {
256+ t.data_ptr <float >()[i] = static_cast <float >(i + 1 );
257+ }
258+ at::Tensor out = t.clamp_min (at::Scalar (2.0 ));
259+ write_tensor_shape_and_data (&file, out);
260+ file.saveFile ();
261+ }
262+
263+ // 测试 clamp_max(scalar)
264+ TEST_F (TensorTest, ClampMaxScalar) {
265+ auto file_name = g_custom_param.get ();
266+ FileManerger file (file_name);
267+ file.createFile ();
268+ std::vector<int64_t > shape = {2 , 3 };
269+ at::Tensor t = at::ones (shape, at::kFloat );
270+ for (int64_t i = 0 ; i < 6 ; ++i) {
271+ t.data_ptr <float >()[i] = static_cast <float >(i + 1 );
272+ }
273+ at::Tensor out = t.clamp_max (at::Scalar (5.0 ));
274+ write_tensor_shape_and_data (&file, out);
275+ file.saveFile ();
276+ }
277+
278+ // 测试 clamp_(scalar)
279+ TEST_F (TensorTest, ClampInplaceScalar) {
280+ auto file_name = g_custom_param.get ();
281+ FileManerger file (file_name);
282+ file.createFile ();
283+ std::vector<int64_t > shape = {2 , 3 };
284+ at::Tensor t = at::ones (shape, at::kFloat );
285+ for (int64_t i = 0 ; i < 6 ; ++i) {
286+ t.data_ptr <float >()[i] = static_cast <float >(i + 1 );
287+ }
288+ t.clamp_ (std::optional<at::Scalar>(2.0 ), std::optional<at::Scalar>(5.0 ));
289+ write_tensor_shape_and_data (&file, t);
290+ file.saveFile ();
291+ }
292+
293+ // 测试 clamp_min(tensor)
294+ TEST_F (TensorTest, ClampMinTensor) {
295+ auto file_name = g_custom_param.get ();
296+ FileManerger file (file_name);
297+ file.createFile ();
298+ std::vector<int64_t > shape = {2 , 3 };
299+ at::Tensor t = at::ones (shape, at::kFloat );
300+ for (int64_t i = 0 ; i < 6 ; ++i) {
301+ t.data_ptr <float >()[i] = static_cast <float >(i + 1 );
302+ }
303+ at::Tensor min_t = at::ones (shape, at::kFloat );
304+ min_t .fill_ (2.0 );
305+ at::Tensor out = t.clamp_min (min_t );
306+ write_tensor_shape_and_data (&file, out);
307+ file.saveFile ();
308+ }
309+
310+ // 测试 clamp_max(tensor)
311+ TEST_F (TensorTest, ClampMaxTensor) {
312+ auto file_name = g_custom_param.get ();
313+ FileManerger file (file_name);
314+ file.createFile ();
315+ std::vector<int64_t > shape = {2 , 3 };
316+ at::Tensor t = at::ones (shape, at::kFloat );
317+ for (int64_t i = 0 ; i < 6 ; ++i) {
318+ t.data_ptr <float >()[i] = static_cast <float >(i + 1 );
319+ }
320+ at::Tensor max_t = at::ones (shape, at::kFloat );
321+ max_t .fill_ (5.0 );
322+ at::Tensor out = t.clamp_max (max_t );
323+ write_tensor_shape_and_data (&file, out);
324+ file.saveFile ();
325+ }
326+
216327} // namespace test
217328} // namespace at
0 commit comments