33import static com .ibm .wala .cast .python .ml .client .OneHot .Parameters .AXIS ;
44import static com .ibm .wala .cast .python .ml .client .OneHot .Parameters .DEPTH ;
55import static com .ibm .wala .cast .python .ml .client .OneHot .Parameters .DTYPE ;
6+ import static com .ibm .wala .cast .python .ml .client .OneHot .Parameters .INDICES ;
67import static com .ibm .wala .cast .python .ml .client .OneHot .Parameters .OFF_VALUE ;
78import static com .ibm .wala .cast .python .ml .client .OneHot .Parameters .ON_VALUE ;
89
2425import java .util .Optional ;
2526import java .util .Set ;
2627
27- public class OneHot extends ZerosLike {
28+ public class OneHot extends Ones {
2829
2930 private static final String FUNCTION_NAME = "tf.one_hot()" ;
3031
@@ -87,6 +88,10 @@ protected int getDTypeParameterPosition() {
8788 return DTYPE .ordinal ();
8889 }
8990
91+ protected int getIndicesParameterPosition () {
92+ return INDICES .ordinal ();
93+ }
94+
9095 protected int getDepthParameterPosition () {
9196 return DEPTH .ordinal ();
9297 }
@@ -111,10 +116,14 @@ protected int getOffValueArgumentValueNumber() {
111116 return this .getArgumentValueNumber (this .getOffValueParameterPosition ());
112117 }
113118
119+ protected int getIndicesArgumentValueNumber () {
120+ return this .getArgumentValueNumber (this .getIndicesParameterPosition ());
121+ }
122+
114123 @ Override
115124 protected Set <List <Dimension <?>>> getShapes (PropagationCallGraphBuilder builder ) {
116125 Set <List <Dimension <?>>> ret = HashSetFactory .make ();
117- Set <List <Dimension <?>>> indices = this .getShapes (builder , this .getValueArgumentValueNumber ());
126+ Set <List <Dimension <?>>> indices = this .getShapes (builder , this .getIndicesArgumentValueNumber ());
118127 int depthArgumentValueNumber = this .getDepthArgumentValueNumber ();
119128
120129 if (depthArgumentValueNumber <= 0 )
0 commit comments