Skip to content

Commit 6f53b16

Browse files
committed
Update tests to use more accurate tensor types.
1 parent cb0d921 commit 6f53b16

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,25 +1623,45 @@ public void testAdd58()
16231623
@Test
16241624
public void testAdd59()
16251625
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1626-
test("tf2_test_add59.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1626+
test(
1627+
"tf2_test_add59.py",
1628+
"add",
1629+
2,
1630+
2,
1631+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
16271632
}
16281633

16291634
@Test
16301635
public void testAdd60()
16311636
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1632-
test("tf2_test_add60.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1637+
test(
1638+
"tf2_test_add60.py",
1639+
"add",
1640+
2,
1641+
2,
1642+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
16331643
}
16341644

16351645
@Test
16361646
public void testAdd61()
16371647
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1638-
test("tf2_test_add61.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1648+
test(
1649+
"tf2_test_add61.py",
1650+
"add",
1651+
2,
1652+
2,
1653+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
16391654
}
16401655

16411656
@Test
16421657
public void testAdd62()
16431658
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1644-
test("tf2_test_add62.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1659+
test(
1660+
"tf2_test_add62.py",
1661+
"add",
1662+
2,
1663+
2,
1664+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
16451665
}
16461666

16471667
@Test
@@ -2157,7 +2177,7 @@ public void testReduceMean3()
21572177
@Test
21582178
public void testGradient()
21592179
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2160-
test("tf2_test_gradient.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
2180+
test("tf2_test_gradient.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_FLOAT32)));
21612181
}
21622182

21632183
@Test

com.ibm.wala.cast.python.test/data/tf2_test_add59.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,12 @@ def add(a, b):
55
return a + b
66

77

8-
c = add(tf.ragged.constant([1, 2]), tf.ragged.constant([2, 2]))
8+
arg1 = tf.ragged.constant([1, 2])
9+
assert arg1.shape == (2,)
10+
assert arg1.dtype == tf.int32
11+
12+
arg2 = tf.ragged.constant([2, 2])
13+
assert arg2.shape == (2,)
14+
assert arg2.dtype == tf.int32
15+
16+
c = add(arg1, arg2)

0 commit comments

Comments
 (0)