Skip to content

Commit 573f06a

Browse files
committed
Pull up getIntValueFromInstanceKey() to Ones super class.
1 parent e6f49ed commit 573f06a

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/OneHot.java

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
1313
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
1414
import com.ibm.wala.ipa.callgraph.CGNode;
15-
import com.ibm.wala.ipa.callgraph.propagation.ConstantKey;
1615
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
1716
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
1817
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
@@ -23,7 +22,6 @@
2322
import java.util.ArrayList;
2423
import java.util.EnumSet;
2524
import java.util.List;
26-
import java.util.Optional;
2725
import java.util.Set;
2826

2927
public class OneHot extends Ones {
@@ -204,19 +202,6 @@ private Set<Integer> getPossibleAxes(PropagationCallGraphBuilder builder) {
204202
return ret;
205203
}
206204

207-
private static Optional<Integer> getIntValueFromInstanceKey(InstanceKey instanceKey) {
208-
if (instanceKey instanceof ConstantKey) {
209-
ConstantKey<?> constantKey = (ConstantKey<?>) instanceKey;
210-
Object value = constantKey.getValue();
211-
212-
if (value == null) return Optional.empty();
213-
return Optional.of(((Long) value).intValue());
214-
}
215-
216-
throw new IllegalArgumentException(
217-
"Cannot get integer value from non-constant InstanceKey: " + instanceKey + ".");
218-
}
219-
220205
private int getDepthArgumentValueNumber() {
221206
// TODO: Handle keyword arguments.
222207
return this.getArgumentValueNumber(this.getDepthParameterPosition());

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Ones.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66
import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
77
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
88
import com.ibm.wala.ipa.callgraph.CGNode;
9+
import com.ibm.wala.ipa.callgraph.propagation.ConstantKey;
10+
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
911
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
1012
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
1113
import java.util.EnumSet;
1214
import java.util.List;
15+
import java.util.Optional;
1316
import java.util.Set;
1417

1518
/**
@@ -60,4 +63,17 @@ protected int getShapeParameterPosition() {
6063
protected int getDTypeParameterPosition() {
6164
return DTYPE_PARAMETER_POSITION;
6265
}
66+
67+
protected static Optional<Integer> getIntValueFromInstanceKey(InstanceKey instanceKey) {
68+
if (instanceKey instanceof ConstantKey) {
69+
ConstantKey<?> constantKey = (ConstantKey<?>) instanceKey;
70+
Object value = constantKey.getValue();
71+
72+
if (value == null) return Optional.empty();
73+
return Optional.of(((Long) value).intValue());
74+
}
75+
76+
throw new IllegalArgumentException(
77+
"Cannot get integer value from non-constant InstanceKey: " + instanceKey + ".");
78+
}
6379
}

0 commit comments

Comments
 (0)