Skip to content

Commit 32a8bd5

Browse files
committed
Add test.
1 parent 834cb29 commit 32a8bd5

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4390,6 +4390,16 @@ public void testEye5()
43904390
test("tf2_test_eye5.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_2_2_3_FLOAT32)));
43914391
}
43924392

4393+
/**
4394+
* FIXME: Should not throw an {@link IllegalArgumentException} once
4395+
* https://github.com/wala/ML/issues/340 is fixed.
4396+
*/
4397+
@Test(expected = IllegalArgumentException.class)
4398+
public void testEye6()
4399+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4400+
test("tf2_test_eye6.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_2_2_3_FLOAT32)));
4401+
}
4402+
43934403
private void test(
43944404
String filename,
43954405
String functionName,
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
batch = tf.constant([3, 2])
9+
assert batch.shape == (2,)
10+
assert batch.dtype == tf.int32
11+
12+
arg = tf.eye(2, 3, batch)
13+
14+
assert arg.shape == (3, 2, 2, 3)
15+
assert arg.dtype == tf.float32
16+
17+
f(arg)

0 commit comments

Comments
 (0)