Skip to content

Commit 48488d8

Browse files
committed
ITEP-90808: Add cosine similarity as a configurable option
1 parent 64efa7f commit 48488d8

18 files changed

Lines changed: 1478 additions & 116 deletions

controller/config/reid-config.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
{
2+
"similarity_metric": "L2",
23
"stale_feature_timeout_secs": 5.0,
34
"stale_feature_check_interval_secs": 1.0,
45
"feature_accumulation_threshold": 12,
56
"minimum_bbox_area": 5000,
67
"feature_slice_size": 10,
7-
"similarity_threshold": 30.0
8+
"similarity_threshold": 30
89
}

controller/src/controller/data_source.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from abc import ABC, abstractmethod
55
from pathlib import Path
66
import json
7+
78
from scene_common import log
89
from scene_common.rest_client import RESTClient
910

controller/src/controller/reid.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,61 @@
33

44
from abc import ABC, abstractmethod
55

6+
import numpy as np
7+
8+
from scene_common import log
9+
610
class ReIDDatabase(ABC):
11+
def prepare_reid_dict(self, embedding_vector, dimensions=None,
                      caller_name="prepare_reid_dict",
                      normalize_embeddings=False):
    """
    Validate an embedding and package it as a ReID payload.

    The input may be shaped (N,), (1, N), or be any array-like object; it is
    converted to float32 and flattened to 1D before validation. A vector is
    rejected (a warning is logged and None is returned) when it is None, cannot
    be converted to a numeric array, is empty, does not match the expected
    length, contains non-finite values, or - when normalization is requested -
    has a zero or non-finite norm.

    @param    embedding_vector      Array-like embedding, or None
    @param    dimensions            Expected flattened length; inferred from the
                                    vector itself when None
    @param    caller_name           Prefix used in the warning messages
    @param    normalize_embeddings  When True, scale the vector to unit L2 norm
    @return   payload               Dict with "embedded_vector" (1D float32 array)
                                    and "dimensions" (int), or None when rejected
    """
    if embedding_vector is None:
        log.warning(f"{caller_name}: Empty embedding vector, skipping this vector")
        return None

    # Ragged or non-numeric input makes the float32 conversion raise; treat it
    # like every other invalid vector instead of propagating the exception.
    try:
        vec_array = np.asarray(embedding_vector, dtype="float32").reshape(-1)
    except (TypeError, ValueError) as err:
        log.warning(
            f"{caller_name}: Cannot convert embedding to a float32 vector ({err}), skipping this vector")
        return None

    # An empty array-like would otherwise pass every check below (with inferred
    # dimensions of 0) and be returned as a valid zero-length embedding.
    if vec_array.size == 0:
        log.warning(f"{caller_name}: Empty embedding vector, skipping this vector")
        return None

    inferred_dimensions = int(vec_array.shape[0])
    expected_dimensions = inferred_dimensions if dimensions is None else int(dimensions)

    if inferred_dimensions != expected_dimensions:
        log.warning(
            f"{caller_name}: Expected vector shape ({expected_dimensions},) but got {vec_array.shape}, skipping this vector")
        return None

    if not np.all(np.isfinite(vec_array)):
        log.warning(f"{caller_name}: Vector contains non-finite values, skipping this vector")
        return None

    if normalize_embeddings:
        norm = np.linalg.norm(vec_array)
        # A zero norm cannot be scaled to unit length; a non-finite norm can
        # still occur through float32 overflow even when every element is finite.
        if not np.isfinite(norm) or norm == 0.0:
            log.warning(f"{caller_name}: Invalid vector norm ({norm}), skipping this vector")
            return None
        vec_array = vec_array / norm

    return {
        "embedded_vector": vec_array.astype("float32", copy=False),
        "dimensions": expected_dimensions,
    }
