Skip to content

Commit a48bc7d

Browse files
committed
Handle the multiplication of tensors with broadcasting in TensorFlow.
1 parent a2ba663 commit a48bc7d

File tree

12 files changed

+429
-3
lines changed

12 files changed

+429
-3
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.ibm.wala.cast.python.ipa.callgraph.PythonSSAPropagationCallGraphBuilder;
2222
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
2323
import com.ibm.wala.cast.python.ml.analysis.TensorVariable;
24+
import com.ibm.wala.cast.python.ml.client.NonBroadcastableShapesException;
2425
import com.ibm.wala.cast.python.ml.client.PythonTensorAnalysisEngine;
2526
import com.ibm.wala.cast.python.ml.types.TensorType;
2627
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
@@ -192,6 +193,11 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
192193
FLOAT_32,
193194
asList(new NumericDim(3), new NumericDim(2), new NumericDim(2), new NumericDim(3)));
194195

196+
private static final TensorType TENSOR_2_2_2_3_FLOAT32 =
197+
new TensorType(
198+
FLOAT_32,
199+
asList(new NumericDim(2), new NumericDim(2), new NumericDim(2), new NumericDim(3)));
200+
195201
private static final TensorType TENSOR_20_28_28_FLOAT32 =
196202
new TensorType(FLOAT_32, asList(new NumericDim(20), new NumericDim(28), new NumericDim(28)));
197203

@@ -2215,7 +2221,48 @@ public void testMultiply()
22152221
@Test
public void testMultiply2()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  // The fixture multiplies two Python ints (7 and 6) and asserts `shape == ()` and
  // `dtype == tf.int32`, so parameter 2 of f() should be a scalar int32 tensor.
  test("tf2_test_multiply2.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32)));
}

@Test
public void testMultiply3()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  // The fixture multiplies a (2, 3) float32 matrix by a (1,)-shaped operand, which
  // broadcasts to (2, 3).
  test("tf2_test_multiply3.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
}

@Test
public void testMultiply4()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  // The fixture multiplies a (2, 3) float32 matrix by a (2, 1) column vector; the
  // broadcast result keeps shape (2, 3).
  test("tf2_test_multiply4.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
}

@Test
public void testMultiply5()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  // Expected result is a rank-4 (2, 2, 2, 3) float32 tensor — presumably broadcasting
  // higher-rank operands; fixture not shown in this change, confirm against
  // tf2_test_multiply5.py.
  test("tf2_test_multiply5.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_2_3_FLOAT32)));
}

/**
 * This is an invalid case since the inputs have different ranks.
 *
 * <p>For now, we are throwing an exception. But, this is invalid code.
 *
 * <p>TODO: We'll need to come up with a suitable way to handle this in the future.
 */
@Test(expected = NonBroadcastableShapesException.class)
public void testMultiply6()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  test("tf2_test_multiply6.py", "f", 1, 1);
}

/**
 * Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
 * is fixed.
 */
@Test(expected = IllegalArgumentException.class)
public void testMultiply7()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  test("tf2_test_multiply7.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
}
22202267

22212268
@Test
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package com.ibm.wala.cast.python.ml.client;

import static com.ibm.wala.cast.python.ml.client.Multiply.Parameters.X;
import static com.ibm.wala.cast.python.ml.client.Multiply.Parameters.Y;
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TYPE_REFERENCE_TO_SIGNATURE;
import static com.ibm.wala.cast.python.ml.util.TensorShapeUtil.areBroadcastable;
import static com.ibm.wala.cast.python.ml.util.TensorShapeUtil.getBroadcastedShapes;
import static com.ibm.wala.cast.python.util.Util.getFunction;
import static java.util.logging.Logger.getLogger;

import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import com.ibm.wala.types.TypeReference;
import com.ibm.wala.util.collections.HashSetFactory;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;

/**
 * A representation of a multiply operation in TensorFlow.
 *
 * @see <a href="https://www.tensorflow.org/api_docs/python/tf/multiply">tf.multiply</a>.
 * @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
 */
public class Multiply extends ZerosLike {

  @SuppressWarnings("unused")
  private static final Logger logger = getLogger(Multiply.class.getName());

