Skip to content

Commit 2755447

Browse files
committed
Handle tf.random.gamma.
1 parent 6576292 commit 2755447

File tree

12 files changed

+375
-5
lines changed

12 files changed

+375
-5
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
101101
private static final TensorType TENSOR_2_1_FLOAT32 =
102102
new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(1)));
103103

104+
private static final TensorType TENSOR_10_2_FLOAT32 =
105+
new TensorType(FLOAT_32, asList(new NumericDim(10), new NumericDim(2)));
106+
107+
private static final TensorType TENSOR_10_2_FLOAT64 =
108+
new TensorType(FLOAT_64, asList(new NumericDim(10), new NumericDim(2)));
109+
104110
private static final TensorType TENSOR_2_3_3_INT32 =
105111
new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(3)));
106112

@@ -113,6 +119,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
113119
private static final TensorType TENSOR_3_2_2_FLOAT32 =
114120
new TensorType(FLOAT_32, asList(new NumericDim(3), new NumericDim(2), new NumericDim(2)));
115121

122+
private static final TensorType TENSOR_7_5_2_FLOAT32 =
123+
new TensorType(FLOAT_32, asList(new NumericDim(7), new NumericDim(5), new NumericDim(2)));
124+
125+
private static final TensorType TENSOR_30_3_2_FLOAT32 =
126+
new TensorType(FLOAT_32, asList(new NumericDim(30), new NumericDim(3), new NumericDim(2)));
127+
116128
private static final TensorType TENSOR_3_2_2_3_FLOAT32 =
117129
new TensorType(
118130
FLOAT_32,
@@ -1795,25 +1807,45 @@ public void testAdd99()
17951807
@Test
17961808
public void testAdd100()
17971809
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1798-
test("tf2_test_add100.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1810+
test(
1811+
"tf2_test_add100.py",
1812+
"add",
1813+
2,
1814+
2,
1815+
Map.of(2, Set.of(TENSOR_10_2_FLOAT32), 3, Set.of(TENSOR_10_2_FLOAT32)));
17991816
}
18001817

18011818
@Test
18021819
public void testAdd101()
18031820
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1804-
test("tf2_test_add101.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1821+
test(
1822+
"tf2_test_add101.py",
1823+
"add",
1824+
2,
1825+
2,
1826+
Map.of(2, Set.of(TENSOR_10_2_FLOAT32), 3, Set.of(TENSOR_10_2_FLOAT32)));
18051827
}
18061828

18071829
@Test
18081830
public void testAdd102()
18091831
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1810-
test("tf2_test_add102.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1832+
test(
1833+
"tf2_test_add102.py",
1834+
"add",
1835+
2,
1836+
2,
1837+
Map.of(2, Set.of(TENSOR_10_2_FLOAT32), 3, Set.of(TENSOR_10_2_FLOAT32)));
18111838
}
18121839

18131840
@Test
18141841
public void testAdd103()
18151842
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1816-
test("tf2_test_add103.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1843+
test(
1844+
"tf2_test_add103.py",
1845+
"add",
1846+
2,
1847+
2,
1848+
Map.of(2, Set.of(TENSOR_10_2_FLOAT32), 3, Set.of(TENSOR_10_2_FLOAT32)));
18171849
}
18181850

18191851
@Test
@@ -4398,6 +4430,47 @@ public void testEye6()
43984430
test("tf2_test_eye6.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_2_2_3_FLOAT32)));
43994431
}
44004432

4433+
@Test
4434+
public void testGamma()
4435+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4436+
test("tf2_test_gamma.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_10_2_FLOAT32)));
4437+
}
4438+
4439+
@Test
4440+
public void testGamma2()
4441+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4442+
test("tf2_test_gamma2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_10_2_FLOAT64)));
4443+
}
4444+
4445+
@Test
4446+
public void testGamma3()
4447+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4448+
test("tf2_test_gamma3.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_7_5_2_FLOAT32)));
4449+
}
4450+
4451+
/** FIXME: Handle keyword arguments properly so that this test passes. */
4452+
@Test(expected = IllegalStateException.class)
4453+
public void testGamma4()
4454+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4455+
test("tf2_test_gamma4.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_30_3_2_FLOAT32)));
4456+
}
4457+
4458+
/**
4459+
* FIXME: Should not throw an {@link IllegalArgumentException} once
4460+
* https://github.com/wala/ML/issues/340 is fixed.
4461+
*/
4462+
@Test(expected = IllegalArgumentException.class)
4463+
public void testGamma5()
4464+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4465+
test("tf2_test_gamma5.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_30_3_2_FLOAT32)));
4466+
}
4467+
4468+
@Test
4469+
public void testGamma6()
4470+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4471+
test("tf2_test_gamma6.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_30_3_2_FLOAT32)));
4472+
}
4473+
44014474
private void test(
44024475
String filename,
44034476
String functionName,
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package com.ibm.wala.cast.python.ml.client;
2+
3+
import static com.ibm.wala.cast.python.ml.client.Gamma.Parameters.ALPHA;
4+
import static com.ibm.wala.cast.python.ml.client.Gamma.Parameters.BETA;
5+
import static com.ibm.wala.cast.python.ml.client.Gamma.Parameters.DTYPE;
6+
7+
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
8+
import com.ibm.wala.ipa.callgraph.CGNode;
9+
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
10+
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
11+
import com.ibm.wala.util.collections.HashSetFactory;
12+
import java.util.ArrayList;
13+
import java.util.List;
14+
import java.util.Set;
15+
16+
/**
17+
* A representation of the `tf.random.gamma` API in TensorFlow.
18+
*
19+
* @see <a href="https://www.tensorflow.org/api_docs/python/tf/random/gamma">tf.random.gamma</a>
20+
* API.
21+
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
22+
*/
23+
public class Gamma extends Ones {
24+
25+
private static final String FUNCTION_NAME = "tf.random.gamma()";
26+
27+
enum Parameters {
28+
SHAPE,
29+
ALPHA,
30+
BETA,
31+
DTYPE,
32+
SEED,
33+
NAME
34+
}
35+
36+
public Gamma(PointsToSetVariable source, CGNode node) {
37+
super(source, node);
38+
}
39+
40+
@Override
41+
protected int getDTypeParameterPosition() {
42+
return DTYPE.ordinal();
43+
}
44+
45+
protected int getAlphaParameterPosition() {
46+
return ALPHA.ordinal();
47+
}
48+
49+
protected int getBetaParameterPosition() {
50+
return BETA.ordinal();
51+
}
52+
53+
protected int getAlphaParameterValueNumber(PropagationCallGraphBuilder builder) {
54+
Set<Integer> numberOfPossiblePositionalArguments =
55+
this.getNumberOfPossiblePositionalArguments(builder);
56+
int alphaParameterPosition = this.getAlphaParameterPosition();
57+
58+
if (!numberOfPossiblePositionalArguments.stream()
59+
.anyMatch(n -> n >= alphaParameterPosition + 1))
60+
throw new IllegalStateException(
61+
"Alpha parameter is mandatory and must be provided explicitly.");
62+
63+
return this.getArgumentValueNumber(alphaParameterPosition);
64+
}
65+
66+
protected int getBetaParameterValueNumber(PropagationCallGraphBuilder builder) {
67+
Set<Integer> numberOfPossiblePositionalArguments =
68+
this.getNumberOfPossiblePositionalArguments(builder);
69+
int betaParameterPosition = this.getBetaParameterPosition();
70+
71+
if (!numberOfPossiblePositionalArguments.stream().anyMatch(n -> n >= betaParameterPosition + 1))
72+
return -1; // Beta parameter is optional.
73+
74+
return this.getArgumentValueNumber(betaParameterPosition);
75+
}
76+
77+
@Override
78+
protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) {
79+
Set<List<Dimension<?>>> ret = HashSetFactory.make();
80+
Set<List<Dimension<?>>> shapes = super.getShapes(builder);
81+
82+
// Get the shape of the alpha parameter.
83+
Set<List<Dimension<?>>> alphaShapes =
84+
this.getShapes(builder, this.getAlphaParameterValueNumber(builder));
85+
86+
// If there is no beta parameter.
87+
if (this.getBetaParameterValueNumber(builder) < 0)
88+
// return shape `tf.concat([shape, tf.shape(alpha)], axis=0)`.
89+
shapes.forEach(
90+
shape -> {
91+
alphaShapes.forEach(
92+
alphaShape -> {
93+
List<Dimension<?>> newShape = new ArrayList<>(shape);
94+
newShape.addAll(alphaShape);
95+
ret.add(newShape);
96+
});
97+
});
98+
else { // There is a beta parameter.
99+
// return shape `tf.concat([shape, tf.shape(alpha + beta)], axis=0)`.
100+
shapes.forEach(
101+
shape -> {
102+
// Get the shape of the beta parameter, which is optional.
103+
Set<List<Dimension<?>>> betaShapes =
104+
this.getShapes(builder, this.getBetaParameterValueNumber(builder));
105+
106+
alphaShapes.forEach(
107+
aShape -> {
108+
betaShapes.forEach(
109+
bShape -> {
110+
List<Dimension<?>> newShape = new ArrayList<>(shape);
111+
// Here we assume that alphaShape and betaShape are compatible for
112+
// broadcasting.
113+
// In a complete implementation, we would need to handle broadcasting rules
114+
// properly.
115+
int maxLength = Math.max(aShape.size(), bShape.size());
116+
117+
for (int i = 0; i < maxLength; i++) {
118+
Dimension<?> dim;
119+
120+
if (i < aShape.size() && i < bShape.size())
121+
// Both shapes have this dimension, take the maximum.
122+
dim = Dimension.max(aShape.get(i), bShape.get(i));
123+
else if (i < aShape.size())
124+
// Only alpha shape has this dimension.
125+
dim = aShape.get(i);
126+
else
127+
// Only beta shape has this dimension.
128+
dim = bShape.get(i);
129+
130+
newShape.add(dim);
131+
}
132+
133+
ret.add(newShape);
134+
});
135+
});
136+
});
137+
}
138+
139+
return ret;
140+
}
141+
142+
@Override
143+
protected String getSignature() {
144+
return FUNCTION_NAME;
145+
}
146+
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.CONVERT_TO_TENSOR;
55
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.EYE;
66
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.FILL;
7+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.GAMMA;
78
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.NORMAL;
89
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONES;
910
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONE_HOT;
@@ -54,6 +55,7 @@ else if (calledFunction.equals(CONVERT_TO_TENSOR.getDeclaringClass()))
5455
return new ConvertToTensor(source, node);
5556
else if (calledFunction.equals(ONE_HOT.getDeclaringClass())) return new OneHot(source, node);
5657
else if (calledFunction.equals(EYE.getDeclaringClass())) return new Eye(source, node);
58+
else if (calledFunction.equals(GAMMA.getDeclaringClass())) return new Gamma(source, node);
5759
else
5860
throw new IllegalArgumentException(
5961
"Unknown call: " + calledFunction + " for source: " + source + ".");

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,13 @@ public boolean canConvertTo(DType other) {
157157
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/eye")),
158158
AstMethodReference.fnSelector);
159159

160+
/** https://www.tensorflow.org/api_docs/python/tf/gamma. */
161+
public static final MethodReference GAMMA =
162+
MethodReference.findOrCreate(
163+
TypeReference.findOrCreate(
164+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/gamma")),
165+
AstMethodReference.fnSelector);
166+
160167
/**
161168
* Represents the TensorFlow float32 data type.
162169
*

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorType.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ public boolean equals(Object obj) {
110110
} else if (!v.equals(other.v)) return false;
111111
return true;
112112
}
113+
114+
public static Dimension<?> max(Dimension<?> d1, Dimension<?> d2) {
115+
if (d1 instanceof NumericDim && d2 instanceof NumericDim) {
116+
Integer v1 = ((NumericDim) d1).value();
117+
Integer v2 = ((NumericDim) d2).value();
118+
119+
return new NumericDim(Math.max(v1, v2));
120+
} else
121+
throw new IllegalArgumentException(
122+
"Cannot compute max of non-numeric dimensions: " + d1 + ", " + d2);
123+
}
113124
}
114125

115126
public static class SymbolicDim extends Dimension<String> {

com.ibm.wala.cast.python.test/data/tf2_test_add100.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,14 @@ def add(a, b):
55
return a + b
66

77

8-
c = add(tf.random.gamma([10], [0.5, 1.5]), tf.random.gamma([10], [1, 2.5]))
8+
a = tf.random.gamma([10], [0.5, 1.5])
9+
assert isinstance(a, tf.Tensor)
10+
assert a.shape == (10, 2)
11+
assert a.dtype == tf.float32
12+
13+
b = tf.random.gamma([10], [1, 2.5])
14+
assert isinstance(a, tf.Tensor)
15+
assert b.shape == (10, 2)
16+
assert a.dtype == tf.float32
17+
18+
c = add(a, b)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
a = [0.5, 1.5]
9+
assert isinstance(a, list)
10+
assert len(a) == 2
11+
assert all(isinstance(x, float) for x in a)
12+
assert tf.shape(a) == (2,)
13+
14+
samples = tf.random.gamma([10], a)
15+
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
16+
# the samples drawn from each distribution
17+
assert isinstance(samples, tf.Tensor)
18+
assert samples.shape == (10, 2)
19+
assert samples.dtype == tf.float32
20+
21+
f(samples)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
a = [0.5, 1.5]
9+
assert isinstance(a, list)
10+
assert len(a) == 2
11+
assert all(isinstance(x, float) for x in a)
12+
assert tf.shape(a) == (2,)
13+
14+
samples = tf.random.gamma([10], a, None, tf.float64)
15+
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
16+
# the samples drawn from each distribution
17+
assert isinstance(samples, tf.Tensor)
18+
assert samples.shape == (10, 2)
19+
assert samples.dtype == tf.float64
20+
21+
f(samples)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
samples = tf.random.gamma([7, 5], [0.5, 1.5])
9+
# samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
10+
# represents the 7x5 samples drawn from each of the two distributions
11+
12+
assert isinstance(samples, tf.Tensor)
13+
assert samples.shape == (7, 5, 2)
14+
assert samples.dtype == tf.float32
15+
16+
f(samples)

0 commit comments

Comments
 (0)