Skip to content

Commit bf63246

Browse files
committed
Add future test.
1 parent acef928 commit bf63246

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
9595
private static final TensorType TENSOR_3_3_INT32 =
9696
new TensorType(INT_32, asList(new NumericDim(3), new NumericDim(3)));
9797

98+
private static final TensorType TENSOR_2_3_INT32 =
99+
new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3)));
100+
98101
private static final TensorType TENSOR_2_1_FLOAT32 =
99102
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(1)));
100103

@@ -104,6 +107,9 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
104107
private static final TensorType TENSOR_2_3_4_INT32 =
105108
new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4)));
106109

110+
private static final TensorType TENSOR_2_5_3_FLOAT32 =
111+
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(5), new NumericDim(3)));
112+
107113
private static final TensorType TENSOR_20_28_28_FLOAT32 =
108114
new TensorType(FLOAT_32, asList(new NumericDim(20), new NumericDim(28), new NumericDim(28)));
109115

@@ -4314,6 +4320,23 @@ public void testOneHot17()
43144320
test("tf2_test_one_hot17.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
43154321
}
43164322

4323+
@Test
4324+
public void testOneHot18()
4325+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4326+
test("tf2_test_one_hot18.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_5_3_FLOAT32)));
4327+
}
4328+
4329+
/**
4330+
* FIXME: Should not throw an {@link IllegalArgumentException} once
4331+
* https://github.com/wala/ML/issues/340 is fixed.
4332+
*/
4333+
@Test(expected = IllegalArgumentException.class)
4334+
public void testOneHot19()
4335+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4336+
test("tf2_test_one_hot19.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_2_3_INT32)));
4337+
test("tf2_test_one_hot19.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_5_3_FLOAT32)));
4338+
}
4339+
43174340
private void test(
43184341
String filename,
43194342
String functionName,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
arg1 = [[10, 20, 30], [40, 50, 60]] # Row 1 # Row 2
9+
10+
assert isinstance(arg1, list)
11+
assert all(isinstance(row, list) for row in arg1)
12+
assert all(isinstance(elem, int) for row in arg1 for elem in row)
13+
assert len(arg1) == 2
14+
assert all(len(row) == 3 for row in arg1)
15+
16+
arg2 = tf.one_hot(arg1, 5, None, None, 1)
17+
assert isinstance(arg2, tf.Tensor)
18+
assert arg2.dtype == tf.float32
19+
assert arg2.shape == (2, 5, 3)
20+
21+
f(arg2)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
def g(a):
9+
pass
10+
11+
12+
my_tensor = tf.constant([[10, 20, 30], [40, 50, 60]]) # Row 1 # Row 2
13+
assert isinstance(my_tensor, tf.Tensor)
14+
assert my_tensor.dtype == tf.int32
15+
assert my_tensor.shape == (2, 3)
16+
17+
g(my_tensor)
18+
19+
arg = tf.one_hot(my_tensor, 5, None, None, 1)
20+
assert isinstance(arg, tf.Tensor)
21+
assert arg.dtype == tf.float32
22+
assert arg.shape == (2, 5, 3)
23+
24+
f(arg)

0 commit comments

Comments
 (0)