Skip to content

Commit 784fd30

Browse files
committed
Handle tf.random.poisson API in TensorFlow.
1 parent 521fcc4 commit 784fd30

File tree

8 files changed

+191
-2
lines changed

8 files changed

+191
-2
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,13 +1884,23 @@ public void testAdd106()
18841884
@Test
18851885
public void testAdd107()
18861886
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1887-
test("tf2_test_add107.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1887+
test(
1888+
"tf2_test_add107.py",
1889+
"add",
1890+
2,
1891+
2,
1892+
Map.of(2, Set.of(TENSOR_10_2_FLOAT32), 3, Set.of(TENSOR_10_2_FLOAT32)));
18881893
}
18891894

18901895
@Test
18911896
public void testAdd108()
18921897
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1893-
test("tf2_test_add108.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1898+
test(
1899+
"tf2_test_add108.py",
1900+
"add",
1901+
2,
1902+
2,
1903+
Map.of(2, Set.of(TENSOR_10_2_FLOAT32), 3, Set.of(TENSOR_10_2_FLOAT32)));
18941904
}
18951905

18961906
@Test
@@ -4471,6 +4481,30 @@ public void testGamma6()
44714481
test("tf2_test_gamma6.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_30_3_2_FLOAT32)));
44724482
}
44734483

