Skip to content

Commit 0bbab35

Browse files
committed
Progress.
1 parent 44a980c commit 0bbab35

File tree

8 files changed

+306
-23
lines changed

8 files changed

+306
-23
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
105105
private static final TensorType TENSOR_2_NONE_INT32 =
106106
new TensorType(INT_32, asList(new NumericDim(2), null));
107107

108+
private static final TensorType TENSOR_2_NONE_FLOAT32 =
109+
new TensorType(FLOAT_32, asList(new NumericDim(2), null));
110+
111+
private static final TensorType TENSOR_2_NONE_2_FLOAT32 =
112+
new TensorType(FLOAT_32, asList(new NumericDim(2), null, new NumericDim(2)));
113+
108114
@SuppressWarnings("unused")
109115
private static final TensorType TENSOR_2_NONE_NONE_NONE_INT32 =
110116
new TensorType(INT_32, asList(new NumericDim(2), null));
@@ -115,6 +121,15 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
115121
private static final TensorType TENSOR_3_NONE_INT32 =
116122
new TensorType(INT_32, asList(new NumericDim(3), null));
117123

124+
private static final TensorType TENSOR_3_NONE_FLOAT32 =
125+
new TensorType(FLOAT_32, asList(new NumericDim(3), null));
126+
127+
private static final TensorType TENSOR_3_NONE_NONE_FLOAT32 =
128+
new TensorType(FLOAT_32, asList(new NumericDim(3), null, null));
129+
130+
private static final TensorType TENSOR_3_NONE_1_FLOAT32 =
131+
new TensorType(FLOAT_32, asList(new NumericDim(3), null, new NumericDim(1)));
132+
118133
private static final TensorType TENSOR_2_3_INT32 =
119134
new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3)));
120135

@@ -4623,6 +4638,36 @@ public void testRaggedConstant6() throws ClassHierarchyException, CancelExceptio
46234638
test("tf2_test_ragged_constant6.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32)));
46244639
}
46254640

