Abstract base class for defining the classification matching strategy.
Inherits From: ABC
```python
TFSimilarity.callbacks.ClassificationMatch(
    name: str = '', canonical_name: str = ''
) -> None
```
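Because `ClassificationMatch` is an abstract base class, a concrete strategy presumably supplies its own rule for reducing k neighbors to one match (the role of `derive_match`). The minimal sketch below shows that pattern with Python's `abc` module. `MatchStrategySketch` and `NearestSketch` are hypothetical stand-ins, not TFSimilarity classes, and plain nested lists replace the `IntTensor`/`FloatTensor` types. The nearest-neighbor subclass assumes each row of neighbors is already sorted by ascending distance.

```python
from abc import ABC, abstractmethod
from typing import List, Tuple


class MatchStrategySketch(ABC):
    """Hypothetical stand-in for an ABC-based matching strategy (not the library class)."""

    def __init__(self, name: str = "", canonical_name: str = "") -> None:
        self.name = name
        self.canonical_name = canonical_name

    @abstractmethod
    def derive_match(
        self, lookup_labels: List[List[int]], lookup_distances: List[List[float]]
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """Reduce each query's k neighbors to one label and one distance."""

    def get_config(self) -> dict:
        # Expose the constructor arguments so the strategy can be re-created.
        return {"name": self.name, "canonical_name": self.canonical_name}


class NearestSketch(MatchStrategySketch):
    """Hypothetical strategy: keep the first (assumed nearest) neighbor of each query."""

    def __init__(self, name: str = "match_nearest") -> None:
        super().__init__(name=name, canonical_name="match_nearest")

    def derive_match(self, lookup_labels, lookup_distances):
        # Assumes rows are sorted by ascending distance, so column 0 is the nearest.
        labels = [[row[0]] for row in lookup_labels]
        dists = [[row[0]] for row in lookup_distances]
        return labels, dists
```

Instantiating `MatchStrategySketch` directly raises `TypeError`, which is how `abc` enforces that every strategy provides `derive_match`.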
Attributes | |
---|---|
name | Name associated with the match object, e.g., `match_nearest`. |
canonical_name | The canonical name associated with the match strategy, e.g., `match_nearest`. |
distance_thresholds | The max distance below which a nearest neighbor is considered a valid match. Defaults to `None` and must be set using `ClassificationMatch.compile()`. |
count | The total number of queries. |
fn | The count of false negative matches. A false negative match is when the query label == the label generated by the matcher and the distance to the query is greater than the distance threshold. |
fp | The count of false positive matches. A false positive match is when the query label != the label generated by the matcher and the distance to the query is less than or equal to the distance threshold. |
tn | The count of true negative matches. A true negative match is when the query label != the label generated by the matcher and the distance to the query is greater than the distance threshold. |
tp | The count of true positive matches. A true positive match is when the query label == the label generated by the matcher and the distance to the query is less than or equal to the distance threshold. |
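The four counts partition the queries by two yes/no questions: does the derived label equal the query label, and is the derived distance within the threshold? The sketch below classifies a single query at a single threshold. `classify_outcome` is a hypothetical helper, not part of TFSimilarity, and it treats a distance equal to the threshold as within it, following the `<=` wording used for `compute_count`.

```python
def classify_outcome(query_label: int, derived_label: int,
                     derived_distance: float, threshold: float) -> str:
    """Return 'tp', 'fp', 'fn', or 'tn' for one query at one threshold (hypothetical helper)."""
    label_match = query_label == derived_label
    within = derived_distance <= threshold  # a distance equal to the threshold counts as within
    if label_match and within:
        return "tp"
    if within:  # labels differ but the neighbor is close enough to be accepted
        return "fp"
    if label_match:  # labels agree but the neighbor is too far away
        return "fn"
    return "tn"
```

For a query labeled 7 whose derived neighbor is also labeled 7 at distance 0.3, the outcome is `fn` at threshold 0.2 and `tp` at thresholds 0.3 and above, which is why the counts are tracked per threshold.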
```python
compile(
    distance_thresholds: Optional[TFSimilarity.callbacks.FloatTensor] = None
)
```
Configures the distance thresholds used during matching.
Args | |
---|---|
distance_thresholds | The max distance below which a nearest neighbor is considered a valid match. A threshold of math.inf is used if None is passed. |
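The documented behavior of `compile()` can be sketched as a small piece of state: thresholds start unset, `None` becomes a single `math.inf` threshold, and any counting step refuses to run before `compile()`. `ThresholdConfigSketch` and `check_compiled` are hypothetical names, and storing the thresholds as a Python list (rather than a `FloatTensor`) is an assumption made to keep the sketch dependency-free.

```python
import math
from typing import List, Optional, Sequence


class ThresholdConfigSketch:
    """Hypothetical illustration of compile(); not the TFSimilarity implementation."""

    def __init__(self) -> None:
        # Mirrors the documented default: thresholds stay unset until compile() runs.
        self.distance_thresholds: Optional[List[float]] = None

    def compile(self, distance_thresholds: Optional[Sequence[float]] = None) -> None:
        if distance_thresholds is None:
            # A single infinite threshold accepts every finite neighbor distance.
            self.distance_thresholds = [math.inf]
        else:
            self.distance_thresholds = [float(t) for t in distance_thresholds]

    def check_compiled(self) -> List[float]:
        # Counting without thresholds is meaningless, so fail loudly.
        if self.distance_thresholds is None:
            raise RuntimeError("compile() must be called before computing counts")
        return self.distance_thresholds
```

With the `math.inf` default every finite distance is within the threshold, so no query can be counted as a false negative or true negative; meaningful `fn`/`tn` counts require explicit thresholds.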
```python
compute_count(
    query_labels: TFSimilarity.callbacks.IntTensor,
    lookup_labels: TFSimilarity.callbacks.IntTensor,
    lookup_distances: TFSimilarity.callbacks.FloatTensor
) -> None
```
Computes the match counts at each of the distance thresholds. At each threshold, the method counts:
- True Positive: The query label matches the derived lookup label and the derived lookup distance is <= the current distance threshold.
- False Positive: The query label does not match the derived lookup label but the derived lookup distance is <= the current distance threshold.
- False Negative: The query label matches the derived lookup label but the derived lookup distance is > the current distance threshold.
- True Negative: The query label does not match the derived lookup label and the derived lookup distance is > the current distance threshold.

Note: `compile()` must be called before calling `compute_count()`.
Args | |
---|---|
query_labels | A 1D array of the labels associated with the queries. |
lookup_labels | A 2D array where the jth row is the labels associated with the set of k neighbors for the jth query. |
lookup_distances | A 2D array where the jth row is the distances between the jth query and the set of k neighbors. |
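The counting rules above can be sketched in pure Python. The description refers to one derived label and one derived distance per query, so this sketch assumes the k neighbors have already been reduced to those values (the `derive_match` step) and takes them as 1D sequences. `count_at_thresholds` is a hypothetical helper, not the library method, and it returns the counts instead of storing them.

```python
from typing import Dict, List, Sequence


def count_at_thresholds(
    query_labels: Sequence[int],
    derived_labels: Sequence[int],
    derived_distances: Sequence[float],
    thresholds: Sequence[float],
) -> Dict[str, List[int]]:
    """Count tp/fp/fn/tn at every threshold (hypothetical helper, pure Python)."""
    counts: Dict[str, List[int]] = {"tp": [], "fp": [], "fn": [], "tn": []}
    for t in thresholds:
        tp = fp = fn = tn = 0
        for q, d_label, d_dist in zip(query_labels, derived_labels, derived_distances):
            label_match = q == d_label
            within = d_dist <= t  # <= as in the description above
            if label_match and within:
                tp += 1
            elif within:
                fp += 1
            elif label_match:
                fn += 1
            else:
                tn += 1
        # One entry per threshold, in the same order as `thresholds`.
        for key, value in (("tp", tp), ("fp", fp), ("fn", fn), ("tn", tn)):
            counts[key].append(value)
    return counts
```

With query labels `[1, 2, 3, 4]`, derived labels `[1, 2, 9, 9]`, derived distances `[0.1, 0.6, 0.2, 0.8]`, and thresholds `[0.5, 1.0]`, the helper returns `tp == [1, 2]`, `fp == [1, 2]`, `fn == [1, 0]`, and `tn == [1, 0]`. At each threshold the four counts sum to the number of queries, matching the `count` attribute.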
```python
derive_match(
    lookup_labels: TFSimilarity.callbacks.IntTensor,
    lookup_distances: TFSimilarity.callbacks.FloatTensor
) -> Tuple[TFSimilarity.callbacks.IntTensor, TFSimilarity.callbacks.FloatTensor]
```
Derive a match label and distance from a set of K neighbors.
For each query, derive a single match label and distance given the associated set of lookup labels and distances.
Args | |
---|---|
lookup_labels | A 2D array where the jth row is the labels associated with the set of k neighbors for the jth query. |
lookup_distances | A 2D array where the jth row is the distances between the jth query and the set of k neighbors. |
Returns | |
---|---|
A tuple of tensors:
- derived_labels: An IntTensor of shape [len(lookup_labels), 1] where the jth row contains the derived label for the jth query.
- derived_distances: A FloatTensor of shape [len(lookup_labels), 1] where the jth row contains the distance associated with the jth derived label. |
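As an illustration of what a `derive_match` implementation might do, the sketch below applies a majority vote over each query's k neighbors and returns two `[n, 1]`-shaped nested lists. `derive_majority_vote` is a hypothetical function, not the TFSimilarity code. Two choices are assumptions of this sketch: ties go to the label encountered first (the nearest one if rows are sorted by distance), and the derived distance is the mean distance of the neighbors that voted for the winning label.

```python
from typing import Dict, List, Tuple


def derive_majority_vote(
    lookup_labels: List[List[int]], lookup_distances: List[List[float]]
) -> Tuple[List[List[int]], List[List[float]]]:
    """Hypothetical majority-vote strategy returning [n, 1]-shaped nested lists."""
    derived_labels: List[List[int]] = []
    derived_distances: List[List[float]] = []
    for labels, dists in zip(lookup_labels, lookup_distances):
        votes: Dict[int, int] = {}
        for label in labels:
            votes[label] = votes.get(label, 0) + 1
        # max() returns the first maximal key and dicts keep insertion order,
        # so ties go to the label seen first in the row.
        winner = max(votes, key=votes.get)
        winner_dists = [d for lbl, d in zip(labels, dists) if lbl == winner]
        derived_labels.append([winner])
        # Assumption: report the mean distance of the neighbors that voted for the winner.
        derived_distances.append([sum(winner_dists) / len(winner_dists)])
    return derived_labels, derived_distances
```

For neighbor labels `[[5, 3, 5], [2, 4, 4]]` with distances `[[0.25, 0.5, 0.75], [0.25, 0.5, 1.0]]`, the sketch derives labels `[[5], [4]]` and distances `[[0.5], [0.75]]`. A nearest-neighbor strategy would instead keep column 0 of each row.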
get_config()