Skip to content

Commit 9a54455

Browse files
committed
Merge branch '267-initial-tensor-dimensions-arent-always-accurate' of https://github.com/ponder-lab/ML into 267-initial-tensor-dimensions-arent-always-accurate
2 parents 7cd16ef + e494084 commit 9a54455

File tree

17 files changed

+597
-19
lines changed

17 files changed

+597
-19
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 125 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.ibm.wala.cast.python.ipa.callgraph.PythonSSAPropagationCallGraphBuilder;
2222
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
2323
import com.ibm.wala.cast.python.ml.analysis.TensorVariable;
24+
import com.ibm.wala.cast.python.ml.client.NonBroadcastableShapesException;
2425
import com.ibm.wala.cast.python.ml.client.PythonTensorAnalysisEngine;
2526
import com.ibm.wala.cast.python.ml.types.TensorType;
2627
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
@@ -120,6 +121,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
120121
private static final TensorType TENSOR_2_NONE_2_INT32 =
121122
new TensorType(INT_32, asList(new NumericDim(2), null, new NumericDim(2)));
122123

124+
private static final TensorType TENSOR_2_NONE_2_3_INT32 =
125+
new TensorType(INT_32, asList(new NumericDim(2), null, new NumericDim(2), new NumericDim(3)));
126+
127+
private static final TensorType TENSOR_2_NONE_2_2_INT32 =
128+
new TensorType(INT_32, asList(new NumericDim(2), null, new NumericDim(2), new NumericDim(2)));
129+
123130
@SuppressWarnings("unused")
124131
private static final TensorType TENSOR_2_NONE_NONE_NONE_INT32 =
125132
new TensorType(INT_32, asList(new NumericDim(2), null));
@@ -186,6 +193,11 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
186193
FLOAT_32,
187194
asList(new NumericDim(3), new NumericDim(2), new NumericDim(2), new NumericDim(3)));
188195

196+
private static final TensorType TENSOR_2_2_2_3_FLOAT32 =
197+
new TensorType(
198+
FLOAT_32,
199+
asList(new NumericDim(2), new NumericDim(2), new NumericDim(2), new NumericDim(3)));
200+
189201
private static final TensorType TENSOR_20_28_28_FLOAT32 =
190202
new TensorType(FLOAT_32, asList(new NumericDim(20), new NumericDim(28), new NumericDim(28)));
191203

@@ -227,7 +239,11 @@ public void testValueIndex()
227239
Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
228240
}
229241

230-
@Test
242+
/**
243+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
244+
* is fixed.
245+
*/
246+
@Test(expected = IllegalArgumentException.class)
231247
public void testValueIndex2()
232248
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
233249
test(
@@ -238,7 +254,11 @@ public void testValueIndex2()
238254
Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
239255
}
240256

241-
@Test
257+
/**
258+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
259+
* is fixed.
260+
*/
261+
@Test(expected = IllegalArgumentException.class)
242262
public void testValueIndex3()
243263
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
244264
test(
@@ -249,7 +269,11 @@ public void testValueIndex3()
249269
Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
250270
}
251271

252-
@Test
272+
/**
273+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
274+
* is fixed.
275+
*/
276+
@Test(expected = IllegalArgumentException.class)
253277
public void testValueIndex4()
254278
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
255279
test(
@@ -260,7 +284,11 @@ public void testValueIndex4()
260284
Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
261285
}
262286

