Skip to content

Commit d1e2ec5

Browse files
committed
Can't use declared number of positional parameters.
We need to consider the calling contexts.
1 parent b606645 commit d1e2ec5

File tree

2 files changed

+116
-77
lines changed

2 files changed

+116
-77
lines changed
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
11
eclipse.preferences.version=1
2-
encoding//src/main/resources=UTF-8
3-
encoding//src/test/resources=UTF-8
42
encoding/<project>=UTF-8
53
encoding/source=UTF-8

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java

Lines changed: 116 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
package com.ibm.wala.cast.python.ml.client;
22

3+
import static com.ibm.wala.ipa.callgraph.propagation.cfa.CallStringContextSelector.CALL_STRING;
34
import static java.util.function.Function.identity;
45

56
import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
67
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
78
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
9+
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
10+
import com.ibm.wala.classLoader.CallSiteReference;
811
import com.ibm.wala.ipa.callgraph.CGNode;
912
import com.ibm.wala.ipa.callgraph.propagation.ConstantKey;
1013
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
1114
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
1215
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
1316
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
1417
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
18+
import com.ibm.wala.ipa.callgraph.propagation.cfa.CallString;
19+
import com.ibm.wala.ssa.SSAAbstractInvokeInstruction;
1520
import com.ibm.wala.util.collections.HashSetFactory;
1621
import com.ibm.wala.util.intset.OrdinalSet;
1722
import java.util.EnumSet;
23+
import java.util.Iterator;
1824
import java.util.List;
1925
import java.util.Set;
2026
import java.util.logging.Logger;
@@ -33,7 +39,6 @@
3339
*/
3440
public class Range extends TensorGenerator {
3541

36-
@SuppressWarnings("unused")
3742
private static final Logger LOGGER = Logger.getLogger(Range.class.getName());
3843

3944
private static final String FUNCTION_NAME = "tf.range()";
@@ -48,8 +53,7 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
4853
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
4954

5055
// The shape of a range tensor is always a 1D tensor with the length equal to the number of
51-
// elements in the range.
52-
// For example, `tf.range(5)` produces a tensor with shape (5,).
56+
// elements in the range. For example, `tf.range(5)` produces a tensor with shape (5,).
5357

5458
double start = 0; // Default start value.
5559
double limit = start; // Default limit value.
@@ -63,82 +67,119 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
6367
// Decide which version of the `range` function is being called based on the number of numeric
6468
// arguments.
6569
// TODO: Handle keyword arguments.
66-
int numberOfParameters =
67-
this.getNode().getMethod().isStatic()
68-
? this.getNode().getIR().getNumberOfParameters()
69-
: this.getNode().getIR().getNumberOfParameters() - 1;
70-
71-
if (numberOfParameters == 1) {
72-
// it must *just* be `limit`.
73-
int limitValueNumber =
74-
this.getNode().getMethod().isStatic()
75-
? this.getNode().getIR().getParameter(0)
76-
: this.getNode().getIR().getParameter(1);
77-
78-
PointerKey limitPK =
79-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), limitValueNumber);
80-
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);
81-
82-
assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";
83-
84-
for (InstanceKey limitIK : limitPointsToSet) {
85-
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
86-
int shape = (int) Math.ceil((limit - start) / delta);
87-
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
88-
}
89-
} else if (numberOfParameters == 3) {
90-
// it must be `start`, `limit`, and `delta`.
91-
int startValueNumber =
92-
this.getNode().getMethod().isStatic()
93-
? this.getNode().getIR().getParameter(0)
94-
: this.getNode().getIR().getParameter(1);
70+
for (Integer numOfPoisitionArguments : getNumberOfPossiblePositionalArguments(builder))
71+
if (numOfPoisitionArguments == 1) {
72+
// it must *just* be `limit`.
73+
int limitValueNumber =
74+
this.getNode().getMethod().isStatic()
75+
? this.getNode().getIR().getParameter(0)
76+
: this.getNode().getIR().getParameter(1);
9577

96-
PointerKey startPK =
97-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), startValueNumber);
78+
PointerKey limitPK =
79+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), limitValueNumber);
80+
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);
9881

