Skip to content

Commit c6a9a6a

Browse files
authored
Merge pull request #12 from roboflow/refactor/deepsort
refactor(trackers): decouple DeepSORTKalmanBoxTracker from SORTKalmanBoxTracker
2 parents 86a9bf0 + d004327 commit c6a9a6a

File tree

3 files changed

+153
-13
lines changed

3 files changed

+153
-13
lines changed

trackers/core/deepsort/kalman_box_tracker.py

+135-4
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,154 @@
22

33
import numpy as np
44

5-
from trackers.core.sort.kalman_box_tracker import SORTKalmanBoxTracker
65

7-
8-
class DeepSORTKalmanBoxTracker(SORTKalmanBoxTracker):
6+
class DeepSORTKalmanBoxTracker:
97
"""
108
The `DeepSORTKalmanBoxTracker` class represents the internals of a single
119
tracked object (bounding box), with a Kalman filter to predict and update
1210
its position. It also maintains a feature vector for the object, which is
1311
used to identify the object across frames.
12+
13+
Attributes:
14+
tracker_id (int): Unique identifier for the tracker.
15+
number_of_successful_updates (int): Number of times the object has been
16+
updated successfully.
17+
time_since_update (int): Number of frames since the last update.
18+
state (np.ndarray): State vector of the bounding box.
19+
F (np.ndarray): State transition matrix.
20+
H (np.ndarray): Measurement matrix.
21+
Q (np.ndarray): Process noise covariance matrix.
22+
R (np.ndarray): Measurement noise covariance matrix.
23+
P (np.ndarray): Error covariance matrix.
24+
features (list[np.ndarray]): List of feature vectors.
25+
count_id (int): Class variable to assign unique IDs to each tracker.
26+
27+
Args:
28+
bbox (np.ndarray): Initial bounding box in the form [x1, y1, x2, y2].
29+
feature (Optional[np.ndarray]): Optional initial feature vector.
1430
"""
1531

32+
count_id = 0
33+
34+
@classmethod
35+
def get_next_tracker_id(cls) -> int:
36+
"""
37+
Class method that returns the next available tracker ID.
38+
39+
Returns:
40+
int: The next available tracker ID.
41+
"""
42+
next_id = cls.count_id
43+
cls.count_id += 1
44+
return next_id
45+
1646
def __init__(self, bbox: np.ndarray, feature: Optional[np.ndarray] = None):
17-
super().__init__(bbox)
47+
# Initialize with a temporary ID of -1
48+
# Will be assigned a real ID when the track is considered mature
49+
self.tracker_id = -1
50+
51+
# Number of hits indicates how many times the object has been
52+
# updated successfully
53+
self.number_of_successful_updates = 1
54+
# Number of frames since the last update
55+
self.time_since_update = 0
56+
57+
# For simplicity, we keep a small state vector:
58+
# (x, y, x2, y2, vx, vy, vx2, vy2).
59+
# We'll store the bounding box in "self.state"
60+
self.state = np.zeros((8, 1), dtype=np.float32)
61+
62+
# Initialize state directly from the first detection
63+
self.state[0] = bbox[0]
64+
self.state[1] = bbox[1]
65+
self.state[2] = bbox[2]
66+
self.state[3] = bbox[3]
67+
68+
# Basic constant velocity model
69+
self._initialize_kalman_filter()
70+
71+
# Initialize features list
1872
self.features: list[np.ndarray] = []
1973
if feature is not None:
2074
self.features.append(feature)
2175

76+
def _initialize_kalman_filter(self) -> None:
77+
"""
78+
Sets up the matrices for the Kalman filter.
79+
"""
80+
# State transition matrix (F): 8x8
81+
# We assume a constant velocity model. Positions are incremented by
82+
# velocity each step.
83+
self.F = np.eye(8, dtype=np.float32)
84+
for i in range(4):
85+
self.F[i, i + 4] = 1.0
86+
87+
# Measurement matrix (H): we directly measure x1, y1, x2, y2
88+
self.H = np.eye(4, 8, dtype=np.float32) # 4x8
89+
90+
# Process covariance matrix (Q)
91+
self.Q = np.eye(8, dtype=np.float32) * 0.01
92+
93+
# Measurement covariance (R): noise in detection
94+
self.R = np.eye(4, dtype=np.float32) * 0.1
95+
96+
# Error covariance matrix (P)
97+
self.P = np.eye(8, dtype=np.float32)
98+
99+
def predict(self) -> None:
100+
"""
101+
Predict the next state of the bounding box (applies the state transition).
102+
"""
103+
# Predict state
104+
self.state = self.F @ self.state
105+
# Predict error covariance
106+
self.P = self.F @ self.P @ self.F.T + self.Q
107+
108+
# Increase time since update
109+
self.time_since_update += 1
110+
111+
def update(self, bbox: np.ndarray) -> None:
112+
"""
113+
Updates the state with a new detected bounding box.
114+
115+
Args:
116+
bbox (np.ndarray): Detected bounding box in the form [x1, y1, x2, y2].
117+
"""
118+
self.time_since_update = 0
119+
self.number_of_successful_updates += 1
120+
121+
# Kalman Gain
122+
S = self.H @ self.P @ self.H.T + self.R
123+
K = self.P @ self.H.T @ np.linalg.inv(S)
124+
125+
# Residual
126+
measurement = bbox.reshape((4, 1))
127+
y = measurement - self.H @ self.state
128+
129+
# Update state
130+
self.state = self.state + K @ y
131+
132+
# Update covariance
133+
identity_matrix = np.eye(8, dtype=np.float32)
134+
self.P = (identity_matrix - K @ self.H) @ self.P
135+
136+
def get_state_bbox(self) -> np.ndarray:
137+
"""
138+
Returns the current bounding box estimate from the state vector.
139+
140+
Returns:
141+
np.ndarray: The bounding box [x1, y1, x2, y2].
142+
"""
143+
return np.array(
144+
[
145+
self.state[0], # x1
146+
self.state[1], # y1
147+
self.state[2], # x2
148+
self.state[3], # y2
149+
],
150+
dtype=float,
151+
).reshape(-1)
152+
22153
def update_feature(self, feature: np.ndarray):
23154
self.features.append(feature)
24155

