Skip to content

Commit 317bd3b

Browse files
committed
Initial support for tf.fill operation in TensorFlow models.
Also move tensor types.
1 parent 2cc65b9 commit 317bd3b

File tree

5 files changed

+128
-49
lines changed

5 files changed

+128
-49
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,12 @@ public void testAdd29()
12541254
@Test
12551255
public void testAdd30()
12561256
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1257-
test("tf2_test_add30.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1257+
test(
1258+
"tf2_test_add30.py",
1259+
"add",
1260+
2,
1261+
2,
1262+
Map.of(2, Set.of(TENSOR_1_2_INT32), 3, Set.of(TENSOR_2_2_INT32)));
12581263
}
12591264

12601265
@Test
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package com.ibm.wala.cast.python.ml.client;
2+
3+
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
4+
import com.ibm.wala.ipa.callgraph.CGNode;
5+
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
6+
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
7+
import java.util.List;
8+
import java.util.Set;
9+
10+
/**
11+
* A representation of the TensorFlow <code>fill()</code> function.
12+
*
13+
* @see <a href="https://www.tensorflow.org/api_docs/python/tf/fill">TensorFlow fill() API</a>.
14+
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
15+
*/
16+
public class Fill extends Constant {
17+
18+
private static final int VALUE_NUMBER_FOR_SHAPE_ARGUMENT = 2;
19+
20+
private static final int VALUE_NUMBER_FOR_VALUE_ARGUMENT = 3;
21+
22+
/**
23+
* The dtype argument is not explicitly provided to fill(); rather, the dtype is inferred from the
24+
* `value` argument.
25+
*/
26+
private static final int VALUE_NUMBER_FOR_DTYPE_ARGUMENT = -1;
27+
28+
public Fill(PointsToSetVariable source, CGNode node) {
29+
super(source, node);
30+
}
31+
32+
@Override
33+
protected int getValueNumberForDTypeArgument() {
34+
return VALUE_NUMBER_FOR_DTYPE_ARGUMENT;
35+
}
36+
37+
@Override
38+
protected int getValueNumberForValueArgument() {
39+
return VALUE_NUMBER_FOR_VALUE_ARGUMENT;
40+
}
41+
42+
@Override
43+
protected int getValueNumberForShapeArgument() {
44+
return VALUE_NUMBER_FOR_SHAPE_ARGUMENT;
45+
}
46+
47+
@Override
48+
protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
49+
throw new UnsupportedOperationException("Shape is mandatory and must be provided explicitly.");
50+
}
51+
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java

Lines changed: 14 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,29 @@
11
package com.ibm.wala.cast.python.ml.client;
22

3-
import com.ibm.wala.cast.python.types.PythonTypes;
4-
import com.ibm.wala.cast.types.AstMethodReference;
3+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.CONSTANT;
4+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.FILL;
5+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONES;
6+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.RANGE;
7+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.UNIFORM;
8+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ZEROS;
9+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ZEROS_LIKE;
10+
511
import com.ibm.wala.ipa.callgraph.CGNode;
612
import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey;
713
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
814
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
9-
import com.ibm.wala.types.MethodReference;
10-
import com.ibm.wala.types.TypeName;
1115
import com.ibm.wala.types.TypeReference;
1216
import java.util.logging.Logger;
1317

