Skip to content

Commit ff52a09

Browse files
committed
Initial work on tf.eye().
1 parent 573f06a commit ff52a09

File tree

6 files changed

+222
-3
lines changed

6 files changed

+222
-3
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,19 +1478,34 @@ public void testAdd50()
14781478
@Test
14791479
public void testAdd51()
14801480
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1481-
test("tf2_test_add51.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1481+
test(
1482+
"tf2_test_add51.py",
1483+
"add",
1484+
2,
1485+
2,
1486+
Map.of(2, Set.of(TENSOR_2_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32)));
14821487
}
14831488

14841489
@Test
14851490
public void testAdd52()
14861491
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1487-
test("tf2_test_add52.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1492+
test(
1493+
"tf2_test_add52.py",
1494+
"add",
1495+
2,
1496+
2,
1497+
Map.of(2, Set.of(TENSOR_2_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32)));
14881498
}
14891499

14901500
@Test
14911501
public void testAdd53()
14921502
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1493-
test("tf2_test_add53.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1503+
test(
1504+
"tf2_test_add53.py",
1505+
"add",
1506+
2,
1507+
2,
1508+
Map.of(2, Set.of(TENSOR_2_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32)));
14941509
}
14951510

14961511
@Test
@@ -4337,6 +4352,18 @@ public void testOneHot19()
43374352
test("tf2_test_one_hot19.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_5_3_FLOAT32)));
43384353
}
43394354

