Skip to content

Commit 48b7d76

Browse files
committed
Redo number of parameters.
1 parent 462ce2d commit 48b7d76

File tree

2 files changed

+54
-99
lines changed
  • com.ibm.wala.cast.python.ml.test
  • com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client

2 files changed

+54
-99
lines changed

com.ibm.wala.cast.python.ml.test/.classpath

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
<classpath>
33
<classpathentry kind="src" output="target/test-classes" path="source">
44
<attributes>
5+
<attribute name="test" value="true"/>
56
<attribute name="optional" value="true"/>
67
<attribute name="maven.pomderived" value="true"/>
7-
<attribute name="test" value="true"/>
88
</attributes>
99
</classpathentry>
1010
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-25">
1111
<attributes>
12+
<attribute name="module" value="true"/>
1213
<attribute name="maven.pomderived" value="true"/>
1314
</attributes>
1415
</classpathentry>

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java

Lines changed: 52 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,20 @@
11
package com.ibm.wala.cast.python.ml.client;
22

3-
import static com.ibm.wala.ipa.callgraph.propagation.cfa.CallStringContextSelector.CALL_STRING;
43
import static java.util.function.Function.identity;
54

65
import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
76
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
87
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
9-
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
10-
import com.ibm.wala.classLoader.CallSiteReference;
118
import com.ibm.wala.ipa.callgraph.CGNode;
129
import com.ibm.wala.ipa.callgraph.propagation.ConstantKey;
1310
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
1411
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
1512
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
1613
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
1714
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
18-
import com.ibm.wala.ipa.callgraph.propagation.cfa.CallString;
19-
import com.ibm.wala.ssa.SSAAbstractInvokeInstruction;
2015
import com.ibm.wala.util.collections.HashSetFactory;
2116
import com.ibm.wala.util.intset.OrdinalSet;
2217
import java.util.EnumSet;
23-
import java.util.Iterator;
2418
import java.util.List;
2519
import java.util.Set;
2620
import java.util.logging.Logger;
@@ -39,6 +33,7 @@
3933
*/
4034
public class Range extends TensorGenerator {
4135

36+
@SuppressWarnings("unused")
4237
private static final Logger LOGGER = Logger.getLogger(Range.class.getName());
4338

4439
private static final String FUNCTION_NAME = "tf.range()";
@@ -68,117 +63,76 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
6863
// Decide which version of the `range` function is being called based on the number of numeric
6964
// arguments.
7065
// TODO: Handle keyword arguments.
71-
for (Integer numOfPoisitionArguments : getNumberOfPossiblePositionalArguments(builder))
72-
if (numOfPoisitionArguments == 1) {
73-
// it must *just* be `limit`.
74-
PointerKey limitPK =
75-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 2);
76-
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);
77-
78-
assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";
79-
80-
for (InstanceKey limitIK : limitPointsToSet) {
81-
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
82-
int shape = (int) Math.ceil((limit - start) / delta);
83-
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
84-
}
85-
} else if (numOfPoisitionArguments == 3) {
86-
// it must be `start`, `limit`, and `delta`.
87-
PointerKey startPK =
88-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 2);
89-
PointerKey limitPK =
90-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 3);
91-
PointerKey deltaPK =
92-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 4);
93-
94-
OrdinalSet<InstanceKey> startPointsToSet = pointerAnalysis.getPointsToSet(startPK);
95-
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);
96-
OrdinalSet<InstanceKey> deltaPointsToSet = pointerAnalysis.getPointsToSet(deltaPK);
97-
98-
assert !startPointsToSet.isEmpty() : "Expected a non-empty points-to set for start.";
99-
assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";
100-
assert !deltaPointsToSet.isEmpty() : "Expected a non-empty points-to set for delta.";
101-
102-
for (InstanceKey startIK : startPointsToSet) {
103-
start = ((Number) ((ConstantKey<?>) startIK).getValue()).doubleValue();
104-
105-
for (InstanceKey limitIK : limitPointsToSet) {
106-
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
107-
108-
for (InstanceKey deltaIK : deltaPointsToSet) {
109-
delta = ((Number) ((ConstantKey<?>) deltaIK).getValue()).doubleValue();
110-
111-
int shape = (int) Math.ceil((limit - start) / delta);
112-
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
113-
}
114-
}
115-
}
116-
} else
117-
throw new IllegalStateException(
118-
"Expected either 1 or 3 positional arguments for range(), but got: "
119-
+ numOfPoisitionArguments
120-
+ ".");
121-
122-
return ret;
123-
}
124-
125-
/**
126-
* Returns the set of possible numbers of positional arguments passed to the range function at the
127-
* call.
128-
*
129-
* @param builder The {@link PropagationCallGraphBuilder} used for the analysis.
130-
* @return A set of integers representing the possible number of positional arguments.
131-
*/
132-
private Set<Integer> getNumberOfPossiblePositionalArguments(PropagationCallGraphBuilder builder) {
133-
Set<Integer> ret = HashSetFactory.make();
134-
135-
CallString cs = (CallString) this.getNode().getContext().get(CALL_STRING);
136-
CallSiteReference siteReference = cs.getCallSiteRefs()[0];
66+
int numberOfParameters =
67+
this.getNode().getMethod().isStatic()
68+
? this.getNode().getIR().getNumberOfParameters()
69+
: this.getNode().getIR().getNumberOfParameters() - 1;
70+
71+
if (numberOfParameters == 1) {
72+
// it must *just* be `limit`.
73+
PointerKey limitPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 2);
74+
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);
75+
76+
assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";
77+
78+
for (InstanceKey limitIK : limitPointsToSet) {
79+
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
80+
int shape = (int) Math.ceil((limit - start) / delta);
81+
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
82+
}
83+
} else if (numberOfParameters == 3) {
84+
// it must be `start`, `limit`, and `delta`.
85+
PointerKey startPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 2);
86+
PointerKey limitPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 3);
87+
PointerKey deltaPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 4);
13788

