Commit fe63c88

Add ObjectTracker module with DepthAI integration and software tracker
1 parent 9609392 commit fe63c88

File tree

modules/object_tracker/
  __init__.py
  detection.py
  object_tracker.py
  object_tracker_worker.py
  software_tracker.py
  tracked_object.py

6 files changed: +597 -0 lines changed

modules/object_tracker/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
"""
ObjectTracker module for persistent tracking of detected objects across frames.

Two modes available:
1. On-device (DepthAI): Use configure_tracker_node() + parse_tracklets()
2. Software (host): Use SoftwareTracker.update(detections)
"""

from .tracked_object import TrackedObject, TrackingStatus
from .detection import Detection

# On-device DepthAI tracker
from .object_tracker import configure_tracker_node, parse_tracklets

# Software tracker (accepts Detection objects)
from .software_tracker import SoftwareTracker

# Workers
from .object_tracker_worker import object_tracker_run, object_tracker_read_loop
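
These exports support a quick host-side smoke test of the software mode. The following is a minimal sketch, not part of this commit's diff: it assumes SoftwareTracker takes no constructor arguments and that update() accepts a list of Detection objects and returns the current TrackedObject list (software_tracker.py is in this commit but not shown on this page).

# Hypothetical usage sketch (not part of the diff above). Assumes
# SoftwareTracker() needs no arguments and update() returns the
# current list of TrackedObject.
from modules.object_tracker import Detection, SoftwareTracker

tracker = SoftwareTracker()

# One Detection per object found in the current frame, matching the
# interface contract in detection.py below.
detections = [
    Detection(
        label="landing_pad",
        confidence=0.91,
        x=0.4, y=-0.1, z=3.2,  # meters, camera frame
        xmin=120, ymin=80, xmax=240, ymax=200,  # pixels
    ),
]

tracked = tracker.update(detections)
for obj in tracked:
    print(obj.object_id, obj.label, obj.status)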
modules/object_tracker/detection.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
"""
Detection input class - interface contract with the detection team.

This matches the Detection class from the SpatialDetectionNetwork team.
"""

from dataclasses import dataclass


@dataclass
class Detection:
    """
    Standardized detection result from SpatialDetectionNetwork.

    This is the input format we receive from the detection team.
    """

    label: str
    confidence: float
    x: float  # spatial X (meters, camera frame)
    y: float  # spatial Y (meters, camera frame)
    z: float  # spatial Z / depth (meters, camera frame)
    xmin: float  # bbox left (pixels or normalized)
    ymin: float  # bbox top (pixels or normalized)
    xmax: float  # bbox right (pixels or normalized)
    ymax: float  # bbox bottom (pixels or normalized)
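
Since the bbox fields are documented as "pixels or normalized", downstream code may need to resolve that ambiguity. A small hedged helper sketch (not part of this commit), assuming values at or below 1.0 indicate normalized coordinates:

# Hypothetical helper (not in this commit): resolve the
# "pixels or normalized" ambiguity by treating bbox values <= 1.0
# as normalized coordinates.
def bbox_to_pixels(det: Detection, frame_width: int, frame_height: int) -> tuple:
    if max(det.xmax, det.ymax) <= 1.0:  # heuristic: looks normalized
        return (
            int(det.xmin * frame_width),
            int(det.ymin * frame_height),
            int(det.xmax * frame_width),
            int(det.ymax * frame_height),
        )
    return int(det.xmin), int(det.ymin), int(det.xmax), int(det.ymax)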
modules/object_tracker/object_tracker.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
"""
ObjectTracker module using DepthAI's built-in ObjectTracker node.

Configures the ObjectTracker node within a DepthAI pipeline and parses
tracklet output into TrackedObject data classes.

The ObjectTracker node is part of the on-device pipeline:

    SpatialDetectionNetwork.out ──► ObjectTracker ──► XLinkOut("tracklets")

This module provides:
- configure_tracker_node(): sets up the node in a shared pipeline
- parse_tracklets(): converts raw DepthAI tracklets into a TrackedObject list

Reference: https://docs.luxonis.com/software/depthai/depthai-components/nodes/objecttracker/
"""

from typing import List, Optional

import depthai as dai

from .tracked_object import TrackedObject, TrackingStatus


# Map DepthAI tracklet status to our TrackingStatus enum
_STATUS_MAP = {
    dai.Tracklet.TrackingStatus.NEW: TrackingStatus.NEW,
    dai.Tracklet.TrackingStatus.TRACKED: TrackingStatus.TRACKED,
    dai.Tracklet.TrackingStatus.LOST: TrackingStatus.LOST,
    dai.Tracklet.TrackingStatus.REMOVED: TrackingStatus.LOST,
}

# Available tracker algorithms
TRACKER_TYPES = {
    "ZERO_TERM_COLOR_HISTOGRAM": dai.TrackerType.ZERO_TERM_COLOR_HISTOGRAM,
    "ZERO_TERM_IMAGELESS": dai.TrackerType.ZERO_TERM_IMAGELESS,
    "SHORT_TERM_IMAGELESS": dai.TrackerType.SHORT_TERM_IMAGELESS,
    "SHORT_TERM_KCF": dai.TrackerType.SHORT_TERM_KCF,
}


def configure_tracker_node(
    pipeline: dai.Pipeline,
    spatial_detection_network: dai.node.SpatialDetectionNetwork,
    tracker_type: str = "SHORT_TERM_IMAGELESS",
    labels_to_track: Optional[List[int]] = None,
) -> dai.node.ObjectTracker:
    """
    Create and configure an ObjectTracker node in the DepthAI pipeline.

    This wires the tracker to the SpatialDetectionNetwork outputs.
    Teammates provide the pipeline and spatial_detection_network node;
    this function adds the tracker on top.

    Args:
        pipeline: The shared DepthAI pipeline (created by teammates).
        spatial_detection_network: The detection network node whose
            outputs we consume.
        tracker_type: Algorithm name. One of:
            ZERO_TERM_COLOR_HISTOGRAM, ZERO_TERM_IMAGELESS,
            SHORT_TERM_IMAGELESS, SHORT_TERM_KCF.
        labels_to_track: List of class label indices to track.
            If None, tracks all detected labels.

    Returns:
        The configured ObjectTracker node (already linked to inputs
        and to an XLinkOut named "tracklets").
    """
    if tracker_type not in TRACKER_TYPES:
        raise ValueError(
            f"Unknown tracker_type '{tracker_type}'. "
            f"Options: {list(TRACKER_TYPES.keys())}"
        )

    # --- create tracker node ---
    tracker = pipeline.create(dai.node.ObjectTracker)
    tracker.setTrackerType(TRACKER_TYPES[tracker_type])
    tracker.setTrackerIdAssignmentPolicy(
        dai.TrackerIdAssignmentPolicy.UNIQUE_ID,
    )

    if labels_to_track is not None:
        tracker.setDetectionLabelsToTrack(labels_to_track)

    # --- link detection network outputs into tracker inputs ---
    # passthrough frame (RGB preview used for detection)
    spatial_detection_network.passthrough.link(tracker.inputTrackerFrame)
    # detection frame (same frame, used for re-identification)
    spatial_detection_network.passthrough.link(tracker.inputDetectionFrame)
    # detection results (bounding boxes + spatial coords)
    spatial_detection_network.out.link(tracker.inputDetections)

    # --- create XLinkOut so host can read tracklets ---
    tracker_out = pipeline.create(dai.node.XLinkOut)
    tracker_out.setStreamName("tracklets")
    tracker.out.link(tracker_out.input)

    return tracker


def parse_tracklets(
    tracklets_data: dai.Tracklets,
    label_map: List[str],
    frame_width: int,
    frame_height: int,
) -> List[TrackedObject]:
    """
    Convert raw DepthAI Tracklets output into a list of TrackedObject.

    Called each frame after reading from the device output queue.

    Args:
        tracklets_data: Raw tracklets from device.getOutputQueue("tracklets").get()
        label_map: Ordered list of class names matching model label indices
            (e.g. ["person", "car", "landing_pad"]).
        frame_width: Original frame width in pixels (for denormalizing bbox).
        frame_height: Original frame height in pixels.

    Returns:
        List of TrackedObject with persistent IDs, status, and smoothed
        spatial coordinates.
    """
    tracked_objects: List[TrackedObject] = []

    for tracklet in tracklets_data.tracklets:
        # --- status ---
        status = _STATUS_MAP.get(tracklet.status, TrackingStatus.LOST)

        # skip objects that have been fully removed
        if tracklet.status == dai.Tracklet.TrackingStatus.REMOVED:
            continue

        # --- label ---
        label_index = tracklet.label
        label = (
            label_map[label_index]
            if label_index < len(label_map)
            else str(label_index)
        )

        # --- confidence ---
        confidence = tracklet.srcImgDetection.confidence

        # --- smoothed spatial coordinates (meters) ---
        spatial = tracklet.spatialCoordinates
        x = spatial.x / 1000.0  # mm -> m
        y = spatial.y / 1000.0
        z = spatial.z / 1000.0

        # --- bounding box (denormalize from 0-1 to pixels) ---
        roi = tracklet.roi.denormalize(frame_width, frame_height)
        bbox_x = int(roi.topLeft().x)
        bbox_y = int(roi.topLeft().y)
        bbox_width = int(roi.bottomRight().x - roi.topLeft().x)
        bbox_height = int(roi.bottomRight().y - roi.topLeft().y)

        tracked_objects.append(
            TrackedObject(
                object_id=tracklet.id,
                status=status,
                label=label,
                confidence=confidence,
                x=x,
                y=y,
                z=z,
                bbox_x=bbox_x,
                bbox_y=bbox_y,
                bbox_width=bbox_width,
                bbox_height=bbox_height,
            )
        )

    return tracked_objects
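
For orientation, a hedged sketch of how these two functions compose end to end. build_detection_pipeline() is a hypothetical stand-in for the teammates' code that builds the pipeline and SpatialDetectionNetwork; the label map and 640x640 frame size are illustrative values, not part of this commit.

# Hypothetical end-to-end sketch. build_detection_pipeline() stands in
# for the teammates' pipeline construction and is NOT defined in this commit.
import depthai as dai

from modules.object_tracker import configure_tracker_node, parse_tracklets

pipeline, sdn = build_detection_pipeline()  # assumed helper

configure_tracker_node(
    pipeline=pipeline,
    spatial_detection_network=sdn,
    tracker_type="SHORT_TERM_IMAGELESS",
    labels_to_track=[0],  # e.g. track only label index 0
)

with dai.Device(pipeline) as device:
    queue = device.getOutputQueue(name="tracklets", maxSize=4, blocking=False)
    while True:
        tracked = parse_tracklets(
            tracklets_data=queue.get(),
            label_map=["landing_pad"],  # illustrative label map
            frame_width=640,            # illustrative preview size
            frame_height=640,
        )
        for obj in tracked:
            print(f"id={obj.object_id} {obj.label} z={obj.z:.2f}m")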
modules/object_tracker/object_tracker_worker.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
"""
Worker process for ObjectTracker.

Reads tracklet output from the OAK-D device queue, converts it into
TrackedObject data classes, and pushes them to the next pipeline stage.

Follows the existing worker pattern (producer-consumer via queues).
"""

import logging
from typing import List, Optional

import depthai as dai

from .object_tracker import configure_tracker_node, parse_tracklets
from .tracked_object import TrackedObject

logger = logging.getLogger(__name__)


def object_tracker_run(
    pipeline: dai.Pipeline,
    spatial_detection_network: dai.node.SpatialDetectionNetwork,
    label_map: List[str],
    frame_width: int,
    frame_height: int,
    output_queue,  # multiprocessing.Queue[List[TrackedObject]]
    tracker_type: str = "SHORT_TERM_IMAGELESS",
    labels_to_track: Optional[List[int]] = None,
) -> None:
    """
    Main worker entry point for the ObjectTracker.

    Configures the tracker node inside the given pipeline. In the full
    system the pipeline is started externally (because StereoDepth and
    SpatialDetectionNetwork share the same device pipeline), so this
    function only wires the tracker node; once the caller has started
    the device, it should call object_tracker_read_loop() to stream
    TrackedObject lists to output_queue.

    label_map, frame_width, frame_height, and output_queue are part of
    the shared worker signature and are consumed by
    object_tracker_read_loop(), not here.

    Args:
        pipeline: The shared DepthAI pipeline.
        spatial_detection_network: Detection node to wire into.
        label_map: Ordered class names matching model label indices.
        frame_width: Frame width in pixels.
        frame_height: Frame height in pixels.
        output_queue: Queue for downstream consumers.
        tracker_type: Tracker algorithm name.
        labels_to_track: Label indices to track (None = all).
    """
    configure_tracker_node(
        pipeline=pipeline,
        spatial_detection_network=spatial_detection_network,
        tracker_type=tracker_type,
        labels_to_track=labels_to_track,
    )

    logger.info(
        "ObjectTracker node configured (type=%s). "
        "Waiting for pipeline to start on device.",
        tracker_type,
    )


def object_tracker_read_loop(
    device: dai.Device,
    label_map: List[str],
    frame_width: int,
    frame_height: int,
    output_queue,  # multiprocessing.Queue[List[TrackedObject]]
) -> None:
    """
    Blocking loop that reads tracklets from the device and pushes
    TrackedObject lists to output_queue.

    Call this after the device has been started with the pipeline.

    Args:
        device: Running OAK-D device.
        label_map: Ordered class names.
        frame_width: Frame width in pixels.
        frame_height: Frame height in pixels.
        output_queue: Queue for downstream consumers.
    """
    tracklet_queue = device.getOutputQueue(
        name="tracklets",
        maxSize=4,
        blocking=False,
    )

    logger.info("ObjectTracker read loop started.")

    while True:
        tracklets_data = tracklet_queue.get()  # blocks until next frame

        tracked_objects = parse_tracklets(
            tracklets_data=tracklets_data,
            label_map=label_map,
            frame_width=frame_width,
            frame_height=frame_height,
        )

        if tracked_objects:
            logger.debug(
                "Frame produced %d tracked objects: %s",
                len(tracked_objects),
                [
                    f"id={t.object_id} status={t.status.value}"
                    for t in tracked_objects
                ],
            )

        output_queue.put(tracked_objects)
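
A hedged driver sketch showing the intended call order for the two worker functions: wire the node before the device starts, then enter the blocking read loop. As before, build_detection_pipeline() is a hypothetical stand-in for the teammates' pipeline code, and the label map and frame size are illustrative.

# Hypothetical driver (not in this commit). build_detection_pipeline()
# stands in for the teammates' pipeline construction code.
import multiprocessing as mp

import depthai as dai

from modules.object_tracker import (
    object_tracker_run,
    object_tracker_read_loop,
)

pipeline, sdn = build_detection_pipeline()  # assumed helper
label_map = ["landing_pad"]                 # illustrative
output_queue: mp.Queue = mp.Queue()

# 1) wire the tracker node before the pipeline is uploaded to the device
object_tracker_run(
    pipeline=pipeline,
    spatial_detection_network=sdn,
    label_map=label_map,
    frame_width=640,
    frame_height=640,
    output_queue=output_queue,
)

# 2) start the device, then enter the blocking read loop
with dai.Device(pipeline) as device:
    object_tracker_read_loop(
        device=device,
        label_map=label_map,
        frame_width=640,
        frame_height=640,
        output_queue=output_queue,
    )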
