|
1 | 1 | package com.ibm.wala.cast.python.ml.client; |
2 | 2 |
|
3 | | -import static com.ibm.wala.ipa.callgraph.propagation.cfa.CallStringContextSelector.CALL_STRING; |
4 | 3 | import static java.util.function.Function.identity; |
5 | 4 |
|
6 | 5 | import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; |
7 | 6 | import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; |
8 | 7 | import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim; |
9 | | -import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction; |
10 | | -import com.ibm.wala.classLoader.CallSiteReference; |
11 | 8 | import com.ibm.wala.ipa.callgraph.CGNode; |
12 | 9 | import com.ibm.wala.ipa.callgraph.propagation.ConstantKey; |
13 | 10 | import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; |
14 | 11 | import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis; |
15 | 12 | import com.ibm.wala.ipa.callgraph.propagation.PointerKey; |
16 | 13 | import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; |
17 | 14 | import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; |
18 | | -import com.ibm.wala.ipa.callgraph.propagation.cfa.CallString; |
19 | | -import com.ibm.wala.ssa.SSAAbstractInvokeInstruction; |
20 | 15 | import com.ibm.wala.util.collections.HashSetFactory; |
21 | 16 | import com.ibm.wala.util.intset.OrdinalSet; |
22 | 17 | import java.util.EnumSet; |
23 | | -import java.util.Iterator; |
24 | 18 | import java.util.List; |
25 | 19 | import java.util.Set; |
26 | 20 | import java.util.logging.Logger; |
|
39 | 33 | */ |
40 | 34 | public class Range extends TensorGenerator { |
41 | 35 |
|
| 36 | + @SuppressWarnings("unused") |
42 | 37 | private static final Logger LOGGER = Logger.getLogger(Range.class.getName()); |
43 | 38 |
|
44 | 39 | private static final String FUNCTION_NAME = "tf.range()"; |
@@ -68,117 +63,76 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) |
68 | 63 | // Decide which version of the `range` function is being called based on the number of numeric |
69 | 64 | // arguments. |
70 | 65 | // TODO: Handle keyword arguments. |
71 | | - for (Integer numOfPoisitionArguments : getNumberOfPossiblePositionalArguments(builder)) |
72 | | - if (numOfPoisitionArguments == 1) { |
73 | | - // it must *just* be `limit`. |
74 | | - PointerKey limitPK = |
75 | | - pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 2); |
76 | | - OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK); |
77 | | - |
78 | | - assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit."; |
79 | | - |
80 | | - for (InstanceKey limitIK : limitPointsToSet) { |
81 | | - limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue(); |
82 | | - int shape = (int) Math.ceil((limit - start) / delta); |
83 | | - ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor. |
84 | | - } |
85 | | - } else if (numOfPoisitionArguments == 3) { |
86 | | - // it must be `start`, `limit`, and `delta`. |
87 | | - PointerKey startPK = |
88 | | - pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 2); |
89 | | - PointerKey limitPK = |
90 | | - pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 3); |
91 | | - PointerKey deltaPK = |
92 | | - pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 4); |
93 | | - |
94 | | - OrdinalSet<InstanceKey> startPointsToSet = pointerAnalysis.getPointsToSet(startPK); |
95 | | - OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK); |
96 | | - OrdinalSet<InstanceKey> deltaPointsToSet = pointerAnalysis.getPointsToSet(deltaPK); |
97 | | - |
98 | | - assert !startPointsToSet.isEmpty() : "Expected a non-empty points-to set for start."; |
99 | | - assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit."; |
100 | | - assert !deltaPointsToSet.isEmpty() : "Expected a non-empty points-to set for delta."; |
101 | | - |
102 | | - for (InstanceKey startIK : startPointsToSet) { |
103 | | - start = ((Number) ((ConstantKey<?>) startIK).getValue()).doubleValue(); |
104 | | - |
105 | | - for (InstanceKey limitIK : limitPointsToSet) { |
106 | | - limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue(); |
107 | | - |
108 | | - for (InstanceKey deltaIK : deltaPointsToSet) { |
109 | | - delta = ((Number) ((ConstantKey<?>) deltaIK).getValue()).doubleValue(); |
110 | | - |
111 | | - int shape = (int) Math.ceil((limit - start) / delta); |
112 | | - ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor. |
113 | | - } |
114 | | - } |
115 | | - } |
116 | | - } else |
117 | | - throw new IllegalStateException( |
118 | | - "Expected either 1 or 3 positional arguments for range(), but got: " |
119 | | - + numOfPoisitionArguments |
120 | | - + "."); |
121 | | - |
122 | | - return ret; |
123 | | - } |
124 | | - |
125 | | - /** |
126 | | - * Returns the set of possible numbers of positional arguments passed to the range function at the |
127 | | - * call. |
128 | | - * |
129 | | - * @param builder The {@link PropagationCallGraphBuilder} used for the analysis. |
130 | | - * @return A set of integers representing the possible number of positional arguments. |
131 | | - */ |
132 | | - private Set<Integer> getNumberOfPossiblePositionalArguments(PropagationCallGraphBuilder builder) { |
133 | | - Set<Integer> ret = HashSetFactory.make(); |
134 | | - |
135 | | - CallString cs = (CallString) this.getNode().getContext().get(CALL_STRING); |
136 | | - CallSiteReference siteReference = cs.getCallSiteRefs()[0]; |
| 66 | + int numberOfParameters = |
| 67 | + this.getNode().getMethod().isStatic() |
| 68 | + ? this.getNode().getIR().getNumberOfParameters() |
| 69 | + : this.getNode().getIR().getNumberOfParameters() - 1; |
| 70 | + |
| 71 | + if (numberOfParameters == 1) { |
| 72 | + // it must *just* be `limit`. |
| 73 | + PointerKey limitPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 2); |
| 74 | + OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK); |
| 75 | + |
| 76 | + assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit."; |
| 77 | + |
| 78 | + for (InstanceKey limitIK : limitPointsToSet) { |
| 79 | + limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue(); |
| 80 | + int shape = (int) Math.ceil((limit - start) / delta); |
| 81 | + ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor. |
| 82 | + } |
| 83 | + } else if (numberOfParameters == 3) { |
| 84 | + // it must be `start`, `limit`, and `delta`. |
| 85 | + PointerKey startPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 2); |
| 86 | + PointerKey limitPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 3); |
| 87 | + PointerKey deltaPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 4); |
137 | 88 |
|
138 | | - for (CGNode caller : builder.getCallGraph()) { |
139 | | - for (Iterator<CallSiteReference> it = caller.getIR().iterateCallSites(); it.hasNext(); ) { |
140 | | - CallSiteReference callSite = it.next(); |
| 89 | + OrdinalSet<InstanceKey> startPointsToSet = pointerAnalysis.getPointsToSet(startPK); |
| 90 | + OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK); |
| 91 | + OrdinalSet<InstanceKey> deltaPointsToSet = pointerAnalysis.getPointsToSet(deltaPK); |
141 | 92 |
|
142 | | - if (callSite.equals(siteReference)) { |
143 | | - // caller is the node that made the call. |
144 | | - LOGGER.finest(() -> "Caller node: " + caller.getMethod().getSignature() + "."); |
| 93 | + assert !startPointsToSet.isEmpty() : "Expected a non-empty points-to set for start."; |
| 94 | + assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit."; |
| 95 | + assert !deltaPointsToSet.isEmpty() : "Expected a non-empty points-to set for delta."; |
145 | 96 |
|
146 | | - SSAAbstractInvokeInstruction[] calls = caller.getIR().getCalls(callSite); |
147 | | - LOGGER.finest(() -> "Number of calls at this site: " + calls.length + "."); |
| 97 | + for (InstanceKey startIK : startPointsToSet) { |
| 98 | + start = ((Number) ((ConstantKey<?>) startIK).getValue()).doubleValue(); |
148 | 99 |
|
149 | | - for (SSAAbstractInvokeInstruction callInstr : calls) { |
150 | | - LOGGER.finest(() -> "Call instruction: " + callInstr + "."); |
| 100 | + for (InstanceKey limitIK : limitPointsToSet) { |
| 101 | + limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue(); |
151 | 102 |
|
152 | | - PythonInvokeInstruction pyCallInstr = (PythonInvokeInstruction) callInstr; |
153 | | - int numberOfPositionalParameters = |
154 | | - pyCallInstr.getNumberOfPositionalParameters() - 1; // Exclude the function name. |
155 | | - LOGGER.finer( |
156 | | - () -> "Number of positional parameters: " + numberOfPositionalParameters + "."); |
| 103 | + for (InstanceKey deltaIK : deltaPointsToSet) { |
| 104 | + delta = ((Number) ((ConstantKey<?>) deltaIK).getValue()).doubleValue(); |
157 | 105 |
|
158 | | - ret.add(numberOfPositionalParameters); |
| 106 | + int shape = (int) Math.ceil((limit - start) / delta); |
| 107 | + ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor. |
159 | 108 | } |
160 | 109 | } |
161 | 110 | } |
162 | | - } |
| 111 | + } else |
| 112 | + throw new IllegalStateException( |
| 113 | + "Expected either 1 or 3 positional arguments for range(), but got: " |
| 114 | + + numberOfParameters |
| 115 | + + "."); |
| 116 | + |
163 | 117 | return ret; |
164 | 118 | } |
165 | 119 |
|
166 | 120 | @Override |
167 | 121 | protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) { |
168 | 122 | // The dtype of the resulting tensor is inferred from the inputs unless it is provided |
169 | 123 | // explicitly. |
170 | | - |
171 | 124 | // TODO: Handle keyword arguments. |
| 125 | + int numberOfParameters = |
| 126 | + this.getNode().getMethod().isStatic() |
| 127 | + ? this.getNode().getIR().getNumberOfParameters() |
| 128 | + : this.getNode().getIR().getNumberOfParameters() - 1; |
| 129 | + |
172 | 130 | EnumSet<DType> types = |
173 | | - getNumberOfPossiblePositionalArguments(builder).stream() |
174 | | - .map( |
175 | | - numArgs -> |
176 | | - IntStream.range(0, numArgs) |
177 | | - .map(i -> i + 2) // Positional arguments start at index 2. |
178 | | - .mapToObj(val -> getDTypes(builder, val).stream()) |
179 | | - .flatMap(identity()) |
180 | | - .distinct()) |
| 131 | + IntStream.range(0, numberOfParameters) |
| 132 | + .map(i -> i + 2) // Positional arguments start at index 2. |
| 133 | + .mapToObj(val -> getDTypes(builder, val).stream()) |
181 | 134 | .flatMap(identity()) |
| 135 | + .distinct() |
182 | 136 | .collect(Collectors.toCollection(() -> EnumSet.noneOf(DType.class))); |
183 | 137 |
|
184 | 138 | // FIXME: We can't tell the difference here between varying dtypes in a single call and that of |
|
0 commit comments