18+
/**
19+
* A factory for creating TensorGenerator instances based on the called TensorFlow function.
20+
*
21+
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
22+
*/
1423
public class TensorGeneratorFactory {
1524

1625
private static final Logger LOGGER = Logger.getLogger(TensorGeneratorFactory.class.getName());
1726

18-
/** https://www.tensorflow.org/api_docs/python/tf/ones. */
19-
private static final MethodReference ONES =
20-
MethodReference.findOrCreate(
21-
TypeReference.findOrCreate(
22-
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/ones")),
23-
AstMethodReference.fnSelector);
24-
25-
/** https://www.tensorflow.org/api_docs/python/tf/constant. */
26-
private static final MethodReference CONSTANT =
27-
MethodReference.findOrCreate(
28-
TypeReference.findOrCreate(
29-
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/constant")),
30-
AstMethodReference.fnSelector);
31-
32-
/** https://www.tensorflow.org/api_docs/python/tf/range. */
33-
private static final MethodReference RANGE =
34-
MethodReference.findOrCreate(
35-
TypeReference.findOrCreate(
36-
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/range")),
37-
AstMethodReference.fnSelector);
38-
39-
/** https://www.tensorflow.org/api_docs/python/tf/random/uniform. */
40-
private static final MethodReference UNIFORM =
41-
MethodReference.findOrCreate(
42-
TypeReference.findOrCreate(
43-
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/uniform")),
44-
AstMethodReference.fnSelector);
45-
46-
/** https://www.tensorflow.org/api_docs/python/tf/zeros. */
47-
private static final MethodReference ZEROS =
48-
MethodReference.findOrCreate(
49-
TypeReference.findOrCreate(
50-
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/zeros")),
51-
AstMethodReference.fnSelector);
52-
53-
/** https://www.tensorflow.org/api_docs/python/tf/zeros_like. */
54-
private static final MethodReference ZEROS_LIKE =
55-
MethodReference.findOrCreate(
56-
TypeReference.findOrCreate(
57-
PythonTypes.pythonLoader,
58-
TypeName.string2TypeName("Ltensorflow/functions/zeros_like")),
59-
AstMethodReference.fnSelector);
60-
6127
public static TensorGenerator getGenerator(PointsToSetVariable source) {
6228
// Get the pointer key for the source.
6329
PointerKey pointerKey = source.getPointerKey();
@@ -75,6 +41,7 @@ public static TensorGenerator getGenerator(PointsToSetVariable source) {
7541
else if (calledFunction.equals(ZEROS.getDeclaringClass())) return new Zeros(source, node);
7642
else if (calledFunction.equals(ZEROS_LIKE.getDeclaringClass()))
7743
return new ZerosLike(source, node);
44+
else if (calledFunction.equals(FILL.getDeclaringClass())) return new Fill(source, node);
7845
else
7946
throw new IllegalArgumentException(
8047
"Unknown call: " + calledFunction + " for source: " + source + ".");

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import static com.ibm.wala.core.util.strings.Atom.findOrCreateAsciiAtom;
66

77
import com.ibm.wala.cast.python.types.PythonTypes;
8+
import com.ibm.wala.cast.types.AstMethodReference;
89
import com.ibm.wala.types.FieldReference;
10+
import com.ibm.wala.types.MethodReference;
911
import com.ibm.wala.types.TypeName;
1012
import com.ibm.wala.types.TypeReference;
1113
import java.util.Map;
@@ -45,6 +47,56 @@ public enum DType {
4547
public static final TypeReference D_TYPE =
4648
TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Ltensorflow/dtypes/DType"));
4749

50+
/** https://www.tensorflow.org/api_docs/python/tf/ones. */
51+
public static final MethodReference ONES =
52+
MethodReference.findOrCreate(
53+
TypeReference.findOrCreate(
54+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/ones")),
55+
AstMethodReference.fnSelector);
56+
57+
/** https://www.tensorflow.org/api_docs/python/tf/constant. */
58+
public static final MethodReference CONSTANT =
59+
MethodReference.findOrCreate(
60+
TypeReference.findOrCreate(
61+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/constant")),
62+
AstMethodReference.fnSelector);
63+
64+
/** https://www.tensorflow.org/api_docs/python/tf/range. */
65+
public static final MethodReference RANGE =
66+
MethodReference.findOrCreate(
67+
TypeReference.findOrCreate(
68+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/range")),
69+
AstMethodReference.fnSelector);
70+
71+
/** https://www.tensorflow.org/api_docs/python/tf/random/uniform. */
72+
public static final MethodReference UNIFORM =
73+
MethodReference.findOrCreate(
74+
TypeReference.findOrCreate(
75+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/uniform")),
76+
AstMethodReference.fnSelector);
77+
78+
/** https://www.tensorflow.org/api_docs/python/tf/zeros. */
79+
public static final MethodReference ZEROS =
80+
MethodReference.findOrCreate(
81+
TypeReference.findOrCreate(
82+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/zeros")),
83+
AstMethodReference.fnSelector);
84+
85+
/** https://www.tensorflow.org/api_docs/python/tf/zeros_like. */
86+
public static final MethodReference ZEROS_LIKE =
87+
MethodReference.findOrCreate(
88+
TypeReference.findOrCreate(
89+
PythonTypes.pythonLoader,
90+
TypeName.string2TypeName("Ltensorflow/functions/zeros_like")),
91+
AstMethodReference.fnSelector);
92+
93+
/** https://www.tensorflow.org/api_docs/python/tf/fill. */
94+
public static final MethodReference FILL =
95+
MethodReference.findOrCreate(
96+
TypeReference.findOrCreate(
97+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/fill")),
98+
AstMethodReference.fnSelector);
99+
48100
/**
49101
* Represents the TensorFlow float32 data type.
50102
*

com.ibm.wala.cast.python.test/data/tf2_test_add30.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,8 @@ def add(a, b):
55
return a + b
66

77

8-
c = add(tf.fill([1, 2], 2), tf.fill([2, 2], 1))
8+
arg1 = tf.fill([1, 2], 2)
9+
assert arg1.shape == (1, 2)
10+
assert arg1.dtype == tf.int32
11+
12+
c = add(arg1, tf.fill([2, 2], 1))

0 commit comments

Comments
 (0)