Skip to content

Commit 278abab

Browse files
committed
Progress.
1 parent 6423d8d commit 278abab

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,11 +447,11 @@ public String toString() {
447447
};
448448
}
449449

450-
private final Map<PointsToSetVariable, TensorType> init;
450+
private final Map<PointsToSetVariable, Set<TensorType>> init;
451451

452452
public TensorTypeAnalysis(
453453
Graph<PointsToSetVariable> G,
454-
Map<PointsToSetVariable, TensorType> init,
454+
Map<PointsToSetVariable, Set<TensorType>> init,
455455
Map<PointsToSetVariable, TensorType> reshapeTypes,
456456
Map<PointsToSetVariable, TensorType> set_shapes,
457457
Set<PointsToSetVariable> conv2ds,
@@ -480,7 +480,8 @@ protected TensorVariable[] makeStmtRHS(int size) {
480480
protected void initializeVariables() {
481481
super.initializeVariables();
482482
for (PointsToSetVariable src : init.keySet()) {
483-
getOut(src).state.add(init.get(src));
483+
Set<TensorType> tensorTypes = init.get(src);
484+
getOut(src).state.addAll(tensorTypes);
484485
}
485486
}
486487

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -715,9 +715,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
715715

716716
Set<PointsToSetVariable> conv3ds = getKeysDefinedByCall(conv3d, builder);
717717

718-
TensorTypeAnalysis tt = null;
719-
// new TensorTypeAnalysis(dataflow, init, shapeOps, setCalls, conv2ds, conv3ds,
720-
// errorLog);
718+
TensorTypeAnalysis tt =
719+
new TensorTypeAnalysis(dataflow, init, shapeOps, setCalls, conv2ds, conv3ds, errorLog);
721720

722721
tt.solve(new NullProgressMonitor());
723722

0 commit comments

Comments
 (0)