Skip to content

Commit 143e7f7

Browse files
committed
Progress.
1 parent 2dd55f1 commit 143e7f7

File tree

7 files changed

+145
-44
lines changed

7 files changed

+145
-44
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
9595
private static final TensorType TENSOR_3_3_INT32 =
9696
new TensorType(INT_32, asList(new NumericDim(3), new NumericDim(3)));
9797

98+
private static final TensorType TENSOR_0_NONE_FLOAT32 =
99+
new TensorType(FLOAT_32, asList(new NumericDim(0), null));
100+
101+
private static final TensorType TENSOR_0_NONE_3_FLOAT32 =
102+
new TensorType(FLOAT_32, asList(new NumericDim(0), null, new NumericDim(3)));
103+
98104
@SuppressWarnings("unused")
99105
private static final TensorType TENSOR_1_NONE_INT32 =
100106
new TensorType(INT_32, asList(new NumericDim(1), null));
@@ -111,6 +117,9 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
111117
private static final TensorType TENSOR_2_NONE_2_FLOAT32 =
112118
new TensorType(FLOAT_32, asList(new NumericDim(2), null, new NumericDim(2)));
113119

120+
private static final TensorType TENSOR_2_NONE_2_INT32 =
121+
new TensorType(INT_32, asList(new NumericDim(2), null, new NumericDim(2)));
122+
114123
@SuppressWarnings("unused")
115124
private static final TensorType TENSOR_2_NONE_NONE_NONE_INT32 =
116125
new TensorType(INT_32, asList(new NumericDim(2), null));
@@ -4668,6 +4677,21 @@ public void testRaggedConstant12() throws ClassHierarchyException, CancelExcepti
46684677
test("tf2_test_ragged_constant12.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_FLOAT32)));
46694678
}
46704679

