Skip to content

Commit 5aa75be

Browse files
committed
Add test for tf.convert_to_tensor with 2D list input.
1 parent fc60ced commit 5aa75be

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4311,6 +4311,12 @@ public void testConvertToTensor11()
43114311
test("tf2_test_convert_to_tensor11.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_FLOAT32)));
43124312
}
43134313

4314+
@Test
4315+
public void testConvertToTensor12()
4316+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4317+
test("tf2_test_convert_to_tensor12.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_INT32)));
4318+
}
4319+
43144320
@Test
43154321
public void testOneHot()
43164322
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
# A 2D list (Matrix)
9+
matrix_list = [
10+
[1, 2, 3],
11+
[4, 5, 6]
12+
]
13+
14+
assert isinstance(matrix_list, list)
15+
assert len(matrix_list) == 2
16+
assert all(isinstance(row, list) for row in matrix_list)
17+
assert all(isinstance(x, int) for row in matrix_list for x in row)
18+
assert len(matrix_list[0]) == 3
19+
assert len(matrix_list[1]) == 3
20+
21+
# Convert the 2D list to a TensorFlow Tensor
22+
matrix_tensor = tf.convert_to_tensor(matrix_list)
23+
24+
# Output: shape=(2, 3), dtype=int32
25+
26+
assert isinstance(matrix_tensor, tf.Tensor)
27+
assert matrix_tensor.dtype == tf.int32
28+
assert matrix_tensor.shape == (2, 3)
29+
30+
f(matrix_tensor)

0 commit comments

Comments
 (0)