Skip to content

Commit cb07edc

Browse files
committed
Handle tf.sparse.eye() in addition to tf.eye().
1 parent 550c997 commit cb07edc

File tree

11 files changed

+296
-111
lines changed

11 files changed

+296
-111
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,18 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
107107
private static final TensorType TENSOR_10_2_FLOAT64 =
108108
new TensorType(FLOAT_64, asList(new NumericDim(10), new NumericDim(2)));
109109

110+
private static final TensorType TENSOR_5_2_FLOAT32 =
111+
new TensorType(FLOAT_32, asList(new NumericDim(5), new NumericDim(2)));
112+
113+
private static final TensorType TENSOR_5_2_INT32 =
114+
new TensorType(INT_32, asList(new NumericDim(5), new NumericDim(2)));
115+
116+
private static final TensorType TENSOR_5_5_FLOAT32 =
117+
new TensorType(FLOAT_32, asList(new NumericDim(5), new NumericDim(5)));
118+
119+
private static final TensorType TENSOR_5_5_INT32 =
120+
new TensorType(INT_32, asList(new NumericDim(5), new NumericDim(5)));
121+
110122
private static final TensorType TENSOR_2_3_3_INT32 =
111123
new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(3)));
112124

@@ -1928,19 +1940,34 @@ public void testAdd110()
19281940
@Test
19291941
public void testAdd111()
19301942
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1931-
test("tf2_test_add111.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1943+
test(
1944+
"tf2_test_add111.py",
1945+
"add",
1946+
2,
1947+
2,
1948+
Map.of(2, Set.of(TENSOR_2_3_FLOAT32), 3, Set.of(TENSOR_2_3_FLOAT32)));
19321949
}
19331950

19341951
@Test
19351952
public void testAdd112()
19361953
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1937-
test("tf2_test_add112.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1954+
test(
1955+
"tf2_test_add112.py",
1956+
"add",
1957+
2,
1958+
2,
1959+
Map.of(2, Set.of(TENSOR_2_3_FLOAT32), 3, Set.of(TENSOR_2_3_FLOAT32)));
19381960
}
19391961

19401962
@Test
19411963
public void testAdd113()
19421964
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1943-
test("tf2_test_add113.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1965+
test(
1966+
"tf2_test_add113.py",
1967+
"add",
1968+
2,
1969+
2,
1970+
Map.of(2, Set.of(TENSOR_2_3_FLOAT32), 3, Set.of(TENSOR_2_3_FLOAT32)));
19441971
}
19451972

19461973
@Test
@@ -4505,6 +4532,36 @@ public void testPoisson4()
45054532
test("tf2_test_poisson4.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_7_5_2_FLOAT32)));
45064533
}
45074534

