Skip to content

Commit cc07365

Browse files
committed
Progress.
1 parent 8e10b2c commit cc07365

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DATASET;
55
import static com.ibm.wala.cast.python.util.Util.getAllocationSiteInNode;
66
import static com.ibm.wala.cast.types.AstMethodReference.fnReference;
7+
import static java.util.Arrays.asList;
78

89
import com.ibm.wala.cast.ipa.callgraph.AstPointerKeyFactory;
910
import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
@@ -51,12 +52,10 @@
5152
import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph;
5253
import com.ibm.wala.util.intset.OrdinalSet;
5354
import java.io.File;
54-
import java.util.ArrayList;
5555
import java.util.Iterator;
5656
import java.util.List;
5757
import java.util.Map;
5858
import java.util.Set;
59-
import java.util.TreeMap;
6059
import java.util.logging.Logger;
6160

6261
public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeAnalysis> {
@@ -735,8 +734,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
735734
*/
736735
private Set<TensorType> getTensorType(
737736
PointsToSetVariable source, PropagationCallGraphBuilder builder) {
738-
739737
logger.info("Getting tensor types for source: " + source + ".");
738+
740739
Set<TensorType> ret = HashSetFactory.make();
741740

742741
// Get the pointer key for the source.
@@ -769,10 +768,10 @@ private Set<TensorType> getTensorType(
769768
OrdinalSet<InstanceKey> objectCatalogPointsToSet =
770769
pointerAnalysis.getPointsToSet(pointerKeyForObjectCatalog);
771770

772-
// We expect the object catalog to contain a list of integers. Each element in the map
771+
// We expect the object catalog to contain a list of integers. Each element in the array
773772
// corresponds to the set of possible dimensions for that index.
774-
Map<Integer, Set<Dimension<Integer>>> indexToPossibleDimensions =
775-
new TreeMap<Integer, Set<Dimension<Integer>>>();
773+
@SuppressWarnings("unchecked")
774+
Set<Dimension<Integer>>[] possibleDimensions = new Set[objectCatalogPointsToSet.size()];
776775

777776
for (InstanceKey catalogIK : objectCatalogPointsToSet) {
778777
if (catalogIK instanceof ConstantKey) {
@@ -847,14 +846,14 @@ private Set<TensorType> getTensorType(
847846
+ ".");
848847

849848
// Add the shape dimensions.
850-
assert !indexToPossibleDimensions.containsKey(fieldIndex)
849+
assert possibleDimensions[fieldIndex] == null
851850
: "Duplicate field index: "
852851
+ fieldIndex
853852
+ " in object catalog: "
854853
+ objectCatalogPointsToSet
855854
+ ".";
856855

857-
indexToPossibleDimensions.put(fieldIndex, tensorDimensions);
856+
possibleDimensions[fieldIndex] = tensorDimensions;
858857
logger.fine(
859858
"Added shape dimensions: "
860859
+ tensorDimensions
@@ -877,23 +876,26 @@ private Set<TensorType> getTensorType(
877876
+ ".");
878877
}
879878

880-
for (Integer i : indexToPossibleDimensions.keySet()) {
881-
Set<Dimension<Integer>> iDims = indexToPossibleDimensions.get(i);
882-
883-
for (Dimension<Integer> iDim : iDims) {
884-
List<Dimension<Integer>> dimensionList = new ArrayList<>();
885-
dimensionList.add(iDim);
879+
for (int i = 0; i < possibleDimensions.length; i++) {
880+
for (Dimension<Integer> iDim : possibleDimensions[i]) {
881+
@SuppressWarnings("unchecked")
882+
Dimension<Integer>[] dimensions = new Dimension[possibleDimensions.length];
886883

887-
for (int j = i + 1; j < indexToPossibleDimensions.keySet().size(); j++) {
888-
Set<Dimension<Integer>> jDims = indexToPossibleDimensions.get(j);
884+
dimensions[i] = iDim;
889885

890-
for (Dimension<Integer> jDim : jDims) dimensionList.add(jDim);
886+
for (int j = 0; j < possibleDimensions.length; j++) {
887+
if (i != j) {
888+
for (Dimension<Integer> jDim : possibleDimensions[j]) {
889+
dimensions[j] = jDim;
890+
}
891+
}
891892
}
892893

893-
System.out.println(dimensionList);
894+
List<Dimension<?>> dimensionList = asList(dimensions);
895+
TensorType tensorType = new TensorType("pixel", dimensionList);
896+
ret.add(tensorType);
894897
}
895898
}
896-
897899
} else
898900
throw new IllegalStateException(
899901
"Expected a " + PythonTypes.list + " for the shape, but got: " + reference + ".");

0 commit comments

Comments
 (0)