Skip to content

Commit c6683f2

Browse files
committed
Fix tf.range().
- Actually send a literal value into the `constant` op, rather than a value number. - Update tests.
1 parent c8395e6 commit c6683f2

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2033,7 +2033,7 @@ public void testTFRange2()
20332033
@Test
20342034
public void testTFRange3()
20352035
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
2036-
test("test_tf_range.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT)));
2036+
test("test_tf_range.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32)));
20372037
}
20382038

20392039
@Test

com.ibm.wala.cast.python.ml/data/tensorflow.xml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,8 @@
537537
<method name="read_data" descriptor="()LRoot;">
538538
<new def="x" class="Llist" />
539539
<new def="z" class="Ltensorflow/functions/constant" />
540-
<call class="Ltensorflow/functions/constant" name="do" descriptor="()LRoot;" type="virtual" arg0="z" arg1="1" numArgs="2" def="y" />
540+
<constant name="p" type="int" value="1" />
541+
<call class="Ltensorflow/functions/constant" name="do" descriptor="()LRoot;" type="virtual" arg0="z" arg1="p" numArgs="2" def="y" />
541542
<putfield class="LRoot" field="0" fieldType="LRoot" ref="x" value="y" />
542543
<return value="x" />
543544
</method>

0 commit comments

Comments
 (0)