Skip to content

Commit 319eed7

Browse files
committed
Handle multiple parameters for tf.range().
1 parent b897827 commit 319eed7

File tree

4 files changed

+71
-12
lines changed

4 files changed

+71
-12
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,13 +1328,23 @@ public void testAdd37()
13281328
@Test
13291329
public void testAdd38()
13301330
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1331-
test("tf2_test_add38.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1331+
test(
1332+
"tf2_test_add38.py",
1333+
"add",
1334+
2,
1335+
2,
1336+
Map.of(2, Set.of(TENSOR_5_INT32), 3, Set.of(TENSOR_5_INT32)));
13321337
}
13331338

13341339
@Test
13351340
public void testAdd39()
13361341
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1337-
test("tf2_test_add39.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1342+
test(
1343+
"tf2_test_add39.py",
1344+
"add",
1345+
2,
1346+
2,
1347+
Map.of(2, Set.of(TENSOR_5_INT32), 3, Set.of(TENSOR_5_INT32)));
13381348
}
13391349

13401350
@Test
@@ -1422,7 +1432,12 @@ public void testAdd49()
14221432
@Test
14231433
public void testAdd50()
14241434
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1425-
test("tf2_test_add50.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1435+
test(
1436+
"tf2_test_add50.py",
1437+
"add",
1438+
2,
1439+
2,
1440+
Map.of(2, Set.of(TENSOR_5_INT32), 3, Set.of(TENSOR_5_INT32)));
14261441
}
14271442

14281443
@Test

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import com.ibm.wala.ipa.callgraph.propagation.cfa.CallString;
1919
import com.ibm.wala.ssa.SSAAbstractInvokeInstruction;
2020
import com.ibm.wala.util.collections.HashSetFactory;
21-
import com.ibm.wala.util.debug.UnimplementedError;
2221
import com.ibm.wala.util.intset.OrdinalSet;
2322
import java.util.EnumSet;
2423
import java.util.Iterator;
@@ -64,10 +63,10 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
6463
// 2. `tf.range(start, limit, delta)` - generates a range from start to limit with a step of
6564
// delta.
6665

67-
// First, decide which version of the `range` function is being called based on the number of
68-
// numeric arguments.j
66+
// Decide which version of the `range` function is being called based on the number of numeric
67+
// arguments.
6968
// TODO: Handle keyword arguments.
70-
for (Integer numOfPoisitionArguments : getNumberOfPossiblePositionalArguments(builder)) {
69+
for (Integer numOfPoisitionArguments : getNumberOfPossiblePositionalArguments(builder))
7170
if (numOfPoisitionArguments == 1) {
7271
// it must *just* be `limit`.
7372
PointerKey limitPK =
@@ -81,11 +80,42 @@ protected Set<List<Dimension<?>>> getShapes(PropagationCallGraphBuilder builder)
8180
int shape = (int) Math.ceil((limit - start) / delta);
8281
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
8382
}
83+
} else if (numOfPoisitionArguments == 3) {
84+
// it must be `start`, `limit`, and `delta`.
85+
PointerKey startPK =
86+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 2);
87+
PointerKey limitPK =
88+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 3);
89+
PointerKey deltaPK =
90+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), 4);
91+
92+
OrdinalSet<InstanceKey> startPointsToSet = pointerAnalysis.getPointsToSet(startPK);
93+
OrdinalSet<InstanceKey> limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK);
94+
OrdinalSet<InstanceKey> deltaPointsToSet = pointerAnalysis.getPointsToSet(deltaPK);
95+
96+
assert !startPointsToSet.isEmpty() : "Expected a non-empty points-to set for start.";
97+
assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit.";
98+
assert !deltaPointsToSet.isEmpty() : "Expected a non-empty points-to set for delta.";
99+
100+
for (InstanceKey startIK : startPointsToSet) {
101+
start = ((Number) ((ConstantKey<?>) startIK).getValue()).doubleValue();
102+
103+
for (InstanceKey limitIK : limitPointsToSet) {
104+
limit = ((Number) ((ConstantKey<?>) limitIK).getValue()).doubleValue();
105+
106+
for (InstanceKey deltaIK : deltaPointsToSet) {
107+
delta = ((Number) ((ConstantKey<?>) deltaIK).getValue()).doubleValue();
108+
109+
int shape = (int) Math.ceil((limit - start) / delta);
110+
ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor.
111+
}
112+
}
113+
}
84114
} else
85-
// TODO: Handle more cases.
86-
throw new UnimplementedError(
87-
"Currently cannot handle more than one numeric positional argument for range().");
88-
}
115+
throw new IllegalStateException(
116+
"Expected either 1 or 3 positional arguments for range(), but got: "
117+
+ numOfPoisitionArguments
118+
+ ".");
89119

90120
return ret;
91121
}

com.ibm.wala.cast.python.test/data/tf2_test_add38.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,12 @@ def add(a, b):
55
return a + b
66

77

8-
c = add(tf.range(3, 18, 3), tf.range(5))
8+
arg1 = tf.range(3, 18, 3)
9+
assert arg1.shape == (5,)
10+
assert arg1.dtype == tf.int32
11+
12+
arg2 = tf.range(5)
13+
assert arg2.shape == (5,)
14+
assert arg2.dtype == tf.int32
15+
16+
c = add(arg1, arg2)

com.ibm.wala.cast.python.test/data/tf2_test_tf_range.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ def f(a):
1212
delta = 3
1313

1414
r = tf.range(start, limit, delta)
15+
assert isinstance(r, tf.Tensor)
16+
assert r.shape == (5,)
17+
assert r.dtype == tf.int32
1518

1619
for i in r:
20+
assert isinstance(i, tf.Tensor)
21+
assert i.dtype == tf.int32
22+
assert i.shape == ()
1723
f(i)

0 commit comments

Comments
 (0)