Skip to content

Commit 669a775

Browse files
committed
If tensor shapes are coming from a variable, make sure we know about it.
1 parent 9913316 commit 669a775

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4128,13 +4128,19 @@ public void testConvertToTensor4()
41284128
test("tf2_test_convert_to_tensor4.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32)));
41294129
}
41304130

4131-
@Test
4131+
/**
4132+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340 is fixed.
4133+
*/
4134+
@Test(expected = IllegalArgumentException.class)
41324135
public void testConvertToTensor5()
41334136
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
41344137
test("tf2_test_convert_to_tensor5.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32)));
41354138
}
41364139

4137-
@Test
4140+
/**
4141+
* Should not throw an {@link IllegalArgumentException} once https://github.com/wala/ML/issues/340 is fixed.
4142+
*/
4143+
@Test(expected = IllegalArgumentException.class)
41384144
public void testConvertToTensor6()
41394145
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
41404146
test("tf2_test_convert_to_tensor6.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32)));

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ public Set<TensorType> getTensorTypes(PropagationCallGraphBuilder builder) {
8888
*/
8989
protected Set<List<Dimension<?>>> getShapesFromShapeArgument(
9090
PropagationCallGraphBuilder builder, Iterable<InstanceKey> pointsToSet) {
91+
if (pointsToSet == null || !pointsToSet.iterator().hasNext())
92+
throw new IllegalArgumentException(
93+
"Empty points-to set for shape argument in source: " + source + ".");
94+
9195
Set<List<Dimension<?>>> ret = HashSetFactory.make();
9296
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
9397

@@ -266,6 +270,11 @@ protected Set<List<Dimension<?>>> getShapes(
266270
PointerKey valuePK =
267271
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), valueNumber);
268272
OrdinalSet<InstanceKey> valuePointsToSet = pointerAnalysis.getPointsToSet(valuePK);
273+
274+
if (valuePointsToSet.isEmpty())
275+
throw new IllegalArgumentException(
276+
"Empty points-to set for value number: " + valueNumber + " in: " + this.getNode() + ".");
277+
269278
return getShapesOfValue(builder, valuePointsToSet);
270279
}
271280

@@ -278,6 +287,10 @@ protected Set<List<Dimension<?>>> getShapes(
278287
*/
279288
private Set<List<Dimension<?>>> getShapesOfValue(
280289
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> valuePointsToSet) {
290+
if (valuePointsToSet == null || valuePointsToSet.isEmpty())
291+
throw new IllegalArgumentException(
292+
"Empty points-to set for value in source: " + source + ".");
293+
281294
Set<List<Dimension<?>>> ret = HashSetFactory.make();
282295
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
283296

@@ -349,6 +362,10 @@ else if (valueIK instanceof AllocationSiteInNode) {
349362
*/
350363
protected EnumSet<DType> getDTypesFromDTypeArgument(
351364
PropagationCallGraphBuilder builder, Iterable<InstanceKey> pointsToSet) {
365+
if (pointsToSet == null || !pointsToSet.iterator().hasNext())
366+
throw new IllegalArgumentException(
367+
"Empty points-to set for dtype argument in source: " + source + ".");
368+
352369
EnumSet<DType> ret = EnumSet.noneOf(DType.class);
353370
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
354371

@@ -509,6 +526,11 @@ protected EnumSet<DType> getDTypes(PropagationCallGraphBuilder builder, int valu
509526
PointerKey valuePK =
510527
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), valueNumber);
511528
OrdinalSet<InstanceKey> valuePointsToSet = pointerAnalysis.getPointsToSet(valuePK);
529+
530+
if (valuePointsToSet == null || valuePointsToSet.isEmpty())
531+
throw new IllegalArgumentException(
532+
"Empty points-to set for value number: " + valueNumber + " in: " + this.getNode() + ".");
533+
512534
return getDTypesOfValue(builder, valuePointsToSet);
513535
}
514536

@@ -522,6 +544,10 @@ protected EnumSet<DType> getDTypes(PropagationCallGraphBuilder builder, int valu
522544
*/
523545
private EnumSet<DType> getDTypesOfValue(
524546
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> valuePointsToSet) {
547+
if (valuePointsToSet == null || valuePointsToSet.isEmpty())
548+
throw new IllegalArgumentException(
549+
"Empty points-to set for value in source: " + source + ".");
550+
525551
EnumSet<DType> ret = EnumSet.noneOf(DType.class);
526552
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
527553

0 commit comments

Comments
 (0)