2121import com .ibm .wala .cast .python .ipa .callgraph .PythonSSAPropagationCallGraphBuilder ;
2222import com .ibm .wala .cast .python .ml .analysis .TensorTypeAnalysis ;
2323import com .ibm .wala .cast .python .ml .analysis .TensorVariable ;
24+ import com .ibm .wala .cast .python .ml .client .NonBroadcastableShapesException ;
2425import com .ibm .wala .cast .python .ml .client .PythonTensorAnalysisEngine ;
2526import com .ibm .wala .cast .python .ml .types .TensorType ;
2627import com .ibm .wala .cast .python .ml .types .TensorType .Dimension ;
@@ -120,6 +121,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
120121 private static final TensorType TENSOR_2_NONE_2_INT32 =
121122 new TensorType (INT_32 , asList (new NumericDim (2 ), null , new NumericDim (2 )));
122123
124+ private static final TensorType TENSOR_2_NONE_2_3_INT32 =
125+ new TensorType (INT_32 , asList (new NumericDim (2 ), null , new NumericDim (2 ), new NumericDim (3 )));
126+
127+ private static final TensorType TENSOR_2_NONE_2_2_INT32 =
128+ new TensorType (INT_32 , asList (new NumericDim (2 ), null , new NumericDim (2 ), new NumericDim (2 )));
129+
123130 @ SuppressWarnings ("unused" )
124131 private static final TensorType TENSOR_2_NONE_NONE_NONE_INT32 =
125132 new TensorType (INT_32 , asList (new NumericDim (2 ), null ));
@@ -186,6 +193,11 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
186193 FLOAT_32 ,
187194 asList (new NumericDim (3 ), new NumericDim (2 ), new NumericDim (2 ), new NumericDim (3 )));
188195
196+ private static final TensorType TENSOR_2_2_2_3_FLOAT32 =
197+ new TensorType (
198+ FLOAT_32 ,
199+ asList (new NumericDim (2 ), new NumericDim (2 ), new NumericDim (2 ), new NumericDim (3 )));
200+
189201 private static final TensorType TENSOR_20_28_28_FLOAT32 =
190202 new TensorType (FLOAT_32 , asList (new NumericDim (20 ), new NumericDim (28 ), new NumericDim (28 )));
191203
@@ -227,7 +239,11 @@ public void testValueIndex()
227239 Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
228240 }
229241
230- @ Test
242+ /**
243+ * Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
244+ * is fixed.
245+ */
246+ @ Test (expected = IllegalArgumentException .class )
231247 public void testValueIndex2 ()
232248 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
233249 test (
@@ -238,7 +254,11 @@ public void testValueIndex2()
238254 Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
239255 }
240256
241- @ Test
257+ /**
258+ * Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
259+ * is fixed.
260+ */
261+ @ Test (expected = IllegalArgumentException .class )
242262 public void testValueIndex3 ()
243263 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
244264 test (
@@ -249,7 +269,11 @@ public void testValueIndex3()
249269 Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
250270 }
251271
252- @ Test
272+ /**
273+ * Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
274+ * is fixed.
275+ */
276+ @ Test (expected = IllegalArgumentException .class )
253277 public void testValueIndex4 ()
254278 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
255279 test (
@@ -260,7 +284,11 @@ public void testValueIndex4()
260284 Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
261285 }
262286
263- @ Test
287+ /**
288+ * Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
289+ * is fixed.
290+ */
291+ @ Test (expected = IllegalArgumentException .class )
264292 public void testFunction ()
265293 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
266294 test ("tf2_test_function.py" , "func2" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
@@ -1617,25 +1645,45 @@ public void testAdd58()
16171645 @ Test
16181646 public void testAdd59 ()
16191647 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1620- test ("tf2_test_add59.py" , "add" , 2 , 2 , Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
1648+ test (
1649+ "tf2_test_add59.py" ,
1650+ "add" ,
1651+ 2 ,
1652+ 2 ,
1653+ Map .of (2 , Set .of (TENSOR_2_INT32 ), 3 , Set .of (TENSOR_2_INT32 )));
16211654 }
16221655
16231656 @ Test
16241657 public void testAdd60 ()
16251658 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1626- test ("tf2_test_add60.py" , "add" , 2 , 2 , Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
1659+ test (
1660+ "tf2_test_add60.py" ,
1661+ "add" ,
1662+ 2 ,
1663+ 2 ,
1664+ Map .of (2 , Set .of (TENSOR_2_INT32 ), 3 , Set .of (TENSOR_2_INT32 )));
16271665 }
16281666
16291667 @ Test
16301668 public void testAdd61 ()
16311669 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1632- test ("tf2_test_add61.py" , "add" , 2 , 2 , Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
1670+ test (
1671+ "tf2_test_add61.py" ,
1672+ "add" ,
1673+ 2 ,
1674+ 2 ,
1675+ Map .of (2 , Set .of (TENSOR_2_INT32 ), 3 , Set .of (TENSOR_2_INT32 )));
16331676 }
16341677
16351678 @ Test
16361679 public void testAdd62 ()
16371680 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1638- test ("tf2_test_add62.py" , "add" , 2 , 2 , Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
1681+ test (
1682+ "tf2_test_add62.py" ,
1683+ "add" ,
1684+ 2 ,
1685+ 2 ,
1686+ Map .of (2 , Set .of (TENSOR_2_INT32 ), 3 , Set .of (TENSOR_2_INT32 )));
16391687 }
16401688
16411689 @ Test
@@ -2151,16 +2199,24 @@ public void testReduceMean3()
21512199 @ Test
21522200 public void testGradient ()
21532201 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
2154- test ("tf2_test_gradient.py" , "f" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
2202+ test ("tf2_test_gradient.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_NONE_FLOAT32 )));
21552203 }
21562204
2157- @ Test
2205+ /**
2206+ * Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
2207+ * is fixed.
2208+ */
2209+ @ Test (expected = IllegalArgumentException .class )
21582210 public void testGradient2 ()
21592211 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
21602212 test ("tf2_test_gradient2.py" , "f" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
21612213 }
21622214
2163- @ Test
2215+ /**
2216+ * Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
2217+ * is fixed.
2218+ */
2219+ @ Test (expected = IllegalArgumentException .class )
21642220 public void testMultiply ()
21652221 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
21662222 test ("tf2_test_multiply.py" , "f" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
@@ -2169,7 +2225,48 @@ public void testMultiply()
21692225 @ Test
21702226 public void testMultiply2 ()
21712227 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
2172- test ("tf2_test_multiply2.py" , "f" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
2228+ test ("tf2_test_multiply2.py" , "f" , 1 , 1 , Map .of (2 , Set .of (SCALAR_TENSOR_OF_INT32 )));
2229+ }
2230+
2231+ @ Test
2232+ public void testMultiply3 ()
2233+ throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
2234+ test ("tf2_test_multiply3.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_3_FLOAT32 )));
2235+ }
2236+
2237+ @ Test
2238+ public void testMultiply4 ()
2239+ throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
2240+ test ("tf2_test_multiply4.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_3_FLOAT32 )));
2241+ }
2242+
2243+ @ Test
2244+ public void testMultiply5 ()
2245+ throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
2246+ test ("tf2_test_multiply5.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_2_2_3_FLOAT32 )));
2247+ }
2248+
2249+ /**
2250+ * This is an invalid case since the inputs have different ranks.
2251+ *
2252+ * <p>For now, we are throwing an exception. But, this is invalid code.
2253+ *
2254+ * <p>TODO: We'll need to come up with a suitable way to handle this in the future.
2255+ */
2256+ @ Test (expected = NonBroadcastableShapesException .class )
2257+ public void testMultiply6 ()
2258+ throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
2259+ test ("tf2_test_multiply6.py" , "f" , 1 , 1 );
2260+ }
2261+
2262+ /**
2263+ * Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340
2264+ * is fixed.
2265+ */
2266+ @ Test (expected = IllegalArgumentException .class )
2267+ public void testMultiply7 ()
2268+ throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
2269+ test ("tf2_test_multiply7.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_3_FLOAT32 )));
21732270 }
21742271
21752272 @ Test
@@ -4697,6 +4794,22 @@ public void testRaggedConstant16() throws ClassHierarchyException, CancelExcepti
46974794 test ("tf2_test_ragged_constant16.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_NONE_2_INT32 )));
46984795 }
46994796
4797+ /**
4798+ * Test non-uniform inner dimensions.
4799+ *
4800+ * <p>TODO: Remove expected assertion error once https://github.com/wala/ML/issues/350 is fixed.
4801+ */
4802+ @ Test (expected = AssertionError .class )
4803+ public void testRaggedConstant17 () throws ClassHierarchyException , CancelException , IOException {
4804+ test ("tf2_test_ragged_constant17.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_NONE_2_3_INT32 )));
4805+ }
4806+
4807+ /** This one works because the inner dimensions are uniform. */
4808+ @ Test
4809+ public void testRaggedConstant18 () throws ClassHierarchyException , CancelException , IOException {
4810+ test ("tf2_test_ragged_constant18.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_NONE_2_2_INT32 )));
4811+ }
4812+
47004813 private void test (
47014814 String filename ,
47024815 String functionName ,
0 commit comments