  /** The positional parameters of tf.multiply(), in declaration order. */
  protected enum Parameters {
    X,
    Y,
    NAME
  }

  /**
   * The dtype argument is not explicitly provided to multiply(); rather, the dtype is inferred from
   * the `x` argument.
   *
   * @see <a
   *     href="https://www.tensorflow.org/api_docs/python/tf/math/multiply#returns">tf.math.multiply
   *     - Returns</a>.
   */
  protected static final int DTYPE_PARAMETER_POSITION = -1;

  @Override
  protected int getDTypeParameterPosition() {
    return DTYPE_PARAMETER_POSITION;
  }

  public Multiply(PointsToSetVariable source) {
    super(source);
  }

  /** Returns the positional index of the `x` parameter. */
  protected int getXParameterPosition() {
    return X.ordinal();
  }

  /** Returns the value number of the argument supplied for the `x` parameter. */
  protected int getXArgumentValueNumber(PropagationCallGraphBuilder builder) {
    // TODO: Handle keyword arguments.
    return this.getArgumentValueNumber(builder, this.getXParameterPosition());
  }

  /** Returns the positional index of the `y` parameter. */
  protected int getYParameterPosition() {
    return Y.ordinal();
  }

  /** Returns the value number of the argument supplied for the `y` parameter. */
  protected int getYArgumentValueNumber(PropagationCallGraphBuilder builder) {
    // TODO: Handle keyword arguments.
    return this.getArgumentValueNumber(builder, this.getYParameterPosition());
  }

  /**
   * Returns the TensorFlow function signature represented by this generator.
   *
   * @return The TensorFlow function signature represented by this generator.
   */
  @Override
  protected String getSignature() {
    TypeReference calledFunction = getFunction(this.getSource());
    return TYPE_REFERENCE_TO_SIGNATURE.get(calledFunction);
  }

  @Override
  protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
    // The resulting shape is the broadcast of every possible pair of x and y shapes.
    Set<List<Dimension<?>>> result = HashSetFactory.make();

    Set<List<Dimension<?>>> shapesOfX =
        this.getShapes(builder, this.getXArgumentValueNumber(builder));
    Set<List<Dimension<?>>> shapesOfY =
        this.getShapes(builder, this.getYArgumentValueNumber(builder));

    for (List<Dimension<?>> shapeOfX : shapesOfX) {
      for (List<Dimension<?>> shapeOfY : shapesOfY) {
        // A single incompatible pair is treated as an error for the whole operation.
        if (!areBroadcastable(shapeOfX, shapeOfY)) {
          throw new NonBroadcastableShapesException(this, shapeOfX, shapeOfY);
        }

        result.add(getBroadcastedShapes(shapeOfX, shapeOfY));
      }
    }

    return result;
  }
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package com.ibm.wala.cast.python.ml.client;

import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
import java.util.List;

/**
 * An exception indicating that two shapes are not broadcastable for a given operation.
 *
 * @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
 * @see <a href="https://numpy.org/doc/stable/user/basics.broadcasting.html">NumPy Broadcasting</a>.
 */
public class NonBroadcastableShapesException extends RuntimeException {

  /** Serial version UID. */
  private static final long serialVersionUID = 805036824027449575L;

  /** The operation whose operand shapes failed to broadcast. */
  private final transient Object operation;

  /** The shape of the first operand. */
  private final transient List<Dimension<?>> firstShape;

  /** The shape of the second operand. */
  private final transient List<Dimension<?>> secondShape;

  /**
   * Constructs a new exception indicating that the given shapes are not broadcastable for the given
   * operation.
   *
   * @param op The operation for which the shapes are not broadcastable.
   * @param xShape The first shape.
   * @param yShape The second shape.
   */
  public NonBroadcastableShapesException(
      Object op, List<Dimension<?>> xShape, List<Dimension<?>> yShape) {
    this.operation = op;
    this.firstShape = xShape;
    this.secondShape = yShape;
  }

