Skip to content

Commit c609ccb

Browse files
committed
Guard against None.
1 parent f0c3e99 commit c609ccb

File tree

1 file changed

+13
-7
lines changed
  • com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client

1 file changed

+13
-7
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/OneHot.java

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.ArrayList;
2222
import java.util.EnumSet;
2323
import java.util.List;
24+
import java.util.Optional;
2425
import java.util.Set;
2526

2627
public class OneHot extends ZerosLike {
@@ -137,7 +138,12 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
137138

138139
for (int axis : possibleAxes)
139140
for (InstanceKey depthIK : depthPTS) {
140-
int depth = getIntValueFromInstanceKey(depthIK);
141+
int depth =
142+
getIntValueFromInstanceKey(depthIK)
143+
.orElseThrow(
144+
() ->
145+
new IllegalStateException(
146+
"Depth argument value for OneHot is not an integer: " + depthIK + "."));
141147

142148
// For each shape in indices, append the depth as a new dimension.
143149
for (List<Dimension<?>> shape : indices) {
@@ -180,21 +186,21 @@ private Set<Integer> getPossibleAxes(PropagationCallGraphBuilder builder) {
180186
// No axis argument value found; default to AXIS_END.
181187
ret.add(AXIS_END);
182188
else
183-
for (InstanceKey instanceKey : pointsToSet) {
184-
int axis = getIntValueFromInstanceKey(instanceKey);
185-
ret.add(axis);
186-
}
189+
for (InstanceKey instanceKey : pointsToSet)
190+
ret.add(getIntValueFromInstanceKey(instanceKey).orElse(AXIS_END));
187191
}
188192
}
189193

190194
return ret;
191195
}
192196

193-
private static int getIntValueFromInstanceKey(InstanceKey instanceKey) {
197+
private static Optional<Integer> getIntValueFromInstanceKey(InstanceKey instanceKey) {
194198
if (instanceKey instanceof ConstantKey) {
195199
ConstantKey<?> constantKey = (ConstantKey<?>) instanceKey;
196200
Object value = constantKey.getValue();
197-
return ((Long) value).intValue();
201+
202+
if (value == null) return Optional.empty();
203+
return Optional.of(((Long) value).intValue());
198204
}
199205

200206
throw new IllegalStateException(

0 commit comments

Comments
 (0)