Skip to content

Commit 723fd38

Browse files
committed
Progress on tf.ragged.constant.
1 parent 3daa5c5 commit 723fd38

File tree

1 file changed

+143
-37
lines changed

1 file changed

+143
-37
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/RaggedConstant.java

Lines changed: 143 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -77,56 +77,145 @@ private static Set<Integer> getPossibleOuterListLengths(
7777
return ret;
7878
}
7979

80-
private static Set<Integer> getMaximumDepthOfScalars(
81-
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> valuePointsToSet) {
82-
Set<Integer> ret = HashSetFactory.make();
80+
private static Set<InstanceKey> containsScalars(
81+
PropagationCallGraphBuilder builder, OrdinalSet<InstanceKey> pts) {
82+
Set<InstanceKey> ret = HashSetFactory.make();
83+
for (InstanceKey ik : pts) if (containsScalars(builder, ik)) ret.add(ik);
84+
return ret;
85+
}
86+
87+
private static boolean containsScalars(PropagationCallGraphBuilder builder, InstanceKey ik) {
8388
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
8489

85-
for (InstanceKey valueIK : valuePointsToSet) {
86-
int maxDepth = -1;
90+
if (ik instanceof ConstantKey) return true; // Scalar value.
91+
else {
92+
AllocationSiteInNode asin = getAllocationSiteInNode(ik);
93+
TypeReference reference = asin.getConcreteType().getReference();
8794

88-
if (valueIK instanceof ConstantKey) maxDepth = Math.max(maxDepth, 0); // Scalar value.
89-
else {
90-
AllocationSiteInNode asin = getAllocationSiteInNode(valueIK);
91-
TypeReference reference = asin.getConcreteType().getReference();
95+
// A nested `list`, `tuple`, or `np.ndarray`.
96+
if (reference.equals(list) || reference.equals(tuple)) {
97+
OrdinalSet<InstanceKey> objectCatalogPointsToSet =
98+
pointerAnalysis.getPointsToSet(
99+
((AstPointerKeyFactory) builder.getPointerKeyFactory())
100+
.getPointerKeyForObjectCatalog(asin));
92101

93-
// A nested `list`, `tuple`, or `np.ndarray`.
94-
if (reference.equals(list) || reference.equals(tuple)) {
95-
OrdinalSet<InstanceKey> objectCatalogPointsToSet =
96-
pointerAnalysis.getPointsToSet(
97-
((AstPointerKeyFactory) builder.getPointerKeyFactory())
98-
.getPointerKeyForObjectCatalog(asin));
102+
for (InstanceKey catalogIK : objectCatalogPointsToSet) {
103+
ConstantKey<?> constantKey = (ConstantKey<?>) catalogIK;
104+
Object constantKeyValue = constantKey.getValue();
99105

100-
for (InstanceKey catalogIK : objectCatalogPointsToSet) {
101-
ConstantKey<?> constantKey = (ConstantKey<?>) catalogIK;
102-
Object constantKeyValue = constantKey.getValue();
106+
Integer fieldIndex = (Integer) constantKeyValue;
103107

104-
Integer fieldIndex = (Integer) constantKeyValue;
108+
FieldReference subscript =
109+
FieldReference.findOrCreate(Root, findOrCreateAsciiAtom(fieldIndex.toString()), Root);
105110

106-
FieldReference subscript =
107-
FieldReference.findOrCreate(
108-
Root, findOrCreateAsciiAtom(fieldIndex.toString()), Root);
111+
IField f = builder.getClassHierarchy().resolveField(subscript);
109112

110-
IField f = builder.getClassHierarchy().resolveField(subscript);
113+
PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f);
111114

112-
PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f);
115+
OrdinalSet<InstanceKey> instanceFieldPointsToSet =
116+
pointerAnalysis.getPointsToSet(pointerKeyForInstanceField);
113117

114-
OrdinalSet<InstanceKey> instanceFieldPointsToSet =
115-
pointerAnalysis.getPointsToSet(pointerKeyForInstanceField);
118+
for (InstanceKey fieldIK : instanceFieldPointsToSet)
119+
if (containsScalars(builder, fieldIK)) return true;
120+
}
121+
} else
122+
throw new IllegalArgumentException(
123+
"Expected a list or tuple, but found: " + reference + ".");
124+
}
116125

117-
Set<Integer> possibleDepthsOfField =
118-
getMaximumDepthOfScalars(builder, instanceFieldPointsToSet);
126+
return false;
127+
}
119128

120-
for (int depthOfField : possibleDepthsOfField)
121-
maxDepth = Math.max(maxDepth, 1 + depthOfField);
122-
}
129+
private static int getMaximumDepthOfEmptyList(
130+
PropagationCallGraphBuilder builder, InstanceKey valueIK) {
131+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
132+
int maxDepth = 0;
133+
134+
AllocationSiteInNode asin = getAllocationSiteInNode(valueIK);
135+
TypeReference reference = asin.getConcreteType().getReference();
136+
137+
// A nested `list` or `tuple`.
138+
if (reference.equals(list) || reference.equals(tuple)) {
139+
OrdinalSet<InstanceKey> objectCatalogPointsToSet =
140+
pointerAnalysis.getPointsToSet(
141+
((AstPointerKeyFactory) builder.getPointerKeyFactory())
142+
.getPointerKeyForObjectCatalog(asin));
143+
144+
for (InstanceKey catalogIK : objectCatalogPointsToSet) {
145+
ConstantKey<?> constantKey = (ConstantKey<?>) catalogIK;
146+
Object constantKeyValue = constantKey.getValue();
147+
148+
Integer fieldIndex = (Integer) constantKeyValue;
149+
150+
FieldReference subscript =
151+
FieldReference.findOrCreate(Root, findOrCreateAsciiAtom(fieldIndex.toString()), Root);
152+
153+
IField f = builder.getClassHierarchy().resolveField(subscript);
154+
155+
PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f);
156+
157+
OrdinalSet<InstanceKey> instanceFieldPointsToSet =
158+
pointerAnalysis.getPointsToSet(pointerKeyForInstanceField);
159+
160+
if (instanceFieldPointsToSet.isEmpty())
161+
// An empty list at this field.
162+
maxDepth = Math.max(maxDepth, 0);
163+
164+
for (InstanceKey fieldIK : instanceFieldPointsToSet) {
165+
int depthOfField = getMaximumDepthOfEmptyList(builder, fieldIK);
166+
maxDepth = Math.max(maxDepth, 1 + depthOfField);
123167
}
124168
}
169+
} else
170+
throw new IllegalArgumentException("Expected a list or tuple, but found: " + reference + ".");
125171

126-
ret.add(maxDepth);
172+
return maxDepth;
173+
}
174+
175+
private static int getMaximumDepthOfScalars(
176+
PropagationCallGraphBuilder builder, InstanceKey valueIK) {
177+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
178+
int maxDepth = 0;
179+
180+
if (valueIK instanceof ConstantKey) maxDepth = Math.max(maxDepth, 0); // Scalar value.
181+
else {
182+
AllocationSiteInNode asin = getAllocationSiteInNode(valueIK);
183+
TypeReference reference = asin.getConcreteType().getReference();
184+
185+
// A nested `list`, `tuple`, or `np.ndarray`.
186+
if (reference.equals(list) || reference.equals(tuple)) {
187+
OrdinalSet<InstanceKey> objectCatalogPointsToSet =
188+
pointerAnalysis.getPointsToSet(
189+
((AstPointerKeyFactory) builder.getPointerKeyFactory())
190+
.getPointerKeyForObjectCatalog(asin));
191+
192+
for (InstanceKey catalogIK : objectCatalogPointsToSet) {
193+
ConstantKey<?> constantKey = (ConstantKey<?>) catalogIK;
194+
Object constantKeyValue = constantKey.getValue();
195+
196+
Integer fieldIndex = (Integer) constantKeyValue;
197+
198+
FieldReference subscript =
199+
FieldReference.findOrCreate(Root, findOrCreateAsciiAtom(fieldIndex.toString()), Root);
200+
201+
IField f = builder.getClassHierarchy().resolveField(subscript);
202+
203+
PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f);
204+
205+
OrdinalSet<InstanceKey> instanceFieldPointsToSet =
206+
pointerAnalysis.getPointsToSet(pointerKeyForInstanceField);
207+
208+
for (InstanceKey fieldIK : instanceFieldPointsToSet) {
209+
int depthOfField = getMaximumDepthOfScalars(builder, fieldIK);
210+
maxDepth = Math.max(maxDepth, 1 + depthOfField);
211+
}
212+
}
213+
} else
214+
throw new IllegalArgumentException(
215+
"Expected a list or tuple, but found: " + reference + ".");
127216
}
128217

129-
return ret;
218+
return maxDepth;
130219
}
131220