4535+
@Test
4536+
public void testSparseEye() throws ClassHierarchyException, CancelException, IOException {
4537+
test("tf2_test_sparse_eye.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_5_FLOAT32)));
4538+
}
4539+
4540+
@Test
4541+
public void testSparseEye2() throws ClassHierarchyException, CancelException, IOException {
4542+
test("tf2_test_sparse_eye2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_5_FLOAT32)));
4543+
}
4544+
4545+
@Test
4546+
public void testSparseEye3() throws ClassHierarchyException, CancelException, IOException {
4547+
test("tf2_test_sparse_eye3.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_5_INT32)));
4548+
}
4549+
4550+
@Test
4551+
public void testSparseEye4() throws ClassHierarchyException, CancelException, IOException {
4552+
test("tf2_test_sparse_eye4.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_2_FLOAT32)));
4553+
}
4554+
4555+
@Test
4556+
public void testSparseEye5() throws ClassHierarchyException, CancelException, IOException {
4557+
test("tf2_test_sparse_eye5.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_2_FLOAT32)));
4558+
}
4559+
4560+
@Test
4561+
public void testSparseEye6() throws ClassHierarchyException, CancelException, IOException {
4562+
test("tf2_test_sparse_eye6.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_2_INT32)));
4563+
}
4564+
45084565
private void test(
45094566
String filename,
45104567
String functionName,

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Eye.java

Lines changed: 2 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,23 @@
22

33
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.BATCH_SHAPE;
44
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.DTYPE;
5-
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.NUM_COLUMNS;
6-
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.NUM_ROWS;
75
import static java.util.Collections.emptySet;
86

97
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
10-
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
118
import com.ibm.wala.ipa.callgraph.CGNode;
129
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
1310
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
1411
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
1512
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
1613
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
17-
import com.ibm.wala.util.collections.HashSetFactory;
1814
import com.ibm.wala.util.intset.OrdinalSet;
19-
import java.util.ArrayList;
2015
import java.util.List;
21-
import java.util.Optional;
2216
import java.util.Set;
23-
import java.util.stream.Collectors;
24-
import java.util.stream.StreamSupport;
2517

26-
public class Eye extends Ones {
18+
public class Eye extends SparseEye {
2719

2820
private static final String FUNCTION_NAME = "tf.eye()";
2921

30-
private static final int SHAPE_PARAMETER_POSITION = -1;
31-
3222
protected enum Parameters {
3323
NUM_ROWS,
3424
NUM_COLUMNS,
@@ -46,77 +36,18 @@ protected String getSignature() {
4636
return FUNCTION_NAME;
4737
}
4838

49-
@Override
50-
protected int getShapeParameterPosition() {
51-
return SHAPE_PARAMETER_POSITION;
52-
}
53-
54-
protected int getNumRowsParameterPosition() {
55-
return NUM_ROWS.ordinal();
56-
}
57-
58-
protected int getNumRowsArgumentValueNumber() {
59-
return this.getArgumentValueNumber(this.getNumRowsParameterPosition());
60-
}
61-
6239
protected int getBatchShapesArgumentValueNumber() {
6340
// TOOD: Handle keyword arguments.
6441
return this.getArgumentValueNumber(this.getBatchShapeParameterPosition());
6542
}
6643

67-
protected int getNumColumnsParameterPosition() {
68-
return NUM_COLUMNS.ordinal();
69-
}
70-
71-
protected int getNumColumnsArgumentValueNumber(PropagationCallGraphBuilder builder) {
72-
return this.getArgumentValueNumber(this.getNumColumnsParameterPosition());
73-
}
74-
7544
protected int getBatchShapeParameterPosition() {
7645
return BATCH_SHAPE.ordinal();
7746
}
7847

7948
@Override
8049
protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) {
81-
Set<List<Dimension<?>>> ret = HashSetFactory.make();
82-
Set<Optional<Integer>> numRows = this.getNumberOfRows(builder);
83-
Set<Optional<Integer>> numColumns = this.getNumberOfColumns(builder);
84-
85-
for (Optional<Integer> nRow : numRows) {
86-
if (numColumns.isEmpty())
87-
// If numColumns is not provided, it defaults to numRows.
88-
for (Optional<Integer> nCol : numRows)
89-
// Build the shape using nRow and nCol.
90-
numColumns.add(nCol);
91-
92-
for (Optional<Integer> nCol : numColumns)
93-
if (nCol.isEmpty()) {
94-
// If numColumns is not provided, it defaults to numRows.
95-
for (Optional<Integer> nCol2 : numRows) {
96-
// Build the shape using nRow and nCol.
97-
List<Dimension<?>> shape = new ArrayList<>();
98-
99-
NumericDim rowDim = new NumericDim(nRow.get());
100-
NumericDim colDim = new NumericDim(nCol2.get());
101-
102-
shape.add(rowDim);
103-
shape.add(colDim);
104-
105-
ret.add(shape);
106-
}
107-
} else {
108-
List<Dimension<?>> shape = new ArrayList<>();
109-
110-
NumericDim rowDim = new NumericDim(nRow.get());
111-
NumericDim colDim = new NumericDim(nCol.get());
112-
113-
shape.add(rowDim);
114-
shape.add(colDim);
115-
116-
ret.add(shape);
117-
}
118-
}
119-
50+
Set<List<Dimension<?>>> ret = super.getShapes(builder);
12051
Set<List<Dimension<?>>> batchShapes = this.getBatchShapes(builder);
12152

12253
// prepend batch dimensions to each shape.
@@ -126,22 +57,6 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
12657
return ret;
12758
}
12859

129-
private Set<Optional<Integer>> getNumberOfRows(PropagationCallGraphBuilder builder) {
130-
// TODO Handle keyword arguments.
131-
Set<Optional<Integer>> values =
132-
this.getPossiblePositionalArgumentValues(builder, this.getNumRowsParameterPosition());
133-
134-
if (values == null || values.isEmpty())
135-
throw new IllegalStateException("The num_rows parameter is required for tf.eye().");
136-
137-
return values;
138-
}
139-
140-
private Set<Optional<Integer>> getNumberOfColumns(PropagationCallGraphBuilder builder) {
141-
// TODO Handle keyword arguments.
142-
return this.getPossiblePositionalArgumentValues(builder, this.getNumColumnsParameterPosition());
143-
}
144-
14560
private Set<List<Dimension<?>>> getBatchShapes(PropagationCallGraphBuilder builder) {
14661
// TODO Handle keyword arguments.
14762
Set<Integer> possibleNumArgs = this.getNumberOfPossiblePositionalArguments(builder);
@@ -169,27 +84,6 @@ private Set<List<Dimension<?>>> getBatchShapes(PropagationCallGraphBuilder build
16984
return emptySet();
17085
}
17186

172-
private Set<Optional<Integer>> getPossiblePositionalArgumentValues(
173-
PropagationCallGraphBuilder builder, int paramPosition) {
174-
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
175-
Set<Integer> possibleNumArgs = this.getNumberOfPossiblePositionalArguments(builder);
176-
177-
return possibleNumArgs.stream()
178-
.filter(numArgs -> numArgs >= paramPosition + 1)
179-
.map(
180-
_ -> {
181-
PointerKey pointerKey =
182-
pointerAnalysis
183-
.getHeapModel()
184-
.getPointerKeyForLocal(
185-
this.getNode(), this.getArgumentValueNumber(paramPosition));
186-
return pointerAnalysis.getPointsToSet(pointerKey);
187-
})
188-
.flatMap(pts -> StreamSupport.stream(pts.spliterator(), false))
189-
.map(Eye::getIntValueFromInstanceKey)
190-
.collect(Collectors.toSet());
191-
}
192-
19387
@Override
19488
protected int getDTypeParameterPosition() {
19589
return DTYPE.ordinal();

0 commit comments

Comments
 (0)