Skip to content

Commit 216bb9a

Browse files
committed
Add test for varying dtypes.
1 parent 13c7ec0 commit 216bb9a

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,17 @@ public void testDecorator11()
291291
test("tf2_test_decorator11.py", "C.returned", 1, 1, Map.of(3, Set.of(TENSOR_5_INT32)));
292292
}
293293

294+
@Test
295+
public void testDecorator12()
296+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
297+
test(
298+
"tf2_test_decorator12.py",
299+
"returned",
300+
1,
301+
1,
302+
Map.of(2, Set.of(TENSOR_2_FLOAT32, TENSOR_2_INT32)));
303+
}
304+
294305
@Test
295306
public void testDataset()
296307
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
130130
.distinct()
131131
.collect(Collectors.toCollection(() -> EnumSet.noneOf(DType.class)));
132132

133+
// FIXME: We can't tell the difference here between varying dtypes in a single call and that of
134+
// possible varying dtypes values from the points-to graph. Below, we are treating it as these
135+
// values lie in a single call, but that may not be the case.
136+
133137
if (types.contains(DType.FLOAT64)) return EnumSet.of(DType.FLOAT64);
134138
else if (types.contains(DType.FLOAT32)) return EnumSet.of(DType.FLOAT32);
135139
else if (types.contains(DType.INT64)) return EnumSet.of(DType.INT64);
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import tensorflow as tf
2+
3+
4+
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),))
5+
@tf.function(reduce_retracing=True)
6+
def returned(a):
7+
return a
8+
9+
10+
a = tf.constant([1, 1.0])
11+
b = returned(a)
12+
13+
assert a.shape == (2,)
14+
assert a.dtype == tf.float32

0 commit comments

Comments
 (0)