22
33import static com .ibm .wala .cast .python .ml .client .Eye .Parameters .BATCH_SHAPE ;
44import static com .ibm .wala .cast .python .ml .client .Eye .Parameters .DTYPE ;
5- import static com .ibm .wala .cast .python .ml .client .Eye .Parameters .NUM_COLUMNS ;
6- import static com .ibm .wala .cast .python .ml .client .Eye .Parameters .NUM_ROWS ;
75import static java .util .Collections .emptySet ;
86
97import com .ibm .wala .cast .python .ml .types .TensorType .Dimension ;
10- import com .ibm .wala .cast .python .ml .types .TensorType .NumericDim ;
118import com .ibm .wala .ipa .callgraph .CGNode ;
129import com .ibm .wala .ipa .callgraph .propagation .InstanceKey ;
1310import com .ibm .wala .ipa .callgraph .propagation .PointerAnalysis ;
1411import com .ibm .wala .ipa .callgraph .propagation .PointerKey ;
1512import com .ibm .wala .ipa .callgraph .propagation .PointsToSetVariable ;
1613import com .ibm .wala .ipa .callgraph .propagation .PropagationCallGraphBuilder ;
17- import com .ibm .wala .util .collections .HashSetFactory ;
1814import com .ibm .wala .util .intset .OrdinalSet ;
19- import java .util .ArrayList ;
2015import java .util .List ;
21- import java .util .Optional ;
2216import java .util .Set ;
23- import java .util .stream .Collectors ;
24- import java .util .stream .StreamSupport ;
2517
26- public class Eye extends Ones {
18+ public class Eye extends SparseEye {
2719
2820 private static final String FUNCTION_NAME = "tf.eye()" ;
2921
30- private static final int SHAPE_PARAMETER_POSITION = -1 ;
31-
3222 protected enum Parameters {
3323 NUM_ROWS ,
3424 NUM_COLUMNS ,
@@ -46,77 +36,18 @@ protected String getSignature() {
4636 return FUNCTION_NAME ;
4737 }
4838
49- @ Override
50- protected int getShapeParameterPosition () {
51- return SHAPE_PARAMETER_POSITION ;
52- }
53-
54- protected int getNumRowsParameterPosition () {
55- return NUM_ROWS .ordinal ();
56- }
57-
58- protected int getNumRowsArgumentValueNumber () {
59- return this .getArgumentValueNumber (this .getNumRowsParameterPosition ());
60- }
61-
6239 protected int getBatchShapesArgumentValueNumber () {
6340 // TOOD: Handle keyword arguments.
6441 return this .getArgumentValueNumber (this .getBatchShapeParameterPosition ());
6542 }
6643
67- protected int getNumColumnsParameterPosition () {
68- return NUM_COLUMNS .ordinal ();
69- }
70-
71- protected int getNumColumnsArgumentValueNumber (PropagationCallGraphBuilder builder ) {
72- return this .getArgumentValueNumber (this .getNumColumnsParameterPosition ());
73- }
74-
7544 protected int getBatchShapeParameterPosition () {
7645 return BATCH_SHAPE .ordinal ();
7746 }
7847
7948 @ Override
8049 protected Set <List <Dimension <?>>> getShapes (PropagationCallGraphBuilder builder ) {
81- Set <List <Dimension <?>>> ret = HashSetFactory .make ();
82- Set <Optional <Integer >> numRows = this .getNumberOfRows (builder );
83- Set <Optional <Integer >> numColumns = this .getNumberOfColumns (builder );
84-
85- for (Optional <Integer > nRow : numRows ) {
86- if (numColumns .isEmpty ())
87- // If numColumns is not provided, it defaults to numRows.
88- for (Optional <Integer > nCol : numRows )
89- // Build the shape using nRow and nCol.
90- numColumns .add (nCol );
91-
92- for (Optional <Integer > nCol : numColumns )
93- if (nCol .isEmpty ()) {
94- // If numColumns is not provided, it defaults to numRows.
95- for (Optional <Integer > nCol2 : numRows ) {
96- // Build the shape using nRow and nCol.
97- List <Dimension <?>> shape = new ArrayList <>();
98-
99- NumericDim rowDim = new NumericDim (nRow .get ());
100- NumericDim colDim = new NumericDim (nCol2 .get ());
101-
102- shape .add (rowDim );
103- shape .add (colDim );
104-
105- ret .add (shape );
106- }
107- } else {
108- List <Dimension <?>> shape = new ArrayList <>();
109-
110- NumericDim rowDim = new NumericDim (nRow .get ());
111- NumericDim colDim = new NumericDim (nCol .get ());
112-
113- shape .add (rowDim );
114- shape .add (colDim );
115-
116- ret .add (shape );
117- }
118- }
119-
50+ Set <List <Dimension <?>>> ret = super .getShapes (builder );
12051 Set <List <Dimension <?>>> batchShapes = this .getBatchShapes (builder );
12152
12253 // prepend batch dimensions to each shape.
@@ -126,22 +57,6 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
12657 return ret ;
12758 }
12859
129- private Set <Optional <Integer >> getNumberOfRows (PropagationCallGraphBuilder builder ) {
130- // TODO Handle keyword arguments.
131- Set <Optional <Integer >> values =
132- this .getPossiblePositionalArgumentValues (builder , this .getNumRowsParameterPosition ());
133-
134- if (values == null || values .isEmpty ())
135- throw new IllegalStateException ("The num_rows parameter is required for tf.eye()." );
136-
137- return values ;
138- }
139-
140- private Set <Optional <Integer >> getNumberOfColumns (PropagationCallGraphBuilder builder ) {
141- // TODO Handle keyword arguments.
142- return this .getPossiblePositionalArgumentValues (builder , this .getNumColumnsParameterPosition ());
143- }
144-
14560 private Set <List <Dimension <?>>> getBatchShapes (PropagationCallGraphBuilder builder ) {
14661 // TODO Handle keyword arguments.
14762 Set <Integer > possibleNumArgs = this .getNumberOfPossiblePositionalArguments (builder );
@@ -169,27 +84,6 @@ private Set<List<Dimension<?>>> getBatchShapes(PropagationCallGraphBuilder build
16984 return emptySet ();
17085 }
17186
172- private Set <Optional <Integer >> getPossiblePositionalArgumentValues (
173- PropagationCallGraphBuilder builder , int paramPosition ) {
174- PointerAnalysis <InstanceKey > pointerAnalysis = builder .getPointerAnalysis ();
175- Set <Integer > possibleNumArgs = this .getNumberOfPossiblePositionalArguments (builder );
176-
177- return possibleNumArgs .stream ()
178- .filter (numArgs -> numArgs >= paramPosition + 1 )
179- .map (
180- _ -> {
181- PointerKey pointerKey =
182- pointerAnalysis
183- .getHeapModel ()
184- .getPointerKeyForLocal (
185- this .getNode (), this .getArgumentValueNumber (paramPosition ));
186- return pointerAnalysis .getPointsToSet (pointerKey );
187- })
188- .flatMap (pts -> StreamSupport .stream (pts .spliterator (), false ))
189- .map (Eye ::getIntValueFromInstanceKey )
190- .collect (Collectors .toSet ());
191- }
192-
19387 @ Override
19488 protected int getDTypeParameterPosition () {
19589 return DTYPE .ordinal ();
0 commit comments