Skip to content

Commit 1b88e96

Browse files
committed
More tests.
1 parent 86770a3 commit 1b88e96

File tree

4 files changed

+67
-0
lines changed

4 files changed

+67
-0
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
7777
private static final TensorType TENSOR_2_1_FLOAT32 =
7878
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(1)));
7979

80+
private static final TensorType TENSOR_2_3_3_FLOAT32 =
81+
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(3)));
82+
83+
private static final TensorType TENSOR_2_3_4_FLOAT32 =
84+
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4)));
85+
8086
private static final TensorType TENSOR_2_FLOAT32 =
8187
new TensorType(FLOAT_32, asList(new NumericDim(2)));
8288

@@ -183,6 +189,24 @@ public void testFunction8()
183189
Map.of(2, Set.of(TENSOR_2_1_FLOAT32, TENSOR_2_FLOAT32)));
184190
}
185191

192+
@Test
193+
public void testFunction9()
194+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
195+
test("tf2_test_function9.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32)));
196+
}
197+
198+
@Test
199+
public void testFunction10()
200+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
201+
test("tf2_test_function10.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_4_FLOAT32)));
202+
}
203+
204+
@Test
205+
public void testFunction11()
206+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
207+
test("tf2_test_function11.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_3_FLOAT32)));
208+
}
209+
186210
@Test
187211
public void testDecorator()
188212
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import tensorflow as tf
2+
3+
4+
def func(t):
5+
pass
6+
7+
8+
a = tf.constant(
9+
[
10+
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
11+
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
12+
]
13+
)
14+
assert a.shape == (2, 3, 4)
15+
16+
func(a)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import tensorflow as tf
2+
3+
4+
def func(t):
5+
pass
6+
7+
8+
a = tf.constant(
9+
[
10+
[[1, 2, 3], [5, 6, 7], [9, 10, 11]],
11+
[[13, 14, 15], [17, 18, 19], [21, 22, 23]],
12+
]
13+
)
14+
assert a.shape == (2, 3, 3)
15+
16+
func(a)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import tensorflow as tf
2+
3+
4+
def func(t):
5+
pass
6+
7+
8+
a = tf.constant([[1.0, 3.0]])
9+
assert a.shape == (1, 2)
10+
11+
func(a)

0 commit comments

Comments
 (0)