Skip to content

Commit 6f236c1

Browse files
committed
Progress on tf.ragged.constant() shape inference to handle ragged dimensions.
1 parent 81bac7d commit 6f236c1

File tree

1 file changed

+153
-2
lines changed

1 file changed

+153
-2
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/RaggedConstant.java

Lines changed: 153 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11
package com.ibm.wala.cast.python.ml.client;
22

33
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TYPE_REFERENCE_TO_SIGNATURE;
4+
import static com.ibm.wala.cast.python.types.PythonTypes.Root;
5+
import static com.ibm.wala.cast.python.types.PythonTypes.list;
6+
import static com.ibm.wala.cast.python.types.PythonTypes.tuple;
7+
import static com.ibm.wala.cast.python.util.Util.getAllocationSiteInNode;
48
import static com.ibm.wala.cast.python.util.Util.getFunction;
9+
import static com.ibm.wala.core.util.strings.Atom.findOrCreateAsciiAtom;
10+
import static java.util.logging.Logger.getLogger;
511

12+
import com.ibm.wala.cast.ipa.callgraph.AstPointerKeyFactory;
613
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
14+
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
15+
import com.ibm.wala.classLoader.IField;
16+
import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode;
17+
import com.ibm.wala.ipa.callgraph.propagation.ConstantKey;
18+
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
19+
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
20+
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
721
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
822
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
23+
import com.ibm.wala.types.FieldReference;
924
import com.ibm.wala.types.TypeReference;
25+
import com.ibm.wala.util.collections.HashSetFactory;
26+
import com.ibm.wala.util.intset.OrdinalSet;
27+
import java.util.ArrayList;
1028
import java.util.List;
29+
import java.util.Optional;
1130
import java.util.Set;
31+
import java.util.logging.Logger;
1232

