Skip to content

Commit 2344618

Browse files
committed
Handle multiple imports.
Update tests.
1 parent 6dfa1a8 commit 2344618

File tree

2 files changed

+44
-33
lines changed

2 files changed

+44
-33
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1768,7 +1768,12 @@ public void testAdd105()
17681768
@Test
17691769
public void testAdd106()
17701770
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1771-
test("tf2_test_add106.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT)));
1771+
test(
1772+
"tf2_test_add106.py",
1773+
"add",
1774+
2,
1775+
2,
1776+
Map.of(2, Set.of(TENSOR_4_FLOAT32), 3, Set.of(TENSOR_4_FLOAT32)));
17721777
}
17731778

17741779
@Test

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static com.ibm.wala.ipa.callgraph.propagation.cfa.CallStringContextSelector.CALL_STRING;
1212
import static java.util.Arrays.asList;
1313
import static java.util.Collections.emptyList;
14+
import static java.util.stream.Collectors.toSet;
1415

1516
import com.ibm.wala.cast.ipa.callgraph.AstPointerKeyFactory;
1617
import com.ibm.wala.cast.python.ml.types.TensorFlowTypes;
@@ -43,7 +44,6 @@
4344
import java.util.EnumSet;
4445
import java.util.List;
4546
import java.util.Map.Entry;
46-
import java.util.Optional;
4747
import java.util.Set;
4848
import java.util.logging.Logger;
4949

@@ -353,11 +353,11 @@ protected EnumSet<DType> getDTypesFromDTypeArgument(
353353

354354
if (typeReference.equals(TensorFlowTypes.D_TYPE)) {
355355
// we have a dtype.
356-
// let's see if it's float32.
356+
// let's see if it's a dtype.
357357
Set<CGNode> importNodes = builder.getCallGraph().getNodes(IMPORT);
358358

359-
// find the import node from this file.
360-
Optional<CGNode> importNode =
359+
// find the import nodes from this file.
360+
Set<CGNode> importNodesOfInterest =
361361
importNodes.stream()
362362
.filter(
363363
in -> {
@@ -384,38 +384,44 @@ protected EnumSet<DType> getDTypesFromDTypeArgument(
384384

385385
return method.equals(nodeCS.getMethods()[0]);
386386
})
387-
.findFirst();
387+
.collect(toSet());
388388

389-
InstanceKey tensorFlowIK =
390-
pointerAnalysis
391-
.getHeapModel()
392-
.getInstanceKeyForAllocation(
393-
importNode.get(), NewSiteReference.make(0, TENSORFLOW));
389+
if (importNodesOfInterest.isEmpty())
390+
throw new IllegalStateException("No import nodes found for source: " + source + ".");
394391

395-
// Check dtype literals.
396392
boolean found = false;
397393

398-
for (Entry<FieldReference, DType> entry : FIELD_REFERENCE_TO_DTYPE.entrySet()) {
399-
FieldReference fieldRef = entry.getKey();
400-
DType dtype = entry.getValue();
401-
IField field = builder.getClassHierarchy().resolveField(fieldRef);
402-
403-
PointerKey pk =
404-
pointerAnalysis.getHeapModel().getPointerKeyForInstanceField(tensorFlowIK, field);
405-
406-
for (InstanceKey ik : pointerAnalysis.getPointsToSet(pk))
407-
if (ik.equals(instanceKey)) {
408-
ret.add(dtype);
409-
LOGGER.info(
410-
"Found dtype: "
411-
+ dtype
412-
+ " for source: "
413-
+ source
414-
+ " from dType: "
415-
+ instanceKey
416-
+ ".");
417-
found = true;
418-
}
394+
for (CGNode importNode : importNodesOfInterest) {
395+
LOGGER.fine("Found import node of interest: " + importNode + ".");
396+
397+
InstanceKey tensorFlowIK =
398+
pointerAnalysis
399+
.getHeapModel()
400+
.getInstanceKeyForAllocation(importNode, NewSiteReference.make(0, TENSORFLOW));
401+
402+
// Check dtype literals.
403+
for (Entry<FieldReference, DType> entry : FIELD_REFERENCE_TO_DTYPE.entrySet()) {
404+
FieldReference fieldRef = entry.getKey();
405+
DType dtype = entry.getValue();
406+
IField field = builder.getClassHierarchy().resolveField(fieldRef);
407+
408+
PointerKey pk =
409+
pointerAnalysis.getHeapModel().getPointerKeyForInstanceField(tensorFlowIK, field);
410+
411+
for (InstanceKey ik : pointerAnalysis.getPointsToSet(pk))
412+
if (ik.equals(instanceKey)) {
413+
ret.add(dtype);
414+
LOGGER.info(
415+
"Found dtype: "
416+
+ dtype
417+
+ " for source: "
418+
+ source
419+
+ " from dType: "
420+
+ instanceKey
421+
+ ".");
422+
found = true;
423+
}
424+
}
419425
}
420426

421427
if (!found) throw new IllegalStateException("Unknown dtype: " + instanceKey + ".");

0 commit comments

Comments
 (0)