48+
49+
def _prepare_reid_vector(self, reid_vector, dimensions, caller_name,
50+
normalize_embeddings=False):
51+
"""Backward-compatible wrapper returning only the prepared vector."""
52+
prepared_reid = self.prepare_reid_dict(
53+
reid_vector,
54+
dimensions,
55+
caller_name,
56+
normalize_embeddings=normalize_embeddings)
57+
if prepared_reid is None:
58+
return None
59+
return prepared_reid["embedded_vector"]
60+
761
@abstractmethod
862
def connect(self, hostname):
963
"""

controller/src/controller/scene.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,21 @@
33

44
from types import SimpleNamespace
55
from typing import Optional
6+
67
import numpy as np
8+
79
import robot_vision as rv
8-
from controller.controller_mode import ControllerMode
9-
from controller.moving_object import ChainData
1010
from scene_common import log
1111
from scene_common.camera import Camera
1212
from scene_common.earth_lla import convertLLAToECEF, calculateTRSLocal2LLAFromSurfacePoints
13-
from scene_common.geometry import Line, Point, Region, Tripwire, getRegionEvents, getTripwireEvents
13+
from scene_common.geometry import Point, Region, Tripwire, getRegionEvents, getTripwireEvents
1414
from scene_common.scene_model import SceneModel
1515
from scene_common.timestamp import get_epoch_time, get_iso_time
1616
from scene_common.transform import CameraPose
1717
from scene_common.mesh_util import getMeshAxisAlignedProjectionToXY, createRegionMesh, createObjectMesh
1818

19+
from controller.controller_mode import ControllerMode
20+
from controller.moving_object import ChainData
1921
from controller.ilabs_tracking import IntelLabsTracking
2022
from controller.time_chunking import TimeChunkedIntelLabsTracking, DEFAULT_CHUNKING_RATE_FPS
2123
from controller.tracking import (MAX_UNRELIABLE_TIME,

controller/src/controller/uuid_manager.py

Lines changed: 98 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import collections
55
import concurrent.futures
66
import threading
7+
import math
78

89
import numpy as np
910

@@ -13,19 +14,38 @@
1314
from scene_common.timestamp import get_epoch_time
1415

1516
DEFAULT_DATABASE = "VDMS"
16-
DEFAULT_SIMILARITY_THRESHOLD = 40
17+
# NOTE(review): this fallback applies whenever the config omits
# 'similarity_threshold'. DEFAULT_SIMILARITY_METRIC is "L2", where a candidate
# matches only if its distance is < threshold, while 0.5 is on the scale of the
# inner-product scores that this module restricts to [-1, 1]. Confirm that 0.5
# is the intended L2 fallback, or choose the default per metric.
DEFAULT_SIMILARITY_THRESHOLD = 0.5
1718
DEFAULT_MINIMUM_BBOX_AREA = 5000
1819
DEFAULT_MINIMUM_FEATURE_COUNT = 12
1920
DEFAULT_FEATURE_SLICE_SIZE = 10
2021
DEFAULT_MAX_QUERY_TIME = 4
2122
DEFAULT_MAX_SIMILARITY_QUERIES_TRACKED = 10
2223
DEFAULT_STALE_FEATURE_TIMEOUT_SECS = 5.0
2324
DEFAULT_STALE_FEATURE_CHECK_INTERVAL_SECS = 1.0
25+
DEFAULT_SIMILARITY_METRIC = "L2"
26+
SUPPORTED_SIMILARITY_METRICS = {"COSINE", "L2"}
2427
available_databases = {
2528
"VDMS": VDMSDatabase,
2629
}
2730

2831
class UUIDManager:
32+
def _normalizeSimilarityMetric(self, metric):
    """
    Canonicalize a configured similarity metric name.

    The value is stringified, stripped of surrounding whitespace and upper-cased.
    Anything not in SUPPORTED_SIMILARITY_METRICS is logged and replaced by
    DEFAULT_SIMILARITY_METRIC.

    @param    metric   Metric name as read from the configuration
    @return   metric   Upper-case supported metric name
    """
    candidate = str(metric).strip().upper()
    if candidate in SUPPORTED_SIMILARITY_METRICS:
        return candidate

    log.warning(
        f"Unsupported similarity_metric '{metric}', "
        f"supported values are {sorted(SUPPORTED_SIMILARITY_METRICS)}; "
        f"falling back to {DEFAULT_SIMILARITY_METRIC}")
    return DEFAULT_SIMILARITY_METRIC
41+
42+
def _resolveDatabaseSimilarityMetric(self, configured_metric):
43+
"""Translate controller-facing similarity metric to the VDMS descriptor metric."""
44+
metric = self._normalizeSimilarityMetric(configured_metric)
45+
if metric == "COSINE":
46+
return "IP"
47+
return metric
48+
2949
def __init__(self, database=DEFAULT_DATABASE, reid_config_data=None):
3050
self.active_ids = {}
3151
self.active_ids_lock = threading.Lock()
@@ -52,10 +72,13 @@ def __init__(self, database=DEFAULT_DATABASE, reid_config_data=None):
5272
self.stale_feature_timeout_secs = reid_config_data.get('stale_feature_timeout_secs', DEFAULT_STALE_FEATURE_TIMEOUT_SECS)
5373
self.stale_feature_check_interval_secs = reid_config_data.get('stale_feature_check_interval_secs', DEFAULT_STALE_FEATURE_CHECK_INTERVAL_SECS)
5474
self.minimum_feature_count = reid_config_data.get('feature_accumulation_threshold', DEFAULT_MINIMUM_FEATURE_COUNT)
55-
self.similarity_threshold = reid_config_data.get(
56-
'similarity_threshold', DEFAULT_SIMILARITY_THRESHOLD)
75+
self.similarity_threshold = reid_config_data.get('similarity_threshold', DEFAULT_SIMILARITY_THRESHOLD)
76+
self.similarity_metric = self._normalizeSimilarityMetric(
77+
reid_config_data.get('similarity_metric', DEFAULT_SIMILARITY_METRIC))
5778
self.minimum_bbox_area = reid_config_data.get('minimum_bbox_area', DEFAULT_MINIMUM_BBOX_AREA)
5879
self.feature_slice_size = reid_config_data.get('feature_slice_size', DEFAULT_FEATURE_SLICE_SIZE)
80+
self.reid_database = available_databases[database](
81+
similarity_metric=self._resolveDatabaseSimilarityMetric(self.similarity_metric))
5982
self.stale_feature_timer = None
6083
self._start_stale_feature_timer()
6184
return
@@ -82,10 +105,15 @@ def updateReidConfig(self, reid_config_data=None):
82105
'feature_accumulation_threshold', DEFAULT_MINIMUM_FEATURE_COUNT)
83106
self.similarity_threshold = reid_config_data.get(
84107
'similarity_threshold', DEFAULT_SIMILARITY_THRESHOLD)
108+
self.similarity_metric = self._normalizeSimilarityMetric(reid_config_data.get(
109+
'similarity_metric', DEFAULT_SIMILARITY_METRIC))
85110
self.minimum_bbox_area = reid_config_data.get(
86111
'minimum_bbox_area', DEFAULT_MINIMUM_BBOX_AREA)
87112
self.feature_slice_size = reid_config_data.get(
88113
'feature_slice_size', DEFAULT_FEATURE_SLICE_SIZE)
114+
if hasattr(self, 'reid_database') and self.reid_database is not None:
115+
self.reid_database.similarity_metric = self._resolveDatabaseSimilarityMetric(
116+
self.similarity_metric)
89117

90118
# Timer cadence changes require rescheduling the stale feature timer.
91119
if self.stale_feature_timer is not None and old_interval != self.stale_feature_check_interval_secs:
@@ -450,38 +478,67 @@ def parseQueryResults(self, similarity_scores, threshold=None, rv_id=None):
450478
The threshold value is used as the deciding criteria for close matches.
451479
452480
@param similarity_scores The similarity scores obtained from the database query
453-
@param threshold The maximum difference between the Re-ID vectors which would
454-
still be considered a valid match
481+
@param threshold Similarity threshold interpreted according to metric semantics:
482+
- L2-style distance: lower is better, candidate must be < threshold
483+
- IP-style score: higher is better, candidate must be > threshold
455484
@return database_id Returns the ID of the matched entry from the database if one
456485
is found; otherwise, returns None
457-
@return similarity Distance between the Re-ID vectors for the object and the
458-
matched entry if it is found; otherwise, return None
486+
@return similarity Similarity value returned by VDMS (`_distance` field) for
487+
the matched entry if one is found; otherwise, return None
459488
"""
460489
if threshold is None:
461490
threshold = self.similarity_threshold
462491

