@@ -77,56 +77,145 @@ private static Set<Integer> getPossibleOuterListLengths(
7777 return ret ;
7878 }
7979
80- private static Set <Integer > getMaximumDepthOfScalars (
81- PropagationCallGraphBuilder builder , OrdinalSet <InstanceKey > valuePointsToSet ) {
82- Set <Integer > ret = HashSetFactory .make ();
80+ private static Set <InstanceKey > containsScalars (
81+ PropagationCallGraphBuilder builder , OrdinalSet <InstanceKey > pts ) {
82+ Set <InstanceKey > ret = HashSetFactory .make ();
83+ for (InstanceKey ik : pts ) if (containsScalars (builder , ik )) ret .add (ik );
84+ return ret ;
85+ }
86+
87+ private static boolean containsScalars (PropagationCallGraphBuilder builder , InstanceKey ik ) {
8388 PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
8489
85- for (InstanceKey valueIK : valuePointsToSet ) {
86- int maxDepth = -1 ;
90+ if (ik instanceof ConstantKey ) return true ; // Scalar value.
91+ else {
92+ AllocationSiteInNode asin = getAllocationSiteInNode (ik );
93+ TypeReference reference = asin .getConcreteType ().getReference ();
8794
88- if (valueIK instanceof ConstantKey ) maxDepth = Math .max (maxDepth , 0 ); // Scalar value.
89- else {
90- AllocationSiteInNode asin = getAllocationSiteInNode (valueIK );
91- TypeReference reference = asin .getConcreteType ().getReference ();
95+ // A nested `list`, `tuple`, or `np.ndarray`.
96+ if (reference .equals (list ) || reference .equals (tuple )) {
97+ OrdinalSet <InstanceKey > objectCatalogPointsToSet =
98+ pointerAnalysis .getPointsToSet (
99+ ((AstPointerKeyFactory ) builder .getPointerKeyFactory ())
100+ .getPointerKeyForObjectCatalog (asin ));
92101
93- // A nested `list`, `tuple`, or `np.ndarray`.
94- if (reference .equals (list ) || reference .equals (tuple )) {
95- OrdinalSet <InstanceKey > objectCatalogPointsToSet =
96- pointerAnalysis .getPointsToSet (
97- ((AstPointerKeyFactory ) builder .getPointerKeyFactory ())
98- .getPointerKeyForObjectCatalog (asin ));
102+ for (InstanceKey catalogIK : objectCatalogPointsToSet ) {
103+ ConstantKey <?> constantKey = (ConstantKey <?>) catalogIK ;
104+ Object constantKeyValue = constantKey .getValue ();
99105
100- for (InstanceKey catalogIK : objectCatalogPointsToSet ) {
101- ConstantKey <?> constantKey = (ConstantKey <?>) catalogIK ;
102- Object constantKeyValue = constantKey .getValue ();
106+ Integer fieldIndex = (Integer ) constantKeyValue ;
103107
104- Integer fieldIndex = (Integer ) constantKeyValue ;
108+ FieldReference subscript =
109+ FieldReference .findOrCreate (Root , findOrCreateAsciiAtom (fieldIndex .toString ()), Root );
105110
106- FieldReference subscript =
107- FieldReference .findOrCreate (
108- Root , findOrCreateAsciiAtom (fieldIndex .toString ()), Root );
111+ IField f = builder .getClassHierarchy ().resolveField (subscript );
109112
110- IField f = builder .getClassHierarchy (). resolveField ( subscript );
113+ PointerKey pointerKeyForInstanceField = builder .getPointerKeyForInstanceField ( asin , f );
111114
112- PointerKey pointerKeyForInstanceField = builder .getPointerKeyForInstanceField (asin , f );
115+ OrdinalSet <InstanceKey > instanceFieldPointsToSet =
116+ pointerAnalysis .getPointsToSet (pointerKeyForInstanceField );
113117
114- OrdinalSet <InstanceKey > instanceFieldPointsToSet =
115- pointerAnalysis .getPointsToSet (pointerKeyForInstanceField );
118+ for (InstanceKey fieldIK : instanceFieldPointsToSet )
119+ if (containsScalars (builder , fieldIK )) return true ;
120+ }
121+ } else
122+ throw new IllegalArgumentException (
123+ "Expected a list or tuple, but found: " + reference + "." );
124+ }
116125
117- Set < Integer > possibleDepthsOfField =
118- getMaximumDepthOfScalars ( builder , instanceFieldPointsToSet );
126+ return false ;
127+ }
119128
120- for (int depthOfField : possibleDepthsOfField )
121- maxDepth = Math .max (maxDepth , 1 + depthOfField );
122- }
129+ private static int getMaximumDepthOfEmptyList (
130+ PropagationCallGraphBuilder builder , InstanceKey valueIK ) {
131+ PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
132+ int maxDepth = 0 ;
133+
134+ AllocationSiteInNode asin = getAllocationSiteInNode (valueIK );
135+ TypeReference reference = asin .getConcreteType ().getReference ();
136+
137+ // A nested `list` or `tuple`.
138+ if (reference .equals (list ) || reference .equals (tuple )) {
139+ OrdinalSet <InstanceKey > objectCatalogPointsToSet =
140+ pointerAnalysis .getPointsToSet (
141+ ((AstPointerKeyFactory ) builder .getPointerKeyFactory ())
142+ .getPointerKeyForObjectCatalog (asin ));
143+
144+ for (InstanceKey catalogIK : objectCatalogPointsToSet ) {
145+ ConstantKey <?> constantKey = (ConstantKey <?>) catalogIK ;
146+ Object constantKeyValue = constantKey .getValue ();
147+
148+ Integer fieldIndex = (Integer ) constantKeyValue ;
149+
150+ FieldReference subscript =
151+ FieldReference .findOrCreate (Root , findOrCreateAsciiAtom (fieldIndex .toString ()), Root );
152+
153+ IField f = builder .getClassHierarchy ().resolveField (subscript );
154+
155+ PointerKey pointerKeyForInstanceField = builder .getPointerKeyForInstanceField (asin , f );
156+
157+ OrdinalSet <InstanceKey > instanceFieldPointsToSet =
158+ pointerAnalysis .getPointsToSet (pointerKeyForInstanceField );
159+
160+ if (instanceFieldPointsToSet .isEmpty ())
161+ // An empty list at this field.
162+ maxDepth = Math .max (maxDepth , 0 );
163+
164+ for (InstanceKey fieldIK : instanceFieldPointsToSet ) {
165+ int depthOfField = getMaximumDepthOfEmptyList (builder , fieldIK );
166+ maxDepth = Math .max (maxDepth , 1 + depthOfField );
123167 }
124168 }
169+ } else
170+ throw new IllegalArgumentException ("Expected a list or tuple, but found: " + reference + "." );
125171
126- ret .add (maxDepth );
172+ return maxDepth ;
173+ }
174+
175+ private static int getMaximumDepthOfScalars (
176+ PropagationCallGraphBuilder builder , InstanceKey valueIK ) {
177+ PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
178+ int maxDepth = 0 ;
179+
180+ if (valueIK instanceof ConstantKey ) maxDepth = Math .max (maxDepth , 0 ); // Scalar value.
181+ else {
182+ AllocationSiteInNode asin = getAllocationSiteInNode (valueIK );
183+ TypeReference reference = asin .getConcreteType ().getReference ();
184+
185+ // A nested `list`, `tuple`, or `np.ndarray`.
186+ if (reference .equals (list ) || reference .equals (tuple )) {
187+ OrdinalSet <InstanceKey > objectCatalogPointsToSet =
188+ pointerAnalysis .getPointsToSet (
189+ ((AstPointerKeyFactory ) builder .getPointerKeyFactory ())
190+ .getPointerKeyForObjectCatalog (asin ));
191+
192+ for (InstanceKey catalogIK : objectCatalogPointsToSet ) {
193+ ConstantKey <?> constantKey = (ConstantKey <?>) catalogIK ;
194+ Object constantKeyValue = constantKey .getValue ();
195+
196+ Integer fieldIndex = (Integer ) constantKeyValue ;
197+
198+ FieldReference subscript =
199+ FieldReference .findOrCreate (Root , findOrCreateAsciiAtom (fieldIndex .toString ()), Root );
200+
201+ IField f = builder .getClassHierarchy ().resolveField (subscript );
202+
203+ PointerKey pointerKeyForInstanceField = builder .getPointerKeyForInstanceField (asin , f );
204+
205+ OrdinalSet <InstanceKey > instanceFieldPointsToSet =
206+ pointerAnalysis .getPointsToSet (pointerKeyForInstanceField );
207+
208+ for (InstanceKey fieldIK : instanceFieldPointsToSet ) {
209+ int depthOfField = getMaximumDepthOfScalars (builder , fieldIK );
210+ maxDepth = Math .max (maxDepth , 1 + depthOfField );
211+ }
212+ }
213+ } else
214+ throw new IllegalArgumentException (
215+ "Expected a list or tuple, but found: " + reference + "." );
127216 }
128217
129- return ret ;
218+ return maxDepth ;
130219 }
131220
132221 @ Override
@@ -147,11 +236,16 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
147236
148237 Set <List <Dimension <?>>> ret = HashSetFactory .make ();
149238
150- Set <Integer > maxDepthOfScalars = getMaximumDepthOfScalars (builder , valuePointsToSet );
151- LOGGER .fine ("Maximum depth of scalars in `pylist`: " + maxDepthOfScalars );
239+ Set <InstanceKey > scalars = containsScalars (builder , valuePointsToSet );
240+
241+ for (InstanceKey valueIK : valuePointsToSet ) {
242+ int maxDepth = getMaxDepth (builder , scalars , valueIK );
243+ LOGGER .fine ("Maximum depth of `pylist`: " + maxDepth );
244+
245+ // Step 2: Determine Ragged Rank (R).
246+ int K = maxDepth ;
247+ LOGGER .fine ("Tensor rank: " + K );
152248
153- // Step 2: Determine Ragged Rank (R).
154- for (int K : maxDepthOfScalars ) {
155249 Optional <Integer > raggedRank = this .getRaggedRankArgumentValue (builder );
156250 int R = raggedRank .orElse (K - 1 );
157251 LOGGER .fine ("Ragged rank: " + R );
@@ -183,6 +277,18 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
183277 return ret ;
184278 }
185279
280+ private static int getMaxDepth (
281+ PropagationCallGraphBuilder builder , Set <InstanceKey > scalars , InstanceKey valueIK ) {
282+ int maxDepth ;
283+
284+ if (scalars .contains (valueIK )) maxDepth = getMaximumDepthOfScalars (builder , valueIK );
285+ else
286+ // If `pylist` contains no scalar values, then K is one greater than the maximum depth of
287+ // empty lists in `pylist`.
288+ maxDepth = 1 + getMaximumDepthOfEmptyList (builder , valueIK );
289+ return maxDepth ;
290+ }
291+
186292 private Optional <Integer > getRaggedRankArgumentValue (PropagationCallGraphBuilder builder ) {
187293 // TODO Auto-generated method stub
188294 return Optional .empty ();
0 commit comments