138-
for (CGNode caller : builder.getCallGraph()) {
139-
for (Iterator<CallSiteReference> it = caller.getIR().iterateCallSites(); it.hasNext(); ) {
140-
CallSiteReference callSite = it.next();
89+
OrdinalSet<InstanceKey> startPointsToSet = pointerAnalysis.getPointsToSet(startPK);
90+
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);
91+
OrdinalSet<InstanceKey> deltaPointsToSet = pointerAnalysis.getPointsToSet(deltaPK);
14192

142-
if (callSite.equals(siteReference)) {
143-
// caller is the node that made the call.
144-
LOGGER.finest(() -> "Caller node: " + caller.getMethod().getSignature() + ".");
93+
assert !startPointsToSet.isEmpty() : "Expected a non-empty points-to set for start.";
94+
assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";
95+
assert !deltaPointsToSet.isEmpty() : "Expected a non-empty points-to set for delta.";
14596

146-
SSAAbstractInvokeInstruction[] calls = caller.getIR().getCalls(callSite);
147-
LOGGER.finest(() -> "Number of calls at this site: " + calls.length + ".");
97+
for (InstanceKey startIK : startPointsToSet) {
98+
start = ((Number) ((ConstantKey<?>) startIK).getValue()).doubleValue();
14899

149-
for (SSAAbstractInvokeInstruction callInstr : calls) {
150-
LOGGER.finest(() -> "Call instruction: " + callInstr + ".");
100+
for (InstanceKey limitIK : limitPointsToSet) {
101+
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
151102

152-
PythonInvokeInstruction pyCallInstr = (PythonInvokeInstruction) callInstr;
153-
int numberOfPositionalParameters =
154-
pyCallInstr.getNumberOfPositionalParameters() - 1; // Exclude the function name.
155-
LOGGER.finer(
156-
() -> "Number of positional parameters: " + numberOfPositionalParameters + ".");
103+
for (InstanceKey deltaIK : deltaPointsToSet) {
104+
delta = ((Number) ((ConstantKey<?>) deltaIK).getValue()).doubleValue();
157105

158-
ret.add(numberOfPositionalParameters);
106+
int shape = (int) Math.ceil((limit - start) / delta);
107+
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
159108
}
160109
}
161110
}
162-
}
111+
} else
112+
throw new IllegalStateException(
113+
"Expected either 1 or 3 positional arguments for range(), but got: "
114+
+ numberOfParameters
115+
+ ".");
116+
163117
return ret;
164118
}
165119

166120
@Override
167121
protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
168122
// The dtype of the resulting tensor is inferred from the inputs unless it is provided
169123
// explicitly.
170-
171124
// TODO: Handle keyword arguments.
125+
int numberOfParameters =
126+
this.getNode().getMethod().isStatic()
127+
? this.getNode().getIR().getNumberOfParameters()
128+
: this.getNode().getIR().getNumberOfParameters() - 1;
129+
172130
EnumSet<DType> types =
173-
getNumberOfPossiblePositionalArguments(builder).stream()
174-
.map(
175-
numArgs ->
176-
IntStream.range(0, numArgs)
177-
.map(i -> i + 2) // Positional arguments start at index 2.
178-
.mapToObj(val -> getDTypes(builder, val).stream())
179-
.flatMap(identity())
180-
.distinct())
131+
IntStream.range(0, numberOfParameters)
132+
.map(i -> i + 2) // Positional arguments start at index 2.
133+
.mapToObj(val -> getDTypes(builder, val).stream())
181134
.flatMap(identity())
135+
.distinct()
182136
.collect(Collectors.toCollection(() -> EnumSet.noneOf(DType.class)));
183137

184138
// FIXME: We can't tell the difference here between varying dtypes in a single call and that of

0 commit comments

Comments
 (0)