Skip to content

Commit c40ccba

Browse files
authored
New dataset APIs (#129)
Adding `shuffle()` and `batch()`.
1 parent 6e53776 commit c40ccba

File tree

7 files changed

+111
-29
lines changed

7 files changed

+111
-29
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ public void testTf2()
203203
// treating it as one. But, in the literal case, it should be possible to model it like the list
204204
// tests below.
205205
testTf2("tf2_test_dataset.py", "add", 2, 2, 2, 3);
206+
testTf2("tf2_test_dataset2.py", "add", 2, 2, 2, 3);
207+
testTf2("tf2_test_dataset3.py", "add", 2, 2, 2, 3);
208+
testTf2("tf2_test_dataset4.py", "add", 2, 2, 2, 3);
209+
testTf2("tf2_test_dataset5.py", "add", 2, 2, 2, 3);
206210
testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
207211
testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
208212
testTf2("tf2_test_tensor_list3.py", "add", 0, 2);

com.ibm.wala.cast.python.ml/data/tensorflow.xml

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,14 @@
3636
<new def="estimator" class="Lobject" />
3737
<putfield class="LRoot" field="estimator" fieldType="LRoot" ref="x" value="estimator" />
3838

39+
<new def="data" class="Lobject" />
40+
<putfield class="LRoot" field="data" fieldType="LRoot" ref="x" value="data" />
41+
3942
<new def="distribute" class="Lobject" />
4043
<putfield class="LRoot" field="distribute" fieldType="LRoot" ref="x" value="distribute" />
4144

4245
<new def="nn" class="Lobject" />
4346
<putfield class="LRoot" field="nn" fieldType="LRoot" ref="x" value="nn" />
44-
<new def="data" class="Lobject" />
45-
<putfield class="LRoot" field="data" fieldType="LRoot" ref="x" value="data" />
46-
<new def="Dataset" class="Lobject" />
47-
<putfield class="LRoot" field="Dataset" fieldType="LRoot" ref="data" value="Dataset" />
4847
<new def="random" class="Lobject" />
4948
<putfield class="LRoot" field="random" fieldType="LRoot" ref="x" value="random" />
5049
<new def="sparse" class="Lobject" />
@@ -65,6 +64,9 @@
6564
<new def="Estimator" class="Ltensorflow/estimator/Estimator" />
6665
<putfield class="LRoot" field="Estimator" fieldType="LRoot" ref="estimator" value="Estimator" />
6766

67+
<new def="Dataset" class="Ltensorflow/data/Dataset" />
68+
<putfield class="LRoot" field="Dataset" fieldType="LRoot" ref="data" value="Dataset" />
69+
6870
<new def="MirroredStrategy" class="Ltensorflow/distribute/MirroredStrategy" />
6971
<putfield class="LRoot" field="MirroredStrategy" fieldType="LRoot" ref="distribute" value="MirroredStrategy" />
7072

@@ -74,6 +76,9 @@
7476
<new def="numpy_input_fn" class="Ltensorflow/estimator/numpy_input_fn" />
7577
<putfield class="LRoot" field="numpy_input_fn" fieldType="LRoot" ref="inputs" value="numpy_input_fn" />
7678

79+
<new def="from_tensor_slices" class="Ltensorflow/data/Dataset/from_tensor_slices" />
80+
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="Dataset" value="from_tensor_slices" />
81+
7782
<new def="reshape" class="Ltensorflow/functions/reshape" />
7883
<putfield class="LRoot" field="reshape" fieldType="LRoot" ref="x" value="reshape" />
7984

@@ -126,9 +131,6 @@
126131
<new def="array_ops" class="Lobject" />
127132
<putfield class="LRoot" field="array_ops" fieldType="LRoot" ref="ops" value="array_ops" />
128133

129-
<new def="data_ops" class="Lobject" />
130-
<putfield class="LRoot" field="data_ops" fieldType="LRoot" ref="ops" value="data_ops" />
131-
132134
<new def="random_ops" class="Lobject" />
133135
<putfield class="LRoot" field="random_ops" fieldType="LRoot" ref="ops" value="random_ops" />
134136

@@ -174,10 +176,6 @@
174176
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="x" value="ones" />
175177
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="array_ops" value="ones" />
176178

177-
<new def="from_tensor_slices" class="Ltensorflow/functions/from_tensor_slices" />
178-
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="Dataset" value="from_tensor_slices" />
179-
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="data_ops" value="from_tensor_slices" />
180-
181179
<new def="zeros" class="Ltensorflow/functions/zeros" />
182180
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="x" value="zeros" />
183181
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="array_ops" value="zeros" />
@@ -410,18 +408,6 @@
410408
</method>
411409
</class>
412410

413-
<class name="from_tensor_slices" allocatable="true">
414-
<!-- "read_dataset" means that this function reads a tensor iterable. -->
415-
<method name="read_dataset" descriptor="()LRoot;">
416-
<new def="x" class="Ltensorflow/python/ops/data_ops/from_tensor_slices" />
417-
<return value="x" />
418-
</method>
419-
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="tensors name">
420-
<call class="LRoot" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
421-
<return value="x" />
422-
</method>
423-
</class>
424-
425411
<class name="Variable" allocatable="true">
426412
<method name="read_data" descriptor="()LRoot;">
427413
<new def="x" class="Ltensorflow/python/ops/variables/Variable" />
@@ -804,6 +790,53 @@
804790
</class>
805791
</package>
806792

793+
<package name="tensorflow/data">
794+
<class name="Dataset" allocatable="true">
795+
<!-- "read_dataset" means that this function reads a tensor iterable. -->
796+
<method name="read_dataset" descriptor="()LRoot;">
797+
<new def="shuffle" class="Ltensorflow/data/shuffle" />
798+
<putfield class="LRoot" field="shuffle" fieldType="LRoot" ref="arg0" value="shuffle" />
799+
<new def="batch" class="Ltensorflow/data/batch" />
800+
<putfield class="LRoot" field="batch" fieldType="LRoot" ref="arg0" value="batch" />
801+
<return value="arg0" />
802+
</method>
803+
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="self variant_tensor">
804+
<call class="LRoot" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
805+
<return value="x" />
806+
</method>
807+
</class>
808+
809+
<class name="shuffle" allocatable="true">
810+
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/data/Dataset#shuffle -->
811+
<method name="do" descriptor="()LRoot;" numArgs="5" paramNames="self buffer_size seed reshuffle_each_iteration name">
812+
<!-- FIXME: Workaround for https://github.com/wala/ML/issues/127. This method (shuffle) doesn't really return a "new" dataset but rather a modified version of the receiver. However, the receiver isn't available here without a trampoline (AFAIK), so a fresh Dataset is allocated and returned instead. -->
813+
<new def="x" class="Ltensorflow/data/Dataset" />
814+
<call class="Ltensorflow/data/Dataset" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="x" def="xx" />
815+
<return value="xx" />
816+
</method>
817+
</class>
818+
819+
<class name="batch" allocatable="true">
820+
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/data/Dataset#batch -->
821+
<method name="do" descriptor="()LRoot;" numArgs="6" paramNames="self batch_size drop_remainder num_parallel_calls deterministic name">
822+
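<!-- FIXME: Workaround for https://github.com/wala/ML/issues/127. This method (batch) allocates and returns a fresh Dataset rather than one derived from the receiver. -->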
823+
<new def="x" class="Ltensorflow/data/Dataset" />
824+
<call class="Ltensorflow/data/Dataset" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="x" def="xx" />
825+
<return value="xx" />
826+
</method>
827+
</class>
828+
</package>
829+
830+
<package name="tensorflow/data/Dataset">
831+
<class name="from_tensor_slices" allocatable="true">
832+
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="tensors name">
833+
<new def="x" class="Ltensorflow/data/Dataset" />
834+
<call class="Ltensorflow/data/Dataset" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="x" def="xx" />
835+
<return value="xx" />
836+
</method>
837+
</class>
838+
</package>
839+
807840
<package name="tensorflow/estimator/train">
808841
<class name="train" allocatable="true">
809842
<method name="do" descriptor="()LRoot;" numArgs="3">

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,11 @@ private static Set<PointsToSetVariable> getDataflowSources(
119119
// We are potentially pulling a tensor out of a tensor iterable.
120120
EachElementGetInstruction eachElementGetInstruction = (EachElementGetInstruction) inst;
121121

122-
// Find the potential tensor iterable creation site.
123-
SSAInstruction def = du.getDef(eachElementGetInstruction.getUse(0));
122+
// Find the potential tensor iterable definition.
123+
int use = eachElementGetInstruction.getUse(0);
124+
SSAInstruction def = du.getDef(use);
124125

125-
if (createsTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) {
126+
if (definesTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) {
126127
sources.add(src);
127128
logger.info("Added dataflow source from tensor iterable: " + src + ".");
128129
}
@@ -133,16 +134,16 @@ private static Set<PointsToSetVariable> getDataflowSources(
133134
}
134135

135136
/**
136-
* Returns true iff the given {@link SSAInstruction} creates an iterable of tensors.
137+
* Returns true iff the given {@link SSAInstruction} defines an iterable of tensors.
137138
*
138139
* @param instruction The {@link SSAInstruction} in question.
139140
* @param node The {@link CGNode} of the function containing the given {@link SSAInstruction}.
140141
* @param callGraph The {@link CallGraph} that includes a node corresponding to the given {@link
141142
* SSAInstruction}.
142143
* @param pointerAnalysis The {@link PointerAnalysis} built from the given {@link CallGraph}.
143-
* @return True iff the given {@link SSAInstruction} creates an iterable over tensors.
144+
* @return True iff the given {@link SSAInstruction} defines an iterable of tensors.
144145
*/
145-
private static boolean createsTensorIterable(
146+
private static boolean definesTensorIterable(
146147
SSAInstruction instruction,
147148
CGNode node,
148149
CallGraph callGraph,
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import tensorflow as tf
2+
3+
4+
def add(a, b):
5+
return a + b
6+
7+
8+
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3)
9+
10+
for element in dataset:
11+
c = add(element, element)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import tensorflow as tf
2+
3+
4+
def add(a, b):
5+
return a + b
6+
7+
8+
dataset = tf.data.Dataset(None) # This is actually illegal since this ctor is not publicly visible.
9+
10+
for element in dataset:
11+
c = add(element, element)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import tensorflow as tf
2+
3+
4+
def add(a, b):
5+
return a + b
6+
7+
8+
dataset = tf.data.Dataset(None).shuffle(3) # This is actually illegal since this ctor is not publicly visible.
9+
10+
for element in dataset:
11+
c = add(element, element)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import tensorflow as tf
2+
3+
4+
def add(a, b):
5+
return a + b
6+
7+
8+
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2)
9+
10+
for element in dataset:
11+
c = add(element, element)

0 commit comments

Comments
 (0)