Skip to content

Commit 0ce90af

Browse files
committed
Some progress on tf.one_hot().
1 parent b3176c6 commit 0ce90af

File tree

7 files changed

+243
-47
lines changed

7 files changed

+243
-47
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
8686
private static final TensorType TENSOR_3_2_FLOAT32 =
8787
new TensorType(FLOAT_32, asList(new NumericDim(3), new NumericDim(2)));
8888

89+
private static final TensorType TENSOR_3_3_FLOAT32 =
90+
new TensorType(FLOAT_32, asList(new NumericDim(3), new NumericDim(3)));
91+
8992
private static final TensorType TENSOR_2_1_FLOAT32 =
9093
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(1)));
9194

@@ -1304,13 +1307,23 @@ public void testAdd33()
13041307
@Test
13051308
public void testAdd34()
13061309
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1307-
test("tf2_test_add34.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1310+
test(
1311+
"tf2_test_add34.py",
1312+
"add",
1313+
2,
1314+
2,
1315+
Map.of(2, Set.of(TENSOR_3_3_FLOAT32), 3, Set.of(TENSOR_3_3_FLOAT32)));
13081316
}
13091317

13101318
@Test
13111319
public void testAdd35()
13121320
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1313-
test("tf2_test_add35.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1321+
test(
1322+
"tf2_test_add35.py",
1323+
"add",
1324+
2,
1325+
2,
1326+
Map.of(2, Set.of(TENSOR_3_3_FLOAT32), 3, Set.of(TENSOR_3_3_FLOAT32)));
13141327
}
13151328

