Skip to content

Commit fc81337

Browse files
committed
Add test.
1 parent 6d9eac5 commit fc81337

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4364,6 +4364,12 @@ public void testEye2()
43644364
test("tf2_test_eye2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32)));
43654365
}
43664366

4367+
@Test
4368+
public void testEye3()
4369+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4370+
test("tf2_test_eye3.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
4371+
}
4372+
43674373
private void test(
43684374
String filename,
43694375
String functionName,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
# Construct one 2 x 3 "identity" matrix
9+
arg = tf.eye(2, 3)
10+
assert arg.shape == (2, 3)
11+
assert arg.dtype == tf.float32
12+
13+
f(arg)

0 commit comments

Comments
 (0)