11package com .ibm .wala .cast .python .ml .client ;
22
3+ import static com .ibm .wala .cast .python .ml .types .TensorFlowTypes .DType .FLOAT32 ;
34import static com .ibm .wala .cast .python .types .PythonTypes .Root ;
45import static com .ibm .wala .cast .python .types .PythonTypes .list ;
56import static com .ibm .wala .cast .python .types .PythonTypes .tuple ;
67import static com .ibm .wala .cast .python .util .Util .getAllocationSiteInNode ;
78import static com .ibm .wala .core .util .strings .Atom .findOrCreateAsciiAtom ;
89import static java .util .logging .Logger .getLogger ;
10+ import static java .util .stream .Collectors .toSet ;
911
1012import com .ibm .wala .cast .ipa .callgraph .AstPointerKeyFactory ;
1113import com .ibm .wala .cast .python .ml .types .TensorFlowTypes .DType ;
2931import java .util .Optional ;
3032import java .util .Set ;
3133import java .util .logging .Logger ;
34+ import java .util .stream .StreamSupport ;
3235
3336/**
3437 * A representation of the `tf.ragged.constant()` API in TensorFlow.
@@ -79,13 +82,6 @@ private static Set<Integer> getPossibleOuterListLengths(
7982 return ret ;
8083 }
8184
82- private static Set <InstanceKey > containsScalars (
83- PropagationCallGraphBuilder builder , OrdinalSet <InstanceKey > pts ) {
84- Set <InstanceKey > ret = HashSetFactory .make ();
85- for (InstanceKey ik : pts ) if (containsScalars (builder , ik )) ret .add (ik );
86- return ret ;
87- }
88-
8985 private static boolean containsScalars (PropagationCallGraphBuilder builder , InstanceKey ik ) {
9086 PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
9187
@@ -231,17 +227,19 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
231227 // than the maximum depth of empty lists in `pylist`.
232228
233229 // Step 1: Calculate K, the maximum depth of scalar values in `pylist`.
234-
235230 if (valuePointsToSet == null || valuePointsToSet .isEmpty ())
236231 throw new IllegalArgumentException (
237232 "Empty points-to set for value in source: " + this .getSource () + "." );
238233
239234 Set <List <Dimension <?>>> ret = HashSetFactory .make ();
240235
241- Set <InstanceKey > scalars = containsScalars (builder , valuePointsToSet );
236+ Set <InstanceKey > valuesWithScalars =
237+ StreamSupport .stream (valuePointsToSet .spliterator (), false )
238+ .filter (ik -> containsScalars (builder , ik ))
239+ .collect (toSet ());
242240
243241 for (InstanceKey valueIK : valuePointsToSet ) {
244- int maxDepth = getMaxDepth (builder , scalars , valueIK );
242+ int maxDepth = getMaximumDepthOfInstance (builder , valuesWithScalars , valueIK );
245243 LOGGER .fine ("Maximum depth of `pylist`: " + maxDepth );
246244
247245 // Step 2: Determine Ragged Rank (R).
@@ -279,16 +277,15 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
279277 return ret ;
280278 }
281279
282- private static int getMaxDepth (
283- PropagationCallGraphBuilder builder , Set < InstanceKey > scalars , InstanceKey valueIK ) {
284- int maxDepth ;
285-
286- if (scalars .contains (valueIK )) maxDepth = getMaximumDepthOfScalars (builder , valueIK );
280+ private static int getMaximumDepthOfInstance (
281+ PropagationCallGraphBuilder builder ,
282+ Set < InstanceKey > instancesWithScalars ,
283+ InstanceKey instance ) {
284+ if (instancesWithScalars .contains (instance )) return getMaximumDepthOfScalars (builder , instance );
287285 else
288286 // If `pylist` contains no scalar values, then K is one greater than the maximum depth of
289287 // empty lists in `pylist`.
290- maxDepth = 1 + getMaximumDepthOfEmptyList (builder , valueIK );
291- return maxDepth ;
288+ return 1 + getMaximumDepthOfEmptyList (builder , instance );
292289 }
293290
294291 private Optional <Integer > getRaggedRankArgumentValue (PropagationCallGraphBuilder builder ) {
@@ -314,9 +311,16 @@ protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
314311 pointerAnalysis .getHeapModel ().getPointerKeyForLocal (this .getNode (), valueNumber );
315312 OrdinalSet <InstanceKey > valuePointsToSet = pointerAnalysis .getPointsToSet (valuePK );
316313
317- if (containsScalars (builder , valuePointsToSet ).isEmpty ()) {
314+ if (valuePointsToSet == null || valuePointsToSet .isEmpty ())
315+ throw new IllegalArgumentException (
316+ "Empty points-to set for value in source: " + this .getSource () + "." );
317+
318+ if (StreamSupport .stream (valuePointsToSet .spliterator (), false )
319+ .filter (ik -> containsScalars (builder , ik ))
320+ .count ()
321+ == 0 ) {
318322 LOGGER .fine ("No scalars found in `pylist`; defaulting to `tf.float32` dtype." );
319- return EnumSet .of (DType . FLOAT32 );
323+ return EnumSet .of (FLOAT32 );
320324 }
321325
322326 return super .getDefaultDTypes (builder );
0 commit comments