132221
@Override
@@ -147,11 +236,16 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
147236

148237
Set<List<Dimension<?>>> ret = HashSetFactory.make();
149238

150-
Set<Integer> maxDepthOfScalars = getMaximumDepthOfScalars(builder, valuePointsToSet);
151-
LOGGER.fine("Maximum depth of scalars in `pylist`: " + maxDepthOfScalars);
239+
Set<InstanceKey> scalars = containsScalars(builder, valuePointsToSet);
240+
241+
for (InstanceKey valueIK : valuePointsToSet) {
242+
int maxDepth = getMaxDepth(builder, scalars, valueIK);
243+
LOGGER.fine("Maximum depth of `pylist`: " + maxDepth);
244+
245+
// Step 2: Determine Ragged Rank (R).
246+
int K = maxDepth;
247+
LOGGER.fine("Tensor rank: " + K);
152248

153-
// Step 2: Determine Ragged Rank (R).
154-
for (int K : maxDepthOfScalars) {
155249
Optional<Integer> raggedRank = this.getRaggedRankArgumentValue(builder);
156250
int R = raggedRank.orElse(K - 1);
157251
LOGGER.fine("Ragged rank: " + R);
@@ -183,6 +277,18 @@ protected Set<List<Dimension<?>>> getShapesOfValue(
183277
return ret;
184278
}
185279

280+
private static int getMaxDepth(
281+
PropagationCallGraphBuilder builder, Set<InstanceKey> scalars, InstanceKey valueIK) {
282+
int maxDepth;
283+
284+
if (scalars.contains(valueIK)) maxDepth = getMaximumDepthOfScalars(builder, valueIK);
285+
else
286+
// If `pylist` contains no scalar values, then K is one greater than the maximum depth of
287+
// empty lists in `pylist`.
288+
maxDepth = 1 + getMaximumDepthOfEmptyList(builder, valueIK);
289+
return maxDepth;
290+
}
291+
186292
private Optional<Integer> getRaggedRankArgumentValue(PropagationCallGraphBuilder builder) {
187293
// TODO Auto-generated method stub
188294
return Optional.empty();

0 commit comments

Comments
 (0)