4641+
@Test
4642+
public void testRaggedConstant7() throws ClassHierarchyException, CancelException, IOException {
4643+
test("tf2_test_ragged_constant7.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_NONE_INT32)));
4644+
}
4645+
4646+
@Test
4647+
public void testRaggedConstant8() throws ClassHierarchyException, CancelException, IOException {
4648+
test("tf2_test_ragged_constant8.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_NONE_FLOAT32)));
4649+
}
4650+
4651+
@Test
4652+
public void testRaggedConstant9() throws ClassHierarchyException, CancelException, IOException {
4653+
test("tf2_test_ragged_constant9.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_NONE_NONE_FLOAT32)));
4654+
}
4655+
4656+
@Test
4657+
public void testRaggedConstant10() throws ClassHierarchyException, CancelException, IOException {
4658+
test("tf2_test_ragged_constant10.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_NONE_1_FLOAT32)));
4659+
}
4660+
4661+
@Test
4662+
public void testRaggedConstant11() throws ClassHierarchyException, CancelException, IOException {
4663+
test("tf2_test_ragged_constant11.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_FLOAT32)));
4664+
}
4665+
4666+
@Test
4667+
public void testRaggedConstant12() throws ClassHierarchyException, CancelException, IOException {
4668+
test("tf2_test_ragged_constant12.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_FLOAT32)));
4669+
}
4670+
46264671
private void test(
46274672
String filename,
46284673
String functionName,

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/RaggedConstant.java

Lines changed: 147 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.ibm.wala.cast.python.ml.client;
22

3+
import static com.ibm.wala.cast.python.ml.client.RaggedConstant.Parameters.RAGGED_RANK;
34
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32;
45
import static com.ibm.wala.cast.python.types.PythonTypes.Root;
56
import static com.ibm.wala.cast.python.types.PythonTypes.list;
@@ -28,7 +29,6 @@
2829
import java.util.ArrayList;
2930
import java.util.EnumSet;
3031
import java.util.List;
31-
import java.util.Optional;
3232
import java.util.Set;
3333
import java.util.logging.Logger;
3434
import java.util.stream.StreamSupport;
@@ -57,6 +57,65 @@ public RaggedConstant(PointsToSetVariable source) {
5757
super(source);
5858
}
5959

60+
private static Set<Integer> getPossibleInnerListLengths(
61+
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> pts) {
62+
Set<Integer> ret = HashSetFactory.make();
63+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
64+
65+
for (InstanceKey ik : pts) {
66+
AllocationSiteInNode asin = getAllocationSiteInNode(ik);
67+
TypeReference reference = asin.getConcreteType().getReference();
68+
69+
// A `list` or `tuple`.
70+
if (reference.equals(list) || reference.equals(tuple)) {
71+
OrdinalSet<InstanceKey> objectCatalogPointsToSet =
72+
pointerAnalysis.getPointsToSet(
73+
((AstPointerKeyFactory) builder.getPointerKeyFactory())
74+
.getPointerKeyForObjectCatalog(asin));
75+
76+
assert objectCatalogPointsToSet.iterator().hasNext();
77+
78+
InstanceKey catalogIK =
79+
objectCatalogPointsToSet
80+
.iterator()
81+
.next(); // Just need one element to check inner length.
82+
83+
ConstantKey<?> constantKey = (ConstantKey<?>) catalogIK;
84+
Object constantKeyValue = constantKey.getValue();
85+
86+
Integer fieldIndex = (Integer) constantKeyValue;
87+
88+
FieldReference subscript =
89+
FieldReference.findOrCreate(Root, findOrCreateAsciiAtom(fieldIndex.toString()), Root);
90+
91+
IField f = builder.getClassHierarchy().resolveField(subscript);
92+
93+
PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f);
94+
95+
OrdinalSet<InstanceKey> instanceFieldPointsToSet =
96+
pointerAnalysis.getPointsToSet(pointerKeyForInstanceField);
97+
98+
boolean containsAllListsOrTuples =
99+
StreamSupport.stream(instanceFieldPointsToSet.spliterator(), false)
100+
.allMatch(
101+
ik -> {
102+
AllocationSiteInNode innerAsin = getAllocationSiteInNode(ik);
103+
104+
if (innerAsin == null) return false;
105+
106+
TypeReference innerReference = innerAsin.getConcreteType().getReference();
107+
return innerReference.equals(list) || innerReference.equals(tuple);
108+
});
109+
110+
if (!containsAllListsOrTuples) ret.add(objectCatalogPointsToSet.size());
111+
else ret.addAll(getPossibleInnerListLengths(builder, instanceFieldPointsToSet));
112+
} else
113+
throw new IllegalStateException("Expected a list or tuple, but found: " + reference + ".");
114+
}
115+
116+
return ret;
117+
}
118+
60119
private static Set<Integer> getPossibleOuterListLengths(
61120
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> valuePointsToSet) {
62121
Set<Integer> ret = HashSetFactory.make();
@@ -246,31 +305,58 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
246305
int K = maxDepth;
247306
LOGGER.fine("Tensor rank: " + K);
248307

249-
Optional<Integer> raggedRank = this.getRaggedRankArgumentValue(builder);
250-
int R = raggedRank.orElse(K - 1);
251-
LOGGER.fine("Ragged rank: " + R);
308+
Set<Long> rankArguments = this.getPossibleRaggedRankArguments(builder);
252309

253-
// Step 3: Construct shape with rank K and ragged rank R.
310+
if (rankArguments.isEmpty()) rankArguments.add(K - 1L); // Default ragged rank.
254311

255-
// Get the length of the outer list.
256-
Set<Integer> possibleOuterListLengths =
257-
getPossibleOuterListLengths(builder, valuePointsToSet);
312+
for (Long R : rankArguments) {
313+
LOGGER.fine("Ragged rank: " + R);
258314

259-
for (int outerListLength : possibleOuterListLengths) {
260-
List<Dimension<?>> shape = new ArrayList<>();
261-
shape.add(new NumericDim(outerListLength));
315+
// Step 3: Construct shape with rank K and ragged rank R.
316+
// The final shape is constructed by concatenating the Ragged Portion and the Uniform
317+
// Portion.
262318

263-
// The first R dimensions are ragged.
264-
for (int i = 0; i < R; i++) shape.add(null); // Unknown size for ragged dimensions.
319+
// Part A: The Ragged Portion (Dimensions 0 to R)
265320

266-
/*
267-
// The remaining K - R dimensions are dense.
268-
for (int i = R; i < K; i++) {
269-
shape.add(new NumericDim(-1)); // Unknown size for dense dimensions.
270-
}
271-
*/
321+
// For the ragged dimensions, TensorFlow does not look for a uniform length. It assigns the
322+
// shape based on the row_splits.
272323

273-
ret.add(shape);
324+
// Get the length of the outer list.
325+
Set<Integer> possibleOuterListLengths =
326+
getPossibleOuterListLengths(builder, valuePointsToSet);
327+
328+
for (int outerListLength : possibleOuterListLengths) {
329+
List<Dimension<?>> shape = new ArrayList<>();
330+
331+
// Dim 0 (Batch): Always fixed. It is simply len(input_list).
332+
shape.add(new NumericDim(outerListLength));
333+
334+
// The first R dimensions are ragged.
335+
// Dim 1 to R: These are assigned None (or ? in older outputs) in the static shape,
336+
// indicating they can vary.
337+
for (Long i = 0L; i < R; i++) shape.add(null); // Unknown size for ragged dimensions.
338+
339+
// Part B: The Uniform Portion (Dimensions R + 1 to K)
340+
// If R < K - 1 (meaning you requested fewer ragged dimensions than the total depth),
341+
// TensorFlow enforces uniformity on the remaining inner dimensions.
342+
343+
// 1. It checks the length of every sub-list at these levels.
344+
// 2. If any lengths differ, it throws a ValueError.
345+
// 3. If they match, that length becomes the fixed size for that dimension.
346+
347+
if (R < K - 1) {
348+
Set<Integer> possibleInnerListLengths =
349+
getPossibleInnerListLengths(builder, valuePointsToSet);
350+
351+
// Determine the uniform lengths for dimensions R + 1 to K - 1.
352+
for (long i = R + 1; i < K; i++) {
353+
for (int innerListLength : possibleInnerListLengths)
354+
shape.add(new NumericDim(innerListLength));
355+
}
356+
}
357+
358+
ret.add(shape);
359+
}
274360
}
275361
}
276362

@@ -288,9 +374,46 @@ private static int getMaximumDepthOfInstance(
288374
return 1 + getMaximumDepthOfEmptyList(builder, instance);
289375
}
290376

291-
private Optional<Integer> getRaggedRankArgumentValue(PropagationCallGraphBuilder builder) {
292-
// TODO Auto-generated method stub
293-
return Optional.empty();
377+
protected Set<Long> getPossibleRaggedRankArguments(PropagationCallGraphBuilder builder) {
378+
Set<Long> ret = HashSetFactory.make();
379+
int valueNumber = this.getRaggedRankArgumentValueNumber(builder);
380+
381+
if (valueNumber >= 0) {
382+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
383+
PointerKey raggedRankPK =
384+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), valueNumber);
385+
OrdinalSet<InstanceKey> raggedRankPointsToSet = pointerAnalysis.getPointsToSet(raggedRankPK);
386+
387+
if (raggedRankPointsToSet == null || raggedRankPointsToSet.isEmpty())
388+
throw new IllegalArgumentException(
389+
"Empty points-to set for ragged_rank in source: " + this.getSource() + ".");
390+
391+
for (InstanceKey raggedRankIK : raggedRankPointsToSet)
392+
if (raggedRankIK instanceof ConstantKey) {
393+
ConstantKey<?> constantKey = (ConstantKey<?>) raggedRankIK;
394+
Object constantKeyValue = constantKey.getValue();
395+
396+
if (constantKeyValue instanceof Long) {
397+
Long raggedRankValue = (Long) constantKeyValue;
398+
ret.add(raggedRankValue);
399+
} else
400+
throw new IllegalArgumentException(
401+
"Expected an integer for ragged_rank, but found: " + constantKeyValue + ".");
402+
} else
403+
throw new IllegalArgumentException(
404+
"Expected a constant key for ragged_rank, but found: " + raggedRankIK + ".");
405+
}
406+
407+
return ret;
408+
}
409+
410+
protected int getRaggedRankParameterPosition() {
411+
return RAGGED_RANK.ordinal();
412+
}
413+
414+
protected int getRaggedRankArgumentValueNumber(PropagationCallGraphBuilder builder) {
415+
// TODO: Handle keyword arguments.
416+
return this.getArgumentValueNumber(builder, this.getRaggedRankParameterPosition(), true);
294417
}
295418

296419
/**
@@ -323,6 +446,7 @@ protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
323446
return EnumSet.of(FLOAT32);
324447
}
325448

449+
// Otherwise, there are values available to infer the dtype from.
326450
return super.getDefaultDTypes(builder);
327451
}
328452
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
arg = [[[1], [2]], [[3]], [[4], [5], [6]]]
11+
12+
x = tf.ragged.constant(arg, tf.float32, 1)
13+
assert x.shape == (3, None, 1)
14+
assert x.dtype == tf.float32
15+
16+
f(x)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
arg = [[1, 2], [3]]
11+
12+
x = tf.ragged.constant(arg, tf.float32, 1)
13+
14+
assert x.shape == (2, None)
15+
assert x.dtype == tf.float32
16+
17+
f(x)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
arg = [[[1, 2]], [[3, 4]]]
11+
12+
x = tf.ragged.constant(arg, tf.float32, 1)
13+
14+
assert x.shape == (2, None, 2)
15+
assert x.dtype == tf.float32
16+
17+
f(x)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
arg = [[1, 2], [3], [4, 5, 6]]
11+
assert isinstance(arg, list)
12+
assert len(arg) == 3
13+
assert all(isinstance(row, list) for row in arg)
14+
assert all(isinstance(x, int) for row in arg for x in row)
15+
assert len(arg[0]) == 2
16+
assert len(arg[1]) == 1
17+
assert len(arg[2]) == 3
18+
19+
x = tf.ragged.constant(arg, tf.int32)
20+
assert isinstance(x, tf.RaggedTensor)
21+
assert x.shape == (3, None)
22+
assert x.dtype == tf.int32
23+
24+
f(x)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
arg = [[1, 2], [3], [4, 5, 6]]
11+
assert isinstance(arg, list)
12+
assert len(arg) == 3
13+
assert all(isinstance(row, list) for row in arg)
14+
assert all(isinstance(x, int) for row in arg for x in row)
15+
assert len(arg[0]) == 2
16+
assert len(arg[1]) == 1
17+
assert len(arg[2]) == 3
18+
19+
x = tf.ragged.constant(arg, tf.float32)
20+
assert isinstance(x, tf.RaggedTensor)
21+
assert x.shape == (3, None)
22+
assert x.dtype == tf.float32
23+
24+
f(x)

0 commit comments

Comments
 (0)