@@ -80,12 +80,30 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
8080 private static final TensorType TENSOR_2_3_3_FLOAT32 =
8181 new TensorType (FLOAT_32 , asList (new NumericDim (2 ), new NumericDim (3 ), new NumericDim (3 )));
8282
83+ private static final TensorType TENSOR_2_3_3_INT32 =
84+ new TensorType (INT_32 , asList (new NumericDim (2 ), new NumericDim (3 ), new NumericDim (3 )));
85+
8386 private static final TensorType TENSOR_2_3_4_FLOAT32 =
8487 new TensorType (FLOAT_32 , asList (new NumericDim (2 ), new NumericDim (3 ), new NumericDim (4 )));
8588
89+ private static final TensorType TENSOR_2_3_4_INT32 =
90+ new TensorType (INT_32 , asList (new NumericDim (2 ), new NumericDim (3 ), new NumericDim (4 )));
91+
8692 private static final TensorType TENSOR_2_FLOAT32 =
8793 new TensorType (FLOAT_32 , asList (new NumericDim (2 )));
8894
95+ private static final TensorType TENSOR_2_INT32 =
96+ new TensorType (INT_32 , asList (new NumericDim (2 )));
97+
98+ private static final TensorType TENSOR_3_INT32 =
99+ new TensorType (INT_32 , asList (new NumericDim (3 )));
100+
101+ private static final TensorType TENSOR_3_FLOAT32 =
102+ new TensorType (FLOAT_32 , asList (new NumericDim (3 )));
103+
104+ private static final TensorType TENSOR_4_FLOAT32 =
105+ new TensorType (FLOAT_32 , asList (new NumericDim (4 )));
106+
89107 private static final TensorType TENSOR_5_FLOAT32 =
90108 new TensorType (FLOAT_32 , asList (new NumericDim (5 )));
91109
@@ -198,13 +216,13 @@ public void testFunction9()
198216 @ Test
199217 public void testFunction10 ()
200218 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
201- test ("tf2_test_function10.py" , "func" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_3_4_FLOAT32 )));
219+ test ("tf2_test_function10.py" , "func" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_3_4_INT32 )));
202220 }
203221
204222 @ Test
205223 public void testFunction11 ()
206224 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
207- test ("tf2_test_function11.py" , "func" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_3_3_FLOAT32 )));
225+ test ("tf2_test_function11.py" , "func" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_3_3_INT32 )));
208226 }
209227
210228 @ Test
@@ -222,7 +240,7 @@ public void testDecorator2()
222240 @ Test
223241 public void testDecorator3 ()
224242 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
225- test ("tf2_test_decorator3.py" , "returned" , 1 , 1 , Map .of (2 , Set .of (TENSOR_5_INT32 )));
243+ test ("tf2_test_decorator3.py" , "returned" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_FLOAT32 )));
226244 }
227245
228246 @ Test
@@ -907,13 +925,13 @@ public void testAutoencoder4()
907925 @ Test
908926 public void testSigmoid ()
909927 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
910- test ("tf2_test_sigmoid.py" , "f" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
928+ test ("tf2_test_sigmoid.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_4_FLOAT32 )));
911929 }
912930
913931 @ Test
914932 public void testSigmoid2 ()
915933 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
916- test ("tf2_test_sigmoid2.py" , "f" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
934+ test ("tf2_test_sigmoid2.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_4_FLOAT32 )));
917935 }
918936
919937 @ Test
@@ -1066,19 +1084,34 @@ public void testAdd22()
10661084 @ Test
10671085 public void testAdd23 ()
10681086 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1069- test ("tf2_test_add23.py" , "add" , 2 , 2 , Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
1087+ test (
1088+ "tf2_test_add23.py" ,
1089+ "add" ,
1090+ 2 ,
1091+ 2 ,
1092+ Map .of (2 , Set .of (TENSOR_2_INT32 ), 3 , Set .of (TENSOR_2_INT32 )));
10701093 }
10711094
10721095 @ Test
10731096 public void testAdd24 ()
10741097 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1075- test ("tf2_test_add24.py" , "add" , 2 , 2 , Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
1098+ test (
1099+ "tf2_test_add24.py" ,
1100+ "add" ,
1101+ 2 ,
1102+ 2 ,
1103+ Map .of (2 , Set .of (TENSOR_2_INT32 ), 3 , Set .of (TENSOR_2_INT32 )));
10761104 }
10771105
10781106 @ Test
10791107 public void testAdd25 ()
10801108 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1081- test ("tf2_test_add25.py" , "add" , 2 , 2 , Map .of (2 , Set .of (MNIST_INPUT ), 3 , Set .of (MNIST_INPUT )));
1109+ test (
1110+ "tf2_test_add25.py" ,
1111+ "add" ,
1112+ 2 ,
1113+ 2 ,
1114+ Map .of (2 , Set .of (TENSOR_2_INT32 ), 3 , Set .of (TENSOR_2_INT32 )));
10821115 }
10831116
10841117 @ Test
@@ -1695,19 +1728,19 @@ public void testMultiGPUTraining2()
16951728 @ Test
16961729 public void testReduceMean ()
16971730 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1698- test ("tf2_test_reduce_mean.py" , "f" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
1731+ test ("tf2_test_reduce_mean.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_2_FLOAT32 )));
16991732 }
17001733
17011734 @ Test
17021735 public void testReduceMean2 ()
17031736 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1704- test ("tf2_test_reduce_mean.py" , "g" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
1737+ test ("tf2_test_reduce_mean.py" , "g" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_2_FLOAT32 )));
17051738 }
17061739
17071740 @ Test
17081741 public void testReduceMean3 ()
17091742 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1710- test ("tf2_test_reduce_mean.py" , "h" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
1743+ test ("tf2_test_reduce_mean.py" , "h" , 1 , 1 , Map .of (2 , Set .of (TENSOR_2_2_FLOAT32 )));
17111744 }
17121745
17131746 @ Test
@@ -1742,13 +1775,13 @@ public void testSparseSoftmaxCrossEntropyWithLogits()
17421775 "f" ,
17431776 1 ,
17441777 1 ,
1745- Map .of (2 , Set .of (MNIST_INPUT )));
1778+ Map .of (2 , Set .of (TENSOR_3_INT32 )));
17461779 }
17471780
17481781 @ Test
17491782 public void testRelu ()
17501783 throws ClassHierarchyException , IllegalArgumentException , CancelException , IOException {
1751- test ("tf2_test_relu.py" , "f" , 1 , 1 , Map .of (2 , Set .of (MNIST_INPUT )));
1784+ test ("tf2_test_relu.py" , "f" , 1 , 1 , Map .of (2 , Set .of (TENSOR_3_FLOAT32 )));
17521785 }
17531786
17541787 @ Test
0 commit comments