11package com .ibm .wala .cast .python .ml .client ;
22
33import static com .ibm .wala .cast .python .ml .types .TensorFlowTypes .TYPE_REFERENCE_TO_SIGNATURE ;
4+ import static com .ibm .wala .cast .python .types .PythonTypes .Root ;
5+ import static com .ibm .wala .cast .python .types .PythonTypes .list ;
6+ import static com .ibm .wala .cast .python .types .PythonTypes .tuple ;
7+ import static com .ibm .wala .cast .python .util .Util .getAllocationSiteInNode ;
48import static com .ibm .wala .cast .python .util .Util .getFunction ;
9+ import static com .ibm .wala .core .util .strings .Atom .findOrCreateAsciiAtom ;
10+ import static java .util .logging .Logger .getLogger ;
511
12+ import com .ibm .wala .cast .ipa .callgraph .AstPointerKeyFactory ;
613import com .ibm .wala .cast .python .ml .types .TensorType .Dimension ;
14+ import com .ibm .wala .cast .python .ml .types .TensorType .NumericDim ;
15+ import com .ibm .wala .classLoader .IField ;
16+ import com .ibm .wala .ipa .callgraph .propagation .AllocationSiteInNode ;
17+ import com .ibm .wala .ipa .callgraph .propagation .ConstantKey ;
18+ import com .ibm .wala .ipa .callgraph .propagation .InstanceKey ;
19+ import com .ibm .wala .ipa .callgraph .propagation .PointerAnalysis ;
20+ import com .ibm .wala .ipa .callgraph .propagation .PointerKey ;
721import com .ibm .wala .ipa .callgraph .propagation .PointsToSetVariable ;
822import com .ibm .wala .ipa .callgraph .propagation .PropagationCallGraphBuilder ;
23+ import com .ibm .wala .types .FieldReference ;
924import com .ibm .wala .types .TypeReference ;
25+ import com .ibm .wala .util .collections .HashSetFactory ;
26+ import com .ibm .wala .util .intset .OrdinalSet ;
27+ import java .util .ArrayList ;
1028import java .util .List ;
29+ import java .util .Optional ;
1130import java .util .Set ;
31+ import java .util .logging .Logger ;
1232
1333/**
1434 * A representation of the `tf.ragged.constant()` API in TensorFlow.
1939 */
2040public class RaggedConstant extends ZerosLike {
2141
42+ private static final Logger LOGGER = getLogger (RaggedConstant .class .getName ());
43+
2244 protected enum Parameters {
2345 PYLIST ,
2446 DTYPE ,
@@ -38,9 +60,138 @@ protected String getSignature() {
3860 return TYPE_REFERENCE_TO_SIGNATURE .get (function );
3961 }
4062
63+ private static Set <Integer > getPossibleListLengths (
64+ PropagationCallGraphBuilder builder , OrdinalSet <InstanceKey > valuePointsToSet ) {
65+ Set <Integer > ret = HashSetFactory .make ();
66+ PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
67+
68+ for (InstanceKey valueIK : valuePointsToSet ) {
69+ AllocationSiteInNode asin = getAllocationSiteInNode (valueIK );
70+ TypeReference reference = asin .getConcreteType ().getReference ();
71+
72+ // A `list` or `tuple`.
73+ if (reference .equals (list ) || reference .equals (tuple )) {
74+ OrdinalSet <InstanceKey > objectCatalogPointsToSet =
75+ pointerAnalysis .getPointsToSet (
76+ ((AstPointerKeyFactory ) builder .getPointerKeyFactory ())
77+ .getPointerKeyForObjectCatalog (asin ));
78+
79+ ret .add (objectCatalogPointsToSet .size ());
80+ } else
81+ throw new IllegalArgumentException (
82+ "Expected a list or tuple, but found: " + reference + "." );
83+ }
84+
85+ return ret ;
86+ }
87+
88+ private static Set <Integer > getMaximumDepthOfScalars (
89+ PropagationCallGraphBuilder builder , OrdinalSet <InstanceKey > valuePointsToSet ) {
90+ Set <Integer > ret = HashSetFactory .make ();
91+ PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
92+
93+ for (InstanceKey valueIK : valuePointsToSet ) {
94+ int maxDepth = -1 ;
95+
96+ if (valueIK instanceof ConstantKey ) maxDepth = Math .max (maxDepth , 0 ); // Scalar value.
97+ else {
98+ AllocationSiteInNode asin = getAllocationSiteInNode (valueIK );
99+ TypeReference reference = asin .getConcreteType ().getReference ();
100+
101+ // A nested `list`, `tuple`, or `np.ndarray`.
102+ if (reference .equals (list ) || reference .equals (tuple )) {
103+ OrdinalSet <InstanceKey > objectCatalogPointsToSet =
104+ pointerAnalysis .getPointsToSet (
105+ ((AstPointerKeyFactory ) builder .getPointerKeyFactory ())
106+ .getPointerKeyForObjectCatalog (asin ));
107+
108+ for (InstanceKey catalogIK : objectCatalogPointsToSet ) {
109+ ConstantKey <?> constantKey = (ConstantKey <?>) catalogIK ;
110+ Object constantKeyValue = constantKey .getValue ();
111+
112+ Integer fieldIndex = (Integer ) constantKeyValue ;
113+
114+ FieldReference subscript =
115+ FieldReference .findOrCreate (
116+ Root , findOrCreateAsciiAtom (fieldIndex .toString ()), Root );
117+
118+ IField f = builder .getClassHierarchy ().resolveField (subscript );
119+
120+ PointerKey pointerKeyForInstanceField = builder .getPointerKeyForInstanceField (asin , f );
121+
122+ OrdinalSet <InstanceKey > instanceFieldPointsToSet =
123+ pointerAnalysis .getPointsToSet (pointerKeyForInstanceField );
124+
125+ Set <Integer > possibleDepthsOfField =
126+ getMaximumDepthOfScalars (builder , instanceFieldPointsToSet );
127+
128+ for (int depthOfField : possibleDepthsOfField )
129+ maxDepth = Math .max (maxDepth , 1 + depthOfField );
130+ }
131+ }
132+ }
133+
134+ ret .add (maxDepth );
135+ }
136+
137+ return ret ;
138+ }
139+
41140 @ Override
42- protected Set <List <Dimension <?>>> getDefaultShapes (PropagationCallGraphBuilder builder ) {
141+ protected Set <List <Dimension <?>>> getShapesOfValue (
142+ PropagationCallGraphBuilder builder , OrdinalSet <InstanceKey > valuePointsToSet ) {
143+ // Returns a potentially ragged tensor with rank K and the specified `ragged_rank`, containing
144+ // the values from `pylist`.
145+
146+ // All scalar values in `pylist` must have the same nesting depth K, and the returned
147+ // `RaggedTensor` will have rank K. If `pylist` contains no scalar values, then K is one greater
148+ // than the maximum depth of empty lists in `pylist`.
149+
150+ // Step 1: Calculate K, the maximum depth of scalar values in `pylist`.
151+
152+ if (valuePointsToSet == null || valuePointsToSet .isEmpty ())
153+ throw new IllegalArgumentException (
154+ "Empty points-to set for value in source: " + this .getSource () + "." );
155+
156+ Set <List <Dimension <?>>> ret = HashSetFactory .make ();
157+
158+ Set <Integer > maxDepthOfScalars = getMaximumDepthOfScalars (builder , valuePointsToSet );
159+ LOGGER .fine ("Maximum depth of scalars in pylist: " + maxDepthOfScalars );
160+
161+ // Step 2: Determine Ragged Rank (R).
162+ for (int K : maxDepthOfScalars ) {
163+ Optional <Integer > raggedRank = this .getRaggedRankArgumentValue (builder );
164+ int R = raggedRank .orElse (K - 1 );
165+ LOGGER .fine ("Ragged rank: " + R );
166+
167+ // Step 3: Construct shape with rank K and ragged rank R.
168+
169+ // Get the length of the outer list.
170+ Set <Integer > possibleOuterListLengths = getPossibleListLengths (builder , valuePointsToSet );
171+
172+ for (int outerListLength : possibleOuterListLengths ) {
173+ List <Dimension <?>> shape = new ArrayList <>();
174+ shape .add (new NumericDim (outerListLength ));
175+
176+ // The first R dimensions are ragged.
177+ for (int i = 0 ; i < R ; i ++) shape .add (null ); // Unknown size for ragged dimensions.
178+
179+ /*
180+ // The remaining K - R dimensions are dense.
181+ for (int i = R; i < K; i++) {
182+ shape.add(new NumericDim(-1)); // Unknown size for dense dimensions.
183+ }
184+ */
185+
186+ ret .add (shape );
187+ }
188+ }
189+
190+ return ret ;
191+ }
192+
193+ private Optional <Integer > getRaggedRankArgumentValue (PropagationCallGraphBuilder builder ) {
43194 // TODO Auto-generated method stub
44- return super . getDefaultShapes ( builder );
195+ return Optional . empty ( );
45196 }
46197}
0 commit comments