Skip to content

Commit 2ce5856

Browse files
committed
Progress.
1 parent 2344618 commit 2ce5856

File tree

5 files changed

+48
-3
lines changed

5 files changed

+48
-3
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,13 +1791,23 @@ public void testAdd108()
17911791
@Test
17921792
public void testAdd109()
17931793
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1794-
test("tf2_test_add109.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1794+
test(
1795+
"tf2_test_add109.py",
1796+
"add",
1797+
2,
1798+
2,
1799+
Map.of(2, Set.of(TENSOR_2_FLOAT32), 3, Set.of(TENSOR_2_FLOAT32)));
17951800
}
17961801

17971802
@Test
17981803
public void testAdd110()
17991804
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1800-
test("tf2_test_add110.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1805+
test(
1806+
"tf2_test_add110.py",
1807+
"add",
1808+
2,
1809+
2,
1810+
Map.of(2, Set.of(TENSOR_2_FLOAT32), 3, Set.of(TENSOR_2_FLOAT32)));
18011811
}
18021812

18031813
@Test

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.NORMAL;
66
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONES;
77
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.RANGE;
8+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TRUNCATED_NORMAL;
89
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.UNIFORM;
910
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ZEROS;
1011
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ZEROS_LIKE;
@@ -40,6 +41,8 @@ public static TensorGenerator getGenerator(PointsToSetVariable source) {
4041
else if (calledFunction.equals(RANGE.getDeclaringClass())) return new Range(source, node);
4142
else if (calledFunction.equals(UNIFORM.getDeclaringClass())) return new Uniform(source, node);
4243
else if (calledFunction.equals(NORMAL.getDeclaringClass())) return new Normal(source, node);
44+
else if (calledFunction.equals(TRUNCATED_NORMAL.getDeclaringClass()))
45+
return new TruncatedNormal(source, node);
4346
else if (calledFunction.equals(ZEROS.getDeclaringClass())) return new Zeros(source, node);
4447
else if (calledFunction.equals(ZEROS_LIKE.getDeclaringClass()))
4548
return new ZerosLike(source, node);
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.ibm.wala.cast.python.ml.client;
2+
3+
import com.ibm.wala.ipa.callgraph.CGNode;
4+
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
5+
6+
/**
7+
* A representation of the `tf.random.truncated_normal' API in TensorFlow.
8+
*
9+
* @see <a
10+
* href="https://www.tensorflow.org/api_docs/python/tf/random/truncated_normal">tf.random.truncated_normal
11+
* API</a>.
12+
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
13+
*/
14+
public class TruncatedNormal extends Normal {
15+
16+
public TruncatedNormal(PointsToSetVariable source, CGNode node) {
17+
super(source, node);
18+
// TODO Auto-generated constructor stub
19+
}
20+
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,14 @@ public enum DType {
8383
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/normal")),
8484
AstMethodReference.fnSelector);
8585

86+
/** https://www.tensorflow.org/api_docs/python/tf/random/truncated_normal. */
87+
public static final MethodReference TRUNCATED_NORMAL =
88+
MethodReference.findOrCreate(
89+
TypeReference.findOrCreate(
90+
PythonTypes.pythonLoader,
91+
TypeName.string2TypeName("Ltensorflow/functions/truncated_normal")),
92+
AstMethodReference.fnSelector);
93+
8694
/** https://www.tensorflow.org/api_docs/python/tf/zeros. */
8795
public static final MethodReference ZEROS =
8896
MethodReference.findOrCreate(

com.ibm.wala.cast.python.test/data/tf2_test_add109.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,9 @@
44
def add(a, b):
55
return a + b
66

7+
arg1 = tf.random.truncated_normal([2])
8+
assert isinstance(arg1, tf.Tensor)
9+
assert arg1.dtype == tf.float32
10+
assert arg1.shape == (2,)
711

8-
c = add(tf.random.truncated_normal([2]), tf.random.truncated_normal([2], 3, 1))
12+
c = add(arg1, tf.random.truncated_normal([2], 3, 1))

0 commit comments

Comments
 (0)