Skip to content

Commit 13c7ec0

Browse files
committed
Update tests.
1 parent 6b098bd commit 13c7ec0

File tree

3 files changed

+50
-13
lines changed

3 files changed

+50
-13
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,30 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
8080
private static final TensorType TENSOR_2_3_3_FLOAT32 =
8181
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(3)));
8282

83+
private static final TensorType TENSOR_2_3_3_INT32 =
84+
new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(3)));
85+
8386
private static final TensorType TENSOR_2_3_4_FLOAT32 =
8487
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4)));
8588

89+
private static final TensorType TENSOR_2_3_4_INT32 =
90+
new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4)));
91+
8692
private static final TensorType TENSOR_2_FLOAT32 =
8793
new TensorType(FLOAT_32, asList(new NumericDim(2)));
8894

95+
private static final TensorType TENSOR_2_INT32 =
96+
new TensorType(INT_32, asList(new NumericDim(2)));
97+
98+
private static final TensorType TENSOR_3_INT32 =
99+
new TensorType(INT_32, asList(new NumericDim(3)));
100+
101+
private static final TensorType TENSOR_3_FLOAT32 =
102+
new TensorType(FLOAT_32, asList(new NumericDim(3)));
103+
104+
private static final TensorType TENSOR_4_FLOAT32 =
105+
new TensorType(FLOAT_32, asList(new NumericDim(4)));
106+
89107
private static final TensorType TENSOR_5_FLOAT32 =
90108
new TensorType(FLOAT_32, asList(new NumericDim(5)));
91109

@@ -198,13 +216,13 @@ public void testFunction9()
198216
@Test
199217
public void testFunction10()
200218
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
201-
test("tf2_test_function10.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_4_FLOAT32)));
219+
test("tf2_test_function10.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_4_INT32)));
202220
}
203221

204222
@Test
205223
public void testFunction11()
206224
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
207-
test("tf2_test_function11.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_3_FLOAT32)));
225+
test("tf2_test_function11.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_3_INT32)));
208226
}
209227

210228
@Test
@@ -222,7 +240,7 @@ public void testDecorator2()
222240
@Test
223241
public void testDecorator3()
224242
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
225-
test("tf2_test_decorator3.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32)));
243+
test("tf2_test_decorator3.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_2_FLOAT32)));
226244
}
227245

228246
@Test
@@ -907,13 +925,13 @@ public void testAutoencoder4()
907925
@Test
908926
public void testSigmoid()
909927
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
910-
test("tf2_test_sigmoid.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
928+
test("tf2_test_sigmoid.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_4_FLOAT32)));
911929
}
912930

913931
@Test
914932
public void testSigmoid2()
915933
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
916-
test("tf2_test_sigmoid2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
934+
test("tf2_test_sigmoid2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_4_FLOAT32)));
917935
}
918936

919937
@Test
@@ -1066,19 +1084,34 @@ public void testAdd22()
10661084
@Test
10671085
public void testAdd23()
10681086
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1069-
test("tf2_test_add23.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1087+
test(
1088+
"tf2_test_add23.py",
1089+
"add",
1090+
2,
1091+
2,
1092+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
10701093
}
10711094

10721095
@Test
10731096
public void testAdd24()
10741097
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1075-
test("tf2_test_add24.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1098+
test(
1099+
"tf2_test_add24.py",
1100+
"add",
1101+
2,
1102+
2,
1103+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
10761104
}
10771105

10781106
@Test
10791107
public void testAdd25()
10801108
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1081-
test("tf2_test_add25.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1109+
test(
1110+
"tf2_test_add25.py",
1111+
"add",
1112+
2,
1113+
2,
1114+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
10821115
}
10831116

10841117
@Test
@@ -1695,19 +1728,19 @@ public void testMultiGPUTraining2()
16951728
@Test
16961729
public void testReduceMean()
16971730
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1698-
test("tf2_test_reduce_mean.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
1731+
test("tf2_test_reduce_mean.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32)));
16991732
}
17001733

17011734
@Test
17021735
public void testReduceMean2()
17031736
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1704-
test("tf2_test_reduce_mean.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
1737+
test("tf2_test_reduce_mean.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32)));
17051738
}
17061739

17071740
@Test
17081741
public void testReduceMean3()
17091742
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1710-
test("tf2_test_reduce_mean.py", "h", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
1743+
test("tf2_test_reduce_mean.py", "h", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32)));
17111744
}
17121745

17131746
@Test
@@ -1742,13 +1775,13 @@ public void testSparseSoftmaxCrossEntropyWithLogits()
17421775
"f",
17431776
1,
17441777
1,
1745-
Map.of(2, Set.of(MNIST_INPUT)));
1778+
Map.of(2, Set.of(TENSOR_3_INT32)));
17461779
}
17471780

17481781
@Test
17491782
public void testRelu()
17501783
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1751-
test("tf2_test_relu.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
1784+
test("tf2_test_relu.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_FLOAT32)));
17521785
}
17531786

17541787
@Test

com.ibm.wala.cast.python.test/data/tf2_test_decorator3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ def returned(a):
99

1010
a = tf.constant([1.0, 1.0])
1111
b = returned(a)
12+
13+
assert a.shape == (2,)
14+
assert a.dtype == tf.float32

com.ibm.wala.cast.python.test/data/tf2_test_function11.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ def func(t):
1212
]
1313
)
1414
assert a.shape == (2, 3, 3)
15+
assert a.dtype == tf.int32
1516

1617
func(a)

0 commit comments

Comments
 (0)