|
1 | 1 | package com.ibm.wala.cast.python.ml.test; |
2 | 2 |
|
| 3 | +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32; |
3 | 4 | import static com.ibm.wala.cast.python.ml.types.TensorType.mnistInput; |
4 | 5 | import static com.ibm.wala.cast.python.util.Util.addPytestEntrypoints; |
5 | 6 | import static java.util.Arrays.asList; |
@@ -58,6 +59,8 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape { |
58 | 59 |
|
59 | 60 | private static final TensorType MNIST_INPUT = mnistInput(); |
60 | 61 |
|
| 62 | + private static final String FLOAT_32 = FLOAT32.name().toLowerCase(); |
| 63 | + |
61 | 64 | @Test |
62 | 65 | public void testValueIndex() |
63 | 66 | throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { |
@@ -880,8 +883,8 @@ public void testAdd7() |
880 | 883 |
|
881 | 884 | List<Dimension<?>> bDimensions = asList(bX, bY); |
882 | 885 |
|
883 | | - TensorType expectedTypeForA = new TensorType("pixel", aDimensions); |
884 | | - TensorType expectedTypeForB = new TensorType("pixel", bDimensions); |
| 886 | + TensorType expectedTypeForA = new TensorType(FLOAT_32, aDimensions); |
| 887 | + TensorType expectedTypeForB = new TensorType(FLOAT_32, bDimensions); |
885 | 888 |
|
886 | 889 | test( |
887 | 890 | "tf2_test_add7.py", |
@@ -1564,8 +1567,8 @@ public void testAdd116() |
1564 | 1567 |
|
1565 | 1568 | List<Dimension<?>> bDimensions = asList(bX, bY); |
1566 | 1569 |
|
1567 | | - TensorType expectedTypeForA = new TensorType("pixel", aDimensions); |
1568 | | - TensorType expectedTypeForB = new TensorType("pixel", bDimensions); |
| 1570 | + TensorType expectedTypeForA = new TensorType(FLOAT_32, aDimensions); |
| 1571 | + TensorType expectedTypeForB = new TensorType(FLOAT_32, bDimensions); |
1569 | 1572 |
|
1570 | 1573 | test( |
1571 | 1574 | "tf2_test_add116.py", |
|
0 commit comments