Skip to content

Commit b08fe50

Browse files
committed
Progress.
1 parent dbe3c78 commit b08fe50

File tree

4 files changed

+51
-4
lines changed

4 files changed

+51
-4
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4692,6 +4692,11 @@ public void testRaggedConstant15() throws ClassHierarchyException, CancelExcepti
46924692
test("tf2_test_ragged_constant15.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_INT32)));
46934693
}
46944694

4695+
@Test
4696+
public void testRaggedConstant16() throws ClassHierarchyException, CancelException, IOException {
4697+
test("tf2_test_ragged_constant16.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_INT32)));
4698+
}
4699+
46954700
private void test(
46964701
String filename,
46974702
String functionName,

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/RaggedConstant.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.util.ArrayList;
3333
import java.util.EnumSet;
3434
import java.util.List;
35+
import java.util.Objects;
3536
import java.util.Set;
3637
import java.util.logging.Logger;
3738
import java.util.stream.StreamSupport;
@@ -308,8 +309,15 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
308309
int K = maxDepth;
309310
LOGGER.fine("Tensor rank: " + K);
310311

311-
Set<Long> rankArguments = this.getPossibleRaggedRankArguments(builder);
312-
Set<List<Dimension<?>>> innerShapeArguments = this.getPossibleInnerShapeArguments(builder);
312+
Set<Long> rankArguments =
313+
this.getPossibleRaggedRankArguments(builder).stream()
314+
.filter(Objects::nonNull)
315+
.collect(toSet());
316+
317+
Set<List<Dimension<?>>> innerShapeArguments =
318+
this.getPossibleInnerShapeArguments(builder).stream()
319+
.filter(Objects::nonNull)
320+
.collect(toSet());
313321

314322
if (rankArguments.isEmpty())
315323
// Default ragged rank.

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,15 @@ protected Set<Integer> getNumberOfPossiblePositionalArguments(
738738
return ret;
739739
}
740740

741+
/**
742+
* Returns the possible long arguments for the given value number. If the argument is `None`, then
743+
* a null value will be contained within the returned set.
744+
*
745+
* @param builder The {@link PropagationCallGraphBuilder} used for the analysis.
746+
* @param valueNumber The value number of the argument.
747+
* @return A set of possible long arguments. If the argument is `None`, then a null value will be
748+
* contained within the returned set.
749+
*/
741750
protected Set<Long> getPossibleLongArguments(
742751
PropagationCallGraphBuilder builder, int valueNumber) {
743752
Set<Long> ret = HashSetFactory.make();
@@ -753,14 +762,17 @@ protected Set<Long> getPossibleLongArguments(
753762
"Empty points-to set in source: " + this.getSource() + ".");
754763

755764
for (InstanceKey instanceKey : pointsToSet)
756-
if (instanceKey instanceof com.ibm.wala.ipa.callgraph.propagation.ConstantKey) {
765+
if (instanceKey instanceof ConstantKey) {
757766
ConstantKey<?> constantKey = (ConstantKey<?>) instanceKey;
758767
Object constantKeyValue = constantKey.getValue();
759768

760769
if (constantKeyValue instanceof Long) {
761770
Long value = (Long) constantKeyValue;
762771
ret.add(value);
763-
} else
772+
} else if (constantKeyValue == null)
773+
// The argument may be `None`.
774+
ret.add(null);
775+
else
764776
throw new IllegalStateException(
765777
"Expected a long, but found: " + constantKeyValue + ".");
766778
} else
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
data = [[[1, 2], [3, 4]], [[5, 6]]]
11+
12+
# Success: The data matches the inner shape (2,)
13+
t1 = tf.ragged.constant(data, None, None, (2,))
14+
assert t1.shape == (2, None, 2)
15+
assert t1.dtype == tf.int32
16+
# Output: (2, None, 2)
17+
18+
# Failure: You claim inner shape is (3,), but data is length 2
19+
# t2 = tf.ragged.constant(data, ragged_rank=1, inner_shape=(3,))
20+
# Raises ValueError: Inner shape mismatch
21+
22+
f(t1)

0 commit comments

Comments
 (0)