Skip to content

Commit a8acc8d

Browse files
committed
Switch hierarchy.
1 parent bf63246 commit a8acc8d

File tree

1 file changed

+11
-2
lines changed
  • com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client

1 file changed

+11
-2
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/OneHot.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.AXIS;
44
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.DEPTH;
55
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.DTYPE;
6+
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.INDICES;
67
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.OFF_VALUE;
78
import static com.ibm.wala.cast.python.ml.client.OneHot.Parameters.ON_VALUE;
89

@@ -24,7 +25,7 @@
2425
import java.util.Optional;
2526
import java.util.Set;
2627

27-
public class OneHot extends ZerosLike {
28+
public class OneHot extends Ones {
2829

2930
private static final String FUNCTION_NAME = "tf.one_hot()";
3031

@@ -87,6 +88,10 @@ protected int getDTypeParameterPosition() {
8788
return DTYPE.ordinal();
8889
}
8990

91+
protected int getIndicesParameterPosition() {
92+
return INDICES.ordinal();
93+
}
94+
9095
protected int getDepthParameterPosition() {
9196
return DEPTH.ordinal();
9297
}
@@ -111,10 +116,14 @@ protected int getOffValueArgumentValueNumber() {
111116
return this.getArgumentValueNumber(this.getOffValueParameterPosition());
112117
}
113118

119+
protected int getIndicesArgumentValueNumber() {
120+
return this.getArgumentValueNumber(this.getIndicesParameterPosition());
121+
}
122+
114123
@Override
115124
protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) {
116125
Set<List<Dimension<?>>> ret = HashSetFactory.make();
117-
Set<List<Dimension<?>>> indices = this.getShapes(builder, this.getValueArgumentValueNumber());
126+
Set<List<Dimension<?>>> indices = this.getShapes(builder, this.getIndicesArgumentValueNumber());
118127
int depthArgumentValueNumber = this.getDepthArgumentValueNumber();
119128

120129
if (depthArgumentValueNumber <= 0)

0 commit comments

Comments
 (0)