Skip to content

Commit a2ba663

Browse files
committed
Fix getSignature in RaggedConstant to return the correct function signature.
1 parent 4951e69 commit a2ba663

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/RaggedConstant.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
import static com.ibm.wala.cast.python.ml.client.RaggedConstant.Parameters.INNER_SHAPE;
44
import static com.ibm.wala.cast.python.ml.client.RaggedConstant.Parameters.RAGGED_RANK;
55
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32;
6+
import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TYPE_REFERENCE_TO_SIGNATURE;
67
import static com.ibm.wala.cast.python.types.PythonTypes.Root;
78
import static com.ibm.wala.cast.python.types.PythonTypes.list;
89
import static com.ibm.wala.cast.python.types.PythonTypes.tuple;
910
import static com.ibm.wala.cast.python.util.Util.getAllocationSiteInNode;
11+
import static com.ibm.wala.cast.python.util.Util.getFunction;
1012
import static com.ibm.wala.core.util.strings.Atom.findOrCreateAsciiAtom;
1113
import static java.lang.Math.max;
1214
import static java.util.Collections.emptySet;
@@ -478,4 +480,15 @@ protected EnumSet<DType> getDefaultDTypes(PropagationCallGraphBuilder builder) {
478480
// Otherwise, there are values available to infer the dtype from.
479481
return super.getDefaultDTypes(builder);
480482
}
483+
484+
/**
485+
* Returns the TensorFlow function signature represented by this generator.
486+
*
487+
* @return The TensorFlow function signature represented by this generator.
488+
*/
489+
@Override
490+
protected String getSignature() {
491+
TypeReference function = getFunction(this.getSource());
492+
return TYPE_REFERENCE_TO_SIGNATURE.get(function);
493+
}
481494
}

0 commit comments

Comments
 (0)