Skip to content

Commit addd7d0

Browse files
committed
Merge branch 'master' into 267-initial-tensor-dimensions-arent-always-accurate
2 parents 1bbccc8 + 0c3ed3d commit addd7d0

File tree

2 files changed

+37
-28
lines changed

2 files changed

+37
-28
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph;
6464
import com.ibm.wala.util.intset.OrdinalSet;
6565
import java.io.File;
66+
import java.io.IOException;
6667
import java.util.EnumSet;
6768
import java.util.Iterator;
6869
import java.util.List;
@@ -653,16 +654,19 @@ private Map<PointsToSetVariable, TensorType> getShapeSourceCalls(
653654
op,
654655
builder,
655656
(CGNode src, SSAAbstractInvokeInstruction call) -> {
656-
if (call.getNumberOfUses() > param) {
657-
targets.put(
658-
builder
659-
.getPropagationSystem()
660-
.findOrCreatePointsToSet(
661-
builder
662-
.getPointerAnalysis()
663-
.getHeapModel()
664-
.getPointerKeyForLocal(src, call.getDef())),
665-
TensorType.shapeArg(src, call.getUse(param)));
657+
try {
658+
if (call.getNumberOfUses() > param)
659+
targets.put(
660+
builder
661+
.getPropagationSystem()
662+
.findOrCreatePointsToSet(
663+
builder
664+
.getPointerAnalysis()
665+
.getHeapModel()
666+
.getPointerKeyForLocal(src, call.getDef())),
667+
TensorType.shapeArg(src, call.getUse(param)));
668+
} catch (IOException e) {
669+
throw new RuntimeException("Error while processing shape source call: " + call, e);
666670
}
667671
});
668672
return targets;
@@ -701,9 +705,13 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
701705

702706
for (PointsToSetVariable v : sources) init.put(v, getTensorType(v, builder));
703707

704-
Map<PointsToSetVariable, TensorType> placeholders =
705-
handleShapeSourceOp(builder, dataflow, placeholder, 2);
706-
logger.fine(() -> "Placeholders: " + placeholders);
708+
Map<PointsToSetVariable, TensorType> placeholders = null;
709+
try {
710+
placeholders = handleShapeSourceOp(builder, dataflow, placeholder, 2);
711+
} catch (IOException e) {
712+
throw new RuntimeException("Error while processing placeholder calls.", e);
713+
}
714+
logger.fine("Placeholders: " + placeholders);
707715

708716
for (Map.Entry<PointsToSetVariable, TensorType> e : placeholders.entrySet())
709717
init.put(e.getKey(), Set.of(e.getValue()));
@@ -728,7 +736,12 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
728736
}
729737

730738
Map<PointsToSetVariable, TensorType> shapeOps = HashMapFactory.make();
731-
shapeOps.putAll(handleShapeSourceOp(builder, dataflow, reshape, 2));
739+
740+
try {
741+
shapeOps.putAll(handleShapeSourceOp(builder, dataflow, reshape, 2));
742+
} catch (IOException e) {
743+
throw new RuntimeException("Error while processing reshape calls.", e);
744+
}
732745

733746
Set<PointsToSetVariable> conv2ds = getKeysDefinedByCall(conv2d, builder);
734747

@@ -983,7 +996,8 @@ private Map<PointsToSetVariable, TensorType> handleShapeSourceOp(
983996
PropagationCallGraphBuilder builder,
984997
Graph<PointsToSetVariable> dataflow,
985998
MethodReference op,
986-
int shapeSrcOperand) {
999+
int shapeSrcOperand)
1000+
throws IOException {
9871001
Map<PointsToSetVariable, TensorType> reshapeTypes =
9881002
getShapeSourceCalls(op, builder, shapeSrcOperand);
9891003
for (PointsToSetVariable to : reshapeTypes.keySet()) {

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorType.java

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ public static TensorType mnistInput() {
329329
return new TensorType("pixel", Arrays.asList(batch, vec));
330330
}
331331

332-
public static TensorType shapeArg(CGNode node, int literalVn) {
332+
public static TensorType shapeArg(CGNode node, int literalVn) throws IOException {
333333
logger.fine(() -> node.getIR().toString());
334334
ArrayList<Dimension<?>> r = new ArrayList<>();
335335
DefUse du = node.getDU();
@@ -360,18 +360,13 @@ public static TensorType shapeArg(CGNode node, int literalVn) {
360360
.debugInfo()
361361
.getInstructionPosition(du.getDef(val).iIndex());
362362
System.err.println(p);
363-
try {
364-
SourceBuffer b = new SourceBuffer(p);
365-
String expr = b.toString();
366-
System.err.println(expr);
367-
Integer ival = PythonInterpreter.interpretAsInt(expr);
368-
if (ival != null) {
369-
r.add(new NumericDim(ival));
370-
continue;
371-
}
372-
} catch (IOException e) {
373-
// TODO Auto-generated catch block
374-
e.printStackTrace();
363+
SourceBuffer b = new SourceBuffer(p);
364+
String expr = b.toString();
365+
System.err.println(expr);
366+
Integer ival = PythonInterpreter.interpretAsInt(expr);
367+
if (ival != null) {
368+
r.add(new NumericDim(ival));
369+
continue;
375370
}
376371
}
377372
r.add(new SymbolicDim("?"));

0 commit comments

Comments
 (0)