Skip to content

Commit 9771ff5

Browse files
committed
Add tests for TriangulateSession
1 parent b53c103 commit 9771ff5

File tree

1 file changed

+350
-6
lines changed

1 file changed

+350
-6
lines changed

tests/gui/test_commands.py

Lines changed: 350 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import sys
33
import time
44
from pathlib import Path, PurePath
5-
from typing import List
5+
from typing import Dict, List
6+
import numpy as np
67

78
import pytest
89

@@ -17,9 +18,11 @@
1718
RemoveVideo,
1819
ReplaceVideo,
1920
SaveProjectAs,
21+
TriangulateSession,
2022
get_new_version_filename,
2123
)
2224
from sleap.instance import Instance, LabeledFrame
25+
from sleap.io.cameras import Camcorder
2326
from sleap.io.convert import default_analysis_filename
2427
from sleap.io.dataset import Labels
2528
from sleap.io.format.adaptor import Adaptor
@@ -28,11 +31,11 @@
2831
from sleap.io.video import Video
2932
from sleap.util import get_package_file
3033

31-
# These imports cause trouble when running `pytest.main()` from within the file
32-
# Comment out to debug tests file via VSCode's "Debug Python File"
33-
from tests.info.test_h5 import extract_meta_hdf5
34-
from tests.io.test_formats import read_nix_meta
35-
from tests.io.test_video import assert_video_params
34+
# # These imports cause trouble when running `pytest.main()` from within the file
35+
# # Comment out to debug tests file via VSCode's "Debug Python File"
36+
# from tests.info.test_h5 import extract_meta_hdf5
37+
# from tests.io.test_formats import read_nix_meta
38+
# from tests.io.test_video import assert_video_params
3639

3740

