Skip to content

Commit 3612e02

Browse files
committed
Handle batch dimensions for tf.eye().
1 parent a98f6a7 commit 3612e02

File tree

5 files changed

+97
-8
lines changed

5 files changed

+97
-8
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
110110
private static final TensorType TENSOR_2_5_3_FLOAT32 =
111111
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(5), new NumericDim(3)));
112112

113+
private static final TensorType TENSOR_3_2_2_FLOAT32 =
114+
new TensorType(FLOAT_32, asList(new NumericDim(3), new NumericDim(2), new NumericDim(2)));
115+
116+
private static final TensorType TENSOR_3_2_2_3_FLOAT32 =
117+
new TensorType(
118+
FLOAT_32,
119+
asList(new NumericDim(3), new NumericDim(2), new NumericDim(2), new NumericDim(3)));
120+
113121
private static final TensorType TENSOR_20_28_28_FLOAT32 =
114122
new TensorType(FLOAT_32, asList(new NumericDim(20), new NumericDim(28), new NumericDim(28)));
115123

@@ -4370,6 +4378,18 @@ public void testEye3()
43704378
test("tf2_test_eye3.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
43714379
}
43724380

4381+
@Test
4382+
public void testEye4()
4383+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4384+
test("tf2_test_eye4.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_2_2_FLOAT32)));
4385+
}
4386+
4387+
@Test
4388+
public void testEye5()
4389+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4390+
test("tf2_test_eye5.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_2_2_3_FLOAT32)));
4391+
}
4392+
43734393
private void test(
43744394
String filename,
43754395
String functionName,

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Eye.java

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.DTYPE;
55
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.NUM_COLUMNS;
66
import static com.ibm.wala.cast.python.ml.client.Eye.Parameters.NUM_ROWS;
7+
import static java.util.Collections.emptySet;
78

89
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
910
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
@@ -14,6 +15,7 @@
1415
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
1516
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
1617
import com.ibm.wala.util.collections.HashSetFactory;
18+
import com.ibm.wala.util.intset.OrdinalSet;
1719
import java.util.ArrayList;
1820
import java.util.List;
1921
import java.util.Optional;
@@ -53,26 +55,27 @@ protected int getNumRowsParameterPosition() {
5355
return NUM_ROWS.ordinal();
5456
}
5557

56-
protected int getNumRowsValueNumber(PropagationCallGraphBuilder builder) {
58+
protected int getNumRowsArgumentValueNumber() {
5759
return this.getArgumentValueNumber(this.getNumRowsParameterPosition());
5860
}
5961

62+
protected int getBatchShapesArgumentValueNumber() {
63+
// TOOD: Handle keyword arguments.
64+
return this.getArgumentValueNumber(this.getBatchShapeParameterPosition());
65+
}
66+
6067
protected int getNumColumnsParameterPosition() {
6168
return NUM_COLUMNS.ordinal();
6269
}
6370

64-
protected int getNumColumnsValueNumber(PropagationCallGraphBuilder builder) {
71+
protected int getNumColumnsArgumentValueNumber(PropagationCallGraphBuilder builder) {
6572
return this.getArgumentValueNumber(this.getNumColumnsParameterPosition());
6673
}
6774

6875
protected int getBatchShapeParameterPosition() {
6976
return BATCH_SHAPE.ordinal();
7077
}
7178

72-
protected int getBatchShapeValueNumber(PropagationCallGraphBuilder builder) {
73-
return this.getArgumentValueNumber(this.getBatchShapeParameterPosition());
74-
}
75-
7679
@Override
7780
protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) {
7881
Set<List<Dimension<?>>> ret = HashSetFactory.make();
@@ -114,6 +117,12 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
114117
}
115118
}
116119

120+
Set<List<Dimension<?>>> batchShapes = this.getBatchShapes(builder);
121+
122+
// prepend batch dimensions to each shape.
123+
for (List<Dimension<?>> batchDim : batchShapes)
124+
for (List<Dimension<?>> retDim : ret) retDim.addAll(0, batchDim);
125+
117126
return ret;
118127
}
119128

@@ -133,6 +142,33 @@ private Set<Optional<Integer>> getNumberOfColumns(PropagationCallGraphBuilder bu
133142
return this.getPossiblePositionalArgumentValues(builder, this.getNumColumnsParameterPosition());
134143
}
135144

145+
private Set<List<Dimension<?>>> getBatchShapes(PropagationCallGraphBuilder builder) {
146+
// TODO Handle keyword arguments.
147+
Set<Integer> possibleNumArgs = this.getNumberOfPossiblePositionalArguments(builder);
148+
149+
if (possibleNumArgs.contains(this.getBatchShapeParameterPosition() + 1)) {
150+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
151+
152+
PointerKey pointerKey =
153+
pointerAnalysis
154+
.getHeapModel()
155+
.getPointerKeyForLocal(this.getNode(), this.getBatchShapesArgumentValueNumber());
156+
157+
OrdinalSet<InstanceKey> pts = pointerAnalysis.getPointsToSet(pointerKey);
158+
159+
Set<List<Dimension<?>>> shapesFromShapeArgument =
160+
this.getShapesFromShapeArgument(builder, pts);
161+
162+
if (shapesFromShapeArgument == null || shapesFromShapeArgument.isEmpty())
163+
throw new IllegalStateException(
164+
"Batch shape argument for tf.eye() should be a list of dimensions.");
165+
166+
return shapesFromShapeArgument;
167+
}
168+
169+
return emptySet();
170+
}
171+
136172
private Set<Optional<Integer>> getPossiblePositionalArgumentValues(
137173
PropagationCallGraphBuilder builder, int paramPosition) {
138174
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ protected Set<List<Dimension<?>>> getShapesFromShapeArgument(
103103
AllocationSiteInNode asin = getAllocationSiteInNode(instanceKey);
104104
TypeReference reference = asin.getConcreteType().getReference();
105105

106-
if (reference.equals(list)) { // TODO: This can also be a tuple of tensors.
106+
if (reference.equals(list) || reference.equals(tuple)) {
107107
// We have a list of integers that represent the shape.
108108
OrdinalSet<InstanceKey> objectCatalogPointsToSet =
109109
pointerAnalysis.getPointsToSet(
@@ -206,7 +206,13 @@ protected Set<List<Dimension<?>>> getShapesFromShapeArgument(
206206
}
207207
} else
208208
throw new IllegalStateException(
209-
"Expected a " + PythonTypes.list + " for the shape, but got: " + reference + ".");
209+
"Expected a "
210+
+ PythonTypes.list
211+
+ " or "
212+
+ PythonTypes.tuple
213+
+ " for the shape, but got: "
214+
+ reference
215+
+ ".");
210216
}
211217

212218
return ret;
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
# Construct a batch of 3 identity matrices, each 2 x 2.
9+
# batch_identity[i, :, :] is a 2 x 2 identity matrix, i = 0, 1, 2.
10+
arg = tf.eye(2, None, [3])
11+
assert arg.shape == (3, 2, 2)
12+
assert arg.dtype == tf.float32
13+
14+
f(arg)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
arg = tf.eye(2, 3, [3, 2])
9+
10+
assert arg.shape == (3, 2, 2, 3)
11+
assert arg.dtype == tf.float32
12+
13+
f(arg)

0 commit comments

Comments
 (0)