Skip to content

Commit 44a980c

Browse files
committed
Cleanup.
1 parent 19fc100 commit 44a980c

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/RaggedConstant.java

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package com.ibm.wala.cast.python.ml.client;
22

3+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32;
34
import static com.ibm.wala.cast.python.types.PythonTypes.Root;
45
import static com.ibm.wala.cast.python.types.PythonTypes.list;
56
import static com.ibm.wala.cast.python.types.PythonTypes.tuple;
67
import static com.ibm.wala.cast.python.util.Util.getAllocationSiteInNode;
78
import static com.ibm.wala.core.util.strings.Atom.findOrCreateAsciiAtom;
89
import static java.util.logging.Logger.getLogger;
10+
import static java.util.stream.Collectors.toSet;
911

1012
import com.ibm.wala.cast.ipa.callgraph.AstPointerKeyFactory;
1113
import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
@@ -29,6 +31,7 @@
2931
import java.util.Optional;
3032
import java.util.Set;
3133
import java.util.logging.Logger;
34+
import java.util.stream.StreamSupport;
3235

3336
/**
3437
* A representation of the `tf.ragged.constant()` API in TensorFlow.
@@ -79,13 +82,6 @@ private static Set<Integer> getPossibleOuterListLengths(
7982
return ret;
8083
}
8184

82-
private static Set<InstanceKey> containsScalars(
83-
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> pts) {
84-
Set<InstanceKey> ret = HashSetFactory.make();
85-
for (InstanceKey ik : pts) if (containsScalars(builder, ik)) ret.add(ik);
86-
return ret;
87-
}
88-
8985
private static boolean containsScalars(PropagationCallGraphBuilder builder, InstanceKey ik) {
9086
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
9187

@@ -231,17 +227,19 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
231227
// than the maximum depth of empty lists in `pylist`.
232228

233229
// Step 1: Calculate K, the maximum depth of scalar values in `pylist`.
234-
235230
if (valuePointsToSet == null || valuePointsToSet.isEmpty())
236231
throw new IllegalArgumentException(
237232
"Empty points-to set for value in source: " + this.getSource() + ".");
238233

239234
Set<List<Dimension<?>>> ret = HashSetFactory.make();
240235

241-
Set<InstanceKey> scalars = containsScalars(builder, valuePointsToSet);
236+
Set<InstanceKey> valuesWithScalars =
237+
StreamSupport.stream(valuePointsToSet.spliterator(), false)
238+
.filter(ik -> containsScalars(builder, ik))
239+
.collect(toSet());
242240

243241
for (InstanceKey valueIK : valuePointsToSet) {
244-
int maxDepth = getMaxDepth(builder, scalars, valueIK);
242+
int maxDepth = getMaximumDepthOfInstance(builder, valuesWithScalars, valueIK);
245243
LOGGER.fine("Maximum depth of `pylist`: " + maxDepth);
246244

247245
// Step 2: Determine Ragged Rank (R).
@@ -279,16 +277,15 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
279277
return ret;
280278
}
281279

282-
private static int getMaxDepth(
283-
PropagationCallGraphBuilder builder, Set<InstanceKey> scalars, InstanceKey valueIK) {
284-
int maxDepth;
285-
286-
if (scalars.contains(valueIK)) maxDepth = getMaximumDepthOfScalars(builder, valueIK);
280+
private static int getMaximumDepthOfInstance(
281+
PropagationCallGraphBuilder builder,
282+
Set<InstanceKey> instancesWithScalars,
283+
InstanceKey instance) {
284+
if (instancesWithScalars.contains(instance)) return getMaximumDepthOfScalars(builder, instance);
287285
else
288286
// If `pylist` contains no scalar values, then K is one greater than the maximum depth of
289287
// empty lists in `pylist`.
290-
maxDepth = 1 + getMaximumDepthOfEmptyList(builder, valueIK);
291-
return maxDepth;
288+
return 1 + getMaximumDepthOfEmptyList(builder, instance);
292289
}
293290

294291
private Optional<Integer> getRaggedRankArgumentValue(PropagationCallGraphBuilder builder) {
@@ -314,9 +311,16 @@ protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
314311
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), valueNumber);
315312
OrdinalSet<InstanceKey> valuePointsToSet = pointerAnalysis.getPointsToSet(valuePK);
316313

317-
if (containsScalars(builder, valuePointsToSet).isEmpty()) {
314+
if (valuePointsToSet == null || valuePointsToSet.isEmpty())
315+
throw new IllegalArgumentException(
316+
"Empty points-to set for value in source: " + this.getSource() + ".");
317+
318+
if (StreamSupport.stream(valuePointsToSet.spliterator(), false)
319+
.filter(ik -> containsScalars(builder, ik))
320+
.count()
321+
== 0) {
318322
LOGGER.fine("No scalars found in `pylist`; defaulting to `tf.float32` dtype.");
319-
return EnumSet.of(DType.FLOAT32);
323+
return EnumSet.of(FLOAT32);
320324
}
321325

322326
return super.getDefaultDTypes(builder);

0 commit comments

Comments
 (0)