|
8 | 8 | import static java.util.logging.Logger.getLogger; |
9 | 9 |
|
10 | 10 | import com.ibm.wala.cast.ipa.callgraph.AstPointerKeyFactory; |
| 11 | +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; |
11 | 12 | import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; |
12 | 13 | import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim; |
13 | 14 | import com.ibm.wala.classLoader.IField; |
|
23 | 24 | import com.ibm.wala.util.collections.HashSetFactory; |
24 | 25 | import com.ibm.wala.util.intset.OrdinalSet; |
25 | 26 | import java.util.ArrayList; |
| 27 | +import java.util.EnumSet; |
26 | 28 | import java.util.List; |
27 | 29 | import java.util.Optional; |
28 | 30 | import java.util.Set; |
@@ -293,4 +295,30 @@ private Optional<Integer> getRaggedRankArgumentValue(PropagationCallGraphBuilder |
293 | 295 | // TODO Auto-generated method stub |
294 | 296 | return Optional.empty(); |
295 | 297 | } |
| 298 | + |
| 299 | + /** |
| 300 | + * {@inheritDoc} |
| 301 | + * |
| 302 | + * <p>If there no scalars, we default to <code>tf.float32</code>. This isn't in the documentation, |
| 303 | + * but it seems to be the case. |
| 304 | + * |
| 305 | + * @see The <a href="https://github.com/tensorflow/tensorflow/issues/105858">"Update default dtype |
| 306 | + * description in ragged_factory_ops.py" GitHub issue</a>. |
| 307 | + */ |
| 308 | + @Override |
| 309 | + protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) { |
| 310 | + PointerAnalysis<InstanceKey> pointerAnalysis = builder.getPointerAnalysis(); |
| 311 | + |
| 312 | + int valueNumber = this.getValueArgumentValueNumber(); |
| 313 | + PointerKey valuePK = |
| 314 | + pointerAnalysis.getHeapModel().getPointerKeyForLocal(this.getNode(), valueNumber); |
| 315 | + OrdinalSet<InstanceKey> valuePointsToSet = pointerAnalysis.getPointsToSet(valuePK); |
| 316 | + |
| 317 | + if (containsScalars(builder, valuePointsToSet).isEmpty()) { |
| 318 | + LOGGER.fine("No scalars found in `pylist`; defaulting to `tf.float32` dtype."); |
| 319 | + return EnumSet.of(DType.FLOAT32); |
| 320 | + } |
| 321 | + |
| 322 | + return super.getDefaultDTypes(builder); |
| 323 | + } |
296 | 324 | } |
0 commit comments