Skip to content

Commit f87cec7

Browse files
committed
move shared utilities to common_child.py
1 parent 4c2253a commit f87cec7

File tree

2 files changed

+451
-505
lines changed

2 files changed

+451
-505
lines changed

tests/functional/common_child.py

Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
#!/usr/bin/env python3
2+
3+
# SPDX-FileCopyrightText: (C) 2026 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
import json
7+
import threading
8+
import time
9+
10+
import numpy as np
11+
12+
from scene_common import log
13+
from scene_common.mqtt import PubSub
14+
from scene_common.rest_client import RESTClient
15+
from scene_common.timestamp import get_iso_time
16+
17+
FRAME_RATE = 10
18+
MAX_WAIT = 3
19+
NUM_PUBLISH_ITERATIONS = 3
20+
PERSON = "person"
21+
REGION = "region"
22+
TRIPWIRE = "tripwire"
23+
24+
# Object bounding-box y-sweep that produces world-coordinate trajectories
25+
# crossing both the ROI and the tripwire defined in setup_scenes().
26+
# Range matches FunctionalTest.getLocations() used by tc_tripwire_mqtt.py.
27+
_step = 0.02
28+
_opposite = np.arange(-0.5, 0.6, _step)
29+
_across = np.flip(_opposite)[2:]
30+
OBJ_Y_LOCATIONS = np.concatenate((_opposite, _across))
31+
32+
33+
class ChildSceneTest:
34+
"""Manages shared state and common operations for child-scene event tests.
35+
36+
Attributes hold scene/region IDs, accumulated MQTT event messages, and the
37+
connection flag so that tests interact with instance state rather than
38+
module-level globals.
39+
"""
40+
41+
def __init__(self, params):
42+
"""Initialise helper with test parameters dict (from conftest ``params`` fixture)."""
43+
self.params = params
44+
45+
# Scene / region IDs populated by setup_scenes()
46+
self.parent_id = None
47+
self.child_id = None
48+
self.roi_uid = None
49+
self.tripwire_uid = None
50+
self.sensor_uid = None
51+
52+
# Tracks whether the child has already been unlinked (so teardown skips it)
53+
self.child_unlinked = False
54+
55+
# MQTT connection flag
56+
self.connected = False
57+
58+
# Accumulated event messages keyed by category
59+
self.parent_roi_events = []
60+
self.parent_tripwire_events = []
61+
self.parent_sensor_events = []
62+
self.child_roi_events = []
63+
self.child_tripwire_events = []
64+
self.child_sensor_events = []
65+
66+
def make_rest_client(self):
67+
"""Return an authenticated :class:`RESTClient` instance."""
68+
rest_client = RESTClient(self.params["resturl"], rootcert=self.params["rootcert"])
69+
assert rest_client.authenticate(self.params["user"], self.params["password"])
70+
return rest_client
71+
72+
def setup_scenes(self, rest_client, link=True):
73+
"""Create parent scene, optionally link Demo as child, create ROI, tripwire, sensor.
74+
75+
Populates instance attributes with all created UIDs. When *link* is
76+
``False`` the child is located but not linked to the parent, and
77+
:attr:`child_unlinked` is set so :meth:`teardown_scenes` skips the
78+
unlink step.
79+
80+
@param rest_client An authenticated :class:`RESTClient`.
81+
@param link Whether to link the child scene to the parent (default ``True``).
82+
"""
83+
# Create parent scene
84+
parent_scene = rest_client.createScene({"name": "parent_event_test"})
85+
assert parent_scene.statusCode == 201, (
86+
f"Expected 201 creating parent scene, got {parent_scene.statusCode}: {parent_scene.errors}")
87+
self.parent_id = parent_scene["uid"]
88+
log.info(f"[SETUP] Parent scene uid={self.parent_id}")
89+
90+
# Locate the Demo child scene (it has a registered camera)
91+
scenes = rest_client.getScenes({"name": "Demo"})
92+
assert scenes["count"] > 0, "Demo scene not found – required for child camera"
93+
self.child_id = scenes["results"][0]["uid"]
94+
log.info(f"[SETUP] Child scene uid={self.child_id}")
95+
96+
# Link Demo as child of parent (skipped when link=False)
97+
if link:
98+
res = rest_client.updateScene(self.child_id, {"parent": self.parent_id})
99+
assert res.statusCode == 200, (
100+
f"Expected 200 linking child to parent, got {res.statusCode}: {res.errors}")
101+
log.info("[SETUP] Linked child to parent")
102+
103+
# Verify link
104+
res = rest_client.getChildScene({"parent": self.parent_id})
105+
assert res.statusCode == 200, (
106+
f"Expected 200 fetching child scenes, got {res.statusCode}: {res.errors}")
107+
else:
108+
self.child_unlinked = True
109+
log.info("[SETUP] Child NOT linked to parent (link=False)")
110+
111+
# Create ROI in child scene – spans most of the floor plan
112+
roi_points = ((1.38, 5.94), (1.17, 0.8), (7.41, 0.83), (7.35, 6.01))
113+
roi_res = rest_client.createRegion({
114+
"scene": self.child_id,
115+
"name": "TestROI_child",
116+
"points": roi_points,
117+
})
118+
assert roi_res.statusCode == 201, (
119+
f"Expected 201 creating ROI, got {roi_res.statusCode}: {roi_res.errors}")
120+
self.roi_uid = roi_res["uid"]
121+
log.info(f"[SETUP] ROI uid={self.roi_uid}")
122+
123+
# Create tripwire in child scene using the same centre-horizontal geometry
124+
# as tc_tripwire_mqtt.py (create_tripwire_by_ratio with x_ratio=0.8).
125+
# Demo scene: width=900 px, height=643 px, scale=100 px/m → cx=4.5, cy=3.215
126+
_demo_cx = 900 / (2 * 100) # 4.5 m
127+
_demo_cy = 643 / (2 * 100) # 3.215 m
128+
_demo_dx = _demo_cx * 0.8 # 3.6 m
129+
tw_res = rest_client.createTripwire({
130+
"scene": self.child_id,
131+
"name": "TestTripwire_child",
132+
"points": ((_demo_cx - _demo_dx, _demo_cy), (_demo_cx + _demo_dx, _demo_cy)),
133+
})
134+
assert tw_res.statusCode == 201, (
135+
f"Expected 201 creating tripwire, got {tw_res.statusCode}: {tw_res.errors}")
136+
self.tripwire_uid = tw_res["uid"]
137+
log.info(f"[SETUP] Tripwire uid={self.tripwire_uid}")
138+
139+
# Create sensor in child scene
140+
sensor_res = rest_client.createSensor({
141+
"scene": self.child_id,
142+
"name": "TestSensor_child",
143+
"area": "circle",
144+
"radius": 3.21,
145+
"center": (4.5, 3.22),
146+
})
147+
assert sensor_res.statusCode == 201, (
148+
f"Expected 201 creating sensor, got {sensor_res.statusCode}: {sensor_res.errors}")
149+
self.sensor_uid = sensor_res["uid"]
150+
log.info(f"[SETUP] Sensor uid={self.sensor_uid}")
151+
152+
def teardown_scenes(self, rest_client):
153+
"""Remove created analytics objects, unlink child, and delete parent scene.
154+
155+
@param rest_client An authenticated :class:`RESTClient`.
156+
"""
157+
for uid, label, fn in [
158+
(self.roi_uid, "ROI", rest_client.deleteRegion),
159+
(self.tripwire_uid, "Tripwire", rest_client.deleteTripwire),
160+
(self.sensor_uid, "Sensor", rest_client.deleteSensor),
161+
]:
162+
if uid:
163+
res = fn(uid)
164+
log.info(f"[TEARDOWN] Deleted {label} uid={uid}: {res.statusCode}")
165+
166+
if self.child_id and not self.child_unlinked:
167+
res = rest_client.deleteChildSceneLink(self.child_id)
168+
log.info(f"[TEARDOWN] Unlinked child uid={self.child_id}: {res.statusCode}")
169+
170+
if self.parent_id:
171+
res = rest_client.deleteScene(self.parent_id)
172+
log.info(f"[TEARDOWN] Deleted parent scene uid={self.parent_id}: {res.statusCode}")
173+
174+
def unlink_child(self, rest_client):
175+
"""Unlink the child scene from its parent and record that it has been done.
176+
177+
Calling this mid-test means :meth:`teardown_scenes` will skip the unlink
178+
step, avoiding a double-unlink error.
179+
180+
@param rest_client An authenticated :class:`RESTClient`.
181+
"""
182+
res = rest_client.deleteChildSceneLink(self.child_id)
183+
assert res.statusCode == 200, (
184+
f"Expected 200 deleting child link, got {res.statusCode}: {res.errors}")
185+
self.child_unlinked = True
186+
log.info(f"Unlinked child uid={self.child_id}: {res.statusCode}")
187+
188+
def _subscribe_event(self, mqttc, label, region_type, scene_id, region_id):
189+
"""Format an EVENT topic, subscribe, and log the subscription."""
190+
t = PubSub.formatTopic(PubSub.EVENT, region_type=region_type, event_type="+",
191+
scene_id=scene_id, region_id=region_id)
192+
mqttc.subscribe(t)
193+
log.info(f"Subscribed {label}: {t}")
194+
195+
def _on_connect(self, mqttc, obj, flags, rc):
196+
"""Subscribe to all relevant event topics once connected."""
197+
if rc != 0:
198+
log.error(f"MQTT connect failed with rc={rc}")
199+
return
200+
201+
log.info("MQTT connected")
202+
self.connected = True
203+
204+
self._subscribe_event(mqttc, "child ROI events", REGION, self.child_id, self.roi_uid)
205+
self._subscribe_event(mqttc, "child tripwire events", TRIPWIRE, self.child_id, self.tripwire_uid)
206+
if self.sensor_uid:
207+
self._subscribe_event(mqttc, "child sensor events", REGION, self.child_id, self.sensor_uid)
208+
209+
# Parent equivalents (republished by controller)
210+
self._subscribe_event(mqttc, "parent ROI events", REGION, self.parent_id, self.roi_uid)
211+
self._subscribe_event(mqttc, "parent tripwire events", TRIPWIRE, self.parent_id, self.tripwire_uid)
212+
if self.sensor_uid:
213+
self._subscribe_event(mqttc, "parent sensor events", REGION, self.parent_id, self.sensor_uid)
214+
215+
def _on_message(self, mqttc, obj, msg):
216+
"""Route incoming MQTT messages to the correct accumulator list."""
217+
topic = PubSub.parseTopic(msg.topic)
218+
if topic is None:
219+
return
220+
221+
try:
222+
data = json.loads(msg.payload.decode("utf-8"))
223+
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
224+
log.warning(f"Failed to decode MQTT payload on {msg.topic}: {exc}")
225+
return
226+
227+
scene_id = topic.get("scene_id")
228+
region_id = topic.get("region_id")
229+
region_type = topic.get("region_type")
230+
231+
if topic.get("_topic_id") != PubSub.EVENT:
232+
return
233+
234+
if scene_id == self.child_id and region_id == self.roi_uid and region_type == REGION:
235+
self.child_roi_events.append(data)
236+
log.info(f"Child ROI event received: {len(self.child_roi_events)} total")
237+
238+
elif scene_id == self.child_id and region_id == self.tripwire_uid and region_type == TRIPWIRE:
239+
self.child_tripwire_events.append(data)
240+
log.info(f"Child tripwire event received: {len(self.child_tripwire_events)} total")
241+
242+
elif (scene_id == self.child_id and self.sensor_uid
243+
and region_id == self.sensor_uid and region_type == REGION):
244+
self.child_sensor_events.append(data)
245+
log.info(f"Child sensor event received: {len(self.child_sensor_events)} total")
246+
247+
elif scene_id == self.parent_id and region_id == self.roi_uid and region_type == REGION:
248+
self.parent_roi_events.append(data)
249+
log.info(f"Parent ROI event received: {len(self.parent_roi_events)} total")
250+
251+
elif scene_id == self.parent_id and region_id == self.tripwire_uid and region_type == TRIPWIRE:
252+
self.parent_tripwire_events.append(data)
253+
log.info(f"Parent tripwire event received: {len(self.parent_tripwire_events)} total")
254+
255+
elif (scene_id == self.parent_id and self.sensor_uid
256+
and region_id == self.sensor_uid and region_type == REGION):
257+
self.parent_sensor_events.append(data)
258+
log.info(f"Parent sensor event received: {len(self.parent_sensor_events)} total")
259+
260+
def connect_mqtt(self):
261+
"""Create a :class:`PubSub` client, attach callbacks, connect, and wait.
262+
263+
@return The connected :class:`PubSub` client.
264+
"""
265+
client = PubSub(self.params["auth"], None, self.params["rootcert"],
266+
self.params["broker_url"], self.params["broker_port"])
267+
client.onConnect = self._on_connect
268+
client.onMessage = self._on_message
269+
client.connect()
270+
client.loopStart()
271+
272+
start = time.time()
273+
while not self.connected and time.time() - start < MAX_WAIT:
274+
time.sleep(0.5)
275+
assert self.connected, "MQTT client failed to connect within timeout"
276+
return client
277+
278+
def _send_detections(self, client, obj_data, y_locations, stop_event):
279+
"""Publish person detections through a y-sweep to trigger enter/exit events.
280+
281+
Stops early if *stop_event* is set. Called internally by
282+
:meth:`start_detection_thread`.
283+
284+
@param client Connected :class:`PubSub` client.
285+
@param obj_data Detection payload dict (modified in-place).
286+
@param y_locations Iterable of bounding-box y values.
287+
@param stop_event :class:`threading.Event`, publishing stops when set.
288+
"""
289+
cam_id = obj_data["id"]
290+
topic = PubSub.formatTopic(PubSub.DATA_CAMERA, camera_id=cam_id)
291+
for _ in range(NUM_PUBLISH_ITERATIONS):
292+
for y in y_locations:
293+
if stop_event.is_set():
294+
return
295+
obj_data["timestamp"] = get_iso_time()
296+
obj_data["objects"][PERSON][0]["bounding_box"]["y"] = float(y)
297+
obj_data["objects"][PERSON][0]["category"] = PERSON
298+
client.publish(topic, json.dumps(obj_data))
299+
time.sleep(1.0 / FRAME_RATE)
300+
301+
def send_sensor_value(self, client, sensor_name, value):
302+
"""Publish a singleton sensor reading to DATA_SENSOR topic.
303+
304+
@param client Connected :class:`PubSub` client.
305+
@param sensor_name Sensor identifier string.
306+
@param value Sensor reading value.
307+
"""
308+
message = {
309+
"timestamp": get_iso_time(),
310+
"id": sensor_name,
311+
"value": value,
312+
}
313+
topic = PubSub.formatTopic(PubSub.DATA_SENSOR, sensor_id=sensor_name)
314+
client.publish(topic, json.dumps(message))
315+
log.info(f"Published sensor value: id={sensor_name}, value={value}")
316+
317+
def start_detection_thread(self, client, obj_data, stop_event, y_locations=None):
318+
"""Spawn and start a daemon thread that publishes detections.
319+
320+
@param client Connected :class:`PubSub` client.
321+
@param obj_data Detection payload dict.
322+
@param stop_event :class:`threading.Event` to stop publishing.
323+
@param y_locations Y-sweep values, defaults to :data:`OBJ_Y_LOCATIONS`.
324+
@return The started :class:`threading.Thread`.
325+
"""
326+
if y_locations is None:
327+
y_locations = OBJ_Y_LOCATIONS
328+
thread = threading.Thread(
329+
target=self._send_detections,
330+
args=(client, obj_data, y_locations, stop_event),
331+
daemon=True,
332+
)
333+
thread.start()
334+
return thread
335+
336+
def wait_for_events(self, attr, timeout=MAX_WAIT):
337+
"""Block until at least one event is present in the named attribute.
338+
339+
@param attr Name of the list attribute to poll (e.g. ``"parent_roi_events"``).
340+
@param timeout Maximum seconds to wait.
341+
@return ``True`` if events arrived within *timeout*, ``False`` otherwise.
342+
"""
343+
start = time.time()
344+
while time.time() - start < timeout:
345+
if getattr(self, attr):
346+
return True
347+
time.sleep(0.5)
348+
return False

0 commit comments

Comments
 (0)