|
21 | 21 | import java.util.ArrayList; |
22 | 22 | import java.util.EnumSet; |
23 | 23 | import java.util.List; |
| 24 | +import java.util.Optional; |
24 | 25 | import java.util.Set; |
25 | 26 |
|
26 | 27 | public class OneHot extends ZerosLike { |
@@ -137,7 +138,12 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) |
137 | 138 |
|
138 | 139 | for (int axis : possibleAxes) |
139 | 140 | for (InstanceKey depthIK : depthPTS) { |
140 | | - int depth = getIntValueFromInstanceKey(depthIK); |
| 141 | + int depth = |
| 142 | + getIntValueFromInstanceKey(depthIK) |
| 143 | + .orElseThrow( |
| 144 | + () -> |
| 145 | + new IllegalStateException( |
| 146 | + "Depth argument value for OneHot is not an integer: " + depthIK + ".")); |
141 | 147 |
|
142 | 148 | // For each shape in indices, append the depth as a new dimension. |
143 | 149 | for (List<Dimension<?>> shape : indices) { |
@@ -180,21 +186,21 @@ private Set<Integer> getPossibleAxes(PropagationCallGraphBuilder builder) { |
180 | 186 | // No axis argument value found; default to AXIS_END. |
181 | 187 | ret.add(AXIS_END); |
182 | 188 | else |
183 | | - for (InstanceKey instanceKey : pointsToSet) { |
184 | | - int axis = getIntValueFromInstanceKey(instanceKey); |
185 | | - ret.add(axis); |
186 | | - } |
| 189 | + for (InstanceKey instanceKey : pointsToSet) |
| 190 | + ret.add(getIntValueFromInstanceKey(instanceKey).orElse(AXIS_END)); |
187 | 191 | } |
188 | 192 | } |
189 | 193 |
|
190 | 194 | return ret; |
191 | 195 | } |
192 | 196 |
|
193 | | - private static int getIntValueFromInstanceKey(InstanceKey instanceKey) { |
| 197 | + private static Optional<Integer> getIntValueFromInstanceKey(InstanceKey instanceKey) { |
194 | 198 | if (instanceKey instanceof ConstantKey) { |
195 | 199 | ConstantKey<?> constantKey = (ConstantKey<?>) instanceKey; |
196 | 200 | Object value = constantKey.getValue(); |
197 | | - return ((Long) value).intValue(); |
| 201 | + |
| 202 | + if (value == null) return Optional.empty(); |
| 203 | + return Optional.of(((Long) value).intValue()); |
198 | 204 | } |
199 | 205 |
|
200 | 206 | throw new IllegalStateException( |
|
0 commit comments