44import static com .ibm .wala .cast .python .ml .types .TensorFlowTypes .DATASET ;
55import static com .ibm .wala .cast .python .util .Util .getAllocationSiteInNode ;
66import static com .ibm .wala .cast .types .AstMethodReference .fnReference ;
7+ import static java .util .Arrays .asList ;
78
89import com .ibm .wala .cast .ipa .callgraph .AstPointerKeyFactory ;
910import com .ibm .wala .cast .ir .ssa .EachElementGetInstruction ;
5152import com .ibm .wala .util .graph .impl .SlowSparseNumberedGraph ;
5253import com .ibm .wala .util .intset .OrdinalSet ;
5354import java .io .File ;
54- import java .util .ArrayList ;
5555import java .util .Iterator ;
5656import java .util .List ;
5757import java .util .Map ;
5858import java .util .Set ;
59- import java .util .TreeMap ;
6059import java .util .logging .Logger ;
6160
6261public class PythonTensorAnalysisEngine extends PythonAnalysisEngine <TensorTypeAnalysis > {
@@ -735,8 +734,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
735734 */
736735 private Set <TensorType > getTensorType (
737736 PointsToSetVariable source , PropagationCallGraphBuilder builder ) {
738-
739737 logger .info ("Getting tensor types for source: " + source + "." );
738+
740739 Set <TensorType > ret = HashSetFactory .make ();
741740
742741 // Get the pointer key for the source.
@@ -769,10 +768,10 @@ private Set<TensorType> getTensorType(
769768 OrdinalSet <InstanceKey > objectCatalogPointsToSet =
770769 pointerAnalysis .getPointsToSet (pointerKeyForObjectCatalog );
771770
772- // We expect the object catalog to contain a list of integers. Each element in the map
771+ // We expect the object catalog to contain a list of integers. Each element in the array
773772 // corresponds to the set of possible dimensions for that index.
774- Map < Integer , Set < Dimension < Integer >>> indexToPossibleDimensions =
775- new TreeMap < Integer , Set <Dimension <Integer >>>() ;
773+ @ SuppressWarnings ( "unchecked" )
774+ Set <Dimension <Integer >>[] possibleDimensions = new Set [ objectCatalogPointsToSet . size ()] ;
776775
777776 for (InstanceKey catalogIK : objectCatalogPointsToSet ) {
778777 if (catalogIK instanceof ConstantKey ) {
@@ -847,14 +846,14 @@ private Set<TensorType> getTensorType(
847846 + "." );
848847
849848 // Add the shape dimensions.
850- assert ! indexToPossibleDimensions . containsKey ( fieldIndex )
849+ assert possibleDimensions [ fieldIndex ] == null
851850 : "Duplicate field index: "
852851 + fieldIndex
853852 + " in object catalog: "
854853 + objectCatalogPointsToSet
855854 + "." ;
856855
857- indexToPossibleDimensions . put ( fieldIndex , tensorDimensions ) ;
856+ possibleDimensions [ fieldIndex ] = tensorDimensions ;
858857 logger .fine (
859858 "Added shape dimensions: "
860859 + tensorDimensions
@@ -877,23 +876,26 @@ private Set<TensorType> getTensorType(
877876 + "." );
878877 }
879878
880- for (Integer i : indexToPossibleDimensions .keySet ()) {
881- Set <Dimension <Integer >> iDims = indexToPossibleDimensions .get (i );
882-
883- for (Dimension <Integer > iDim : iDims ) {
884- List <Dimension <Integer >> dimensionList = new ArrayList <>();
885- dimensionList .add (iDim );
879+ for (int i = 0 ; i < possibleDimensions .length ; i ++) {
880+ for (Dimension <Integer > iDim : possibleDimensions [i ]) {
881+ @ SuppressWarnings ("unchecked" )
882+ Dimension <Integer >[] dimensions = new Dimension [possibleDimensions .length ];
886883
887- for (int j = i + 1 ; j < indexToPossibleDimensions .keySet ().size (); j ++) {
888- Set <Dimension <Integer >> jDims = indexToPossibleDimensions .get (j );
884+ dimensions [i ] = iDim ;
889885
890- for (Dimension <Integer > jDim : jDims ) dimensionList .add (jDim );
886+ for (int j = 0 ; j < possibleDimensions .length ; j ++) {
887+ if (i != j ) {
888+ for (Dimension <Integer > jDim : possibleDimensions [j ]) {
889+ dimensions [j ] = jDim ;
890+ }
891+ }
891892 }
892893
893- System .out .println (dimensionList );
894+ List <Dimension <?>> dimensionList = asList (dimensions );
895+ TensorType tensorType = new TensorType ("pixel" , dimensionList );
896+ ret .add (tensorType );
894897 }
895898 }
896-
897899 } else
898900 throw new IllegalStateException (
899901 "Expected a " + PythonTypes .list + " for the shape, but got: " + reference + "." );
0 commit comments