13161329
@Test
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package com.ibm.wala.cast.python.ml.client;
2+
3+
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.DEPTH;
4+
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.DTYPE;
5+
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.OFF_VALUE;
6+
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.ON_VALUE;
7+
8+
import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
9+
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
10+
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
11+
import com.ibm.wala.ipa.callgraph.CGNode;
12+
import com.ibm.wala.ipa.callgraph.propagation.ConstantKey;
13+
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
14+
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
15+
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
16+
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
17+
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
18+
import com.ibm.wala.util.collections.HashSetFactory;
19+
import com.ibm.wala.util.intset.OrdinalSet;
20+
import java.util.ArrayList;
21+
import java.util.EnumSet;
22+
import java.util.List;
23+
import java.util.Set;
24+
25+
public class OneHot extends ZerosLike {
26+
27+
private static final String FUNCTION_NAME = "tf.one_hot()";
28+
29+
enum Parameters {
30+
INDICES,
31+
DEPTH,
32+
ON_VALUE,
33+
OFF_VALUE,
34+
AXIS,
35+
DTYPE
36+
}
37+
38+
public OneHot(PointsToSetVariable source, CGNode node) {
39+
super(source, node);
40+
}
41+
42+
@Override
43+
protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
44+
throw new UnsupportedOperationException(
45+
"Shapes are derived from mandatory numeric arguments and must be provided explicitly.");
46+
}
47+
48+
@Override
49+
protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
50+
// If dtype is not provided, it will attempt to assume the data type of on_value or off_value,
51+
// if one or both are passed in. If none of on_value, off_value, or dtype are provided, dtype
52+
// will default to the value tf.float32.
53+
// TODO: Handle keyword arguments.
54+
EnumSet<DType> ret = EnumSet.noneOf(DType.class);
55+
Set<Integer> possiblePositionalArguments = this.getNumberOfPossiblePositionalArguments(builder);
56+
57+
for (int numArgs : possiblePositionalArguments)
58+
if (numArgs == Parameters.DEPTH.ordinal() + 1)
59+
// Neither on_value nor off_value is provided.
60+
ret.add(DType.FLOAT32);
61+
else if (numArgs <= Parameters.OFF_VALUE.ordinal() + 1) {
62+
// Either on_value and off_value are provided.
63+
EnumSet<DType> onValueDTypes =
64+
this.getDTypes(builder, this.getOnValueArgumentValueNumber());
65+
66+
if (!onValueDTypes.isEmpty()) ret.addAll(onValueDTypes);
67+
else {
68+
EnumSet<DType> offValueDTypes =
69+
this.getDTypes(builder, this.getOffValueArgumentValueNumber());
70+
ret.addAll(offValueDTypes);
71+
}
72+
}
73+
74+
return ret;
75+
}
76+
77+
@Override
78+
protected int getDTypeParameterPosition() {
79+
return DTYPE.ordinal();
80+
}
81+
82+
protected int getDepthParameterPosition() {
83+
return DEPTH.ordinal();
84+
}
85+
86+
protected int getOnValueParameterPosition() {
87+
return ON_VALUE.ordinal();
88+
}
89+
90+
protected int getOffValueParameterPosition() {
91+
return OFF_VALUE.ordinal();
92+
}
93+
94+
protected int getOnValueArgumentValueNumber() {
95+
return this.getArgumentValueNumber(this.getOnValueParameterPosition());
96+
}
97+
98+
protected int getOffValueArgumentValueNumber() {
99+
return this.getArgumentValueNumber(this.getOffValueParameterPosition());
100+
}
101+
102+
@Override
103+
protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) {
104+
Set<List<Dimension<?>>> ret = HashSetFactory.make();
105+
Set<List<Dimension<?>>> indices = this.getShapes(builder, this.getValueArgumentValueNumber());
106+
int depthArgumentValueNumber = this.getDepthArgumentValueNumber();
107+
108+
if (depthArgumentValueNumber <= 0)
109+
throw new IllegalStateException(
110+
"No depth argument value found for OneHot tensor generation.");
111+
112+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
113+
114+
PointerKey pointerKey =
115+
pointerAnalysis
116+
.getHeapModel()
117+
.getPointerKeyForLocal(this.getNode(), depthArgumentValueNumber);
118+
OrdinalSet<InstanceKey> pointsToSet = pointerAnalysis.getPointsToSet(pointerKey);
119+
120+
if (pointsToSet == null || pointsToSet.isEmpty())
121+
throw new IllegalStateException(
122+
"No depth argument value found for OneHot tensor generation.");
123+
124+
for (InstanceKey instanceKey : pointsToSet) {
125+
int depth = getIntValueFromInstanceKey(instanceKey);
126+
127+
// For each shape in indices, append the depth as a new dimension.
128+
for (List<Dimension<?>> shape : indices) {
129+
NumericDim dim = new NumericDim(depth);
130+
131+
List<Dimension<?>> newShape = new ArrayList<>(shape);
132+
newShape.add(dim);
133+
ret.add(newShape);
134+
}
135+
}
136+
137+
assert ret.size() >= indices.size()
138+
: "Number of OneHot shapes should be at least the number of indices shapes.";
139+
140+
return ret;
141+
}
142+
143+
private static int getIntValueFromInstanceKey(InstanceKey instanceKey) {
144+
if (instanceKey instanceof ConstantKey) {
145+
ConstantKey<?> constantKey = (ConstantKey<?>) instanceKey;
146+
Object value = constantKey.getValue();
147+
return ((Long) value).intValue();
148+
}
149+
150+
throw new IllegalStateException(
151+
"Cannot get integer value from non-constant InstanceKey: " + instanceKey);
152+
}
153+
154+
private int getDepthArgumentValueNumber() {
155+
return this.getArgumentValueNumber(this.getDepthParameterPosition());
156+
}
157+
158+
@Override
159+
protected String getSignature() {
160+
return FUNCTION_NAME;
161+
}
162+
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
11
package com.ibm.wala.cast.python.ml.client;
22

3-
import static com.ibm.wala.ipa.callgraph.propagation.cfa.CallStringContextSelector.CALL_STRING;
43
import static java.util.function.Function.identity;
54