99-
int limitValueNumber =
100-
this.getNode().getMethod().isStatic()
101-
? this.getNode().getIR().getParameter(1)
102-
: this.getNode().getIR().getParameter(2);
82+
assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";
10383

104-
PointerKey limitPK =
105-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), limitValueNumber);
84+
for (InstanceKey limitIK : limitPointsToSet) {
85+
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
86+
int shape = (int) Math.ceil((limit - start) / delta);
87+
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
88+
}
89+
} else if (numOfPoisitionArguments == 3) {
90+
// it must be `start`, `limit`, and `delta`.
91+
int startValueNumber =
92+
this.getNode().getMethod().isStatic()
93+
? this.getNode().getIR().getParameter(0)
94+
: this.getNode().getIR().getParameter(1);
10695

107-
int deltaValueNumber =
108-
this.getNode().getMethod().isStatic()
109-
? this.getNode().getIR().getParameter(2)
110-
: this.getNode().getIR().getParameter(3);
96+
PointerKey startPK =
97+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), startValueNumber);
11198

112-
PointerKey deltaPK =
113-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), deltaValueNumber);
99+
int limitValueNumber =
100+
this.getNode().getMethod().isStatic()
101+
? this.getNode().getIR().getParameter(1)
102+
: this.getNode().getIR().getParameter(2);
114103

115-
OrdinalSet<InstanceKey> startPointsToSet = pointerAnalysis.getPointsToSet(startPK);
116-
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);
117-
OrdinalSet<InstanceKey> deltaPointsToSet = pointerAnalysis.getPointsToSet(deltaPK);
104+
PointerKey limitPK =
105+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), limitValueNumber);
118106

119-
assert !startPointsToSet.isEmpty() : "Expected a non-empty points-to set for start.";
120-
assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";
121-
assert !deltaPointsToSet.isEmpty() : "Expected a non-empty points-to set for delta.";
107+
int deltaValueNumber =
108+
this.getNode().getMethod().isStatic()
109+
? this.getNode().getIR().getParameter(2)
110+
: this.getNode().getIR().getParameter(3);
122111

123-
for (InstanceKey startIK : startPointsToSet) {
124-
start = ((Number) ((ConstantKey<?>) startIK).getValue()).doubleValue();
112+
PointerKey deltaPK =
113+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), deltaValueNumber);
125114