463492
if similarity_scores:
464-
# VDMS FindDescriptor returns entities sorted ascending by _distance (closest first),
465-
# so each per-vector best match is always entities[0].
466-
minimum_distances = [self._findMinimumDistance(entities)
493+
metric_candidates = [self._findBestMetricCandidate(entities)
467494
for entities in similarity_scores]
468-
distances_below_threshold = [(uuid, distance) for (uuid, distance) in
469-
minimum_distances if
470-
distance is not None and distance < threshold]
471-
472-
if distances_below_threshold:
473-
counter = collections.Counter(item[0] for item in distances_below_threshold)
495+
qualifying_candidates = [(uuid, metric_value) for (uuid, metric_value) in
496+
metric_candidates if
497+
metric_value is not None and
498+
self._isSimilarityMatch(metric_value, threshold)]
499+
if qualifying_candidates:
500+
counter = collections.Counter(item[0] for item in qualifying_candidates)
474501
most_common_uuid, count = counter.most_common(1)[0]
475-
if count >= (len(minimum_distances) / 2):
476-
similarity = min(item[1] for item in distances_below_threshold
477-
if item[0] == most_common_uuid)
502+
if count >= (len(metric_candidates) / 2):
503+
similarity = self._pickBestMetricValue(
504+
[item[1] for item in qualifying_candidates if item[0] == most_common_uuid])
478505
return most_common_uuid, similarity
479506

480507
return None, None
481508

482-
def _findMinimumDistance(self, entities):
509+
def _isHigherBetterMetric(self):
510+
"""Return True when the configured descriptor metric uses higher-is-better semantics."""
511+
metric = getattr(self.reid_database, 'similarity_metric', None)
512+
if metric is None:
513+
return False
514+
return str(metric).strip().upper() in {"IP", "INNER_PRODUCT"}
515+
516+
def _isSimilarityMatch(self, metric_value, threshold):
517+
"""Evaluate threshold semantics according to the active descriptor metric."""
518+
if metric_value is None:
519+
return False
520+
521+
if not math.isfinite(metric_value):
522+
return False
523+
524+
if self._isHigherBetterMetric() and (metric_value < -1.0 or metric_value > 1.0):
525+
return False
526+
527+
if self._isHigherBetterMetric():
528+
return metric_value > threshold
529+
return metric_value < threshold
530+
531+
def _pickBestMetricValue(self, metric_values):
532+
"""Pick best metric value according to descriptor metric semantics."""
533+
if not metric_values:
534+
return None
535+
if self._isHigherBetterMetric():
536+
return max(metric_values)
537+
return min(metric_values)
538+
539+
def _findBestMetricCandidate(self, entities):
483540
"""
484-
Find the uuid with the minimum distance and the corresponding distance value.
541+
Find the best candidate uuid and metric value according to descriptor semantics.
485542
486543
VDMS returns entities sorted ascending by _distance (closest first), so entities[0]
487544
is always the best match.
@@ -490,8 +547,26 @@ def _findMinimumDistance(self, entities):
490547
[{'uuid': <UUID>, 'rvid': <TRACKER_ID>, '_distance': <SIMILARITY_SCORE>}, ...]
491548
"""
492549
if entities:
493-
minimum_distance_entity = entities[0]
494-
return (minimum_distance_entity['uuid'], minimum_distance_entity['_distance'])
550+
filtered_entities = []
551+
for entity in entities:
552+
metric_value = entity.get('_distance')
553+
if metric_value is None or not math.isfinite(metric_value):
554+
continue
555+
if self._isHigherBetterMetric() and (metric_value < -1.0 or metric_value > 1.0):
556+
log.warning(
557+
f"Ignoring out-of-range IP similarity score {metric_value} "
558+
f"for uuid={entity.get('uuid')}")
559+
continue
560+
filtered_entities.append(entity)
561+
562+
if not filtered_entities:
563+
return (None, None)
564+
565+
if self._isHigherBetterMetric():
566+
best_entity = max(filtered_entities, key=lambda x: x['_distance'])
567+
else:
568+
best_entity = min(filtered_entities, key=lambda x: x['_distance'])
569+
return (best_entity['uuid'], best_entity['_distance'])
495570
return (None, None)
496571

497572
def _active_gid_index_locked(self):

0 commit comments

Comments
 (0)