4680+
@Test
4681+
public void testRaggedConstant13() throws ClassHierarchyException, CancelException, IOException {
4682+
test("tf2_test_ragged_constant13.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_0_NONE_FLOAT32)));
4683+
}
4684+
4685+
@Test
4686+
public void testRaggedConstant14() throws ClassHierarchyException, CancelException, IOException {
4687+
test("tf2_test_ragged_constant14.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_0_NONE_3_FLOAT32)));
4688+
}
4689+
4690+
@Test
4691+
public void testRaggedConstant15() throws ClassHierarchyException, CancelException, IOException {
4692+
test("tf2_test_ragged_constant15.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_INT32)));
4693+
}
4694+
46714695
private void test(
46724696
String filename,
46734697
String functionName,

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/RaggedConstant.java

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package com.ibm.wala.cast.python.ml.client;
22

3+
import static com.ibm.wala.cast.python.ml.client.RaggedConstant.Parameters.INNER_SHAPE;
34
import static com.ibm.wala.cast.python.ml.client.RaggedConstant.Parameters.RAGGED_RANK;
45
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32;
56
import static com.ibm.wala.cast.python.types.PythonTypes.Root;
67
import static com.ibm.wala.cast.python.types.PythonTypes.list;
78
import static com.ibm.wala.cast.python.types.PythonTypes.tuple;
89
import static com.ibm.wala.cast.python.util.Util.getAllocationSiteInNode;
910
import static com.ibm.wala.core.util.strings.Atom.findOrCreateAsciiAtom;
11+
import static java.lang.Math.max;
12+
import static java.util.Collections.emptySet;
1013
import static java.util.logging.Logger.getLogger;
1114
import static java.util.stream.Collectors.toSet;
1215

@@ -306,8 +309,14 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
306309
LOGGER.fine("Tensor rank: " + K);
307310

308311
Set<Long> rankArguments = this.getPossibleRaggedRankArguments(builder);
312+
Set<List<Dimension<?>>> innerShapeArguments = this.getPossibleInnerShapeArguments(builder);
309313

310-
if (rankArguments.isEmpty()) rankArguments.add(K - 1L); // Default ragged rank.
314+
if (rankArguments.isEmpty())
315+
// Default ragged rank.
316+
if (innerShapeArguments.isEmpty()) rankArguments.add(max(0, K - 1L));
317+
else
318+
for (List<Dimension<?>> innerShape : innerShapeArguments)
319+
rankArguments.add(max(0, K - 1L - innerShape.size()));
311320

312321
for (Long R : rankArguments) {
313322
LOGGER.fine("Ragged rank: " + R);
@@ -375,36 +384,7 @@ private static int getMaximumDepthOfInstance(
375384
}
376385

377386
protected Set<Long> getPossibleRaggedRankArguments(PropagationCallGraphBuilder builder) {
378-
Set<Long> ret = HashSetFactory.make();
379-
int valueNumber = this.getRaggedRankArgumentValueNumber(builder);
380-
381-
if (valueNumber >= 0) {
382-
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
383-
PointerKey raggedRankPK =
384-
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), valueNumber);
385-
OrdinalSet<InstanceKey> raggedRankPointsToSet = pointerAnalysis.getPointsToSet(raggedRankPK);
386-
387-
if (raggedRankPointsToSet == null || raggedRankPointsToSet.isEmpty())
388-
throw new IllegalArgumentException(
389-
"Empty points-to set for ragged_rank in source: " + this.getSource() + ".");
390-
391-
for (InstanceKey raggedRankIK : raggedRankPointsToSet)
392-
if (raggedRankIK instanceof ConstantKey) {
393-
ConstantKey<?> constantKey = (ConstantKey<?>) raggedRankIK;
394-
Object constantKeyValue = constantKey.getValue();
395-
396-
if (constantKeyValue instanceof Long) {
397-
Long raggedRankValue = (Long) constantKeyValue;
398-
ret.add(raggedRankValue);
399-
} else
400-
throw new IllegalArgumentException(
401-
"Expected an integer for ragged_rank, but found: " + constantKeyValue + ".");
402-
} else
403-
throw new IllegalArgumentException(
404-
"Expected a constant key for ragged_rank, but found: " + raggedRankIK + ".");
405-
}
406-
407-
return ret;
387+
return this.getPossibleLongArguments(builder, this.getRaggedRankArgumentValueNumber(builder));
408388
}
409389

410390
protected int getRaggedRankParameterPosition() {
@@ -416,6 +396,26 @@ protected int getRaggedRankArgumentValueNumber(PropagationCallGraphBuilder build
416396
return this.getArgumentValueNumber(builder, this.getRaggedRankParameterPosition(), true);
417397
}
418398

399+
protected int getInnerShapeParameterPosition() {
400+
return INNER_SHAPE.ordinal();
401+
}
402+
403+
protected int getInnerShapeArgumentValueNumber(PropagationCallGraphBuilder builder) {
404+
// TODO: Handle keyword arguments.
405+
return this.getArgumentValueNumber(builder, this.getInnerShapeParameterPosition(), true);
406+
}
407+
408+
protected Set<List<Dimension<?>>> getPossibleInnerShapeArguments(
409+
PropagationCallGraphBuilder builder) {
410+
int valueNumber = this.getInnerShapeArgumentValueNumber(builder);
411+
412+
if (valueNumber >= 0) {
413+
PointerKey pointerKey = builder.getPointerKeyForLocal(this.getNode(), valueNumber);
414+
OrdinalSet<InstanceKey> pointsToSet = builder.getPointerAnalysis().getPointsToSet(pointerKey);
415+
return this.getShapesFromShapeArgument(builder, pointsToSet);
416+
} else return emptySet();
417+
}
418+
419419
/**
420420
* {@inheritDoc}
421421
*

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public Set<TensorType> getTensorTypes(PropagationCallGraphBuilder builder) {
8888
* Returns the possible shapes of the tensor returned by this generator.
8989
*
9090
* @param builder The {@link PropagationCallGraphBuilder} used to build the call graph.
91-
* @param pointsToSet The points-to set of the shape argument.
91+
* @param pointsToSet The points-to set of the shape argument. FIXME: Why not take a value number?
9292
* @return A set of possible shapes of the tensor returned by this generator.
9393
*/
9494
protected Set<List<Dimension<?>>> getShapesFromShapeArgument(
@@ -290,6 +290,7 @@ protected Set<List<Dimension<?>>> getShapes(
290290
throw new IllegalArgumentException(
291291
"Empty points-to set for value number: " + valueNumber + " in: " + this.getNode() + ".");
292292

293+
// FIXME: Just use the value number directly?
293294
return getShapesOfValue(builder, valuePointsToSet);
294295
}
295296

@@ -736,4 +737,37 @@ protected Set<Integer> getNumberOfPossiblePositionalArguments(
736737

737738
return ret;
738739
}
740+
741+
protected Set<Long> getPossibleLongArguments(
742+
PropagationCallGraphBuilder builder, int valueNumber) {
743+
Set<Long> ret = HashSetFactory.make();
744+
745+
if (valueNumber >= 0) {
746+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
747+
PointerKey pointerKey =
748+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), valueNumber);
749+
OrdinalSet<InstanceKey> pointsToSet = pointerAnalysis.getPointsToSet(pointerKey);
750+
751+
if (pointsToSet == null || pointsToSet.isEmpty())
752+
throw new IllegalArgumentException(
753+
"Empty points-to set in source: " + this.getSource() + ".");
754+
755+
for (InstanceKey instanceKey : pointsToSet)
756+
if (instanceKey instanceof com.ibm.wala.ipa.callgraph.propagation.ConstantKey) {
757+
ConstantKey<?> constantKey = (ConstantKey<?>) instanceKey;
758+
Object constantKeyValue = constantKey.getValue();
759+
760+
if (constantKeyValue instanceof Long) {
761+
Long value = (Long) constantKeyValue;
762+
ret.add(value);
763+
} else
764+
throw new IllegalStateException(
765+
"Expected a long, but found: " + constantKeyValue + ".");
766+
} else
767+
throw new IllegalStateException(
768+
"Expected a constant key, but found: " + instanceKey + ".");
769+
}
770+
771+
return ret;
772+
}
739773
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/ZerosLike.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
package com.ibm.wala.cast.python.ml.client;
22

3-
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
4-
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
53
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
6-
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
7-
import java.util.List;
8-
import java.util.Set;
94

105
/**
116
* A generator for tensors created by the `zeros_like()` function in TensorFlow.
@@ -28,13 +23,6 @@ public ZerosLike(PointsToSetVariable source) {
2823
super(source);
2924
}
3025

31-
@Override
32-
protected Set<List<Dimension<?>>> getShapesFromShapeArgument(
33-
PropagationCallGraphBuilder builder, Iterable<InstanceKey> pointsToSet) {
34-
throw new UnsupportedOperationException(
35-
"Shapes are derived from the `input` argument and cannot be provided explicitly.");
36-
}
37-
3826
@Override
3927
protected int getShapeParameterPosition() {
4028
return SHAPE_PARAMETER_POSITION;
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
# TensorFlow sees an empty list and gives up on the inner dimensions.
11+
t = tf.ragged.constant([], None, 1)
12+
assert t.shape == (0, None)
13+
assert t.dtype == tf.float32
14+
# Output: (0, None) -> It lost the inner dimension info!
15+
16+
f(t)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
# You tell TensorFlow: "Even though it's empty, if there WERE data,
11+
# it would be shape (3,)."
12+
t = tf.ragged.constant([], None, 1, (3,))
13+
assert t.shape == (0, None, 3)
14+
assert t.dtype == tf.float32
15+
# Output: (0, None, 3) -> The inner structure is preserved.
16+
17+
f(t)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
data = [[[1, 2], [3, 4]], [[5, 6]]]
11+
12+
# Success: The data matches the inner shape (2,)
13+
t1 = tf.ragged.constant(data, None, 1, (2,))
14+
assert t1.shape == (2, None, 2)
15+
assert t1.dtype == tf.int32
16+
# Output: (2, None, 2)
17+
18+
# Failure: You claim inner shape is (3,), but data is length 2
19+
# t2 = tf.ragged.constant(data, ragged_rank=1, inner_shape=(3,))
20+
# Raises ValueError: Inner shape mismatch
21+
22+
f(t1)

0 commit comments

Comments
 (0)