Skip to content

Commit 3715148

Browse files
committed
Generalize Multiply to ElementWiseOperation to cover add, subtract, and divide operations as well.
1 parent c4d2be2 commit 3715148

File tree

4 files changed

+79
-18
lines changed

4 files changed

+79
-18
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,16 +1140,24 @@ public void testSigmoid2()
11401140
@Test
11411141
public void testAdd()
11421142
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1143-
test("tf2_test_add.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
1143+
test("tf2_test_add.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32)));
11441144
}
11451145

1146-
@Test
1146+
/**
1147+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
1148+
* is fixed.
1149+
*/
1150+
@Test(expected = IllegalArgumentException.class)
11471151
public void testAdd2()
11481152
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
11491153
test("tf2_test_add2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
11501154
}
11511155

1152-
@Test
1156+
/**
1157+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
1158+
* is fixed.
1159+
*/
1160+
@Test(expected = IllegalArgumentException.class)
11531161
public void testAdd3()
11541162
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
11551163
test("tf2_test_add3.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
@@ -1158,16 +1166,24 @@ public void testAdd3()
11581166
@Test
11591167
public void testAdd4()
11601168
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1161-
test("tf2_test_add4.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
1169+
test("tf2_test_add4.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32)));
11621170
}
11631171

1164-
@Test
1172+
/**
1173+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
1174+
* is fixed.
1175+
*/
1176+
@Test(expected = IllegalArgumentException.class)
11651177
public void testAdd5()
11661178
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
11671179
test("tf2_test_add5.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
11681180
}
11691181

1170-
@Test
1182+
/**
1183+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
1184+
* is fixed.
1185+
*/
1186+
@Test(expected = IllegalArgumentException.class)
11711187
public void testAdd6()
11721188
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
11731189
test("tf2_test_add6.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
@@ -1504,13 +1520,21 @@ public void testAdd42()
15041520
test("tf2_test_add42.py", "func2", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
15051521
}
15061522

1507-
@Test
1523+
/**
1524+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
1525+
* is fixed.
1526+
*/
1527+
@Test(expected = IllegalArgumentException.class)
15081528
public void testAdd43()
15091529
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
15101530
test("tf2_test_add43.py", "add", 2, 3, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
15111531
}
15121532

1513-
@Test
1533+
/**
1534+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
1535+
* is fixed.
1536+
*/
1537+
@Test(expected = IllegalArgumentException.class)
15141538
public void testAdd44()
15151539
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
15161540
test("tf2_test_add44.py", "add", 2, 3, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Multiply.java renamed to com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/ElementWiseOperation.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package com.ibm.wala.cast.python.ml.client;
22

3-
import static com.ibm.wala.cast.python.ml.client.Multiply.Parameters.X;
4-
import static com.ibm.wala.cast.python.ml.client.Multiply.Parameters.Y;
3+
import static com.ibm.wala.cast.python.ml.client.ElementWiseOperation.Parameters.X;
4+
import static com.ibm.wala.cast.python.ml.client.ElementWiseOperation.Parameters.Y;
55
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TYPE_REFERENCE_TO_SIGNATURE;
66
import static com.ibm.wala.cast.python.ml.util.TensorShapeUtil.areBroadcastable;
77
import static com.ibm.wala.cast.python.ml.util.TensorShapeUtil.getBroadcastedShapes;
@@ -18,15 +18,15 @@
1818
import java.util.logging.Logger;
1919

2020
/**
21-
* A representation of a multiply operation in TensorFlow.
21+
* A representation of an element-wise operation in TensorFlow.
2222
*
2323
* @see <a href="https://www.tensorflow.org/api_docs/python/tf/multiply">tf.multiply</a>.
2424
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
2525
*/
26-
public class Multiply extends ZerosLike {
26+
public class ElementWiseOperation extends ZerosLike {
2727

2828
@SuppressWarnings("unused")
29-
private static final Logger logger = getLogger(Multiply.class.getName());
29+
private static final Logger logger = getLogger(ElementWiseOperation.class.getName());
3030

3131
protected enum Parameters {
3232
X,
@@ -35,8 +35,8 @@ protected enum Parameters {
3535
}
3636

3737
/**
38-
* The dtype argument is not explicitly provided to multiply(); rather, the dtype is inferred from
39-
* the `x` argument.
38+
* The dtype argument is not explicitly provided to element-wise operations; rather, the dtype is
39+
* inferred from the `x` argument.
4040
*
4141
* @see <a
4242
* href="https://www.tensorflow.org/api_docs/python/tf/math/multiply#returns">tf.math.multiply
@@ -49,7 +49,7 @@ protected int getDTypeParameterPosition() {
4949
return DTYPE_PARAMETER_POSITION;
5050
}
5151

52-
public Multiply(PointsToSetVariable source) {
52+
public ElementWiseOperation(PointsToSetVariable source) {
5353
super(source);
5454
}
5555

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,14 @@ else if (calledFunction.equals(CONVERT_TO_TENSOR.getDeclaringClass()))
5555
else if (calledFunction.equals(POISSON.getDeclaringClass())) return new Poisson(source);
5656
else if (calledFunction.equals(RAGGED_CONSTANT.getDeclaringClass()))
5757
return new RaggedConstant(source);
58-
else if (calledFunction.equals(MULTIPLY.getDeclaringClass())) return new Multiply(source);
58+
else if (calledFunction.equals(MULTIPLY.getDeclaringClass())
59+
|| calledFunction.equals(
60+
com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ADD.getDeclaringClass())
61+
|| calledFunction.equals(
62+
com.ibm.wala.cast.python.ml.types.TensorFlowTypes.SUBTRACT.getDeclaringClass())
63+
|| calledFunction.equals(
64+
com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DIVIDE.getDeclaringClass()))
65+
return new ElementWiseOperation(source);
5966
else
6067
throw new IllegalArgumentException(
6168
"Unknown call: " + calledFunction + " for source: " + source + ".");

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,43 @@ public boolean canConvertTo(DType other) {
197197

198198
private static final String MULTIPLY_SIGNATURE = "tf.multiply()";
199199

200+
public static final MethodReference ADD =
201+
MethodReference.findOrCreate(
202+
TypeReference.findOrCreate(
203+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/math/add")),
204+
AstMethodReference.fnSelector);
205+
206+
private static final String ADD_SIGNATURE = "tf.add()";
207+
208+
public static final MethodReference SUBTRACT =
209+
MethodReference.findOrCreate(
210+
TypeReference.findOrCreate(
211+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/math/subtract")),
212+
AstMethodReference.fnSelector);
213+
214+
private static final String SUBTRACT_SIGNATURE = "tf.subtract()";
215+
216+
public static final MethodReference DIVIDE =
217+
MethodReference.findOrCreate(
218+
TypeReference.findOrCreate(
219+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/math/divide")),
220+
AstMethodReference.fnSelector);
221+
222+
private static final String DIVIDE_SIGNATURE = "tf.divide()";
223+
200224
/** A mapping from a {@link TypeReference} to its associated TensorFlow signature. */
201225
public static final Map<TypeReference, String> TYPE_REFERENCE_TO_SIGNATURE =
202226
Map.of(
203227
RAGGED_CONSTANT.getDeclaringClass(),
204228
RAGGED_CONSTANT_SIGNATURE,
205229
MULTIPLY.getDeclaringClass(),
206-
MULTIPLY_SIGNATURE);
230+
MULTIPLY_SIGNATURE,
231+
ADD.getDeclaringClass(),
232+
ADD_SIGNATURE,
233+
SUBTRACT.getDeclaringClass(),
234+
SUBTRACT_SIGNATURE,
235+
DIVIDE.getDeclaringClass(),
236+
DIVIDE_SIGNATURE);
207237

208238
/**
209239
* Represents the TensorFlow float32 data type.

0 commit comments

Comments
 (0)