1333
/**
1434
* A representation of the `tf.ragged.constant()` API in TensorFlow.
@@ -19,6 +39,8 @@
1939
*/
2040
public class RaggedConstant extends ZerosLike {
2141

42+
private static final Logger LOGGER = getLogger(RaggedConstant.class.getName());
43+
2244
protected enum Parameters {
2345
PYLIST,
2446
DTYPE,
@@ -38,9 +60,138 @@ protected String getSignature() {
3860
return TYPE_REFERENCE_TO_SIGNATURE.get(function);
3961
}
4062

63+
private static Set<Integer> getPossibleListLengths(
64+
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> valuePointsToSet) {
65+
Set<Integer> ret = HashSetFactory.make();
66+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
67+
68+
for (InstanceKey valueIK : valuePointsToSet) {
69+
AllocationSiteInNode asin = getAllocationSiteInNode(valueIK);
70+
TypeReference reference = asin.getConcreteType().getReference();
71+
72+
// A `list` or `tuple`.
73+
if (reference.equals(list) || reference.equals(tuple)) {
74+
OrdinalSet<InstanceKey> objectCatalogPointsToSet =
75+
pointerAnalysis.getPointsToSet(
76+
((AstPointerKeyFactory) builder.getPointerKeyFactory())
77+
.getPointerKeyForObjectCatalog(asin));
78+
79+
ret.add(objectCatalogPointsToSet.size());
80+
} else
81+
throw new IllegalArgumentException(
82+
"Expected a list or tuple, but found: " + reference + ".");
83+
}
84+
85+
return ret;
86+
}
87+
88+
private static Set<Integer> getMaximumDepthOfScalars(
89+
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> valuePointsToSet) {
90+
Set<Integer> ret = HashSetFactory.make();
91+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
92+
93+
for (InstanceKey valueIK : valuePointsToSet) {
94+
int maxDepth = -1;
95+
96+
if (valueIK instanceof ConstantKey) maxDepth = Math.max(maxDepth, 0); // Scalar value.
97+
else {
98+
AllocationSiteInNode asin = getAllocationSiteInNode(valueIK);
99+
TypeReference reference = asin.getConcreteType().getReference();
100+
101+
// A nested `list`, `tuple`, or `np.ndarray`.
102+
if (reference.equals(list) || reference.equals(tuple)) {
103+
OrdinalSet<InstanceKey> objectCatalogPointsToSet =
104+
pointerAnalysis.getPointsToSet(
105+
((AstPointerKeyFactory) builder.getPointerKeyFactory())
106+
.getPointerKeyForObjectCatalog(asin));
107+
108+
for (InstanceKey catalogIK : objectCatalogPointsToSet) {
109+
ConstantKey<?> constantKey = (ConstantKey<?>) catalogIK;
110+
Object constantKeyValue = constantKey.getValue();
111+
112+
Integer fieldIndex = (Integer) constantKeyValue;
113+
114+
FieldReference subscript =
115+
FieldReference.findOrCreate(
116+
Root, findOrCreateAsciiAtom(fieldIndex.toString()), Root);
117+
118+
IField f = builder.getClassHierarchy().resolveField(subscript);
119+
120+
PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f);
121+
122+
OrdinalSet<InstanceKey> instanceFieldPointsToSet =
123+
pointerAnalysis.getPointsToSet(pointerKeyForInstanceField);
124+
125+
Set<Integer> possibleDepthsOfField =
126+
getMaximumDepthOfScalars(builder, instanceFieldPointsToSet);
127+
128+
for (int depthOfField : possibleDepthsOfField)
129+
maxDepth = Math.max(maxDepth, 1 + depthOfField);
130+
}
131+
}
132+
}
133+
134+
ret.add(maxDepth);
135+
}
136+
137+
return ret;
138+
}
139+
41140
@Override
42-
protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
141+
protected Set<List<Dimension<?>>> getShapesOfValue(
142+
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> valuePointsToSet) {
143+
// Returns a potentially ragged tensor with rank K and the specified `ragged_rank`, containing
144+
// the values from `pylist`.
145+
146+
// All scalar values in `pylist` must have the same nesting depth K, and the returned
147+
// `RaggedTensor` will have rank K. If `pylist` contains no scalar values, then K is one greater
148+
// than the maximum depth of empty lists in `pylist`.
149+
150+
// Step 1: Calculate K, the maximum depth of scalar values in `pylist`.
151+
152+
if (valuePointsToSet == null || valuePointsToSet.isEmpty())
153+
throw new IllegalArgumentException(
154+
"Empty points-to set for value in source: " + this.getSource() + ".");
155+
156+
Set<List<Dimension<?>>> ret = HashSetFactory.make();
157+
158+
Set<Integer> maxDepthOfScalars = getMaximumDepthOfScalars(builder, valuePointsToSet);
159+
LOGGER.fine("Maximum depth of scalars in pylist: " + maxDepthOfScalars);
160+
161+
// Step 2: Determine Ragged Rank (R).
162+
for (int K : maxDepthOfScalars) {
163+
Optional<Integer> raggedRank = this.getRaggedRankArgumentValue(builder);
164+
int R = raggedRank.orElse(K - 1);
165+
LOGGER.fine("Ragged rank: " + R);
166+
167+
// Step 3: Construct shape with rank K and ragged rank R.
168+
169+
// Get the length of the outer list.
170+
Set<Integer> possibleOuterListLengths = getPossibleListLengths(builder, valuePointsToSet);
171+
172+
for (int outerListLength : possibleOuterListLengths) {
173+
List<Dimension<?>> shape = new ArrayList<>();
174+
shape.add(new NumericDim(outerListLength));
175+
176+
// The first R dimensions are ragged.
177+
for (int i = 0; i < R; i++) shape.add(null); // Unknown size for ragged dimensions.
178+
179+
/*
180+
// The remaining K - R dimensions are dense.
181+
for (int i = R; i < K; i++) {
182+
shape.add(new NumericDim(-1)); // Unknown size for dense dimensions.
183+
}
184+
*/
185+
186+
ret.add(shape);
187+
}
188+
}
189+
190+
return ret;
191+
}
192+
193+
private Optional<Integer> getRaggedRankArgumentValue(PropagationCallGraphBuilder builder) {
43194
// TODO Auto-generated method stub
44-
return super.getDefaultShapes(builder);
195+
return Optional.empty();
45196
}
46197
}

0 commit comments

Comments
 (0)