4484+
@Test
4485+
public void testPoisson()
4486+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4487+
test("tf2_test_poisson.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_10_2_FLOAT32)));
4488+
}
4489+
4490+
@Test
4491+
public void testPoisson2()
4492+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4493+
test("tf2_test_poisson2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_10_2_FLOAT32)));
4494+
}
4495+
4496+
@Test
4497+
public void testPoisson3()
4498+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4499+
test("tf2_test_poisson3.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_10_2_FLOAT64)));
4500+
}
4501+
4502+
@Test
4503+
public void testPoisson4()
4504+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
4505+
test("tf2_test_poisson4.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_7_5_2_FLOAT32)));
4506+
}
4507+
44744508
private void test(
44754509
String filename,
44764510
String functionName,
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package com.ibm.wala.cast.python.ml.client;
2+
3+
import static com.ibm.wala.cast.python.ml.client.Poisson.Parameters.DTYPE;
4+
import static com.ibm.wala.cast.python.ml.client.Poisson.Parameters.LAM;
5+
6+
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
7+
import com.ibm.wala.ipa.callgraph.CGNode;
8+
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
9+
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
10+
import com.ibm.wala.util.collections.HashSetFactory;
11+
import java.util.ArrayList;
12+
import java.util.List;
13+
import java.util.Set;
14+
15+
/**
16+
* A representation of the `tf.random.poisson` API in TensorFlow.
17+
*
18+
* @see <a
19+
* href="https://www.tensorflow.org/api_docs/python/tf/random/poisson">tf.random.poisson</a>.
20+
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
21+
*/
22+
public class Poisson extends Ones {
23+
24+
private static final String FUNCTION_NAME = "tf.random.poisson()";
25+
26+
enum Parameters {
27+
SHAPE,
28+
LAM,
29+
DTYPE,
30+
SEED,
31+
NAME
32+
}
33+
34+
public Poisson(PointsToSetVariable source, CGNode node) {
35+
super(source, node);
36+
}
37+
38+
@Override
39+
protected int getDTypeParameterPosition() {
40+
return DTYPE.ordinal();
41+
}
42+
43+
protected int getLamParameterPosition() {
44+
return LAM.ordinal();
45+
}
46+
47+
protected int getLamParameterValueNumber(PropagationCallGraphBuilder builder) {
48+
Set<Integer> numberOfPossiblePositionalArguments =
49+
this.getNumberOfPossiblePositionalArguments(builder);
50+
int lamParameterPosition = this.getLamParameterPosition();
51+
52+
if (!numberOfPossiblePositionalArguments.stream().anyMatch(n -> n >= lamParameterPosition + 1))
53+
throw new IllegalStateException(
54+
"Cannot determine value number for 'lam' parameter of " + FUNCTION_NAME);
55+
56+
return this.getArgumentValueNumber(lamParameterPosition);
57+
}
58+
59+
@Override
60+
protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder) {
61+
Set<List<Dimension<?>>> ret = HashSetFactory.make();
62+
Set<List<Dimension<?>>> shapes = super.getShapes(builder);
63+
64+
if (shapes.isEmpty())
65+
throw new IllegalStateException(
66+
"Cannot determine shape for " + this.getSignature() + " call.");
67+
68+
// Get the shape of the alpha parameter.
69+
Set<List<Dimension<?>>> lamShapes =
70+
this.getShapes(builder, this.getLamParameterValueNumber(builder));
71+
72+
// return shape `tf.concat([shape, tf.shape(lam)], axis=0)`.
73+
shapes.forEach(
74+
shape -> {
75+
lamShapes.forEach(
76+
lShape -> {
77+
List<Dimension<?>> newShape = new ArrayList<>(shape);
78+
newShape.addAll(lShape);
79+
ret.add(newShape);
80+
});
81+
});
82+
83+
return ret;
84+
}
85+
86+
@Override
87+
protected String getSignature() {
88+
return FUNCTION_NAME;
89+
}
90+
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.NORMAL;
99
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONES;
1010
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.ONE_HOT;
11+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.POISSON;
1112
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.RANGE;
1213
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TRUNCATED_NORMAL;
1314
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.UNIFORM;
@@ -56,6 +57,7 @@ else if (calledFunction.equals(CONVERT_TO_TENSOR.getDeclaringClass()))
5657
else if (calledFunction.equals(ONE_HOT.getDeclaringClass())) return new OneHot(source, node);
5758
else if (calledFunction.equals(EYE.getDeclaringClass())) return new Eye(source, node);
5859
else if (calledFunction.equals(GAMMA.getDeclaringClass())) return new Gamma(source, node);
60+
else if (calledFunction.equals(POISSON.getDeclaringClass())) return new Poisson(source, node);
5961
else
6062
throw new IllegalArgumentException(
6163
"Unknown call: " + calledFunction + " for source: " + source + ".");

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,13 @@ public boolean canConvertTo(DType other) {
164164
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/gamma")),
165165
AstMethodReference.fnSelector);
166166

167+
/** https://www.tensorflow.org/api_docs/python/tf/poisson. */
168+
public static final MethodReference POISSON =
169+
MethodReference.findOrCreate(
170+
TypeReference.findOrCreate(
171+
PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/poisson")),
172+
AstMethodReference.fnSelector);
173+
167174
/**
168175
* Represents the TensorFlow float32 data type.
169176
*
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
samples = tf.random.poisson([10], [0.5, 1.5])
9+
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
10+
# the samples drawn from each distribution
11+
assert samples.shape == (10, 2)
12+
assert samples.dtype == tf.float32
13+
14+
f(samples)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
samples = tf.random.poisson([10], [0.5, 1.5], tf.float32)
9+
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
10+
# the samples drawn from each distribution
11+
assert samples.shape == (10, 2)
12+
assert samples.dtype == tf.float32
13+
14+
f(samples)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
samples = tf.random.poisson([10], [0.5, 1.5], tf.float64)
9+
# samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
10+
# the samples drawn from each distribution
11+
assert samples.shape == (10, 2)
12+
assert samples.dtype == tf.float32
13+
14+
f(samples)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import tensorflow as tf
2+
3+
4+
def f(a):
5+
pass
6+
7+
8+
samples = tf.random.poisson([7, 5], [12.2, 3.3])
9+
# samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
10+
# represents the 7x5 samples drawn from each of the two distributions
11+
assert samples.shape == (7, 5, 2)
12+
assert samples.dtype == tf.float32
13+
14+
f(samples)

0 commit comments

Comments
 (0)