|
2 | 2 | import sys |
3 | 3 | import time |
4 | 4 | from pathlib import Path, PurePath |
5 | | -from typing import List |
| 5 | +from typing import Dict, List |
| 6 | +import numpy as np |
6 | 7 |
|
7 | 8 | import pytest |
8 | 9 |
|
|
17 | 18 | RemoveVideo, |
18 | 19 | ReplaceVideo, |
19 | 20 | SaveProjectAs, |
| 21 | + TriangulateSession, |
20 | 22 | get_new_version_filename, |
21 | 23 | ) |
22 | 24 | from sleap.instance import Instance, LabeledFrame |
| 25 | +from sleap.io.cameras import Camcorder |
23 | 26 | from sleap.io.convert import default_analysis_filename |
24 | 27 | from sleap.io.dataset import Labels |
25 | 28 | from sleap.io.format.adaptor import Adaptor |
@@ -952,3 +955,344 @@ def test_AddSession( |
952 | 955 | assert len(labels.sessions) == 2 |
953 | 956 | assert context.state["session"] is session |
954 | 957 | assert labels.sessions[1] is not session |
| 958 | + |
| 959 | + |
| 960 | +def test_triangulate_session_get_all_views_at_frame( |
| 961 | + multiview_min_session_labels: Labels, |
| 962 | +): |
| 963 | + labels = multiview_min_session_labels |
| 964 | + session = labels.sessions[0] |
| 965 | + lf = labels.labeled_frames[0] |
| 966 | + frame_idx = lf.frame_idx |
| 967 | + |
| 968 | + # Test with no cams_to_include, expect views from all linked cameras |
| 969 | + views = TriangulateSession.get_all_views_at_frame(session, frame_idx) |
| 970 | + assert len(views) == len(session.linked_cameras) |
| 971 | + for cam in session.linked_cameras: |
| 972 | + assert views[cam].frame_idx == frame_idx |
| 973 | + assert views[cam].video == session[cam] |
| 974 | + |
| 975 | + # Test with cams_to_include, expect views from only those cameras |
| 976 | + cams_to_include = session.linked_cameras[0:2] |
| 977 | + views = TriangulateSession.get_all_views_at_frame( |
| 978 | + session, frame_idx, cams_to_include=cams_to_include |
| 979 | + ) |
| 980 | + assert len(views) == len(cams_to_include) |
| 981 | + for cam in cams_to_include: |
| 982 | + assert views[cam].frame_idx == frame_idx |
| 983 | + assert views[cam].video == session[cam] |
| 984 | + |
| 985 | + |
| 986 | +def test_triangulate_session_get_instances_across_views( |
| 987 | + multiview_min_session_labels: Labels, |
| 988 | +): |
| 989 | + |
| 990 | + labels = multiview_min_session_labels |
| 991 | + session = labels.sessions[0] |
| 992 | + |
| 993 | + # Test get_instances_across_views |
| 994 | + lf: LabeledFrame = labels[0] |
| 995 | + track = labels.tracks[0] |
| 996 | + instances: Dict[ |
| 997 | + Camcorder, Instance |
| 998 | + ] = TriangulateSession.get_instances_across_views( |
| 999 | + session=session, frame_idx=lf.frame_idx, track=track |
| 1000 | + ) |
| 1001 | + assert len(instances) == len(session.videos) |
| 1002 | + for vid in session.videos: |
| 1003 | + cam = session[vid] |
| 1004 | + inst = instances[cam] |
| 1005 | + assert inst.frame_idx == lf.frame_idx |
| 1006 | + assert inst.track == track |
| 1007 | + assert inst.video == vid |
| 1008 | + |
| 1009 | + # Try with excluding cam views |
| 1010 | + lf: LabeledFrame = labels[2] |
| 1011 | + track = labels.tracks[1] |
| 1012 | + cams_to_include = session.linked_cameras[:4] |
| 1013 | + videos_to_include: Dict[ |
| 1014 | + Camcorder, Video |
| 1015 | + ] = session.get_videos_from_selected_cameras(cams_to_include=cams_to_include) |
| 1016 | + assert len(cams_to_include) == 4 |
| 1017 | + assert len(videos_to_include) == len(cams_to_include) |
| 1018 | + instances: Dict[ |
| 1019 | + Camcorder, Instance |
| 1020 | + ] = TriangulateSession.get_instances_across_views( |
| 1021 | + session=session, |
| 1022 | + frame_idx=lf.frame_idx, |
| 1023 | + track=track, |
| 1024 | + cams_to_include=cams_to_include, |
| 1025 | + ) |
| 1026 | + assert len(instances) == len( |
| 1027 | + videos_to_include |
| 1028 | + ) # May not be true if no instances at that frame |
| 1029 | + for cam, vid in videos_to_include.items(): |
| 1030 | + inst = instances[cam] |
| 1031 | + assert inst.frame_idx == lf.frame_idx |
| 1032 | + assert inst.track == track |
| 1033 | + assert inst.video == vid |
| 1034 | + |
| 1035 | + # Try with only a single view |
| 1036 | + cams_to_include = [session.linked_cameras[0]] |
| 1037 | + with pytest.raises(ValueError): |
| 1038 | + instances = TriangulateSession.get_instances_across_views( |
| 1039 | + session=session, |
| 1040 | + frame_idx=lf.frame_idx, |
| 1041 | + cams_to_include=cams_to_include, |
| 1042 | + track=track, |
| 1043 | + require_multiple_views=True, |
| 1044 | + ) |
| 1045 | + |
| 1046 | + # Try with multiple views, but not enough instances |
| 1047 | + track = labels.tracks[1] |
| 1048 | + cams_to_include = session.linked_cameras[4:6] |
| 1049 | + with pytest.raises(ValueError): |
| 1050 | + instances = TriangulateSession.get_instances_across_views( |
| 1051 | + session=session, |
| 1052 | + frame_idx=lf.frame_idx, |
| 1053 | + cams_to_include=cams_to_include, |
| 1054 | + track=track, |
| 1055 | + require_multiple_views=True, |
| 1056 | + ) |
| 1057 | + |
| 1058 | + |
| 1059 | +def test_triangulate_session_get_and_verify_enough_instances( |
| 1060 | + multiview_min_session_labels: Labels, |
| 1061 | + caplog, |
| 1062 | +): |
| 1063 | + labels = multiview_min_session_labels |
| 1064 | + session = labels.sessions[0] |
| 1065 | + lf = labels.labeled_frames[0] |
| 1066 | + track = labels.tracks[1] |
| 1067 | + |
| 1068 | + # Test with no cams_to_include, expect views from all linked cameras |
| 1069 | + instances = TriangulateSession.get_and_verify_enough_instances( |
| 1070 | + session=session, frame_idx=lf.frame_idx, track=track |
| 1071 | + ) |
| 1072 | + assert len(instances) == 6 # Some views don't have an instance at this track |
| 1073 | + for cam in session.linked_cameras: |
| 1074 | + if cam.name in ["side", "sideL"]: # The views that don't have an instance |
| 1075 | + continue |
| 1076 | + assert instances[cam].frame_idx == lf.frame_idx |
| 1077 | + assert instances[cam].track == track |
| 1078 | + assert instances[cam].video == session[cam] |
| 1079 | + |
| 1080 | + # Test with cams_to_include, expect views from only those cameras |
| 1081 | + cams_to_include = session.linked_cameras[-2:] |
| 1082 | + instances = TriangulateSession.get_and_verify_enough_instances( |
| 1083 | + session=session, |
| 1084 | + frame_idx=lf.frame_idx, |
| 1085 | + cams_to_include=cams_to_include, |
| 1086 | + track=track, |
| 1087 | + ) |
| 1088 | + assert len(instances) == len(cams_to_include) |
| 1089 | + for cam in cams_to_include: |
| 1090 | + assert instances[cam].frame_idx == lf.frame_idx |
| 1091 | + assert instances[cam].track == track |
| 1092 | + assert instances[cam].video == session[cam] |
| 1093 | + |
| 1094 | + # Test with not enough instances, expect views from only those cameras |
| 1095 | + cams_to_include = session.linked_cameras[0:2] |
| 1096 | + instances = TriangulateSession.get_and_verify_enough_instances( |
| 1097 | + session=session, frame_idx=lf.frame_idx, cams_to_include=cams_to_include |
| 1098 | + ) |
| 1099 | + assert isinstance(instances, bool) |
| 1100 | + assert not instances |
| 1101 | + messages = "".join([rec.message for rec in caplog.records]) |
| 1102 | + assert "One or less instances found for frame" in messages |
| 1103 | + |
| 1104 | + |
| 1105 | +def test_triangulate_session_verify_enough_views( |
| 1106 | + multiview_min_session_labels: Labels, caplog |
| 1107 | +): |
| 1108 | + labels = multiview_min_session_labels |
| 1109 | + session = labels.sessions[0] |
| 1110 | + |
| 1111 | + # Test with enough views |
| 1112 | + enough_views = TriangulateSession.verify_enough_views( |
| 1113 | + session=session, show_dialog=False |
| 1114 | + ) |
| 1115 | + assert enough_views |
| 1116 | + messages = "".join([rec.message for rec in caplog.records]) |
| 1117 | + assert len(messages) == 0 |
| 1118 | + caplog.clear() |
| 1119 | + |
| 1120 | + # Test with not enough views |
| 1121 | + cams_to_include = [session.linked_cameras[0]] |
| 1122 | + enough_views = TriangulateSession.verify_enough_views( |
| 1123 | + session=session, cams_to_include=cams_to_include, show_dialog=False |
| 1124 | + ) |
| 1125 | + assert not enough_views |
| 1126 | + messages = "".join([rec.message for rec in caplog.records]) |
| 1127 | + assert "One or less cameras available." in messages |
| 1128 | + |
| 1129 | + |
| 1130 | +def test_triangulate_session_verify_views_and_instances( |
| 1131 | + multiview_min_session_labels: Labels, |
| 1132 | +): |
| 1133 | + labels = multiview_min_session_labels |
| 1134 | + session = labels.sessions[0] |
| 1135 | + |
| 1136 | + # Test with enough views and instances |
| 1137 | + lf = labels.labeled_frames[0] |
| 1138 | + instance = lf.instances[0] |
| 1139 | + |
| 1140 | + context = CommandContext.from_labels(labels) |
| 1141 | + params = { |
| 1142 | + "video": session.videos[0], |
| 1143 | + "session": session, |
| 1144 | + "frame_idx": lf.frame_idx, |
| 1145 | + "instance": instance, |
| 1146 | + "show_dialog": False, |
| 1147 | + } |
| 1148 | + enough_views = TriangulateSession.verify_views_and_instances(context, params) |
| 1149 | + assert enough_views |
| 1150 | + assert "instances" in params |
| 1151 | + |
| 1152 | + # Test with not enough views |
| 1153 | + cams_to_include = [session.linked_cameras[0]] |
| 1154 | + params = { |
| 1155 | + "video": session.videos[0], |
| 1156 | + "session": session, |
| 1157 | + "frame_idx": lf.frame_idx, |
| 1158 | + "instance": instance, |
| 1159 | + "cams_to_include": cams_to_include, |
| 1160 | + "show_dialog": False, |
| 1161 | + } |
| 1162 | + enough_views = TriangulateSession.verify_views_and_instances(context, params) |
| 1163 | + assert not enough_views |
| 1164 | + assert "instances" not in params |
| 1165 | + |
| 1166 | + |
| 1167 | +def test_triangulate_session_calculate_reprojected_points( |
| 1168 | + multiview_min_session_labels: Labels, |
| 1169 | +): |
| 1170 | + """Test `TriangulateSession.calculate_reprojected_points`.""" |
| 1171 | + |
| 1172 | + session = multiview_min_session_labels.sessions[0] |
| 1173 | + lf: LabeledFrame = multiview_min_session_labels[0] |
| 1174 | + track = multiview_min_session_labels.tracks[0] |
| 1175 | + instances: Dict[ |
| 1176 | + Camcorder, Instance |
| 1177 | + ] = TriangulateSession.get_instances_across_views( |
| 1178 | + session=session, frame_idx=lf.frame_idx, track=track |
| 1179 | + ) |
| 1180 | + instances_and_coords = TriangulateSession.calculate_reprojected_points( |
| 1181 | + session=session, instances=instances |
| 1182 | + ) |
| 1183 | + |
| 1184 | + # Check that we get the same number of instances as input |
| 1185 | + assert len(instances) == len(list(instances_and_coords)) |
| 1186 | + |
| 1187 | + # Check that each instance has the same number of points |
| 1188 | + for inst, inst_coords in instances_and_coords: |
| 1189 | + assert inst_coords.shape[1] == len(inst.skeleton) # (1, 15, 2) |
| 1190 | + |
| 1191 | + |
| 1192 | +def test_triangulate_session_get_instances_matrices( |
| 1193 | + multiview_min_session_labels: Labels, |
| 1194 | +): |
| 1195 | + """Test `TriangulateSession.get_instance_matrices`.""" |
| 1196 | + labels = multiview_min_session_labels |
| 1197 | + session = labels.sessions[0] |
| 1198 | + lf: LabeledFrame = labels[0] |
| 1199 | + track = labels.tracks[0] |
| 1200 | + instances: Dict[ |
| 1201 | + Camcorder, Instance |
| 1202 | + ] = TriangulateSession.get_instances_across_views( |
| 1203 | + session=session, frame_idx=lf.frame_idx, track=track |
| 1204 | + ) |
| 1205 | + instances_matrices = TriangulateSession.get_instances_matrices( |
| 1206 | + instances_ordered=instances.values() |
| 1207 | + ) |
| 1208 | + |
| 1209 | + # Verify shape |
| 1210 | + n_views = len(instances) |
| 1211 | + n_frames = 1 |
| 1212 | + n_tracks = 1 |
| 1213 | + n_nodes = len(labels.skeleton) |
| 1214 | + assert instances_matrices.shape == (n_views, n_frames, n_tracks, n_nodes, 2) |
| 1215 | + |
| 1216 | + |
| 1217 | +def test_triangulate_session_update_instances(multiview_min_session_labels: Labels): |
| 1218 | + """Test `RecordingSession.update_instances`.""" |
| 1219 | + |
| 1220 | + # Test update_instances |
| 1221 | + session = multiview_min_session_labels.sessions[0] |
| 1222 | + lf: LabeledFrame = multiview_min_session_labels[0] |
| 1223 | + track = multiview_min_session_labels.tracks[0] |
| 1224 | + instances: Dict[ |
| 1225 | + Camcorder, Instance |
| 1226 | + ] = TriangulateSession.get_instances_across_views( |
| 1227 | + session=session, frame_idx=lf.frame_idx, track=track |
| 1228 | + ) |
| 1229 | + instances_and_coordinates = TriangulateSession.calculate_reprojected_points( |
| 1230 | + session=session, instances=instances |
| 1231 | + ) |
| 1232 | + for inst, inst_coords in instances_and_coordinates: |
| 1233 | + assert inst_coords.shape == (1, len(inst.skeleton), 2) # Tracks, Nodes, 2 |
| 1234 | + # Assert coord are different from original |
| 1235 | + assert not np.array_equal(inst_coords, inst.points_array) |
| 1236 | + |
| 1237 | + # Just run for code coverage testing, do not test output here (race condition) |
| 1238 | + # (see "functional core, imperative shell" pattern) |
| 1239 | + TriangulateSession.update_instances(session=session, instances=instances) |
| 1240 | + |
| 1241 | + |
| 1242 | +def test_triangulate_session_do_action(multiview_min_session_labels: Labels): |
| 1243 | + """Test `TriangulateSession.do_action`.""" |
| 1244 | + |
| 1245 | + labels = multiview_min_session_labels |
| 1246 | + session = labels.sessions[0] |
| 1247 | + |
| 1248 | + # Test with enough views and instances |
| 1249 | + lf = labels.labeled_frames[0] |
| 1250 | + instance = lf.instances[0] |
| 1251 | + |
| 1252 | + context = CommandContext.from_labels(labels) |
| 1253 | + params = { |
| 1254 | + "video": session.videos[0], |
| 1255 | + "session": session, |
| 1256 | + "frame_idx": lf.frame_idx, |
| 1257 | + "instance": instance, |
| 1258 | + "ask_again": True, |
| 1259 | + } |
| 1260 | + TriangulateSession.do_action(context, params) |
| 1261 | + |
| 1262 | + # Test with not enough views |
| 1263 | + cams_to_include = [session.linked_cameras[0]] |
| 1264 | + params = { |
| 1265 | + "video": session.videos[0], |
| 1266 | + "session": session, |
| 1267 | + "frame_idx": lf.frame_idx, |
| 1268 | + "instance": instance, |
| 1269 | + "cams_to_include": cams_to_include, |
| 1270 | + "ask_again": True, |
| 1271 | + } |
| 1272 | + TriangulateSession.do_action(context, params) |
| 1273 | + |
| 1274 | + |
| 1275 | +def test_triangulate_session(multiview_min_session_labels: Labels): |
| 1276 | + """Test `TriangulateSession`, if""" |
| 1277 | + |
| 1278 | + labels = multiview_min_session_labels |
| 1279 | + session = labels.sessions[0] |
| 1280 | + video = session.videos[0] |
| 1281 | + lf = labels.labeled_frames[0] |
| 1282 | + instance = lf.instances[0] |
| 1283 | + context = CommandContext.from_labels(labels) |
| 1284 | + |
| 1285 | + # Test with enough views and instances so we don't get any GUI pop-ups |
| 1286 | + context.triangulateSession( |
| 1287 | + frame_idx=lf.frame_idx, |
| 1288 | + video=video, |
| 1289 | + instance=instance, |
| 1290 | + session=session, |
| 1291 | + ) |
| 1292 | + |
| 1293 | + # Test with using state to gather params |
| 1294 | + context.state["session"] = session |
| 1295 | + context.state["video"] = video |
| 1296 | + context.state["instance"] = instance |
| 1297 | + context.state["frame_idx"] = lf.frame_idx |
| 1298 | + context.triangulateSession() |
0 commit comments