126-
for (InstanceKey limitIK : limitPointsToSet) {
127-
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
115+
OrdinalSet<InstanceKey> startPointsToSet = pointerAnalysis.getPointsToSet(startPK);
116+
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);
117+
OrdinalSet<InstanceKey> deltaPointsToSet = pointerAnalysis.getPointsToSet(deltaPK);
118+
119+
assert !startPointsToSet.isEmpty() : "Expected a non-empty points-to set for start.";
120+
assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";
121+
assert !deltaPointsToSet.isEmpty() : "Expected a non-empty points-to set for delta.";
122+
123+
for (InstanceKey startIK : startPointsToSet) {
124+
start = ((Number) ((ConstantKey<?>) startIK).getValue()).doubleValue();
125+
126+
for (InstanceKey limitIK : limitPointsToSet) {
127+
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
128+
129+
for (InstanceKey deltaIK : deltaPointsToSet) {
130+
delta = ((Number) ((ConstantKey<?>) deltaIK).getValue()).doubleValue();
131+
132+
int shape = (int) Math.ceil((limit - start) / delta);
133+
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
134+
}
135+
}
136+
}
137+
} else
138+
throw new IllegalStateException(
139+
"Expected either 1 or 3 positional arguments for range(), but got: "
140+
+ numOfPoisitionArguments
141+
+ ".");
142+
143+
return ret;
144+
}
145+
146+
/**
147+
* Returns the set of possible numbers of positional arguments passed to the range function at the
148+
* call.
149+
*
150+
* @param builder The {@link PropagationCallGraphBuilder} used for the analysis.
151+
* @return A set of integers representing the possible number of positional arguments.
152+
*/
153+
private Set<Integer> getNumberOfPossiblePositionalArguments(PropagationCallGraphBuilder builder) {
154+
Set<Integer> ret = HashSetFactory.make();
155+
156+
CallString cs = (CallString) this.getNode().getContext().get(CALL_STRING);
157+
CallSiteReference siteReference = cs.getCallSiteRefs()[0];
128158

129-
for (InstanceKey deltaIK : deltaPointsToSet) {
130-
delta = ((Number) ((ConstantKey<?>) deltaIK).getValue()).doubleValue();
159+
for (CGNode caller : builder.getCallGraph())
160+
for (Iterator<CallSiteReference> it = caller.getIR().iterateCallSites(); it.hasNext(); ) {
161+
CallSiteReference callSite = it.next();
131162

132-
int shape = (int) Math.ceil((limit - start) / delta);
133-
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
163+
if (callSite.equals(siteReference)) {
164+
// caller is the node that made the call.
165+
LOGGER.finest(() -> "Caller node: " + caller.getMethod().getSignature() + ".");
166+
167+
SSAAbstractInvokeInstruction[] calls = caller.getIR().getCalls(callSite);
168+
LOGGER.finest(() -> "Number of calls at this site: " + calls.length + ".");
169+
170+
for (SSAAbstractInvokeInstruction callInstr : calls) {
171+
LOGGER.finest(() -> "Call instruction: " + callInstr + ".");
172+
173+
PythonInvokeInstruction pyCallInstr = (PythonInvokeInstruction) callInstr;
174+
int numberOfPositionalParameters =
175+
pyCallInstr.getNumberOfPositionalParameters() - 1; // Exclude the function name.
176+
LOGGER.finer(
177+
() -> "Number of positional parameters: " + numberOfPositionalParameters + ".");
178+
179+
ret.add(numberOfPositionalParameters);
134180
}
135181
}
136182
}
137-
} else
138-
throw new IllegalStateException(
139-
"Expected either 1 or 3 positional arguments for range(), but got: "
140-
+ numberOfParameters
141-
+ ".");
142183

143184
return ret;
144185
}
@@ -147,19 +188,19 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
147188
protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
148189
// The dtype of the resulting tensor is inferred from the inputs unless it is provided
149190
// explicitly.
150-
// TODO: Handle keyword arguments.
151-
int numberOfParameters =
152-
this.getNode().getMethod().isStatic()
153-
? this.getNode().getIR().getNumberOfParameters()
154-
: this.getNode().getIR().getNumberOfParameters() - 1;
155191

192+
// TODO: Handle keyword arguments.
156193
EnumSet<DType> types =
157-
IntStream.range(0, numberOfParameters)
158-
.map(i -> this.getNode().getIR().getMethod().isStatic() ? i : i + 1)
159-
.map(this.getNode().getIR()::getParameter)
160-
.mapToObj(val -> getDTypes(builder, val).stream())
194+
getNumberOfPossiblePositionalArguments(builder).stream()
195+
.map(
196+
numArgs ->
197+
IntStream.range(0, numArgs)
198+
.map(i -> this.getNode().getIR().getMethod().isStatic() ? i : i + 1)
199+
.map(this.getNode().getIR()::getParameter)
200+
.mapToObj(val -> getDTypes(builder, val).stream())
201+
.flatMap(identity())
202+
.distinct())
161203
.flatMap(identity())
162-
.distinct()
163204
.collect(Collectors.toCollection(() -> EnumSet.noneOf(DType.class)));
164205

165206
// FIXME: We can't tell the difference here between varying dtypes in a single call and that of

0 commit comments

Comments
 (0)