11package com .ibm .wala .cast .python .ml .client ;
22
3+ import static com .ibm .wala .ipa .callgraph .propagation .cfa .CallStringContextSelector .CALL_STRING ;
34import static java .util .function .Function .identity ;
45
56import com .ibm .wala .cast .python .ml .types .TensorFlowTypes .DType ;
67import com .ibm .wala .cast .python .ml .types .TensorType .Dimension ;
78import com .ibm .wala .cast .python .ml .types .TensorType .NumericDim ;
9+ import com .ibm .wala .cast .python .ssa .PythonInvokeInstruction ;
10+ import com .ibm .wala .classLoader .CallSiteReference ;
811import com .ibm .wala .ipa .callgraph .CGNode ;
912import com .ibm .wala .ipa .callgraph .propagation .ConstantKey ;
1013import com .ibm .wala .ipa .callgraph .propagation .InstanceKey ;
1114import com .ibm .wala .ipa .callgraph .propagation .PointerAnalysis ;
1215import com .ibm .wala .ipa .callgraph .propagation .PointerKey ;
1316import com .ibm .wala .ipa .callgraph .propagation .PointsToSetVariable ;
1417import com .ibm .wala .ipa .callgraph .propagation .PropagationCallGraphBuilder ;
18+ import com .ibm .wala .ipa .callgraph .propagation .cfa .CallString ;
19+ import com .ibm .wala .ssa .SSAAbstractInvokeInstruction ;
1520import com .ibm .wala .util .collections .HashSetFactory ;
1621import com .ibm .wala .util .intset .OrdinalSet ;
1722import java .util .EnumSet ;
23+ import java .util .Iterator ;
1824import java .util .List ;
1925import java .util .Set ;
2026import java .util .logging .Logger ;
3339 */
3440public class Range extends TensorGenerator {
3541
36- @ SuppressWarnings ("unused" )
3742 private static final Logger LOGGER = Logger .getLogger (Range .class .getName ());
3843
3944 private static final String FUNCTION_NAME = "tf.range()" ;
@@ -48,8 +53,7 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
4853 PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
4954
5055 // The shape of a range tensor is always a 1D tensor with the length equal to the number of
51- // elements in the range.
52- // For example, `tf.range(5)` produces a tensor with shape (5,).
56+ // elements in the range. For example, `tf.range(5)` produces a tensor with shape (5,).
5357
5458 double start = 0 ; // Default start value.
5559 double limit = start ; // Default limit value.
@@ -63,82 +67,119 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
6367 // Decide which version of the `range` function is being called based on the number of numeric
6468 // arguments.
6569 // TODO: Handle keyword arguments.
66- int numberOfParameters =
67- this .getNode ().getMethod ().isStatic ()
68- ? this .getNode ().getIR ().getNumberOfParameters ()
69- : this .getNode ().getIR ().getNumberOfParameters () - 1 ;
70-
71- if (numberOfParameters == 1 ) {
72- // it must *just* be `limit`.
73- int limitValueNumber =
74- this .getNode ().getMethod ().isStatic ()
75- ? this .getNode ().getIR ().getParameter (0 )
76- : this .getNode ().getIR ().getParameter (1 );
77-
78- PointerKey limitPK =
79- pointerAnalysis .getHeapModel ().getPointerKeyForLocal (this .getNode (), limitValueNumber );
80- OrdinalSet <InstanceKey > limitPointsToSet = pointerAnalysis .getPointsToSet (limitPK );
81-
82- assert !limitPointsToSet .isEmpty () : "Expected a non-empty points-to set for limit." ;
83-
84- for (InstanceKey limitIK : limitPointsToSet ) {
85- limit = ((Number ) ((ConstantKey <?>) limitIK ).getValue ()).doubleValue ();
86- int shape = (int ) Math .ceil ((limit - start ) / delta );
87- ret .add (List .of (new NumericDim (shape ))); // Add the shape as a 1D tensor.
88- }
89- } else if (numberOfParameters == 3 ) {
90- // it must be `start`, `limit`, and `delta`.
91- int startValueNumber =
92- this .getNode ().getMethod ().isStatic ()
93- ? this .getNode ().getIR ().getParameter (0 )
94- : this .getNode ().getIR ().getParameter (1 );
70+ for (Integer numOfPoisitionArguments : getNumberOfPossiblePositionalArguments (builder ))
71+ if (numOfPoisitionArguments == 1 ) {
72+ // it must *just* be `limit`.
73+ int limitValueNumber =
74+ this .getNode ().getMethod ().isStatic ()
75+ ? this .getNode ().getIR ().getParameter (0 )
76+ : this .getNode ().getIR ().getParameter (1 );
9577
96- PointerKey startPK =
97- pointerAnalysis .getHeapModel ().getPointerKeyForLocal (this .getNode (), startValueNumber );
78+ PointerKey limitPK =
79+ pointerAnalysis .getHeapModel ().getPointerKeyForLocal (this .getNode (), limitValueNumber );
80+ OrdinalSet <InstanceKey > limitPointsToSet = pointerAnalysis .getPointsToSet (limitPK );
9881
99- int limitValueNumber =
100- this .getNode ().getMethod ().isStatic ()
101- ? this .getNode ().getIR ().getParameter (1 )
102- : this .getNode ().getIR ().getParameter (2 );
82+ assert !limitPointsToSet .isEmpty () : "Expected a non-empty points-to set for limit." ;
10383
104- PointerKey limitPK =
105- pointerAnalysis .getHeapModel ().getPointerKeyForLocal (this .getNode (), limitValueNumber );
84+ for (InstanceKey limitIK : limitPointsToSet ) {
85+ limit = ((Number ) ((ConstantKey <?>) limitIK ).getValue ()).doubleValue ();
86+ int shape = (int ) Math .ceil ((limit - start ) / delta );
87+ ret .add (List .of (new NumericDim (shape ))); // Add the shape as a 1D tensor.
88+ }
89+ } else if (numOfPoisitionArguments == 3 ) {
90+ // it must be `start`, `limit`, and `delta`.
91+ int startValueNumber =
92+ this .getNode ().getMethod ().isStatic ()
93+ ? this .getNode ().getIR ().getParameter (0 )
94+ : this .getNode ().getIR ().getParameter (1 );
10695
107- int deltaValueNumber =
108- this .getNode ().getMethod ().isStatic ()
109- ? this .getNode ().getIR ().getParameter (2 )
110- : this .getNode ().getIR ().getParameter (3 );
96+ PointerKey startPK =
97+ pointerAnalysis .getHeapModel ().getPointerKeyForLocal (this .getNode (), startValueNumber );
11198
112- PointerKey deltaPK =
113- pointerAnalysis .getHeapModel ().getPointerKeyForLocal (this .getNode (), deltaValueNumber );
99+ int limitValueNumber =
100+ this .getNode ().getMethod ().isStatic ()
101+ ? this .getNode ().getIR ().getParameter (1 )
102+ : this .getNode ().getIR ().getParameter (2 );
114103
115- OrdinalSet <InstanceKey > startPointsToSet = pointerAnalysis .getPointsToSet (startPK );
116- OrdinalSet <InstanceKey > limitPointsToSet = pointerAnalysis .getPointsToSet (limitPK );
117- OrdinalSet <InstanceKey > deltaPointsToSet = pointerAnalysis .getPointsToSet (deltaPK );
104+ PointerKey limitPK =
105+ pointerAnalysis .getHeapModel ().getPointerKeyForLocal (this .getNode (), limitValueNumber );
118106
119- assert !startPointsToSet .isEmpty () : "Expected a non-empty points-to set for start." ;
120- assert !limitPointsToSet .isEmpty () : "Expected a non-empty points-to set for limit." ;
121- assert !deltaPointsToSet .isEmpty () : "Expected a non-empty points-to set for delta." ;
107+ int deltaValueNumber =
108+ this .getNode ().getMethod ().isStatic ()
109+ ? this .getNode ().getIR ().getParameter (2 )
110+ : this .getNode ().getIR ().getParameter (3 );
122111
123- for ( InstanceKey startIK : startPointsToSet ) {
124- start = (( Number ) (( ConstantKey <?>) startIK ). getValue ()). doubleValue ( );
112+ PointerKey deltaPK =
113+ pointerAnalysis . getHeapModel (). getPointerKeyForLocal ( this . getNode (), deltaValueNumber );
125114
126- for (InstanceKey limitIK : limitPointsToSet ) {
127- limit = ((Number ) ((ConstantKey <?>) limitIK ).getValue ()).doubleValue ();
115+ OrdinalSet <InstanceKey > startPointsToSet = pointerAnalysis .getPointsToSet (startPK );
116+ OrdinalSet <InstanceKey > limitPointsToSet = pointerAnalysis .getPointsToSet (limitPK );
117+ OrdinalSet <InstanceKey > deltaPointsToSet = pointerAnalysis .getPointsToSet (deltaPK );
118+
119+ assert !startPointsToSet .isEmpty () : "Expected a non-empty points-to set for start." ;
120+ assert !limitPointsToSet .isEmpty () : "Expected a non-empty points-to set for limit." ;
121+ assert !deltaPointsToSet .isEmpty () : "Expected a non-empty points-to set for delta." ;
122+
123+ for (InstanceKey startIK : startPointsToSet ) {
124+ start = ((Number ) ((ConstantKey <?>) startIK ).getValue ()).doubleValue ();
125+
126+ for (InstanceKey limitIK : limitPointsToSet ) {
127+ limit = ((Number ) ((ConstantKey <?>) limitIK ).getValue ()).doubleValue ();
128+
129+ for (InstanceKey deltaIK : deltaPointsToSet ) {
130+ delta = ((Number ) ((ConstantKey <?>) deltaIK ).getValue ()).doubleValue ();
131+
132+ int shape = (int ) Math .ceil ((limit - start ) / delta );
133+ ret .add (List .of (new NumericDim (shape ))); // Add the shape as a 1D tensor.
134+ }
135+ }
136+ }
137+ } else
138+ throw new IllegalStateException (
139+ "Expected either 1 or 3 positional arguments for range(), but got: "
140+ + numOfPoisitionArguments
141+ + "." );
142+
143+ return ret ;
144+ }
145+
146+ /**
147+ * Returns the set of possible numbers of positional arguments passed to the range function at the
148+ * call.
149+ *
150+ * @param builder The {@link PropagationCallGraphBuilder} used for the analysis.
151+ * @return A set of integers representing the possible number of positional arguments.
152+ */
153+ private Set <Integer > getNumberOfPossiblePositionalArguments (PropagationCallGraphBuilder builder ) {
154+ Set <Integer > ret = HashSetFactory .make ();
155+
156+ CallString cs = (CallString ) this .getNode ().getContext ().get (CALL_STRING );
157+ CallSiteReference siteReference = cs .getCallSiteRefs ()[0 ];
128158
129- for (InstanceKey deltaIK : deltaPointsToSet ) {
130- delta = ((Number ) ((ConstantKey <?>) deltaIK ).getValue ()).doubleValue ();
159+ for (CGNode caller : builder .getCallGraph ())
160+ for (Iterator <CallSiteReference > it = caller .getIR ().iterateCallSites (); it .hasNext (); ) {
161+ CallSiteReference callSite = it .next ();
131162
132- int shape = (int ) Math .ceil ((limit - start ) / delta );
133- ret .add (List .of (new NumericDim (shape ))); // Add the shape as a 1D tensor.
163+ if (callSite .equals (siteReference )) {
164+ // caller is the node that made the call.
165+ LOGGER .finest (() -> "Caller node: " + caller .getMethod ().getSignature () + "." );
166+
167+ SSAAbstractInvokeInstruction [] calls = caller .getIR ().getCalls (callSite );
168+ LOGGER .finest (() -> "Number of calls at this site: " + calls .length + "." );
169+
170+ for (SSAAbstractInvokeInstruction callInstr : calls ) {
171+ LOGGER .finest (() -> "Call instruction: " + callInstr + "." );
172+
173+ PythonInvokeInstruction pyCallInstr = (PythonInvokeInstruction ) callInstr ;
174+ int numberOfPositionalParameters =
175+ pyCallInstr .getNumberOfPositionalParameters () - 1 ; // Exclude the function name.
176+ LOGGER .finer (
177+ () -> "Number of positional parameters: " + numberOfPositionalParameters + "." );
178+
179+ ret .add (numberOfPositionalParameters );
134180 }
135181 }
136182 }
137- } else
138- throw new IllegalStateException (
139- "Expected either 1 or 3 positional arguments for range(), but got: "
140- + numberOfParameters
141- + "." );
142183
143184 return ret ;
144185 }
@@ -147,19 +188,19 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
147188 protected EnumSet <DType > getDefaultDTypes (PropagationCallGraphBuilder builder ) {
148189 // The dtype of the resulting tensor is inferred from the inputs unless it is provided
149190 // explicitly.
150- // TODO: Handle keyword arguments.
151- int numberOfParameters =
152- this .getNode ().getMethod ().isStatic ()
153- ? this .getNode ().getIR ().getNumberOfParameters ()
154- : this .getNode ().getIR ().getNumberOfParameters () - 1 ;
155191
192+ // TODO: Handle keyword arguments.
156193 EnumSet <DType > types =
157- IntStream .range (0 , numberOfParameters )
158- .map (i -> this .getNode ().getIR ().getMethod ().isStatic () ? i : i + 1 )
159- .map (this .getNode ().getIR ()::getParameter )
160- .mapToObj (val -> getDTypes (builder , val ).stream ())
194+ getNumberOfPossiblePositionalArguments (builder ).stream ()
195+ .map (
196+ numArgs ->
197+ IntStream .range (0 , numArgs )
198+ .map (i -> this .getNode ().getIR ().getMethod ().isStatic () ? i : i + 1 )
199+ .map (this .getNode ().getIR ()::getParameter )
200+ .mapToObj (val -> getDTypes (builder , val ).stream ())
201+ .flatMap (identity ())
202+ .distinct ())
161203 .flatMap (identity ())
162- .distinct ()
163204 .collect (Collectors .toCollection (() -> EnumSet .noneOf (DType .class )));
164205
165206 // FIXME: We can't tell the difference here between varying dtypes in a single call and that of
0 commit comments