trackers/core/deepsort/tracker.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@ class DeepSORTTracker(BaseTrackerWithFeatures):
2929
from rfdetr import RFDETRBase
3030
from rfdetr.util.coco_classes import COCO_CLASSES
3131
32-
from trackers.core.deepsort.tracker import DeepSORTTracker
32+
from trackers import DeepSORTFeatureExtractor, DeepSORTTracker
3333
3434
model = RFDETRBase(device="mps")
35-
tracker = DeepSORTTracker()
35+
feature_extractor = DeepSORTFeatureExtractor.from_timm(
36+
model_name="mobilenetv4_conv_small.e1200_r224_in1k"
37+
)
38+
tracker = DeepSORTTracker(feature_extractor=feature_extractor)
3639
box_annotator = sv.BoxAnnotator()
3740
label_annotator = sv.LabelAnnotator()
3841
@@ -292,7 +295,8 @@ def _get_associated_indices(
292295
if combined_dist.size > 0:
293296
row_indices, col_indices = np.where(combined_dist < 1.0)
294297
sorted_pairs = sorted(
295-
zip(row_indices, col_indices), key=lambda x: combined_dist[x[0], x[1]]
298+
zip(map(int, row_indices), map(int, col_indices)),
299+
key=lambda x: combined_dist[x[0], x[1]],
296300
)
297301

298302
used_rows = set()
@@ -303,8 +307,10 @@ def _get_associated_indices(
303307
used_cols.add(col)
304308
matched_indices.append((row, col))
305309

306-
unmatched_trackers = unmatched_trackers - used_rows
307-
unmatched_detections = unmatched_detections - used_cols
310+
unmatched_trackers = unmatched_trackers - {int(row) for row in used_rows}
311+
unmatched_detections = unmatched_detections - {
312+
int(col) for col in used_cols
313+
}
308314

309315
return matched_indices, unmatched_trackers, unmatched_detections
310316

trackers/utils/sort_utils.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from copy import deepcopy
2-
from typing import List, Sequence, TypeVar
2+
from typing import List, Sequence, TypeVar, Union
33

44
import numpy as np
55
import supervision as sv
66
from supervision.detection.utils import box_iou_batch
77

8+
from trackers.core.deepsort.kalman_box_tracker import DeepSORTKalmanBoxTracker
89
from trackers.core.sort.kalman_box_tracker import SORTKalmanBoxTracker
910

10-
KalmanBoxTrackerType = TypeVar("KalmanBoxTrackerType", bound=SORTKalmanBoxTracker)
11+
KalmanBoxTrackerType = TypeVar(
12+
"KalmanBoxTrackerType", bound=Union[SORTKalmanBoxTracker, DeepSORTKalmanBoxTracker]
13+
)
1114

1215

1316
def get_alive_trackers(
@@ -42,7 +45,7 @@ def get_alive_trackers(
4245

4346

4447
def get_iou_matrix(
45-
trackers: Sequence[SORTKalmanBoxTracker], detection_boxes: np.ndarray
48+
trackers: Sequence[KalmanBoxTrackerType], detection_boxes: np.ndarray
4649
) -> np.ndarray:
4750
"""
4851
Build IOU cost matrix between detections and predicted bounding boxes
@@ -68,7 +71,7 @@ def get_iou_matrix(
6871

6972

7073
def update_detections_with_track_ids(
71-
trackers: Sequence[SORTKalmanBoxTracker],
74+
trackers: Sequence[KalmanBoxTrackerType],
7275
detections: sv.Detections,
7376
detection_boxes: np.ndarray,
7477
minimum_iou_threshold: float,

0 commit comments

Comments
 (0)