65
import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
76
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
87
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
9-
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
10-
import com.ibm.wala.classLoader.CallSiteReference;
118
import com.ibm.wala.ipa.callgraph.CGNode;
129
import com.ibm.wala.ipa.callgraph.propagation.ConstantKey;
1310
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
1411
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
1512
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
1613
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
1714
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
18-
import com.ibm.wala.ipa.callgraph.propagation.cfa.CallString;
19-
import com.ibm.wala.ssa.SSAAbstractInvokeInstruction;
2015
import com.ibm.wala.util.collections.HashSetFactory;
2116
import com.ibm.wala.util.intset.OrdinalSet;
2217
import java.util.EnumSet;
23-
import java.util.Iterator;
2418
import java.util.List;
2519
import java.util.Set;
2620
import java.util.logging.Logger;
@@ -39,6 +33,7 @@
3933
*/
4034
public class Range extends TensorGenerator {
4135

36+
@SuppressWarnings("unused")
4237
private static final Logger LOGGER = Logger.getLogger(Range.class.getName());
4338

4439
private static final String FUNCTION_NAME = "tf.range()";
@@ -143,44 +138,6 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
143138
return ret;
144139
}
145140

146-
/**
147-
* Returns the set of possible numbers of positional arguments passed to the range function at the
148-
* call.
149-
*
150-
* @param builder The {@link PropagationCallGraphBuilder} used for the analysis.
151-
* @return A set of integers representing the possible number of positional arguments.
152-
*/
153-
private Set<Integer> getNumberOfPossiblePositionalArguments(PropagationCallGraphBuilder builder) {
154-
Set<Integer> ret = HashSetFactory.make();
155-
156-
CallString cs = (CallString) this.getNode().getContext().get(CALL_STRING);
157-
CallSiteReference siteReference = cs.getCallSiteRefs()[0];
158-
LOGGER.fine(() -> "Analyzing call site: " + siteReference + ".");
159-
160-
for (Iterator<CGNode> it = builder.getCallGraph().getPredNodes(this.getNode());
161-
it.hasNext(); ) {
162-
CGNode caller = it.next();
163-
LOGGER.fine(() -> "Analyzing caller node: " + caller.getMethod().getSignature() + ".");
164-
165-
SSAAbstractInvokeInstruction[] calls = caller.getIR().getCalls(siteReference);
166-
LOGGER.finest(() -> "Number of calls at this site: " + calls.length + ".");
167-
168-
for (SSAAbstractInvokeInstruction callInstr : calls) {
169-
LOGGER.finest(() -> "Call instruction: " + callInstr + ".");
170-
171-
PythonInvokeInstruction pyCallInstr = (PythonInvokeInstruction) callInstr;
172-
int numberOfPositionalParameters =
173-
pyCallInstr.getNumberOfPositionalParameters() - 1; // Exclude the function name.
174-
LOGGER.finer(
175-
() -> "Number of positional parameters: " + numberOfPositionalParameters + ".");
176-
177-
ret.add(numberOfPositionalParameters);
178-
}
179-
}
180-
181-
return ret;
182-
}
183-
184141
@Override
185142
protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
186143
// The dtype of the resulting tensor is inferred from the inputs unless it is provided

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import com.ibm.wala.cast.python.ml.types.TensorType;
2121
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
2222
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
23+
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
2324
import com.ibm.wala.cast.python.types.PythonTypes;
25+
import com.ibm.wala.classLoader.CallSiteReference;
2426
import com.ibm.wala.classLoader.IClass;
2527
import com.ibm.wala.classLoader.IField;
2628
import com.ibm.wala.classLoader.IMethod;
@@ -34,6 +36,7 @@
3436
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
3537
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
3638
import com.ibm.wala.ipa.callgraph.propagation.cfa.CallString;
39+
import com.ibm.wala.ssa.SSAAbstractInvokeInstruction;
3740
import com.ibm.wala.types.Descriptor;
3841
import com.ibm.wala.types.FieldReference;
3942
import com.ibm.wala.types.MethodReference;
@@ -42,6 +45,7 @@
4245
import com.ibm.wala.util.intset.OrdinalSet;
4346
import java.util.ArrayList;
4447
import java.util.EnumSet;
48+
import java.util.Iterator;
4549
import java.util.List;
4650
import java.util.Map.Entry;
4751
import java.util.Set;
@@ -656,4 +660,43 @@ protected int getArgumentValueNumber(int parameterPosition) {
656660
? this.getNode().getIR().getParameter(parameterPosition)
657661
: this.getNode().getIR().getParameter(parameterPosition + 1);
658662
}
663+
664+
/**
665+
* Returns the set of possible numbers of positional arguments passed to the range function at the
666+
* call.
667+
*
668+
* @param builder The {@link PropagationCallGraphBuilder} used for the analysis.
669+
* @return A set of integers representing the possible number of positional arguments.
670+
*/
671+
protected Set<Integer> getNumberOfPossiblePositionalArguments(
672+
PropagationCallGraphBuilder builder) {
673+
Set<Integer> ret = HashSetFactory.make();
674+
675+
CallString cs = (CallString) this.getNode().getContext().get(CALL_STRING);
676+
CallSiteReference siteReference = cs.getCallSiteRefs()[0];
677+
LOGGER.fine(() -> "Analyzing call site: " + siteReference + ".");
678+
679+
for (Iterator<CGNode> it = builder.getCallGraph().getPredNodes(this.getNode());
680+
it.hasNext(); ) {
681+
CGNode caller = it.next();
682+
LOGGER.fine(() -> "Analyzing caller node: " + caller.getMethod().getSignature() + ".");
683+
684+
SSAAbstractInvokeInstruction[] calls = caller.getIR().getCalls(siteReference);
685+
LOGGER.finest(() -> "Number of calls at this site: " + calls.length + ".");
686+
687+
for (SSAAbstractInvokeInstruction callInstr : calls) {
688+
LOGGER.finest(() -> "Call instruction: " + callInstr + ".");
689+
690+
PythonInvokeInstruction pyCallInstr = (PythonInvokeInstruction) callInstr;
691+
int numberOfPositionalParameters =
692+
pyCallInstr.getNumberOfPositionalParameters() - 1; // Exclude the function name.
693+
LOGGER.finer(
694+
() -> "Number of positional parameters: " + numberOfPositionalParameters + ".");
695+
696+
ret.add(numberOfPositionalParameters);
697+
}
698+
}
699+
700+
return ret;
701+
}
659702
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.FILL;
66
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.NORMAL;
77
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONES;
8+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONE_HOT;
89
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.RANGE;
910
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TRUNCATED_NORMAL;
1011
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.UNIFORM;
@@ -50,6 +51,7 @@ else if (calledFunction.equals(ZEROS_LIKE.getDeclaringClass()))
5051
else if (calledFunction.equals(FILL.getDeclaringClass())) return new Fill(source, node);
5152
else if (calledFunction.equals(CONVERT_TO_TENSOR.getDeclaringClass()))
5253
return new ConvertToTensor(source, node);
54+
else if (calledFunction.equals(ONE_HOT.getDeclaringClass())) return new OneHot(source, node);
5355
else
5456
throw new IllegalArgumentException(
5557
"Unknown call: " + calledFunction + " for source: " + source + ".");

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,13 @@ public boolean canConvertTo(DType other) {
143143
TypeName.string2TypeName("Ltensorflow/functions/convert_to_tensor")),
144144
AstMethodReference.fnSelector);
145145

