1111import static com .ibm .wala .ipa .callgraph .propagation .cfa .CallStringContextSelector .CALL_STRING ;
1212import static java .util .Arrays .asList ;
1313import static java .util .Collections .emptyList ;
14+ import static java .util .stream .Collectors .toSet ;
1415
1516import com .ibm .wala .cast .ipa .callgraph .AstPointerKeyFactory ;
1617import com .ibm .wala .cast .python .ml .types .TensorFlowTypes ;
4344import java .util .EnumSet ;
4445import java .util .List ;
4546import java .util .Map .Entry ;
46- import java .util .Optional ;
4747import java .util .Set ;
4848import java .util .logging .Logger ;
4949
@@ -353,11 +353,11 @@ protected EnumSet<DType> getDTypesFromDTypeArgument(
353353
354354 if (typeReference .equals (TensorFlowTypes .D_TYPE )) {
355355 // we have a dtype.
356- // let's see if it's float32 .
356+ // let's see if it's a dtype .
357357 Set <CGNode > importNodes = builder .getCallGraph ().getNodes (IMPORT );
358358
359- // find the import node from this file.
360- Optional <CGNode > importNode =
359+ // find the import nodes from this file.
360+ Set <CGNode > importNodesOfInterest =
361361 importNodes .stream ()
362362 .filter (
363363 in -> {
@@ -384,38 +384,44 @@ protected EnumSet<DType> getDTypesFromDTypeArgument(
384384
385385 return method .equals (nodeCS .getMethods ()[0 ]);
386386 })
387- .findFirst ( );
387+ .collect ( toSet () );
388388
389- InstanceKey tensorFlowIK =
390- pointerAnalysis
391- .getHeapModel ()
392- .getInstanceKeyForAllocation (
393- importNode .get (), NewSiteReference .make (0 , TENSORFLOW ));
389+ if (importNodesOfInterest .isEmpty ())
390+ throw new IllegalStateException ("No import nodes found for source: " + source + "." );
394391
395- // Check dtype literals.
396392 boolean found = false ;
397393
398- for (Entry <FieldReference , DType > entry : FIELD_REFERENCE_TO_DTYPE .entrySet ()) {
399- FieldReference fieldRef = entry .getKey ();
400- DType dtype = entry .getValue ();
401- IField field = builder .getClassHierarchy ().resolveField (fieldRef );
402-
403- PointerKey pk =
404- pointerAnalysis .getHeapModel ().getPointerKeyForInstanceField (tensorFlowIK , field );
405-
406- for (InstanceKey ik : pointerAnalysis .getPointsToSet (pk ))
407- if (ik .equals (instanceKey )) {
408- ret .add (dtype );
409- LOGGER .info (
410- "Found dtype: "
411- + dtype
412- + " for source: "
413- + source
414- + " from dType: "
415- + instanceKey
416- + "." );
417- found = true ;
418- }
394+ for (CGNode importNode : importNodesOfInterest ) {
395+ LOGGER .fine ("Found import node of interest: " + importNode + "." );
396+
397+ InstanceKey tensorFlowIK =
398+ pointerAnalysis
399+ .getHeapModel ()
400+ .getInstanceKeyForAllocation (importNode , NewSiteReference .make (0 , TENSORFLOW ));
401+
402+ // Check dtype literals.
403+ for (Entry <FieldReference , DType > entry : FIELD_REFERENCE_TO_DTYPE .entrySet ()) {
404+ FieldReference fieldRef = entry .getKey ();
405+ DType dtype = entry .getValue ();
406+ IField field = builder .getClassHierarchy ().resolveField (fieldRef );
407+
408+ PointerKey pk =
409+ pointerAnalysis .getHeapModel ().getPointerKeyForInstanceField (tensorFlowIK , field );
410+
411+ for (InstanceKey ik : pointerAnalysis .getPointsToSet (pk ))
412+ if (ik .equals (instanceKey )) {
413+ ret .add (dtype );
414+ LOGGER .info (
415+ "Found dtype: "
416+ + dtype
417+ + " for source: "
418+ + source
419+ + " from dType: "
420+ + instanceKey
421+ + "." );
422+ found = true ;
423+ }
424+ }
419425 }
420426
421427 if (!found ) throw new IllegalStateException ("Unknown dtype: " + instanceKey + "." );
0 commit comments