4355+
@Test
4356+
public void testEye()
4357+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4358+
test("tf2_test_eye.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32)));
4359+
}
4360+
4361+
@Test
4362+
public void testEye2()
4363+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4364+
test("tf2_test_eye2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32)));
4365+
}
4366+
43404367
private void test(
43414368
String filename,
43424369
String functionName,
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package com.ibm.wala.cast.python.ml.client;
2+
3+
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.BATCH_SHAPE;
4+
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.DTYPE;
5+
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.NUM_COLUMNS;
6+
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.NUM_ROWS;
7+
8+
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
9+
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
10+
import com.ibm.wala.ipa.callgraph.CGNode;
11+
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
12+
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
13+
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
14+
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
15+
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
16+
import com.ibm.wala.util.collections.HashSetFactory;
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
import java.util.Optional;
20+
import java.util.Set;
21+
import java.util.stream.Collectors;
22+
import java.util.stream.StreamSupport;
23+
24+
public class Eye extends Ones {
25+
26+
private static final String FUNCTION_NAME = "tf.eye()";
27+
28+
private static final int SHAPE_PARAMETER_POSITION = -1;
29+
30+
enum Parameters {
31+
NUM_ROWS,
32+
NUM_COLUMNS,
33+
BATCH_SHAPE,
34+
DTYPE,
35+
NAME
36+
}
37+
38+
public Eye(PointsToSetVariable source, CGNode node) {
39+
super(source, node);
40+
}
41+
42+
@Override
43+
protected String getSignature() {
44+
return FUNCTION_NAME;
45+
}
46+
47+
@Override
48+
protected int getShapeParameterPosition() {
49+
return SHAPE_PARAMETER_POSITION;
50+
}
51+
52+
protected int getNumRowsParameterPosition() {
53+
return NUM_ROWS.ordinal();
54+
}
55+
56+
protected int getNumRowsValueNumber(PropagationCallGraphBuilder builder) {
57+
return this.getArgumentValueNumber(this.getNumRowsParameterPosition());
58+
}
59+
60+
protected int getNumColumnsParameterPosition() {
61+
return NUM_COLUMNS.ordinal();
62+
}
63+
64+
protected int getNumColumnsValueNumber(PropagationCallGraphBuilder builder) {
65+
return this.getArgumentValueNumber(this.getNumColumnsParameterPosition());
66+
}
67+
68+
protected int getBatchShapeParameterPosition() {
69+
return BATCH_SHAPE.ordinal();
70+
}
71+
72+
protected int getBatchShapeValueNumber(PropagationCallGraphBuilder builder) {
73+
return this.getArgumentValueNumber(this.getBatchShapeParameterPosition());
74+
}
75+
76+
@Override
77+
protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) {
78+
Set<List<Dimension<?>>> ret = HashSetFactory.make();
79+
Set<Optional<Integer>> numRows = this.getNumberOfRows(builder);
80+
Set<Optional<Integer>> numColumns = this.getNumberOfColumns(builder);
81+
82+
for (Optional<Integer> nRow : numRows) {
83+
if (numColumns.isEmpty())
84+
// If numColumns is not provided, it defaults to numRows.
85+
for (Optional<Integer> nCol : numRows)
86+
// Build the shape using nRow and nCol.
87+
numColumns.add(nCol);
88+
89+
for (Optional<Integer> nCol : numColumns)
90+
if (nCol.isEmpty()) {
91+
// If numColumns is not provided, it defaults to numRows.
92+
for (Optional<Integer> nCol2 : numRows) {
93+
// Build the shape using nRow and nCol.
94+
List<Dimension<?>> shape = new ArrayList<>();
95+
96+
NumericDim rowDim = new NumericDim(nRow.get());
97+
NumericDim colDim = new NumericDim(nCol2.get());
98+
99+
shape.add(rowDim);
100+
shape.add(colDim);
101+
102+
ret.add(shape);
103+
}
104+
} else {
105+
List<Dimension<?>> shape = new ArrayList<>();
106+
107+
NumericDim rowDim = new NumericDim(nRow.get());
108+
NumericDim colDim = new NumericDim(nCol.get());
109+
110+
shape.add(rowDim);
111+
shape.add(colDim);
112+
113+
ret.add(shape);
114+
}
115+
}
116+
117+
return ret;
118+
}
119+
120+
private Set<Optional<Integer>> getNumberOfRows(PropagationCallGraphBuilder builder) {
121+
// TODO Handle keyword arguments.
122+
return this.getPossiblePositionalArgumentValues(builder, this.getNumRowsParameterPosition());
123+
}
124+
125+
private Set<Optional<Integer>> getNumberOfColumns(PropagationCallGraphBuilder builder) {
126+
// TODO Handle keyword arguments.
127+
return this.getPossiblePositionalArgumentValues(builder, this.getNumColumnsParameterPosition());
128+
}
129+
130+
private Set<Optional<Integer>> getPossiblePositionalArgumentValues(
131+
PropagationCallGraphBuilder builder, int paramPosition) {
132+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
133+
Set<Integer> possibleNumArgs = this.getNumberOfPossiblePositionalArguments(builder);
134+
135+
return possibleNumArgs.stream()
136+
.filter(numArgs -> numArgs >= paramPosition + 1)
137+
.map(
138+
_ -> {
139+
PointerKey pointerKey =
140+
pointerAnalysis
141+
.getHeapModel()
142+
.getPointerKeyForLocal(
143+
this.getNode(), this.getArgumentValueNumber(paramPosition));
144+
return pointerAnalysis.getPointsToSet(pointerKey);
145+
})
146+
.flatMap(pts -> StreamSupport.stream(pts.spliterator(), false))
147+
.map(Eye::getIntValueFromInstanceKey)
148+
.collect(Collectors.toSet());
149+
}
150+
151+
@Override
152+
protected int getDTypeParameterPosition() {
153+
return DTYPE.ordinal();
154+
}
155+
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.CONSTANT;
44
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.CONVERT_TO_TENSOR;
5+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.EYE;
56
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.FILL;
67
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.NORMAL;
78
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONES;
@@ -52,6 +53,7 @@ else if (calledFunction.equals(ZEROS_LIKE.getDeclaringClass()))
5253
else if (calledFunction.equals(CONVERT_TO_TENSOR.getDeclaringClass()))
5354
return new ConvertToTensor(source, node);
5455
else if (calledFunction.equals(ONE_HOT.getDeclaringClass())) return new OneHot(source, node);
56+
else if (calledFunction.equals(EYE.getDeclaringClass())) return new Eye(source, node);
5557
else
5658
throw new IllegalArgumentException(
5759
"Unknown call: " + calledFunction + " for source: " + source + ".");

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ public boolean canConvertTo(DType other) {
150150
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/one_hot")),
151151
AstMethodReference.fnSelector);
152152

153+
/** https://www.tensorflow.org/api_docs/python/tf/eye. */
154+
public static final MethodReference EYE =
155+
MethodReference.findOrCreate(
156+
TypeReference.findOrCreate(
157+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/eye")),
158+
AstMethodReference.fnSelector);
159+
153160
/**
154161
* Represents the TensorFlow float32 data type.
155162
*
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import tensorflow as tf
2+
3+
4+
def f(ab):
5+
pass
6+
7+
8+
# Construct one identity matrix.
9+
arg = tf.eye(2)
10+
assert isinstance(arg, tf.Tensor)
11+
assert arg.dtype == tf.float32
12+
assert arg.shape == (2, 2)
13+
14+
f(arg)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import tensorflow as tf
2+
3+
4+
def f(ab):
5+
pass
6+
7+
8+
# Construct one identity matrix.
9+
arg = tf.eye(2, None)
10+
assert isinstance(arg, tf.Tensor)
11+
assert arg.dtype == tf.float32
12+
assert arg.shape == (2, 2)
13+
14+
f(arg)

0 commit comments

Comments
 (0)