  @Override
  public String getMessage() {
    // Message is built lazily so the (transient) operand shapes are only rendered on demand.
    return String.format(
        "The shapes %s and %s are not broadcastable for %s.", firstShape, secondShape, operation);
  }
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.EYE;
66
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.FILL;
77
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.GAMMA;
8+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.MULTIPLY;
89
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.NORMAL;
910
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONES;
1011
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONE_HOT;
@@ -54,6 +55,7 @@ else if (calledFunction.equals(CONVERT_TO_TENSOR.getDeclaringClass()))
5455
else if (calledFunction.equals(POISSON.getDeclaringClass())) return new Poisson(source);
5556
else if (calledFunction.equals(RAGGED_CONSTANT.getDeclaringClass()))
5657
return new RaggedConstant(source);
58+
else if (calledFunction.equals(MULTIPLY.getDeclaringClass())) return new Multiply(source);
5759
else
5860
throw new IllegalArgumentException(
5961
"Unknown call: " + calledFunction + " for source: " + source + ".");

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,21 @@ public boolean canConvertTo(DType other) {
189189

190190
private static final String RAGGED_CONSTANT_SIGNATURE = "tf.ragged.constant()";
191191

192+
public static final MethodReference MULTIPLY =
193+
MethodReference.findOrCreate(
194+
TypeReference.findOrCreate(
195+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/math/multiply")),
196+
AstMethodReference.fnSelector);
197+
198+
private static final String MULTIPLY_SIGNATURE = "tf.multiply()";
199+
192200
/** A mapping from a {@link TypeReference} to its associated TensorFlow signature. */
193201
public static final Map<TypeReference, String> TYPE_REFERENCE_TO_SIGNATURE =
194-
Map.of(RAGGED_CONSTANT.getDeclaringClass(), RAGGED_CONSTANT_SIGNATURE);
202+
Map.of(
203+
RAGGED_CONSTANT.getDeclaringClass(),
204+
RAGGED_CONSTANT_SIGNATURE,
205+
MULTIPLY.getDeclaringClass(),
206+
MULTIPLY_SIGNATURE);
195207

196208
/**
197209
* Represents the TensorFlow float32 data type.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package com.ibm.wala.cast.python.ml.util;

import static java.lang.Math.max;

import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
import java.util.ArrayList;
import java.util.List;

/**
 * Utility methods for tensor shape broadcasting. Shapes are right-aligned; corresponding numeric
 * dimensions are compatible when they are equal or one of them is 1, and a missing leading
 * dimension is treated as size 1.
 *
 * @see <a href="https://numpy.org/doc/stable/user/basics.broadcasting.html">NumPy Broadcasting</a>.
 */
public class TensorShapeUtil {

  /** Utility class; not instantiable. */
  private TensorShapeUtil() {}

  /**
   * Returns the dimension of the given shape right-aligned at position {@code i} of a {@code
   * maxRank}-sized result, or {@code null} if the shape has no dimension there (i.e., an implicit
   * leading dimension of size 1).
   *
   * @param shape The shape whose dimension to retrieve.
   * @param i The position in the {@code maxRank}-sized result.
   * @param maxRank The rank of the broadcast result.
   * @return The aligned dimension or {@code null} if missing.
   */
  private static Dimension<?> alignedDimension(List<Dimension<?>> shape, int i, int maxRank) {
    int offset = maxRank - shape.size();
    return i < offset ? null : shape.get(i - offset);
  }

  /**
   * Returns whether the two given shapes can be broadcast together.
   *
   * @param xShape The first shape.
   * @param yShape The second shape.
   * @return {@code true} iff every aligned pair of dimensions is broadcast-compatible.
   */
  public static boolean areBroadcastable(List<Dimension<?>> xShape, List<Dimension<?>> yShape) {
    int maxRank = max(xShape.size(), yShape.size());

    for (int i = 0; i < maxRank; i++) {
      Dimension<?> xDim = alignedDimension(xShape, i, maxRank);
      Dimension<?> yDim = alignedDimension(yShape, i, maxRank);

      if (xDim == null || yDim == null) {
        continue; // One of the dimensions is missing, treat as size 1.
      }

      if (xDim instanceof NumericDim && yDim instanceof NumericDim) {
        int xSize = ((NumericDim) xDim).value();
        int ySize = ((NumericDim) yDim).value();

        if (xSize != ySize && xSize != 1 && ySize != 1) return false; // Incompatible sizes.
      } else return false; // Non-numeric dimensions are (conservatively) incompatible.
    }

    return true; // All dimensions are compatible.
  }

  /**
   * Returns the broadcast shape of the two given shapes.
   *
   * @param xShape The first shape.
   * @param yShape The second shape.
   * @return The broadcast result shape, of rank {@code max(xShape.size(), yShape.size())}.
   * @throws IllegalArgumentException If the shapes are not broadcastable; check with {@link
   *     #areBroadcastable(List, List)} first.
   */
  public static List<Dimension<?>> getBroadcastedShapes(
      List<Dimension<?>> xShape, List<Dimension<?>> yShape) {
    int maxRank = max(xShape.size(), yShape.size());
    List<Dimension<?>> ret = new ArrayList<>(maxRank);

    for (int i = 0; i < maxRank; i++) {
      Dimension<?> xDim = alignedDimension(xShape, i, maxRank);
      Dimension<?> yDim = alignedDimension(yShape, i, maxRank);

      if (xDim == null) ret.add(yDim); // x has no dimension here; y's carries over.
      else if (yDim == null) ret.add(xDim); // y has no dimension here; x's carries over.
      else if (xDim instanceof NumericDim && yDim instanceof NumericDim) {
        int xSize = ((NumericDim) xDim).value();
        int ySize = ((NumericDim) yDim).value();

        if (xSize == ySize) ret.add(xDim); // Both sizes are equal.
        else if (xSize == 1) ret.add(yDim); // x is broadcasted.
        else if (ySize == 1) ret.add(xDim); // y is broadcasted.
        else
          throw new IllegalArgumentException(
              "Incompatible dimensions for broadcasting: " + xSize + " and " + ySize + ".");
      } else
        throw new IllegalArgumentException(
            "Non-numeric dimensions cannot be broadcasted: " + xDim + " and " + yDim + ".");
    }

    return ret;
  }
}

com.ibm.wala.cast.python.test/data/tf2_test_multiply2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,8 @@ def f(a):
77
pass
88

99

10-
f(tf.math.multiply(7, 6))
10+
# Multiplying two Python ints yields a scalar (rank-0) int32 tensor, as the asserts below confirm.
arg = tf.math.multiply(7, 6)
assert arg.shape == ()
assert arg.dtype == tf.int32

f(arg)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
# Test fixture for `tf.multiply` broadcasting: a (2, 3) float32 matrix multiplied by a
# (1,)-shaped operand, which broadcasts to (2, 3).
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/math/multiply#for_example/

import tensorflow as tf


def f(a):
    pass


# Shape: (2, 3)
matrix = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
assert len(matrix) == 2 and len(matrix[0]) == 3  # Confirming shape (2, 3)
assert matrix[0][0].__class__ == float  # Confirming dtype float32

# Shape: (1,) -> Broadcasts to (2, 3)
scalar = [10.0]
assert len(scalar) == 1  # Confirming shape (1,)
assert scalar[0].__class__ == float  # Confirming dtype float32

# 1. Scalar Multiplication
result_scalar = tf.multiply(matrix, scalar)
assert result_scalar.shape == (2, 3)
assert result_scalar.dtype == tf.float32

f(result_scalar)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
# Test fixture for `tf.multiply` broadcasting: a (2, 3) float32 matrix multiplied by a
# (2, 1) column vector, which broadcasts along the column axis to (2, 3).
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/math/multiply#for_example/

import tensorflow as tf


def f(a):
    pass


# Shape: (2, 3)
matrix = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
assert len(matrix) == 2 and len(matrix[0]) == 3  # Confirming shape (2, 3)
assert matrix[0][0].__class__ == float  # Confirming dtype float32

# Shape: (2, 1) -> Broadcasts columns to match matrix width (3)
col_vector = [[2.0], [3.0]]
assert len(col_vector) == 2 and len(col_vector[0]) == 1  # Confirming shape (2, 1)
assert col_vector[0][0].__class__ == float  # Confirming dtype float32

# 2. Column Vector Multiplication
# [1, 2, 3] * 2 = [2, 4, 6]
# [4, 5, 6] * 3 = [12, 15, 18]
result_col = tf.multiply(matrix, col_vector)
assert result_col.shape == (2, 3)
assert result_col.dtype == tf.float32

f(result_col)

0 commit comments

Comments
 (0)