263-
@Test
287+
/**
288+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
289+
* is fixed.
290+
*/
291+
@Test(expected = IllegalArgumentException.class)
264292
public void testFunction()
265293
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
266294
test("tf2_test_function.py", "func2", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
@@ -1617,25 +1645,45 @@ public void testAdd58()
16171645
@Test
16181646
public void testAdd59()
16191647
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1620-
test("tf2_test_add59.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1648+
test(
1649+
"tf2_test_add59.py",
1650+
"add",
1651+
2,
1652+
2,
1653+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
16211654
}
16221655

16231656
@Test
16241657
public void testAdd60()
16251658
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1626-
test("tf2_test_add60.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1659+
test(
1660+
"tf2_test_add60.py",
1661+
"add",
1662+
2,
1663+
2,
1664+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
16271665
}
16281666

16291667
@Test
16301668
public void testAdd61()
16311669
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1632-
test("tf2_test_add61.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1670+
test(
1671+
"tf2_test_add61.py",
1672+
"add",
1673+
2,
1674+
2,
1675+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
16331676
}
16341677

16351678
@Test
16361679
public void testAdd62()
16371680
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1638-
test("tf2_test_add62.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1681+
test(
1682+
"tf2_test_add62.py",
1683+
"add",
1684+
2,
1685+
2,
1686+
Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32)));
16391687
}
16401688

16411689
@Test
@@ -2151,16 +2199,24 @@ public void testReduceMean3()
21512199
@Test
21522200
public void testGradient()
21532201
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2154-
test("tf2_test_gradient.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
2202+
test("tf2_test_gradient.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_FLOAT32)));
21552203
}
21562204

2157-
@Test
2205+
/**
2206+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
2207+
* is fixed.
2208+
*/
2209+
@Test(expected = IllegalArgumentException.class)
21582210
public void testGradient2()
21592211
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
21602212
test("tf2_test_gradient2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
21612213
}
21622214

2163-
@Test
2215+
/**
2216+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
2217+
* is fixed.
2218+
*/
2219+
@Test(expected = IllegalArgumentException.class)
21642220
public void testMultiply()
21652221
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
21662222
test("tf2_test_multiply.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
@@ -2169,7 +2225,48 @@ public void testMultiply()
21692225
@Test
21702226
public void testMultiply2()
21712227
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2172-
test("tf2_test_multiply2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
2228+
test("tf2_test_multiply2.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32)));
2229+
}
2230+
2231+
@Test
2232+
public void testMultiply3()
2233+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2234+
test("tf2_test_multiply3.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
2235+
}
2236+
2237+
@Test
2238+
public void testMultiply4()
2239+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2240+
test("tf2_test_multiply4.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
2241+
}
2242+
2243+
@Test
2244+
public void testMultiply5()
2245+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2246+
test("tf2_test_multiply5.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_2_3_FLOAT32)));
2247+
}
2248+
2249+
/**
2250+
* This is an invalid case since the inputs have different ranks.
2251+
*
2252+
* <p>For now, we are throwing an exception. But, this is invalid code.
2253+
*
2254+
* <p>TODO: We'll need to come up with a suitable way to handle this in the future.
2255+
*/
2256+
@Test(expected = NonBroadcastableShapesException.class)
2257+
public void testMultiply6()
2258+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2259+
test("tf2_test_multiply6.py", "f", 1, 1);
2260+
}
2261+
2262+
/**
2263+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
2264+
* is fixed.
2265+
*/
2266+
@Test(expected = IllegalArgumentException.class)
2267+
public void testMultiply7()
2268+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2269+
test("tf2_test_multiply7.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
21732270
}
21742271

21752272
@Test
@@ -4697,6 +4794,22 @@ public void testRaggedConstant16() throws ClassHierarchyException, CancelExcepti
46974794
test("tf2_test_ragged_constant16.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_INT32)));
46984795
}
46994796

4797+
/**
4798+
* Test non-uniform inner dimensions.
4799+
*
4800+
* <p>TODO: Remove expected assertion error once https://github.com/wala/ML/issues/350 is fixed.
4801+
*/
4802+
@Test(expected = AssertionError.class)
4803+
public void testRaggedConstant17() throws ClassHierarchyException, CancelException, IOException {
4804+
test("tf2_test_ragged_constant17.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_3_INT32)));
4805+
}
4806+
4807+
/** This one works because the inner dimensions are uniform. */
4808+
@Test
4809+
public void testRaggedConstant18() throws ClassHierarchyException, CancelException, IOException {
4810+
test("tf2_test_ragged_constant18.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_2_INT32)));
4811+
}
4812+
47004813
private void test(
47014814
String filename,
47024815
String functionName,

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Constant.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public Constant(PointsToSetVariable source) {
3232
protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
3333
// If the shape argument is not specified, then the shape is inferred from the shape of value.
3434
// TODO: Handle keyword arguments.
35-
return getShapes(builder, this.getValueArgumentValueNumber());
35+
return this.getShapes(builder, this.getValueArgumentValueNumber());
3636
}
3737

3838
/**
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package com.ibm.wala.cast.python.ml.client;
2+
3+
import static com.ibm.wala.cast.python.ml.client.Multiply.Parameters.X;
4+
import static com.ibm.wala.cast.python.ml.client.Multiply.Parameters.Y;
5+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TYPE_REFERENCE_TO_SIGNATURE;
6+
import static com.ibm.wala.cast.python.ml.util.TensorShapeUtil.areBroadcastable;
7+
import static com.ibm.wala.cast.python.ml.util.TensorShapeUtil.getBroadcastedShapes;
8+
import static com.ibm.wala.cast.python.util.Util.getFunction;
9+
import static java.util.logging.Logger.getLogger;
10+
11+
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
12+
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
13+
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
14+
import com.ibm.wala.types.TypeReference;
15+
import com.ibm.wala.util.collections.HashSetFactory;
16+
import java.util.List;
17+
import java.util.Set;
18+
import java.util.logging.Logger;
19+
20+
/**
21+
* A representation of a multiply operation in TensorFlow.
22+
*
23+
* @see <a href="https://www.tensorflow.org/api_docs/python/tf/multiply">tf.multiply</a>.
24+
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
25+
*/
26+
public class Multiply extends ZerosLike {
27+
28+
@SuppressWarnings("unused")
29+
private static final Logger logger = getLogger(Multiply.class.getName());
30+
31+
protected enum Parameters {
32+
X,
33+
Y,
34+
NAME
35+
}
36+
37+
/**
38+
* The dtype argument is not explicitly provided to multiply(); rather, the dtype is inferred from
39+
* the `x` argument.
40+
*
41+
* @see <a
42+
* href="https://www.tensorflow.org/api_docs/python/tf/math/multiply#returns">tf.math.multiply
43+
* - Returns</a>.
44+
*/
45+
protected static final int DTYPE_PARAMETER_POSITION = -1;
46+
47+
@Override
48+
protected int getDTypeParameterPosition() {
49+
return DTYPE_PARAMETER_POSITION;
50+
}
51+
52+
public Multiply(PointsToSetVariable source) {
53+
super(source);
54+
}
55+
56+
protected int getXParameterPosition() {
57+
return X.ordinal();
58+
}
59+
60+
protected int getXArgumentValueNumber(PropagationCallGraphBuilder builder) {
61+
// TODO: Handle keyword arguments.
62+
return this.getArgumentValueNumber(builder, this.getXParameterPosition());
63+
}
64+
65+
protected int getYParameterPosition() {
66+
return Y.ordinal();
67+
}
68+
69+
protected int getYArgumentValueNumber(PropagationCallGraphBuilder builder) {
70+
// TODO: Handle keyword arguments.
71+
return this.getArgumentValueNumber(builder, this.getYParameterPosition());
72+
}
73+
74+
/**
75+
* Returns the TensorFlow function signature represented by this generator.
76+
*
77+
* @return The TensorFlow function signature represented by this generator.
78+
*/
79+
@Override
80+
protected String getSignature() {
81+
TypeReference function = getFunction(this.getSource());
82+
return TYPE_REFERENCE_TO_SIGNATURE.get(function);
83+
}
84+
85+
@Override
86+
protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
87+
// The resulting shape is the broadcasted shape of the shapes of x and y.
88+
Set<List<Dimension<?>>> ret = HashSetFactory.make();
89+
90+
Set<List<Dimension<?>>> xShapes =
91+
this.getShapes(builder, this.getXArgumentValueNumber(builder));
92+
Set<List<Dimension<?>>> yShapes =
93+
this.getShapes(builder, this.getYArgumentValueNumber(builder));
94+
95+
for (List<Dimension<?>> xShape : xShapes)
96+
for (List<Dimension<?>> yShape : yShapes)
97+
if (areBroadcastable(xShape, yShape)) ret.add(getBroadcastedShapes(xShape, yShape));
98+
else throw new NonBroadcastableShapesException(this, xShape, yShape);
99+
100+
return ret;
101+
}
102+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package com.ibm.wala.cast.python.ml.client;
2+
3+
import static java.lang.String.format;
4+
5+
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
6+
import java.util.List;
7+
8+
/**
9+
* An exception indicating that two shapes are not broadcastable for a given operation.
10+
*
11+
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
12+
* @see <a href="https://numpy.org/doc/stable/user/basics.broadcasting.html">NumPy Broadcasting</a>.
13+
*/
14+
public class NonBroadcastableShapesException extends RuntimeException {
15+
16+
/** Serial version UID. */
17+
private static final long serialVersionUID = 805036824027449575L;
18+
19+
/** The operation for which the shapes are not broadcastable. */
20+
private final transient Object op;
21+
22+
/** The first shape. */
23+
private final transient List<Dimension<?>> xShape;
24+
25+
/** The second shape. */
26+
private final transient List<Dimension<?>> yShape;
27+
28+
/**
29+
* Constructs a new exception indicating that the given shapes are not broadcastable for the given
30+
* operation.
31+
*
32+
* @param op The operation for which the shapes are not broadcastable.
33+
* @param xShape The first shape.
34+
* @param yShape The second shape.
35+
*/
36+
public NonBroadcastableShapesException(
37+
Object op, List<Dimension<?>> xShape, List<Dimension<?>> yShape) {
38+
this.op = op;
39+
this.xShape = xShape;
40+
this.yShape = yShape;
41+
}
42+
43+
@Override
44+
public String getMessage() {
45+
return format("The shapes %s and %s are not broadcastable for %s.", xShape, yShape, op);
46+
}
47+
}

0 commit comments

Comments
 (0)