146+
/** https://www.tensorflow.org/api_docs/python/tf/one_hot. */
147+
public static final MethodReference ONE_HOT =
148+
MethodReference.findOrCreate(
149+
TypeReference.findOrCreate(
150+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/one_hot")),
151+
AstMethodReference.fnSelector);
152+
146153
/**
147154
* Represents the TensorFlow float32 data type.
148155
*

com.ibm.wala.cast.python.test/data/tf2_test_add34.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,16 @@ def add(a, b):
55
return a + b
66

77

8-
c = add(tf.one_hot([0, 1, 2], 3), tf.one_hot([2, 4, 3], 3))
8+
arg1 = [0, 1, 2]
9+
assert isinstance(arg1, list)
10+
assert all(isinstance(x, int) for x in arg1)
11+
assert len(arg1) == 3
12+
assert tf.convert_to_tensor(arg1).dtype == tf.int32
13+
assert tf.convert_to_tensor(arg1).shape == (3,)
14+
15+
arg2 = tf.one_hot(arg1, 3)
16+
assert isinstance(arg2, tf.Tensor)
17+
assert arg2.dtype == tf.float32
18+
assert arg2.shape == (3, 3)
19+
20+
c = add(arg2, tf.one_hot([2, 4, 3], 3))

0 commit comments

Comments
 (0)