Skip to content

Commit debb869

Browse files
committed
Initial work on dtype argument for tf.range().
1 parent 8dda029 commit debb869

File tree

4 files changed

+39
-3
lines changed

4 files changed

+39
-3
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2036,6 +2036,12 @@ public void testTFRange3()
20362036
test("tf2_test_tf_range3.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32)));
20372037
}
20382038

2039+
@Test
2040+
public void testTFRange4()
2041+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2042+
test("tf2_test_tf_range4.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_FLOAT32)));
2043+
}
2044+
20392045
@Test
20402046
public void testImport()
20412047
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {

com.ibm.wala.cast.python.ml/data/tensorflow.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@
537537
<method name="read_data" descriptor="()LRoot;">
538538
<new def="x" class="Llist" />
539539
<new def="z" class="Ltensorflow/functions/constant" />
540-
<constant name="p" type="int" value="1" />
540+
<constant name="p" type="int" value="1" /> <!-- FIXME: Unfortunately, if a dtype argument is supplied to `tf.range()`, the value may not be an integer. -->
541541
<call class="Ltensorflow/functions/constant" name="do" descriptor="()LRoot;" type="virtual" arg0="z" arg1="p" numArgs="2" def="y" />
542542
<putfield class="LRoot" field="0" fieldType="LRoot" ref="x" value="y" />
543543
<return value="x" />

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public class Range extends TensorGenerator {
4343

4444
private static final String FUNCTION_NAME = "tf.range()";
4545

46+
private static final int DTYPE_POSITIONAL_PARAMETER_INDEX = 3;
47+
4648
public Range(PointsToSetVariable source, CGNode node) {
4749
super(source, node);
4850
}
@@ -232,10 +234,13 @@ protected int getValueNumberForShapeArgument() {
232234

233235
@Override
234236
protected int getValueNumberForDTypeArgument() {
237+
// TODO: Handle keyword arguments.
238+
235239
// TODO: We need a value number for the dtype argument. Also, that value number can differ
236240
// depending on the version of the `range` function being called.
237-
238-
return -1; // Positional dtype argument for range() is not yet implemented.
241+
return this.getNode().getIR().getMethod().isStatic()
242+
? this.getNode().getIR().getParameter(DTYPE_POSITIONAL_PARAMETER_INDEX)
243+
: this.getNode().getIR().getParameter(DTYPE_POSITIONAL_PARAMETER_INDEX + 1);
239244
}
240245

241246
@Override
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# From: https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/range#for_example
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
start = 3
11+
limit = 18
12+
delta = 3
13+
14+
r = tf.range(start, limit, delta, tf.float32)
15+
assert isinstance(r, tf.Tensor)
16+
assert r.shape == (5,)
17+
assert r.dtype == tf.float32
18+
19+
for i in r:
20+
assert isinstance(i, tf.Tensor)
21+
assert (
22+
i.dtype == tf.float32
23+
) # NOTE: This is getting cast here from the original input.
24+
assert i.shape == ()
25+
f(i)

0 commit comments

Comments
 (0)