Skip to content

Commit bea52a4

Browse files
committed
Progress on inner shapes.
1 parent 174ebf7 commit bea52a4

File tree

4 files changed

+79
-0
lines changed

4 files changed

+79
-0
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
120120
private static final TensorType TENSOR_2_NONE_2_INT32 =
121121
new TensorType(INT_32, asList(new NumericDim(2), null, new NumericDim(2)));
122122

123+
private static final TensorType TENSOR_2_NONE_2_3_INT32 =
124+
new TensorType(INT_32, asList(new NumericDim(2), null, new NumericDim(2), new NumericDim(3)));
125+
126+
private static final TensorType TENSOR_2_NONE_2_2_INT32 =
127+
new TensorType(INT_32, asList(new NumericDim(2), null, new NumericDim(2), new NumericDim(2)));
128+
123129
@SuppressWarnings("unused")
124130
private static final TensorType TENSOR_2_NONE_NONE_NONE_INT32 =
125131
new TensorType(INT_32, asList(new NumericDim(2), null));
@@ -4697,6 +4703,22 @@ public void testRaggedConstant16() throws ClassHierarchyException, CancelExcepti
46974703
test("tf2_test_ragged_constant16.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_INT32)));
46984704
}
46994705

4706+
/**
4707+
* Test non-uniform inner dimensions.
4708+
*
4709+
* <p>FIXME: Should not throw an {@link AssertionError}.
4710+
*/
4711+
@Test(expected = AssertionError.class)
4712+
public void testRaggedConstant17() throws ClassHierarchyException, CancelException, IOException {
4713+
test("tf2_test_ragged_constant17.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_3_INT32)));
4714+
}
4715+
4716+
/** This one works because the inner dimensions are uniform. */
4717+
@Test
4718+
public void testRaggedConstant18() throws ClassHierarchyException, CancelException, IOException {
4719+
test("tf2_test_ragged_constant18.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_NONE_2_2_INT32)));
4720+
}
4721+
47004722
private void test(
47014723
String filename,
47024724
String functionName,

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/RaggedConstant.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,19 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
380380
continue;
381381
}
382382

383+
if (!innerShapeArguments.isEmpty()) {
384+
for (List<Dimension<?>> innerShape : innerShapeArguments) {
385+
List<Dimension<?>> newShape = new ArrayList<>(shape.size());
386+
387+
newShape.addAll(shape);
388+
newShape.addAll(innerShape);
389+
390+
ret.add(newShape);
391+
}
392+
393+
continue;
394+
}
395+
383396
ret.add(shape);
384397
}
385398
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
# A list of groups.
11+
# Group 1 has one 2x3 matrix.
12+
# Group 2 has two 2x3 matrices.
13+
data = [
14+
[[[1, 1, 1], [2, 2, 2]]], # Group 1
15+
[[[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]], # Group 2
16+
]
17+
18+
# We set ragged_rank=1.
19+
# This means the OUTER list (the groups) varies in length.
20+
# But everything INSIDE a group is a fixed uniform block.
21+
t = tf.ragged.constant(data, None, 1)
22+
assert t.shape == (2, None, 2, 3)
23+
assert t.dtype == tf.int32
24+
25+
f(t)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient.
2+
3+
import tensorflow as tf
4+
5+
6+
def f(a):
7+
pass
8+
9+
10+
# Two groups. Group 1 has one matrix. Group 2 has two matrices.
11+
data = [[[[1, 1], [2, 2]]], [[[3, 3], [4, 4]], [[5, 5], [6, 6]]]] # Group 1 # Group 2
12+
13+
# We set ragged_rank=1 (The groups are ragged).
14+
# We set inner_shape=(2, 2) (The things inside are 2x2 matrices).
15+
t = tf.ragged.constant(data, None, 1, (2, 2))
16+
assert t.shape == (2, None, 2, 2)
17+
assert t.dtype == tf.int32
18+
19+
f(t)

0 commit comments

Comments
 (0)