3841
def test_delete_user_dialog(centered_pair_predictions):
@@ -952,3 +955,344 @@ def test_AddSession(
952955
assert len(labels.sessions) == 2
953956
assert context.state["session"] is session
954957
assert labels.sessions[1] is not session
958+
959+
960+
def test_triangulate_session_get_all_views_at_frame(
961+
multiview_min_session_labels: Labels,
962+
):
963+
labels = multiview_min_session_labels
964+
session = labels.sessions[0]
965+
lf = labels.labeled_frames[0]
966+
frame_idx = lf.frame_idx
967+
968+
# Test with no cams_to_include, expect views from all linked cameras
969+
views = TriangulateSession.get_all_views_at_frame(session, frame_idx)
970+
assert len(views) == len(session.linked_cameras)
971+
for cam in session.linked_cameras:
972+
assert views[cam].frame_idx == frame_idx
973+
assert views[cam].video == session[cam]
974+
975+
# Test with cams_to_include, expect views from only those cameras
976+
cams_to_include = session.linked_cameras[0:2]
977+
views = TriangulateSession.get_all_views_at_frame(
978+
session, frame_idx, cams_to_include=cams_to_include
979+
)
980+
assert len(views) == len(cams_to_include)
981+
for cam in cams_to_include:
982+
assert views[cam].frame_idx == frame_idx
983+
assert views[cam].video == session[cam]
984+
985+
986+
def test_triangulate_session_get_instances_across_views(
987+
multiview_min_session_labels: Labels,
988+
):
989+
990+
labels = multiview_min_session_labels
991+
session = labels.sessions[0]
992+
993+
# Test get_instances_across_views
994+
lf: LabeledFrame = labels[0]
995+
track = labels.tracks[0]
996+
instances: Dict[
997+
Camcorder, Instance
998+
] = TriangulateSession.get_instances_across_views(
999+
session=session, frame_idx=lf.frame_idx, track=track
1000+
)
1001+
assert len(instances) == len(session.videos)
1002+
for vid in session.videos:
1003+
cam = session[vid]
1004+
inst = instances[cam]
1005+
assert inst.frame_idx == lf.frame_idx
1006+
assert inst.track == track
1007+
assert inst.video == vid
1008+
1009+
# Try with excluding cam views
1010+
lf: LabeledFrame = labels[2]
1011+
track = labels.tracks[1]
1012+
cams_to_include = session.linked_cameras[:4]
1013+
videos_to_include: Dict[
1014+
Camcorder, Video
1015+
] = session.get_videos_from_selected_cameras(cams_to_include=cams_to_include)
1016+
assert len(cams_to_include) == 4
1017+
assert len(videos_to_include) == len(cams_to_include)
1018+
instances: Dict[
1019+
Camcorder, Instance
1020+
] = TriangulateSession.get_instances_across_views(
1021+
session=session,
1022+
frame_idx=lf.frame_idx,
1023+
track=track,
1024+
cams_to_include=cams_to_include,
1025+
)
1026+
assert len(instances) == len(
1027+
videos_to_include
1028+
) # May not be true if no instances at that frame
1029+
for cam, vid in videos_to_include.items():
1030+
inst = instances[cam]
1031+
assert inst.frame_idx == lf.frame_idx
1032+
assert inst.track == track
1033+
assert inst.video == vid
1034+
1035+
# Try with only a single view
1036+
cams_to_include = [session.linked_cameras[0]]
1037+
with pytest.raises(ValueError):
1038+
instances = TriangulateSession.get_instances_across_views(
1039+
session=session,
1040+
frame_idx=lf.frame_idx,
1041+
cams_to_include=cams_to_include,
1042+
track=track,
1043+
require_multiple_views=True,
1044+
)
1045+
1046+
# Try with multiple views, but not enough instances
1047+
track = labels.tracks[1]
1048+
cams_to_include = session.linked_cameras[4:6]
1049+
with pytest.raises(ValueError):
1050+
instances = TriangulateSession.get_instances_across_views(
1051+
session=session,
1052+
frame_idx=lf.frame_idx,
1053+
cams_to_include=cams_to_include,
1054+
track=track,
1055+
require_multiple_views=True,
1056+
)
1057+
1058+
1059+
def test_triangulate_session_get_and_verify_enough_instances(
1060+
multiview_min_session_labels: Labels,
1061+
caplog,
1062+
):
1063+
labels = multiview_min_session_labels
1064+
session = labels.sessions[0]
1065+
lf = labels.labeled_frames[0]
1066+
track = labels.tracks[1]
1067+
1068+
# Test with no cams_to_include, expect views from all linked cameras
1069+
instances = TriangulateSession.get_and_verify_enough_instances(
1070+
session=session, frame_idx=lf.frame_idx, track=track
1071+
)
1072+
assert len(instances) == 6 # Some views don't have an instance at this track
1073+
for cam in session.linked_cameras:
1074+
if cam.name in ["side", "sideL"]: # The views that don't have an instance
1075+
continue
1076+
assert instances[cam].frame_idx == lf.frame_idx
1077+
assert instances[cam].track == track
1078+
assert instances[cam].video == session[cam]
1079+
1080+
# Test with cams_to_include, expect views from only those cameras
1081+
cams_to_include = session.linked_cameras[-2:]
1082+
instances = TriangulateSession.get_and_verify_enough_instances(
1083+
session=session,
1084+
frame_idx=lf.frame_idx,
1085+
cams_to_include=cams_to_include,
1086+
track=track,
1087+
)
1088+
assert len(instances) == len(cams_to_include)
1089+
for cam in cams_to_include:
1090+
assert instances[cam].frame_idx == lf.frame_idx
1091+
assert instances[cam].track == track
1092+
assert instances[cam].video == session[cam]
1093+
1094+
# Test with not enough instances, expect views from only those cameras
1095+
cams_to_include = session.linked_cameras[0:2]
1096+
instances = TriangulateSession.get_and_verify_enough_instances(
1097+
session=session, frame_idx=lf.frame_idx, cams_to_include=cams_to_include
1098+
)
1099+
assert isinstance(instances, bool)
1100+
assert not instances
1101+
messages = "".join([rec.message for rec in caplog.records])
1102+
assert "One or less instances found for frame" in messages
1103+
1104+
1105+
def test_triangulate_session_verify_enough_views(
1106+
multiview_min_session_labels: Labels, caplog
1107+
):
1108+
labels = multiview_min_session_labels
1109+
session = labels.sessions[0]
1110+
1111+
# Test with enough views
1112+
enough_views = TriangulateSession.verify_enough_views(
1113+
session=session, show_dialog=False
1114+
)
1115+
assert enough_views
1116+
messages = "".join([rec.message for rec in caplog.records])
1117+
assert len(messages) == 0
1118+
caplog.clear()
1119+
1120+
# Test with not enough views
1121+
cams_to_include = [session.linked_cameras[0]]
1122+
enough_views = TriangulateSession.verify_enough_views(
1123+
session=session, cams_to_include=cams_to_include, show_dialog=False
1124+
)
1125+
assert not enough_views
1126+
messages = "".join([rec.message for rec in caplog.records])
1127+
assert "One or less cameras available." in messages
1128+
1129+
1130+
def test_triangulate_session_verify_views_and_instances(
1131+
multiview_min_session_labels: Labels,
1132+
):
1133+
labels = multiview_min_session_labels
1134+
session = labels.sessions[0]
1135+
1136+
# Test with enough views and instances
1137+
lf = labels.labeled_frames[0]
1138+
instance = lf.instances[0]
1139+
1140+
context = CommandContext.from_labels(labels)
1141+
params = {
1142+
"video": session.videos[0],
1143+
"session": session,
1144+
"frame_idx": lf.frame_idx,
1145+
"instance": instance,
1146+
"show_dialog": False,
1147+
}
1148+
enough_views = TriangulateSession.verify_views_and_instances(context, params)
1149+
assert enough_views
1150+
assert "instances" in params
1151+
1152+
# Test with not enough views
1153+
cams_to_include = [session.linked_cameras[0]]
1154+
params = {
1155+
"video": session.videos[0],
1156+
"session": session,
1157+
"frame_idx": lf.frame_idx,
1158+
"instance": instance,
1159+
"cams_to_include": cams_to_include,
1160+
"show_dialog": False,
1161+
}
1162+
enough_views = TriangulateSession.verify_views_and_instances(context, params)
1163+
assert not enough_views
1164+
assert "instances" not in params
1165+
1166+
1167+
def test_triangulate_session_calculate_reprojected_points(
1168+
multiview_min_session_labels: Labels,
1169+
):
1170+
"""Test `TriangulateSession.calculate_reprojected_points`."""
1171+
1172+
session = multiview_min_session_labels.sessions[0]
1173+
lf: LabeledFrame = multiview_min_session_labels[0]
1174+
track = multiview_min_session_labels.tracks[0]
1175+
instances: Dict[
1176+
Camcorder, Instance
1177+
] = TriangulateSession.get_instances_across_views(
1178+
session=session, frame_idx=lf.frame_idx, track=track
1179+
)
1180+
instances_and_coords = TriangulateSession.calculate_reprojected_points(
1181+
session=session, instances=instances
1182+
)
1183+
1184+
# Check that we get the same number of instances as input
1185+
assert len(instances) == len(list(instances_and_coords))
1186+
1187+
# Check that each instance has the same number of points
1188+
for inst, inst_coords in instances_and_coords:
1189+
assert inst_coords.shape[1] == len(inst.skeleton) # (1, 15, 2)
1190+
1191+
1192+
def test_triangulate_session_get_instances_matrices(
1193+
multiview_min_session_labels: Labels,
1194+
):
1195+
"""Test `TriangulateSession.get_instance_matrices`."""
1196+
labels = multiview_min_session_labels
1197+
session = labels.sessions[0]
1198+
lf: LabeledFrame = labels[0]
1199+
track = labels.tracks[0]
1200+
instances: Dict[
1201+
Camcorder, Instance
1202+
] = TriangulateSession.get_instances_across_views(
1203+
session=session, frame_idx=lf.frame_idx, track=track
1204+
)
1205+
instances_matrices = TriangulateSession.get_instances_matrices(
1206+
instances_ordered=instances.values()
1207+
)
1208+
1209+
# Verify shape
1210+
n_views = len(instances)
1211+
n_frames = 1
1212+
n_tracks = 1
1213+
n_nodes = len(labels.skeleton)
1214+
assert instances_matrices.shape == (n_views, n_frames, n_tracks, n_nodes, 2)
1215+
1216+
1217+
def test_triangulate_session_update_instances(multiview_min_session_labels: Labels):
1218+
"""Test `RecordingSession.update_instances`."""
1219+
1220+
# Test update_instances
1221+
session = multiview_min_session_labels.sessions[0]
1222+
lf: LabeledFrame = multiview_min_session_labels[0]
1223+
track = multiview_min_session_labels.tracks[0]
1224+
instances: Dict[
1225+
Camcorder, Instance
1226+
] = TriangulateSession.get_instances_across_views(
1227+
session=session, frame_idx=lf.frame_idx, track=track
1228+
)
1229+
instances_and_coordinates = TriangulateSession.calculate_reprojected_points(
1230+
session=session, instances=instances
1231+
)
1232+
for inst, inst_coords in instances_and_coordinates:
1233+
assert inst_coords.shape == (1, len(inst.skeleton), 2) # Tracks, Nodes, 2
1234+
# Assert coord are different from original
1235+
assert not np.array_equal(inst_coords, inst.points_array)
1236+
1237+
# Just run for code coverage testing, do not test output here (race condition)
1238+
# (see "functional core, imperative shell" pattern)
1239+
TriangulateSession.update_instances(session=session, instances=instances)
1240+
1241+
1242+
def test_triangulate_session_do_action(multiview_min_session_labels: Labels):
1243+
"""Test `TriangulateSession.do_action`."""
1244+
1245+
labels = multiview_min_session_labels
1246+
session = labels.sessions[0]
1247+
1248+
# Test with enough views and instances
1249+
lf = labels.labeled_frames[0]
1250+
instance = lf.instances[0]
1251+
1252+
context = CommandContext.from_labels(labels)
1253+
params = {
1254+
"video": session.videos[0],
1255+
"session": session,
1256+
"frame_idx": lf.frame_idx,
1257+
"instance": instance,
1258+
"ask_again": True,
1259+
}
1260+
TriangulateSession.do_action(context, params)
1261+
1262+
# Test with not enough views
1263+
cams_to_include = [session.linked_cameras[0]]
1264+
params = {
1265+
"video": session.videos[0],
1266+
"session": session,
1267+
"frame_idx": lf.frame_idx,
1268+
"instance": instance,
1269+
"cams_to_include": cams_to_include,
1270+
"ask_again": True,
1271+
}
1272+
TriangulateSession.do_action(context, params)
1273+
1274+
1275+
def test_triangulate_session(multiview_min_session_labels: Labels):
1276+
"""Test `TriangulateSession`, if"""
1277+
1278+
labels = multiview_min_session_labels
1279+
session = labels.sessions[0]
1280+
video = session.videos[0]
1281+
lf = labels.labeled_frames[0]
1282+
instance = lf.instances[0]
1283+
context = CommandContext.from_labels(labels)
1284+
1285+
# Test with enough views and instances so we don't get any GUI pop-ups
1286+
context.triangulateSession(
1287+
frame_idx=lf.frame_idx,
1288+
video=video,
1289+
instance=instance,
1290+
session=session,
1291+
)
1292+
1293+
# Test with using state to gather params
1294+
context.state["session"] = session
1295+
context.state["video"] = video
1296+
context.state["instance"] = instance
1297+
context.state["frame_idx"] = lf.frame_idx
1298+
context.triangulateSession()

0 commit comments

Comments
 (0)