Skip to content

Commit fda8d8b

Browse files
committed
Move shared retrack test helpers to common_retrack.py
1 parent 0b5c001 commit fda8d8b

2 files changed

Lines changed: 658 additions & 344 deletions

File tree

tests/functional/common_retrack.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
#!/usr/bin/env python3
2+
3+
# SPDX-FileCopyrightText: (C) 2026 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
"""Shared infrastructure for scene retrack functional tests.
7+
8+
RetrackTest encapsulates the MQTT/REST helpers and per-test scene
9+
state (parent_id, child_id, received message buffers) so that test
10+
functions in tc_scene_retrack.py remain free of module-level globals.
11+
"""
12+
13+
import json
14+
import math
15+
import threading
16+
import time
17+
18+
from scene_common.mqtt import PubSub
19+
from scene_common import log
20+
from scene_common.timestamp import get_iso_time
21+
22+
23+
class RetrackTest:
24+
"""! Infrastructure helpers for scene retrack functional tests.
25+
26+
Each test function should instantiate one RetrackTest and use it
27+
for scene setup, MQTT client management, and result collection. All
28+
per-test state (scene IDs, received message buffers) is stored on the
29+
instance, eliminating module-level globals.
30+
"""
31+
32+
FRAME_RATE = 10
33+
MAX_WAIT = 3
34+
NUM_PUBLISH_ITERATIONS = 5
35+
36+
def __init__(self, params):
37+
"""! Initialise an empty helper bound to the given connection parameters.
38+
39+
@param params Dict of functional-test connection parameters from
40+
the conftest fixture.
41+
"""
42+
self.params = params
43+
self.parent_id = None
44+
self.child_id = None
45+
self.parent_received = []
46+
self.child_received = []
47+
48+
def on_message(self, mqttc, obj, msg):
49+
"""! Default onMessage callback; routes regulated messages into
50+
parent_received or child_received based on scene_id.
51+
52+
@param mqttc MQTT client object.
53+
@param obj Private user data (unused).
54+
@param msg MQTTMessage instance.
55+
"""
56+
topic = PubSub.parseTopic(msg.topic)
57+
if topic is None:
58+
return
59+
data = json.loads(msg.payload.decode("utf-8"))
60+
obj_count = len(data.get('objects', []))
61+
if obj_count == 0:
62+
return
63+
if topic.get('scene_id') == self.parent_id:
64+
log.info(f"Parent regulated: {obj_count} objects")
65+
self.parent_received.append(data)
66+
elif topic.get('scene_id') == self.child_id:
67+
log.info(f"Child regulated: {obj_count} objects")
68+
self.child_received.append(data)
69+
70+
def setup_scenes(self, rest_client):
71+
"""! Create a fresh parent scene and link the existing Demo scene as
72+
child with retrack=True (default).
73+
74+
@param rest_client An authenticated RESTClient instance.
75+
"""
76+
parent_scene = rest_client.createScene({'name': "retrack_parent"})
77+
assert parent_scene.statusCode == 201, \
78+
f"Failed to create parent scene: {parent_scene.statusCode}"
79+
self.parent_id = parent_scene['uid']
80+
log.info(f"Created parent scene: {self.parent_id}")
81+
82+
scenes = rest_client.getScenes({'name': 'Demo'})
83+
assert scenes['count'] > 0, "Demo scene not found – required for retrack tests"
84+
child_scene = scenes['results'][0]
85+
self.child_id = child_scene['uid']
86+
log.info(f"Using Demo as child scene: {self.child_id}")
87+
88+
res = rest_client.updateScene(self.child_id, {'parent': self.parent_id})
89+
assert res.statusCode == 200, \
90+
f"Failed to link child to parent: {res.statusCode}"
91+
92+
child_links = rest_client.getChildScene({'parent': self.parent_id})
93+
assert child_links.statusCode == 200 and child_links['count'] == 1, \
94+
"Child-parent link not found after linking"
95+
96+
def teardown_scenes(self, rest_client):
97+
"""! Unlink the child scene and delete the parent scene created for
98+
the test. The Demo child scene is a fixture and is never deleted.
99+
100+
@param rest_client An authenticated RESTClient instance.
101+
"""
102+
if self.child_id and self.parent_id:
103+
res = rest_client.deleteChildSceneLink(self.child_id)
104+
log.info(f"[TEARDOWN] Unlinked child uid={self.child_id}: {res.statusCode}")
105+
if self.parent_id:
106+
res = rest_client.deleteScene(self.parent_id)
107+
log.info(f"[TEARDOWN] Deleted parent scene uid={self.parent_id}: {res.statusCode}")
108+
109+
def _await_db_notification(self, rest_fn):
110+
"""! Subscribe to CMD_DATABASE, call rest_fn(), then assert the
111+
notification arrives confirming the controller loaded the change.
112+
113+
@param rest_fn Zero-argument callable that performs the REST update.
114+
"""
115+
db_received = threading.Event()
116+
subscribed = threading.Event()
117+
db_topic = PubSub.formatTopic(PubSub.CMD_DATABASE)
118+
119+
def _on_db(mqttc, obj, msg):
120+
db_received.set()
121+
122+
def _on_connected(mqttc, obj, flags, rc):
123+
if rc == 0:
124+
mqttc.addCallback(db_topic, _on_db)
125+
subscribed.set()
126+
127+
tmp = PubSub(self.params["auth"], None, self.params["rootcert"],
128+
self.params["broker_url"], self.params["broker_port"])
129+
tmp.onConnect = _on_connected
130+
tmp.connect()
131+
tmp.loopStart()
132+
assert subscribed.wait(self.MAX_WAIT), \
133+
"Temporary MQTT client failed to subscribe to CMD_DATABASE within timeout"
134+
try:
135+
rest_fn()
136+
assert db_received.wait(self.MAX_WAIT), \
137+
"Timed out waiting for CMD_DATABASE notification"
138+
finally:
139+
tmp.loopStop()
140+
141+
def set_retrack(self, rest_client, value):
142+
"""! Update the retrack flag on the child scene link and wait for the
143+
CMD_DATABASE notification confirming the controller has loaded the change.
144+
145+
@param rest_client An authenticated RESTClient instance.
146+
@param value Boolean value for the retrack field.
147+
"""
148+
def _update():
149+
res = rest_client.updateChildScene(self.child_id, {'retrack': value})
150+
assert res.statusCode == 200, \
151+
f"Failed to set retrack={value}: {res.statusCode}"
152+
log.info(f"Set retrack={value} on child scene {self.child_id}")
153+
verify = rest_client.getChildScene({'parent': self.parent_id})
154+
assert verify.statusCode == 200, \
155+
f"Failed to read back child scene link after setting retrack={value}"
156+
actual = verify['results'][0]['retrack']
157+
log.info(f"Verify child link retrack value: {actual}")
158+
assert actual == value, \
159+
f"retrack mismatch: expected {value}, got {actual}"
160+
self._await_db_notification(_update)
161+
162+
def set_external_rate(self, rest_client, rate):
163+
"""! Update external_update_rate on the child scene and wait for the
164+
CMD_DATABASE notification confirming the controller has loaded the change.
165+
166+
@param rest_client An authenticated RESTClient instance.
167+
@param rate Float Hz value for external_update_rate.
168+
"""
169+
def _update():
170+
res = rest_client.updateScene(self.child_id, {'external_update_rate': rate})
171+
assert res.statusCode == 200, \
172+
f"Failed to set external_update_rate={rate}: {res.statusCode}"
173+
log.info(f"Set external_update_rate={rate} on scene {self.child_id}")
174+
self._await_db_notification(_update)
175+
176+
def make_client(self, topics=None, on_msg=None):
177+
"""! Create and start an MQTT PubSub client, subscribe to *topics* on
178+
connect, and block until the broker confirms connection.
179+
180+
Defaults to subscribing to DATA_REGULATED for both parent and child
181+
scenes with self.on_message as the callback when omitted.
182+
183+
@param topics List of MQTT topic strings. Defaults to
184+
DATA_REGULATED for parent_id and child_id.
185+
@param on_msg onMessage callback. Defaults to self.on_message.
186+
@return Connected PubSub instance.
187+
"""
188+
if topics is None:
189+
topics = [
190+
PubSub.formatTopic(PubSub.DATA_REGULATED, scene_id=self.parent_id),
191+
PubSub.formatTopic(PubSub.DATA_REGULATED, scene_id=self.child_id),
192+
]
193+
if on_msg is None:
194+
on_msg = self.on_message
195+
connected_event = threading.Event()
196+
197+
def _on_connect(mqttc, obj, flags, rc):
198+
if rc == 0:
199+
for t in topics:
200+
mqttc.subscribe(t)
201+
log.info(f"Subscribed: {t}")
202+
connected_event.set()
203+
204+
client = PubSub(self.params["auth"], None, self.params["rootcert"],
205+
self.params["broker_url"], self.params["broker_port"])
206+
client.onConnect = _on_connect
207+
client.onMessage = on_msg
208+
client.connect()
209+
client.loopStart()
210+
assert connected_event.wait(self.MAX_WAIT), \
211+
"MQTT client failed to connect within timeout"
212+
return client
213+
214+
def wait_for_messages(self, timeout=None, require_parent=True, require_child=True):
215+
"""! Block until at least one message with objects has arrived on the
216+
expected topics, or timeout expires.
217+
218+
@param timeout Maximum seconds to wait. Defaults to MAX_WAIT.
219+
@param require_parent Assert that parent received objects if True.
220+
@param require_child Assert that child received objects if True.
221+
"""
222+
if timeout is None:
223+
timeout = self.MAX_WAIT
224+
start = time.time()
225+
while time.time() - start < timeout:
226+
parent_ok = (not require_parent) or len(self.parent_received) > 0
227+
child_ok = (not require_child) or len(self.child_received) > 0
228+
if parent_ok and child_ok:
229+
return
230+
time.sleep(0.5)
231+
if require_parent:
232+
assert len(self.parent_received) > 0, \
233+
f"Timed out after {timeout}s: no objects on parent regulated topic"
234+
if require_child:
235+
assert len(self.child_received) > 0, \
236+
f"Timed out after {timeout}s: no objects on child regulated topic"
237+
238+
@staticmethod
239+
def collect_object_ids(messages):
240+
"""! Return the set of object id values from a list of regulated messages.
241+
242+
@param messages List of decoded regulated-data message dicts.
243+
@return Set of id strings found in 'objects' lists.
244+
"""
245+
ids = set()
246+
for msg in messages:
247+
for obj in msg.get('objects', []):
248+
if 'id' in obj:
249+
ids.add(obj['id'])
250+
return ids
251+
252+
@staticmethod
253+
def publish_data(obj_data, client, obj_category="person"):
254+
"""! Publish simulated object detection data to a camera's MQTT topic.
255+
256+
@param obj_data The object data fixture containing camera id and objects.
257+
@param client The MQTT PubSub client.
258+
@param obj_category The object category to publish (default: "person").
259+
"""
260+
cam_id = obj_data["id"]
261+
topic = PubSub.formatTopic(PubSub.DATA_CAMERA, camera_id=cam_id)
262+
for iteration in range(RetrackTest.NUM_PUBLISH_ITERATIONS):
263+
for i in range(5):
264+
obj_data["timestamp"] = get_iso_time()
265+
obj_data["objects"][obj_category][0]["bounding_box"]["y"] = 100 + (i * 20)
266+
obj_data["objects"][obj_category][0]["category"] = obj_category
267+
client.publish(topic, json.dumps(obj_data))
268+
log.info(
269+
f"Published object via camera {cam_id}: y={100 + (i * 20)} "
270+
f"(iter {iteration})")
271+
time.sleep(1.0 / RetrackTest.FRAME_RATE)
272+
273+
@staticmethod
274+
def publish_timed(obj_data, client, rate, duration):
275+
"""! Publish camera detections at *rate* Hz for *duration* seconds.
276+
277+
@param obj_data The object data fixture containing camera id and objects.
278+
@param client The MQTT PubSub client.
279+
@param rate Publish rate in Hz.
280+
@param duration Duration in seconds.
281+
"""
282+
cam_id = obj_data["id"]
283+
topic = PubSub.formatTopic(PubSub.DATA_CAMERA, camera_id=cam_id)
284+
end = time.time() + duration
285+
i = 0
286+
while time.time() < end:
287+
obj_data["timestamp"] = get_iso_time()
288+
obj_data["objects"]["person"][0]["bounding_box"]["y"] = 100 + (i % 5) * 20
289+
obj_data["objects"]["person"][0]["category"] = "person"
290+
client.publish(topic, json.dumps(obj_data))
291+
time.sleep(1.0 / rate)
292+
i += 1
293+
294+
@staticmethod
295+
def assert_valid_translation(tr, label):
296+
"""! Assert that *tr* is a list of exactly three finite numeric values.
297+
298+
@param tr The translation value to validate.
299+
@param label Human-readable label used in assertion messages.
300+
"""
301+
assert len(tr) == 3, \
302+
f"{label} 'translation' must have 3 elements, got {len(tr)}"
303+
for v in tr:
304+
assert isinstance(v, (int, float)), \
305+
f"{label} translation element not numeric: {v}"
306+
assert math.isfinite(v), \
307+
f"{label} translation element is not finite (NaN/Inf): {v}"

0 commit comments

Comments
 (0)