Skip to content

Commit 95ce3d3

Browse files
committed
Default to tf.float32 dtype in RaggedConstant when there are no
scalars.
1 parent 723fd38 commit 95ce3d3

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Constant.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,14 @@ protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder b
3535
return getShapes(builder, this.getValueArgumentValueNumber());
3636
}
3737

38+
/**
39+
* {@inheritDoc}
40+
*
41+
* <p>If the <code>dtype</code> argument is not specified, then the type is inferred from the type
42+
* of value.
43+
*/
3844
@Override
3945
protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
40-
// If the dtype argument is not specified, then the type is inferred from the type of value.
4146
// TODO: Handle keyword arguments.
4247
return getDTypes(builder, this.getValueArgumentValueNumber());
4348
}

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/RaggedConstant.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import static java.util.logging.Logger.getLogger;
99

1010
import com.ibm.wala.cast.ipa.callgraph.AstPointerKeyFactory;
11+
import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType;
1112
import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
1213
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
1314
import com.ibm.wala.classLoader.IField;
@@ -23,6 +24,7 @@
2324
import com.ibm.wala.util.collections.HashSetFactory;
2425
import com.ibm.wala.util.intset.OrdinalSet;
2526
import java.util.ArrayList;
27+
import java.util.EnumSet;
2628
import java.util.List;
2729
import java.util.Optional;
2830
import java.util.Set;
@@ -293,4 +295,30 @@ private Optional<Integer> getRaggedRankArgumentValue(PropagationCallGraphBuilder
293295
// TODO Auto-generated method stub
294296
return Optional.empty();
295297
}
298+
299+
/**
300+
* {@inheritDoc}
301+
*
302+
* <p>If there no scalars, we default to <code>tf.float32</code>. This isn't in the documentation,
303+
* but it seems to be the case.
304+
*
305+
* @see The <a href="https://github.com/tensorflow/tensorflow/issues/105858">"Update default dtype
306+
* description in ragged_factory_ops.py" GitHub issue</a>.
307+
*/
308+
@Override
309+
protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
310+
PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis();
311+
312+
int valueNumber = this.getValueArgumentValueNumber();
313+
PointerKey valuePK =
314+
pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), valueNumber);
315+
OrdinalSet<InstanceKey> valuePointsToSet = pointerAnalysis.getPointsToSet(valuePK);
316+
317+
if (containsScalars(builder, valuePointsToSet).isEmpty()) {
318+
LOGGER.fine("No scalars found in `pylist`; defaulting to `tf.float32` dtype.");
319+
return EnumSet.of(DType.FLOAT32);
320+
}
321+
322+
return super.getDefaultDTypes(builder);
323+
}
296324
}

0 commit comments

Comments
 (0)