diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2417e418e..9768dbbfe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,6 +62,7 @@ jobs: tests: name: Tests (${{ matrix.os }}) runs-on: ${{ matrix.os }} + timeout-minutes: 50 strategy: fail-fast: false matrix: diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 9e07c0a25..b223157f9 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -194,6 +194,8 @@ optional arguments: --tracking.pre_cull_iou_threshold TRACKING.PRE_CULL_IOU_THRESHOLD If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0) + --tracking.pre_cull_general_iou_threshold TRACKING.PRE_CULL_GENERAL_IOU_THRESHOLD + If non-zero, then use IOU threshold to remove overlapping instances regardless of the target count *before* tracking. (default: 0) --tracking.post_connect_single_breaks TRACKING.POST_CONNECT_SINGLE_BREAKS If non-zero and target_instance_count is also non-zero, then connect track breaks when exactly one track is lost and exactly one track is spawned in frame. (default: 0) diff --git a/sleap/config/shortcuts.yaml b/sleap/config/shortcuts.yaml index e4eccea40..0d1277c30 100644 --- a/sleap/config/shortcuts.yaml +++ b/sleap/config/shortcuts.yaml @@ -16,8 +16,10 @@ goto next suggestion: Space goto next track spawn: Ctrl+E goto next user: Ctrl+U goto next labeled: Alt+Right +goto next view: V goto prev suggestion: Shift+Space goto prev labeled: Alt+Left +goto prev view: Shift+V learning: Ctrl+L mark frame: Ctrl+M new: Ctrl+N diff --git a/sleap/gui/app.py b/sleap/gui/app.py index f53506fae..184a13b96 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -44,7 +44,6 @@ frame and instances listed in data view table. """ - import os import platform import random @@ -72,9 +71,11 @@ from sleap.gui.web import ReleaseChecker, ping_analytics from sleap.gui.widgets.docks import ( InstancesDock, + SessionsDock, SkeletonDock, SuggestionsDock, VideosDock, + InstanceGroupDock, ) from sleap.gui.widgets.slider import set_slider_marks_from_labels from sleap.gui.widgets.video import QtVideoPlayer @@ -86,7 +87,6 @@ from sleap.skeleton import Skeleton from sleap.util import parse_uri_path - logger = getLogger(__name__) @@ -144,6 +144,7 @@ def __init__( self.state["labeled_frame"] = None self.state["last_interacted_frame"] = None self.state["filename"] = None + self.state["session"] = None self.state["show non-visible nodes"] = prefs["show non-visible nodes"] self.state["show instances"] = True self.state["show labels"] = True @@ -223,6 +224,7 @@ def closeEvent(self, event): prefs["color predicted"] = self.state["color predicted"] prefs["trail shade"] = self.state["trail_shade"] prefs["share usage data"] = self.state["share usage data"] + prefs["distinctly_color"] = self.state["distinctly_color"] # Save preferences. prefs.save() @@ -297,6 +299,8 @@ def labels(self, value): def _initialize_gui(self): """Creates menus, dock windows, starts timers to update gui state.""" + self.state["distinctly_color"] = prefs["distinctly_color"] + self._create_color_manager() self._create_video_player() self.statusBar() @@ -326,7 +330,18 @@ def _create_video_player(self): self.setCentralWidget(self.player) def switch_frame(video): - """Jump to last labeled frame""" + """Maintain the same frame index if available. + + If the video is shorter than the current frame index, find the last labeled + frame. If no labeled frame is found, set the frame index to 0. + """ + + # If the new video is long enough, stay on the current frame index + current_frame_idx = self.state["frame_idx"] + if video.num_frames > current_frame_idx: + return + + # If the new video is not long enough, find last labeled frame or set to 0 last_label = self.labels.find_last(video) if last_label is not None: self.state["frame_idx"] = last_label.frame_idx @@ -347,9 +362,16 @@ def update_frame_chunk_suggestions(video): frame_to_spinbox.setMaximum(video.num_frames) frame_from_spinbox.setMaximum(video.num_frames) + def update_session(video): + """Update session state for current video.""" + if video is not None and len(self.labels.sessions) > 0: + session = self.labels.get_session(video=video) + self.state["session"] = session + self.state.connect( "video", callbacks=[ + update_session, # Important to update session before other callbacks switch_frame, lambda x: self._update_seekbar_marks(), update_frame_chunk_suggestions, @@ -570,6 +592,18 @@ def add_submenu_choices(menu, title, options, key): "Next Track Spawn Frame", self.commands.nextTrackFrame, ) + add_menu_item( + goMenu, + "goto next view", + "Next View", + self.commands.nextView, + ) + add_menu_item( + goMenu, + "goto prev view", + "Prev View", + self.commands.prevView, + ) goMenu.addSeparator() @@ -624,7 +658,7 @@ def prev_vid(): key="palette", ) - distinctly_color_options = ("instances", "nodes", "edges") + distinctly_color_options = ("instance groups", "instances", "nodes", "edges") add_submenu_choices( menu=viewMenu, @@ -634,7 +668,7 @@ def prev_vid(): ) self.state["palette"] = prefs["palette"] - self.state["distinctly_color"] = "instances" + self.state["distinctly_color"] = prefs["distinctly_color"] viewMenu.addSeparator() @@ -797,9 +831,18 @@ def new_instance_menu_action(): self.commands.deleteFrameLimitPredictions, ) + ### Sessions Menu ### + + sessionsMenu = self.menuBar().addMenu("Sessions") + + self.inst_groups_menu = sessionsMenu.addMenu("Set Instance Group") + self.inst_groups_delete_menu = sessionsMenu.addMenu("Delete Instance Group") + self.state.connect("frame_idx", self._update_sessions_menu) + ### Tracks Menu ### tracksMenu = self.menuBar().addMenu("Tracks") + self.track_menu = tracksMenu.addMenu("Set Instance Track") add_menu_check_item( tracksMenu, "propagate track labels", "Propagate Track Labels" @@ -1017,9 +1060,11 @@ def _create_dock_windows(self): """Create dock windows and connect them to GUI.""" self.videos_dock = VideosDock(self) + self.sessions_dock = SessionsDock(self, tab_with=self.videos_dock) self.skeleton_dock = SkeletonDock(self, tab_with=self.videos_dock) self.suggestions_dock = SuggestionsDock(self, tab_with=self.videos_dock) self.instances_dock = InstancesDock(self, tab_with=self.videos_dock) + self.instance_groups_dock = InstanceGroupDock(self, tab_with=self.videos_dock) # Bring videos tab forward. self.videos_dock.wgt_layout.parent().parent().raise_() @@ -1079,7 +1124,10 @@ def _update_gui_state(self): has_selected_node = self.state["selected_node"] is not None has_selected_edge = self.state["selected_edge"] is not None has_selected_video = self.state["selected_video"] is not None + has_selected_session = self.state["selected_session"] is not None has_video = self.state["video"] is not None + has_selected_camcorder = self.state["selected_camera"] is not None + has_selected_unlinked_video = self.state["selected_unlinked_video"] is not None has_frame_range = bool(self.state["has_frame_range"]) has_unsaved_changes = bool(self.state["has_changes"]) @@ -1103,6 +1151,7 @@ def _update_gui_state(self): # Update menus + self.inst_groups_menu.setEnabled(has_selected_instance) self.track_menu.setEnabled(has_selected_instance) self.delete_tracks_menu.setEnabled(has_tracks) self._menu_actions["clear selection"].setEnabled(has_selected_instance) @@ -1134,9 +1183,16 @@ def _update_gui_state(self): self._buttons["show video"].setEnabled(has_selected_video) self._buttons["remove video"].setEnabled(has_video) self._buttons["delete instance"].setEnabled(has_selected_instance) + self._buttons["unlink video"].setEnabled(has_selected_camcorder) self.suggestions_dock.suggestions_form_widget.buttons[ "generate_button" ].setEnabled(has_videos) + self._buttons["remove session"].setEnabled(has_selected_session) + self._buttons["link video"].setEnabled( + has_selected_unlinked_video + and has_selected_camcorder + and has_selected_session + ) # Update overlays self.overlays["track_labels"].visible = ( @@ -1162,6 +1218,10 @@ def _has_topic(topic_list): ): self.plotFrame() + if _has_topic([UpdateTopic.sessions]): + self.sessions_dock.sessions_table.model().items = self.labels.sessions + self.labels._cache.update() + if _has_topic( [ UpdateTopic.frame, @@ -1179,6 +1239,10 @@ def _has_topic(topic_list): if _has_topic([UpdateTopic.video]): self.videos_dock.table.model().items = self.labels.videos + self.labels._cache.update() + self.sessions_dock.unlinked_videos_table.model().items = ( + self.labels._cache._linkage_of_videos["unlinked"] + ) if _has_topic([UpdateTopic.skeleton]): self.skeleton_dock.nodes_table.model().items = self.state["skeleton"] @@ -1197,6 +1261,7 @@ def _has_topic(topic_list): if _has_topic([UpdateTopic.project, UpdateTopic.on_frame]): self.instances_dock.table.model().items = self.state["labeled_frame"] + self._update_instance_group_model() if _has_topic([UpdateTopic.suggestions]): self.suggestions_dock.table.model().items = self.labels.suggestions @@ -1221,6 +1286,38 @@ def _has_topic(topic_list): if _has_topic([UpdateTopic.frame, UpdateTopic.project_instances]): self.state["last_interacted_frame"] = self.state["labeled_frame"] + self._update_sessions_menu() + + if _has_topic([UpdateTopic.sessions]): + self.update_cameras_model() + self.update_unlinked_videos_model() + self._update_sessions_menu() + self._update_instance_group_model() + + def update_unlinked_videos_model(self): + """Update the unlinked videos model with the selected session.""" + self.sessions_dock.unlinked_videos_table.model().items = ( + self.labels._cache._linkage_of_videos["unlinked"] + ) + + def _update_instance_group_model(self): + """Update the instance group model with the `InstanceGroup`s in current frame.""" + + session = self.state["session"] + if session is not None: + frame_idx: int = self.state["frame_idx"] + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is not None: + self.instance_groups_dock.table.model().items = ( + frame_group.instance_groups + ) + return + + self.instance_groups_dock.table.model().items = [] + + def update_cameras_model(self): + """Update the cameras model with the selected session.""" + self.sessions_dock.camera_table.model().items = self.state["selected_session"] def plotFrame(self, *args, **kwargs): """Plots (or replots) current frame.""" @@ -1241,7 +1338,8 @@ def _after_plot_update(self, frame_idx): # Replot connected views for multi-camera projects # TODO(LM): Use context.state["session"] in command instead (when implemented) session = self.labels.get_session(video) - self.commands.triangulateSession(session=session) + if self.state.get("auto_triangulate", False): + self.commands.triangulateSession(session=session) def _after_plot_change(self, player, frame_idx, selected_inst): """Called each time a new frame is drawn.""" @@ -1333,6 +1431,11 @@ def updateStatusMessage(self, message: Optional[str] = None): else: self.statusBar().setStyleSheet("color: black") + if self.state["session"] is not None and current_video is not None: + camera = self.state["session"].get_camera(video=self.state["video"]) + if camera is not None: + message += f"{spacer}Camera: {camera.name}" + self.statusBar().showMessage(message) def resetPrefs(self): @@ -1364,6 +1467,51 @@ def _update_track_menu(self): "New Track", self.commands.addTrack, Qt.CTRL + Qt.Key_0 ) + def _update_sessions_menu(self): + """Update the instance groups menu based on the frame index.""" + + # Clear menus before adding more items + self.inst_groups_menu.clear() + self.inst_groups_delete_menu.clear() + + # Get the session + session = self.state.get("session") + if session is None: + return + + # Get the frame group for the current frame + frame_idx = self.state["frame_idx"] + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is not None: + for inst_group_ind, instance_group in enumerate( + frame_group.instance_groups + ): + # Create shortcut key for first 9 groups + key_command = "" + if inst_group_ind < 9: + key_command = Qt.SHIFT + Qt.Key_0 + inst_group_ind + 1 + + # Update the Set Instance Group menu + self.inst_groups_menu.addAction( + instance_group.name, + lambda x=instance_group: self.commands.setInstanceGroup(x), + key_command, + ) + + # Update the Delete Instance Group menu + self.inst_groups_delete_menu.addAction( + instance_group.name, + lambda x=instance_group: self.commands.deleteInstanceGroup( + instance_group=x + ), + ) + + self.inst_groups_menu.addAction( + "New Instance Group", + self.commands.addInstanceGroup, + Qt.SHIFT + Qt.Key_0, + ) + def _update_seekbar_marks(self): """Updates marks on seekbar.""" set_slider_marks_from_labels( diff --git a/sleap/gui/color.py b/sleap/gui/color.py index 6172d236d..889381ca6 100644 --- a/sleap/gui/color.py +++ b/sleap/gui/color.py @@ -11,12 +11,14 @@ Initial color palette (and other settings, like default line width) is read from user preferences but can be changed after object is created. """ + from typing import Any, Iterable, Optional, Union, Text, Tuple import yaml from sleap.util import get_config_file from sleap.instance import Instance, Track, Node +from sleap.io.cameras import RecordingSession from sleap.io.dataset import Labels from sleap.prefs import prefs @@ -26,7 +28,7 @@ class ColorManager: - """Class to determine color to use for track. + """Class to determine color to use for track and instance groups. The color depends on the order of the tracks in `Labels` object, so we need to initialize with `Labels`. @@ -37,7 +39,11 @@ class ColorManager: palette: String with the color palette name to use. """ - def __init__(self, labels: Labels = None, palette: str = "standard"): + def __init__( + self, + labels: Labels = None, + palette: str = "standard", + ): self.labels = labels with open(get_config_file("colors.yaml"), "r") as f: @@ -55,6 +61,7 @@ def __init__(self, labels: Labels = None, palette: str = "standard"): self.set_palette(palette) self.uncolored_prediction_color = (250, 250, 10) + self.ungrouped_instance_color = (250, 250, 10) if prefs["bold lines"]: self.thick_pen_width = 6 @@ -178,6 +185,22 @@ def get_track_color(self, track: Union[Track, int]) -> ColorTupleType: return self.get_color_by_idx(track_idx) + def get_instance_group_color(self, instance_group, frame_group): + """Returns the color to use for a given instance group. + + Args: + instance_group: `InstanceGroup` object + frame_group: `FrameGroup` object + Returns: + (r, g, b)-tuple + """ + if frame_group is None or instance_group is None: + return self.ungrouped_instance_color + + if instance_group in frame_group.instance_groups: + instance_group_idx = frame_group.instance_groups.index(instance_group) + return self.get_color_by_idx(instance_group_idx) + @classmethod def is_sequence(cls, item) -> bool: """Returns whether item is a tuple or list.""" @@ -238,6 +261,9 @@ def get_item_color( item: Any, parent_instance: Optional[Instance] = None, parent_skeleton: Optional["Skeleton"] = None, + parent_session: Optional[RecordingSession] = None, + parent_frame: Optional["LabeledFrame"] = None, + parent_frame_idx: Optional[int] = None, ) -> ColorTupleType: """Gets (r, g, b) tuple of color to use for drawing item.""" @@ -257,6 +283,37 @@ def get_item_color( return (128, 128, 128) + if self.distinctly_color == "instance groups" and parent_instance: + + if parent_frame is None and parent_instance: + parent_frame = parent_instance.frame + + if parent_frame_idx is None and parent_frame: + parent_frame_idx = parent_frame.frame_idx + + if parent_session is None and self.labels and parent_frame: + parent_session = self.labels.get_session(video=parent_frame.video) + + can_retrieve_instance_group = ( + parent_instance is not None + and parent_frame_idx is not None + and parent_session is not None + ) + + if can_retrieve_instance_group: + instance_group = None + frame_group = parent_session.frame_groups.get(parent_frame_idx, None) + if frame_group is not None: + instance_group = frame_group.get_instance_group( + instance=parent_instance + ) + + return self.get_instance_group_color( + instance_group=instance_group, frame_group=frame_group + ) + + return self.uncolored_prediction_color + if self.distinctly_color == "instances" or hasattr(item, "track"): track = None if hasattr(item, "track"): diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 41dbb29a4..7da588c6b 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -33,17 +33,17 @@ class which inherits from `AppCommand` (or a more specialized class such as import subprocess import sys import traceback +import toml from enum import Enum from glob import glob -from itertools import permutations, product from pathlib import Path, PurePath -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union import attr import cv2 import numpy as np -from sleap_anipose import triangulate, reproject from qtpy import QtCore, QtGui, QtWidgets +from sleap_anipose import reproject, triangulate from sleap.gui.dialogs.delete import DeleteDialog from sleap.gui.dialogs.filedialog import FileDialog @@ -54,13 +54,13 @@ class which inherits from `AppCommand` (or a more specialized class such as from sleap.gui.state import GuiState from sleap.gui.suggestions import VideoFrameSuggestions from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track -from sleap.io.cameras import Camcorder, InstanceGroup, FrameGroup, RecordingSession +from sleap.io.cameras import Camcorder, FrameGroup, InstanceGroup, RecordingSession from sleap.io.convert import default_analysis_filename from sleap.io.dataset import Labels from sleap.io.format.adaptor import Adaptor from sleap.io.format.csv import CSVAdaptor from sleap.io.format.ndx_pose import NDXPoseAdaptor -from sleap.io.video import Video +from sleap.io.video import Video, MediaVideo from sleap.skeleton import Node, Skeleton from sleap.util import get_package_file @@ -83,6 +83,7 @@ class UpdateTopic(Enum): frame = 8 project = 9 project_instances = 10 + sessions = 11 class AppCommand: @@ -411,6 +412,14 @@ def gotoVideoAndFrame(self, video: Video, frame_idx: int): """Activates video and goes to frame.""" NavCommand.go_to(self, frame_idx, video) + def nextView(self): + """Goes to next view.""" + self.execute(GoAdjacentView, prev_or_next="next") + + def prevView(self): + """Goes to previous view.""" + self.execute(GoAdjacentView, prev_or_next="prev") + # Editing Commands def toggleGrayscale(self): @@ -437,6 +446,19 @@ def addSession(self): """Shows gui for adding `RecordingSession`s to the project.""" self.execute(AddSession) + def removeSelectedSession(self): + """Removes a session from the project and the sessions dock.""" + self.execute(RemoveSession) + + def linkVideoToSession( + self, + video: Optional[Video] = None, + session: Optional[RecordingSession] = None, + camera: Optional[Camcorder] = None, + ): + """Links a video to a `RecordingSession`.""" + self.execute(LinkVideoToSession, video=video, session=session, camera=camera) + def openSkeletonTemplate(self): """Shows gui for loading saved skeleton into project.""" self.execute(OpenSkeleton, template=True) @@ -577,6 +599,17 @@ def setInstanceTrack(self, new_track: "Track"): """Sets track for selected instance.""" self.execute(SetSelectedInstanceTrack, new_track=new_track) + def addInstanceGroup(self): + """Sets the instance group for selected instance.""" + self.execute(AddInstanceGroup) + + def setInstanceGroup(self, instance_group: Optional["InstanceGroup"]): + """Sets the instance group for selected instance.""" + self.execute(SetSelectedInstanceGroup, instance_group=instance_group) + + def setInstanceGroupName(self, instance_group: InstanceGroup, name: str): + self.execute(SetInstanceGroupName, instance_group=instance_group, name=name) + def deleteTrack(self, track: "Track"): """Delete a track and remove from all instances.""" self.execute(DeleteTrack, track=track) @@ -585,6 +618,10 @@ def deleteMultipleTracks(self, delete_all: bool = False): """Delete all tracks.""" self.execute(DeleteMultipleTracks, delete_all=delete_all) + def deleteInstanceGroup(self, instance_group: "InstanceGroup"): + """Delete an instance group.""" + self.execute(DeleteInstanceGroup, instance_group=instance_group) + def copyInstanceTrack(self): """Copies the selected instance's track to the track clipboard.""" self.execute(CopyInstanceTrack) @@ -619,6 +656,7 @@ def triangulateSession( session: Optional[RecordingSession] = None, frame_idx: Optional[int] = None, instance: Optional[Instance] = None, + triangulate_predictions: bool = False, ): """Triangulates `Instance`s for selected views in a `RecordingSession`.""" self.execute( @@ -626,6 +664,7 @@ def triangulateSession( session=session, frame_idx=frame_idx, instance=instance, + triangulate_predictions=triangulate_predictions, ) def openWebsite(self, url): @@ -644,6 +683,10 @@ def openPrereleaseVersion(self): """Open the current prerelease version.""" self.execute(OpenPrereleaseVersion) + def unlink_video_from_camera(self): + """Unlinks video from a camera""" + self.execute(UnlinkVideo) + # File Commands @@ -686,6 +729,8 @@ def do_action(context: "CommandContext", params: dict): if len(labels.videos): context.state["video"] = labels.videos[0] + context.state["session"] = labels.sessions[0] if len(labels.sessions) else None + context.state["project_loaded"] = True context.state["has_changes"] = params.get("changed_on_load", False) or ( filename is None @@ -1686,6 +1731,25 @@ def ask(cls, context: "CommandContext", params: dict) -> bool: return okay +class GoAdjacentView(NavCommand): + @classmethod + def do_action(cls, context: CommandContext, params: dict): + operator = -1 if params["prev_or_next"] == "prev" else 1 + + labels = context.labels + frame_idx = context.state["frame_idx"] + video = context.state["video"] + session = labels.get_session(video) + + # Get the next view + current_video_idx = session.videos.index(video) + new_video_idx = (current_video_idx + operator) % len(session.videos) + new_video = session.videos[new_video_idx] + + context.state["video"] = new_video + context.state["frame_idx"] = frame_idx + + # Editing Commands @@ -1763,10 +1827,22 @@ class ShowImportVideos(EditCommand): topics = [UpdateTopic.video] @staticmethod - def do_action(context: CommandContext, params: dict): + def ask(context: CommandContext, params: dict) -> bool: filenames = params["filenames"] import_list = ImportVideos().ask(filenames=filenames) + + if len(import_list) > 0: + params["import_list"] = import_list + return True + + return False + + @staticmethod + def do_action(context: CommandContext, params: dict): + import_list = params["import_list"] new_videos = ImportVideos.create_videos(import_list) + params["new_videos"] = new_videos + video = None for video in new_videos: # Add to labels @@ -1939,21 +2015,97 @@ def ask(context: CommandContext, params: dict) -> bool: return True +class RemoveSession(EditCommand): + topics = [UpdateTopic.sessions] + + @staticmethod + def do_action(context: CommandContext, params: dict): + current_session = context.state["selected_session"] + try: + context.labels.remove_recording_session(current_session) + except Exception as e: + raise e + finally: + # Always set the selected session to None, even if it wasn't removed + context.state["selected_session"] = None + + class AddSession(EditCommand): - # topics = [UpdateTopic.session] + topics = [UpdateTopic.sessions, UpdateTopic.video] @staticmethod def do_action(context: CommandContext, params: dict): camera_calibration = params["camera_calibration"] + + # Create session from camera calibration file session = RecordingSession.load(filename=camera_calibration) # Add session context.labels.add_session(session) + # Import the new videos and link them to a camera + if "import_list" in params: + ShowImportVideos().do_action(context=context, params=params) + + # Create a lookup for cameras and videos + camera_by_video_paths = params.get("camera_by_video_paths", {}) + new_videos = params.get("new_videos", []) + new_video_by_filename = { + video.backend.filename: video for video in new_videos + } + camera_by_name = {cam.name: cam for cam in session.cameras} + + # Link videos to cameras + for video_path, camera_name in camera_by_video_paths.items(): + # Get video and camcorder + video = new_video_by_filename.get(video_path, None) + if video is None: + continue + camera = camera_by_name.get(camera_name, None) + if camera is None: + continue + context.linkVideoToSession(video=video, session=session, camera=camera) + # Load if no video currently loaded if context.state["session"] is None: context.state["session"] = session + # Reset since this action is also linked to a button in the SessionsDock and it + # is not visually apparent which session is selected after clicking the button + context.state["selected_session"] = None + + @staticmethod + def find_video_paths(camera_calibration: str) -> Dict[str, str]: + + # Find parent of calibration file + calibration_path = Path(camera_calibration) + parent_dir = calibration_path.parent + + # Use camcorder names in session to find camera folders + calibration_data = toml.load(camera_calibration) + camera_names = [ + value["name"] for value in calibration_data.values() if "name" in value + ] + + # Find videos inside camera folders + camera_by_video_paths = {} + for camera_name in camera_names: + camera_folder = parent_dir / camera_name + + # Skip if camera folder does not exist + if not camera_folder.exists(): + continue + + # Find and append all videos in camera folder + video_path = None + video_extensions = MediaVideo.EXTS + for file in camera_folder.iterdir(): + if file.suffix[1:] in video_extensions: + video_path = camera_folder / file + camera_by_video_paths[video_path.as_posix()] = camera_name + + return camera_by_video_paths + @staticmethod def ask(context: CommandContext, params: dict) -> bool: """Shows gui for adding video to project.""" @@ -1966,8 +2118,41 @@ def ask(context: CommandContext, params: dict) -> bool: ) params["camera_calibration"] = filename + if len(filename) == 0: + return False + + camera_by_video_paths = AddSession.find_video_paths(camera_calibration=filename) + params["camera_by_video_paths"] = camera_by_video_paths + + # Show import video dialog if any videos are found + if len(camera_by_video_paths) > 0: + params["filenames"] = list(camera_by_video_paths.keys()) + ShowImportVideos().ask(context=context, params=params) + + return True + + +class LinkVideoToSession(EditCommand): + topics = [UpdateTopic.sessions] + + @staticmethod + def do_action(context: CommandContext, params: dict): + video = params["video"] or context.state["selected_unlinked_video"] + recording_session = params["session"] or context.state["selected_session"] + camcorder = params["camera"] or context.state["selected_camera"] + + if camcorder is None: + raise ValueError("No camera selected.") + + if recording_session is None: + raise ValueError("No session selected.") + + if camcorder.get_video(recording_session) is None: + recording_session.add_video(video=video, camcorder=camcorder) - return len(filename) > 0 + # Reset the selected camera and video + context.state["selected_camera"] = None + context.state["selected_unlinked_video"] = None class OpenSkeleton(EditCommand): @@ -2613,6 +2798,44 @@ def ask_and_do(context: CommandContext, params: dict): context.signal_update([UpdateTopic.project_instances]) +class AddInstanceGroup(EditCommand): + topics = [UpdateTopic.sessions] + + @staticmethod + def do_action(context, params): + + # Get session and frame index + frame_idx = context.state["frame_idx"] + session: RecordingSession = context.state["session"] + if session is None: + raise ValueError("Cannot add instance group without session.") + + # Get or create frame group + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is None: + frame_group = session.new_frame_group(frame_idx=frame_idx) + + # Create and add instance group + instance_group = frame_group.add_instance_group(instance_group=None) + + # Now add the selected instance to the `InstanceGroup` + context.execute(SetSelectedInstanceGroup, instance_group=instance_group) + + +class SetInstanceGroupName(EditCommand): + + topics = [UpdateTopic.sessions] + + @staticmethod + def do_action(context: CommandContext, params: dict): + instance_group = params["instance_group"] + name = params["name"] + + FrameGroup = context.state["session"].frame_groups[instance_group.frame_idx] + + FrameGroup.set_instance_group_name(instance_group=instance_group, name=name) + + class AddTrack(EditCommand): topics = [UpdateTopic.tracks] @@ -2629,6 +2852,63 @@ def do_action(context: CommandContext, params: dict): context.execute(SetSelectedInstanceTrack, new_track=new_track) +class SetSelectedInstanceGroup(EditCommand): + topics = [UpdateTopic.project_instances] + + @staticmethod + def do_action(context, params): + """Set the `selected_instance` to the `instance_group`. + + Args: + context: The command context. + state: The context state. + instance: The selected instance. + frame_idx: The frame index. + video: The video. + session: The recording session. + + params: The command parameters. + instance_group: The `InstanceGroup` to set the selected instance to. + + Raises: + ValueError: If the `RecordingSession` is None. + ValueError: If the `FrameGroup` does not exist for the frame index. + ValueError: If the `Video` is not linked to a `Camcorder`. + """ + + selected_instance = context.state["instance"] + frame_idx = context.state["frame_idx"] + video = context.state["video"] + + base_message = ( + f"Cannot set instance group for selected instance [{selected_instance}]." + ) + + # `RecordingSession` should not be None + session: RecordingSession = context.state["session"] + if session is None: + raise ValueError(f"{base_message} No session for video [{video}]") + + # `FrameGroup` should already exist + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is None: + raise ValueError( + f"{base_message} Frame group does not exist for frame [{frame_idx}] in " + f"{session}." + ) + + # We need the camera and instance group to set the instance group + camera = session.get_camera(video=video) + if camera is None: + raise ValueError(f"{base_message} No camera linked to video [{video}]") + instance_group = params["instance_group"] + + # Set the instance group + frame_group.add_instance( + instance=selected_instance, camera=camera, instance_group=instance_group + ) + + class SetSelectedInstanceTrack(EditCommand): topics = [UpdateTopic.tracks] @@ -2718,6 +2998,31 @@ def do_action(context: CommandContext, params: dict): context.labels.remove_unused_tracks() +class DeleteInstanceGroup(EditCommand): + topics = [UpdateTopic.sessions] + + @staticmethod + def do_action(context, params): + + instance_group = params["instance_group"] + frame_idx = context.state["frame_idx"] + + base_message = f"Cannot delete instance group [{instance_group}]." + + # `RecordingSession` should not be None + session: RecordingSession = context.state["session"] + if session is None: + raise ValueError(f"{base_message} No session in context state.") + + # `FrameGroup` should already exist + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is None: + raise ValueError(f"{base_message} No frame group for frame {frame_idx}.") + + # Remove the instance group + frame_group.remove_instance_group(instance_group=instance_group) + + class CopyInstanceTrack(EditCommand): @staticmethod def do_action(context: CommandContext, params: dict): @@ -3299,6 +3604,7 @@ def add_nodes_from_template( def add_force_directed_nodes( cls, context, instance, visible, center_point: QtCore.QPoint = None ): + import networkx as nx center_point = center_point or context.app.player.getVisibleRect().center() @@ -3407,7 +3713,11 @@ class TriangulateSession(EditCommand): topics = [UpdateTopic.frame, UpdateTopic.project_instances] @classmethod - def do_action(cls, context: CommandContext, params: dict): + def do_action( + cls, + context: CommandContext, + params: dict, + ): """Triangulate, reproject, and update instances in a session at a frame index. Args: @@ -3417,8 +3727,9 @@ def do_action(cls, context: CommandContext, params: dict): video's session. frame_idx: The frame index to use. Default is current frame index. instance: The `Instance` object to use. Default is current instance. - show_dialog: If True, then show a warning dialog. Default is True. - ask_again: If True, then ask for views/instances again. Default is False. + triangulate_predictions: If True, then include predicted instances in + triangulation. Otherwise, only use user labeled instances. Default + is False. """ session: RecordingSession = ( @@ -3439,6 +3750,8 @@ def do_action(cls, context: CommandContext, params: dict): instance = params.get("instance", None) or context.state["instance"] instance_group = frame_group.get_instance_group(instance) + triangulate_predictions = params.get("triangulate_predictions", False) + # If instance_group is None, then we will try to triangulate entire frame_group instance_groups = ( [instance_group] @@ -3457,7 +3770,10 @@ def do_action(cls, context: CommandContext, params: dict): return # Not enough instances for triangulation # Get the `FrameGroup` of shape M=include x T x N x 2 - fg_tensor = frame_group.numpy(instance_groups=instance_groups, pred_as_nan=True) + pred_as_nan = not triangulate_predictions + fg_tensor = frame_group.numpy( + instance_groups=instance_groups, pred_as_nan=pred_as_nan + ) # Add extra dimension for number of frames frame_group_tensor = np.expand_dims(fg_tensor, axis=1) # M=include x F=1 xTxNx2 @@ -3597,3 +3913,20 @@ def copy_to_clipboard(text: str): clipboard = QtWidgets.QApplication.clipboard() clipboard.clear(mode=clipboard.Clipboard) clipboard.setText(text, mode=clipboard.Clipboard) + + +class UnlinkVideo(EditCommand): + topics = [UpdateTopic.sessions] + + @staticmethod + def do_action(context: CommandContext, params: dict): + camcorder = context.state["selected_camera"] + recording_session = context.state["selected_session"] + + video = camcorder.get_video(recording_session) + + if video is not None and recording_session is not None: + recording_session.remove_video(video) + + # Reset the selected camera + context.state["selected_camera"] = None diff --git a/sleap/gui/dataviews.py b/sleap/gui/dataviews.py index 0a008bea7..d3651bc96 100644 --- a/sleap/gui/dataviews.py +++ b/sleap/gui/dataviews.py @@ -30,6 +30,7 @@ from sleap.io.dataset import Labels from sleap.instance import LabeledFrame, Instance from sleap.skeleton import Skeleton +from sleap.io.cameras import Camcorder, RecordingSession, InstanceGroup class GenericTableModel(QtCore.QAbstractTableModel): @@ -385,6 +386,17 @@ def getSelectedRowItem(self) -> Any: return self.model().original_items[idx.row()] +class SessionsTableModel(GenericTableModel): + properties = ("index", "videos", "cameras") + + def item_to_data(self, obj, item: RecordingSession): + res = {} + res["index"] = item.id + res["cameras"] = len(getattr(item, "cameras")) + res["videos"] = len(getattr(item, "videos")) + return res + + class VideosTableModel(GenericTableModel): properties = ("filename", "frames", "height", "width", "channels") @@ -538,7 +550,6 @@ def sort(self, column_idx: int, order: QtCore.Qt.SortOrder): if prop != "group": super(SuggestionsTableModel, self).sort(column_idx, order) else: - if not reverse: # Use group_int (int) instead of group (str). self.beginResetModel() @@ -652,3 +663,70 @@ def columnCount(self, parent): def flags(self, index: QtCore.QModelIndex): """Overrides Qt method, returns flags (editable etc).""" return QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable + + +class CamerasTableModel(GenericTableModel): + """Table model for unlinking `Camcorder`s and `Video`s within a `RecordingSession`. + + Args: + obj: 'RecordingSession' which has information of cameras + and paired video + """ + + properties = ("camera", "video") + + def object_to_items(self, obj: RecordingSession): + return obj.camera_cluster.cameras + + def item_to_data(self, obj: RecordingSession, item: Camcorder): + + video = obj.get_video(item) + return {"camera": item.name, "video": video.filename if video else ""} + + +class InstanceGroupTableModel(GenericTableModel): + """Table model for displaying all instance groups in a given frame. + + Args: + item: 'InstanceGroup' which has information about the instance group + """ + + properties = ("name", "score", "frame index", "cameras", "instances") + + def item_to_data(self, obj, item: InstanceGroup): + + data = { + "name": item.name, + "score": "" if item.score is None else str(round(item.score, 2)), + "frame index": item.frame_idx, + "cameras": len(item.camera_cluster.cameras), + "instances": len(item.instances), + } + return data + + def get_item_color(self, instance_group: InstanceGroup, key: str): + color_manager = self.context.app.color_manager + if color_manager.distinctly_color == "instance groups" and key == "name": + + # Get the RecordingSession + state = self.context.state + session = state["session"] + if session is None: + return + + # Get the FrameGroup + frame_idx = state["frame_idx"] + frame_group = session.frame_groups.get(frame_idx, None) + if frame_group is None: + return + + # Get the InstanceGroup and color + color = color_manager.get_instance_group_color(instance_group, frame_group) + return QtGui.QColor(*color) + + def can_set(self, item, key): + return True + + def set_item(self, item, key, value): + if key == "name" and value: + self.context.setInstanceGroupName(instance_group=item, name=value) diff --git a/sleap/gui/shortcuts.py b/sleap/gui/shortcuts.py index 37db5fb51..83e51b364 100644 --- a/sleap/gui/shortcuts.py +++ b/sleap/gui/shortcuts.py @@ -41,6 +41,8 @@ class Shortcuts(object): "goto next suggestion", "goto prev suggestion", "goto next track spawn", + "goto next view", + "goto prev view", "show instances", "show labels", "show edges", diff --git a/sleap/gui/widgets/docks.py b/sleap/gui/widgets/docks.py index 43e218adb..5af4f72b6 100644 --- a/sleap/gui/widgets/docks.py +++ b/sleap/gui/widgets/docks.py @@ -1,6 +1,6 @@ """Module for creating dock widgets for the `MainWindow`.""" -from typing import Callable, Iterable, List, Optional, Type, Union +from typing import Callable, Dict, Iterable, List, Optional, Type, Union from qtpy import QtGui from qtpy.QtCore import Qt @@ -12,6 +12,10 @@ QLabel, QLayout, QMainWindow, + QLabel, + QComboBox, + QCheckBox, + QGroupBox, QPushButton, QTabWidget, QVBoxLayout, @@ -27,7 +31,11 @@ SkeletonNodesTableModel, SuggestionsTableModel, VideosTableModel, + CamerasTableModel, + SessionsTableModel, + InstanceGroupTableModel, ) +from sleap.io.cameras import RecordingSession, FrameGroup, InstanceGroup from sleap.gui.dialogs.formbuilder import YamlFormWidget from sleap.gui.widgets.views import CollapsibleWidget from sleap.skeleton import Skeleton @@ -363,7 +371,6 @@ def create_templates_groupbox(self) -> QGroupBox: vb.addWidget(hbw) def updatePreviewImage(preview_image_bytes: bytes): - # Decode the preview image preview_image = decode_preview_image(preview_image_bytes) @@ -566,3 +573,210 @@ def create_table_edit_buttons(self) -> QWidget: hbw = QWidget() hbw.setLayout(hb) return hbw + + +class SessionsDock(DockWidget): + def __init__( + self, + main_window: Optional[QMainWindow], + tab_with: Optional[QLayout] = None, + ): + self.sessions_model_type = SessionsTableModel + self.camera_model_type = CamerasTableModel + self.unlinked_videos_model_type = VideosTableModel + super().__init__( + name="Sessions", + main_window=main_window, + model_type=[ + self.sessions_model_type, + self.camera_model_type, + self.unlinked_videos_model_type, + ], + tab_with=tab_with, + ) + + def create_triangulation_options(self) -> QWidget: + main_window = self.main_window + hb = QHBoxLayout() + + # Add button to triangulate on demand + self.add_button( + hb, + "Triangulate", + main_window.process_events_then(main_window.commands.triangulateSession), + ) + + # Add checkbox and button for "Auto-triangulate" + self.auto_align_checkbox = QCheckBox("Auto-Triangulate") + self.auto_align_checkbox.stateChanged.connect( + lambda x: main_window.state.set("auto_triangulate", x == Qt.Checked) + ) + hb.addWidget(self.auto_align_checkbox) + + hbw = QWidget() + hbw.setLayout(hb) + return hbw + + def create_video_unlink_button(self) -> QWidget: + main_window = self.main_window + + hb = QHBoxLayout() + self.add_button( + hb, "Unlink Video", main_window.commands.unlink_video_from_camera + ) + + hbw = QWidget() + hbw.setLayout(hb) + return hbw + + def create_video_link_button(self) -> QWidget: + main_window = self.main_window + + hb = QHBoxLayout() + self.add_button(hb, "Link Video", main_window.commands.linkVideoToSession) + + hbw = QWidget() + hbw.setLayout(hb) + return hbw + + def create_models(self) -> Union[GenericTableModel, Dict[str, GenericTableModel]]: + main_window = self.main_window + self.sessions_model = self.sessions_model_type( + items=main_window.state["labels"].sessions, context=main_window.commands + ) + self.camera_model = self.camera_model_type( + items=main_window.state["selected_session"], context=main_window.commands + ) + self.unlinked_videos_model = self.unlinked_videos_model_type( + items=main_window.state["selected_session"], context=main_window.commands + ) + + self.model = { + "sessions_model": self.sessions_model, + "camera_model": self.camera_model, + "unlink_videos_model": self.unlinked_videos_model, + } + return self.model + + def create_tables(self) -> Union[GenericTableView, Dict[str, GenericTableView]]: + if self.sessions_model is None: + self.create_models() + + main_window = self.main_window + self.sessions_table = GenericTableView( + state=main_window.state, row_name="session", model=self.sessions_model + ) + self.camera_table = GenericTableView( + is_activatable=True, + state=main_window.state, + row_name="camera", + model=self.camera_model, + ellipsis_left=True, + ) + self.unlinked_videos_table = GenericTableView( + is_activatable=True, + state=main_window.state, + row_name="unlinked_video", + model=self.unlinked_videos_model, + ellipsis_left=True, + ) + + self.main_window.state.connect( + "selected_session", self.main_window.update_cameras_model + ) + + self.main_window.state.connect( + "selected_session", self.main_window.update_unlinked_videos_model + ) + + self.table = { + "sessions_table": self.sessions_table, + "camera_table": self.camera_table, + "unlinked_videos_table": self.unlinked_videos_table, + } + return self.table + + def create_table_edit_buttons(self) -> QWidget: + main_window = self.main_window + + hb = QHBoxLayout() + self.add_button(hb, "Add Session", lambda x: main_window.commands.addSession()) + self.add_button( + hb, "Remove Session", main_window.commands.removeSelectedSession + ) + + hbw = QWidget() + hbw.setLayout(hb) + return hbw + + def lay_everything_out(self) -> None: + if self.table is None: + self.create_tables() + + # TODO(LM): Add this to a create method + # Add the sessions table to the dock + self.wgt_layout.addWidget(self.sessions_table) + + table_edit_buttons = self.create_table_edit_buttons() + self.wgt_layout.addWidget(table_edit_buttons) + + # TODO(LM): Add this to a create method + # Add the cameras table to the dock + self.wgt_layout.addWidget(self.camera_table) + + video_unlink_button = self.create_video_unlink_button() + self.wgt_layout.addWidget(video_unlink_button) + + # Add the triangulation options to the dock + triangulation_options = self.create_triangulation_options() + self.wgt_layout.addWidget(triangulation_options) + + # Add the unlinked videos table to the dock + self.wgt_layout.addWidget(self.unlinked_videos_table) + video_link_button = self.create_video_link_button() + self.wgt_layout.addWidget(video_link_button) + + +class InstanceGroupDock(DockWidget): + """Dock widget for displaying instance groups.""" + + def __init__(self, main_window: QMainWindow, tab_with: Optional[QLayout] = None): + super().__init__( + name="Instance Groups", + main_window=main_window, + model_type=InstanceGroupTableModel, + tab_with=tab_with, + ) + + def create_models(self) -> InstanceGroupTableModel: + session: RecordingSession = self.main_window.state["session"] + instance_groups = [] + if session is not None: + frame_idx: int = self.main_window.state["frame_idx"] + frame_group: FrameGroup = session.frame_groups.get(frame_idx, None) + if frame_group is not None: + instance_groups: List[InstanceGroup] = frame_group.instance_groups + + self.model = self.model_type( + items=instance_groups, + context=self.main_window.commands, + ) + return self.model + + def create_tables(self) -> GenericTableView: + if self.model is None: + self.create_models() + + self.table = GenericTableView( + state=self.main_window.state, + row_name="instance_group", + name_prefix="", + model=self.model, + ) + return self.table + + def lay_everything_out(self) -> None: + if self.table is None: + self.create_tables() + + self.wgt_layout.addWidget(self.table) diff --git a/sleap/gui/widgets/video.py b/sleap/gui/widgets/video.py index 502ea388e..9391dcab8 100644 --- a/sleap/gui/widgets/video.py +++ b/sleap/gui/widgets/video.py @@ -12,6 +12,7 @@ >>> vp.addInstance(instance=my_instance, color=(r, g, b)) """ + from collections import deque # FORCE_REQUESTS controls whether we emit a signal to process frame requests @@ -73,6 +74,8 @@ from sleap.io.video import Video from sleap.prefs import prefs from sleap.skeleton import Node +from sleap.io.cameras import Camcorder +from sleap.io.cameras import InstanceGroup class LoadImageWorker(QtCore.QObject): @@ -268,14 +271,44 @@ def update_selection_state(a, b): self.seekbar.selectionChanged.connect(update_selection_state) - self.state.connect("frame_idx", lambda idx: self.plot()) - self.state.connect("frame_idx", lambda idx: self.seekbar.setValue(idx)) + def frame_idx_callback(frame_idx: int): + """All callbacks that need to be called when frame_idx changes.""" + self.plot() + self.seekbar.setValue(frame_idx) + self.state["instance"] = None + + self.state.connect("frame_idx", lambda idx: frame_idx_callback(idx)) self.state.connect("instance", self.view.selectInstance) + self.state.connect("instance_group", self.view.selectInstance) self.state.connect("show instances", self.plot) self.state.connect("show labels", self.plot) self.state.connect("show edges", self.plot) self.state.connect("video", self.load_video) + + def set_video(video: Video): + self.state["video"] = video + + self.state.connect("unlinked_video", lambda video: set_video(video)) + + def set_video_from_camera(camera: Camcorder): + """Updates the video state when camera state changes. + + Args: + camera: The camera object + """ + # If either the camera or the session is None, we can't get the linked video + session = self.state["session"] + if camera is None or session is None: + return + + # Get the linked video from the camera + video: Optional[Video] = camera.get_video(session=session) + if video is not None: + self.state["video"] = video + + self.state.connect("camera", lambda camera: set_video_from_camera(camera)) + self.state.connect("fit", self.setFitZoom) self.view.show() @@ -313,6 +346,7 @@ def _load_and_show_requested_image(self, frame_idx): # Display image self.view.setImage(qimage) + # TODO: Delegate to command context def _register_shortcuts(self): self._shortcut_triggers = dict() @@ -935,18 +969,42 @@ def all_instances(self) -> List["QtInstance"]: scene_items = self.scene.items(Qt.SortOrder.AscendingOrder) return list(filter(lambda x: isinstance(x, QtInstance), scene_items)) - def selectInstance(self, select: Union[Instance, int]): - """ - Select a particular instance in view. + def selectInstance(self, select: Optional[Union[Instance, int, InstanceGroup]]): + """Select a particular instance in view. Args: - select: Either `Instance` or index of instance in view. + select: Either `None` or `Instance`, index, or `InstanceGroup` of instance + in view. Returns: None """ + + # Decide which function to use to determine if instance is selected + if isinstance(select, int): + + def determine_selected(idx: int, instance: QtInstance): + return idx == select + + elif isinstance(select, Instance): + + def determine_selected(idx: int, instance: QtInstance): + return instance.instance == select + + elif isinstance(select, InstanceGroup): + + def determine_selected(idx: int, instance: QtInstance): + return instance.instance in select.instances + + else: + + def determine_selected(idx: int, instance: QtInstance): + return False + + # Set selected state for each instance for idx, instance in enumerate(self.all_instances): - instance.selected = select == idx or select == instance.instance + instance.selected = determine_selected(idx, instance) + self.updatedSelection.emit() def getSelectionIndex(self) -> Optional[int]: @@ -1528,22 +1586,29 @@ def mousePressEvent(self, event): # Select instance this nodes belong to. self.parentObject().player.state["instance"] = self.parentObject().instance + # TODO(LM): Document this behavior + # Ctrl-click to toggle node(s) as incomplete + complete = False if (event.modifiers() & Qt.ControlModifier) else True + # Alt-click to drag instance - if event.modifiers() == Qt.AltModifier: + if event.modifiers() & Qt.AltModifier: self.dragParent = True self.parentObject().setFlag(QGraphicsItem.ItemIsMovable) # set origin to point clicked so that we can rotate around this point self.parentObject().setTransformOriginPoint(self.scenePos()) self.parentObject().mousePressEvent(event) - # Shift-click to mark all points as complete - elif event.modifiers() == Qt.ShiftModifier: - self.parentObject().updatePoints(complete=True, user_change=True) + # Shift-click to mark all points as complete (or incomplete if ctrl is held) + elif event.modifiers() & Qt.ShiftModifier: + self.parentObject().updatePoints( + complete=complete, incomplete=(not complete), user_change=True + ) else: self.dragParent = False super(QtNode, self).mousePressEvent(event) self.updatePoint() - self.point.complete = True # FIXME: move to command + self.point.complete = complete # FIXME: move to command + elif event.button() == Qt.RightButton: # Select instance this nodes belong to. self.parentObject().player.state["instance"] = self.parentObject().instance @@ -1850,7 +1915,7 @@ def __init__( self.track_label.setHtml(instance_label_text) # Add nodes - for (node, point) in self.instance.nodes_points: + for node, point in self.instance.nodes_points: if point.visible or self.show_non_visible: node_item = QtNode( parent=self, @@ -1865,7 +1930,7 @@ def __init__( self.nodes[node.name] = node_item # Add edges - for (src, dst) in self.skeleton.edge_names: + for src, dst in self.skeleton.edge_names: # Make sure that both nodes are present in this instance before drawing edge if src in self.nodes and dst in self.nodes: edge_item = QtEdge( @@ -1904,14 +1969,21 @@ def __init__( def __repr__(self) -> str: return f"QtInstance(pos()={self.pos()},instance={self.instance})" - def updatePoints(self, complete: bool = False, user_change: bool = False): + def updatePoints( + self, + complete: bool = False, + incomplete: bool = False, + user_change: bool = False, + ): """Update data and display for all points in skeleton. This is called any time the skeleton is manipulated as a whole. Args: - complete: Whether to update all nodes by setting "completed" - attribute. + complete: Whether to update all nodes to complete by setting "completed" + attribute. Overrides `incomplete`. + incomplete: Whether to update all nodes to incomplete by setting "complete" + attribute. Overridden by `complete`. user_change: Whether method is called because of change made by user. @@ -1936,6 +2008,8 @@ def updatePoints(self, complete: bool = False, user_change: bool = False): if complete: # FIXME: move to command node_item.point.complete = True + elif incomplete: + node_item.point.complete = False # Wait to run callbacks until all nodes are updated # Otherwise the label positions aren't correct since # they depend on the edge vectors to old node positions. diff --git a/sleap/instance.py b/sleap/instance.py index ed1fa3d07..9026bb115 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -722,26 +722,34 @@ def update_points(self, points: np.ndarray, exclude_complete: bool = False): Args: points: The new points to update to. - exclude_complete: Whether to update points where Point.complete is True + exclude_complete: Whether to update visible points where Point.complete + and Point.visible is True. This only applies to user-labeled instances. """ + + # Determine if Instance is a PredictedInstance + is_predicted = True if isinstance(self._points, PredictedPointArray) else False + points_dict = dict() for point_new, points_old, node_name in zip( points, self._points, self.skeleton.node_names ): - # Skip if new point is nan or old point is complete - if np.isnan(point_new).any() or (exclude_complete and points_old.complete): + visible = points_old.visible + complete = points_old.complete + + # Skip if new point is nan or old is user-labeled, visible and complete + skip_if_complete = (not is_predicted) and exclude_complete and visible + if np.isnan(point_new).any() or (skip_if_complete and complete): continue # Grab the x, y from the new point and visible, complete from the old point x, y = point_new - visible = points_old.visible - complete = points_old.complete + visible = visible + complete = complete # Create a new point and add to the dict - if type(self._points) == PredictedPointArray: - # TODO(LM): The point score is meant to rate the confidence of the - # prediction, but this method updates from triangulation. + if is_predicted: + # This method does not update the points score. score = points_old.score point_obj = PredictedPoint( x=x, y=y, visible=visible, complete=complete, score=score @@ -1546,33 +1554,18 @@ def unused_predictions(self) -> List[Instance]: a corresponding :class:`Instance` in the same track in frame. """ unused_predictions = [] - any_tracks = [inst.track for inst in self._instances if inst.track is not None] - if len(any_tracks): - # use tracks to determine which predicted instances have been used - used_tracks = [ - inst.track - for inst in self._instances - if type(inst) == Instance and inst.track is not None - ] - unused_predictions = [ - inst - for inst in self._instances - if inst.track not in used_tracks and type(inst) == PredictedInstance - ] - else: - # use from_predicted to determine which predicted instances have been used - # TODO: should we always do this instead of using tracks? - used_instances = [ - inst.from_predicted - for inst in self._instances - if inst.from_predicted is not None - ] - unused_predictions = [ - inst - for inst in self._instances - if type(inst) == PredictedInstance and inst not in used_instances - ] + # Use from_predicted to determine which predicted instances have been used + used_instances = [ + inst.from_predicted + for inst in self._instances + if inst.from_predicted is not None + ] + unused_predictions = [ + inst + for inst in self._instances + if type(inst) == PredictedInstance and inst not in used_instances + ] return unused_predictions @@ -1593,9 +1586,9 @@ def instances_to_show(self) -> List[Instance]: if type(inst) == Instance or inst in unused_predictions ] inst_to_show.sort( - key=lambda inst: inst.track.spawned_on - if inst.track is not None - else math.inf + key=lambda inst: ( + inst.track.spawned_on if inst.track is not None else math.inf + ) ) return inst_to_show diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 4ffd9adf4..e2ed6c4b5 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union +import cv2 import cattr import numpy as np import toml @@ -15,7 +16,7 @@ # from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer from sleap.instance import Instance, LabeledFrame, PredictedInstance from sleap.io.video import Video -from sleap.util import deep_iterable_converter +from sleap.util import compute_oks, deep_iterable_converter logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ def sessions(self) -> List["RecordingSession"]: def get_video(self, session: "RecordingSession") -> Optional[Video]: if session not in self._video_by_session: - logger.warning(f"{session} not found in {self}.") + logger.debug(f"{session} not found in {self}.") return None return self._video_by_session[session] @@ -345,8 +346,11 @@ def load(cls, filename) -> "CameraCluster": Returns: `CameraCluster` object. """ - cgroup: CameraGroup = super().load(filename) - return cls(cameras=cgroup.cameras, metadata=cgroup.metadata) + + calibration_dict = toml.load(filename) + camera_cluster = cls.from_calibration_dict(calibration_dict) + + return camera_cluster @classmethod def from_calibration_dict(cls, calibration_dict: Dict[str, str]) -> "CameraCluster": @@ -366,12 +370,12 @@ def from_calibration_dict(cls, calibration_dict: Dict[str, str]) -> "CameraClust `CameraCluster` object. """ - # Save the calibration dictionary to a temp file and load as `CameraGroup` - with tempfile.TemporaryDirectory() as temp_dir: - temp_file = str(Path(temp_dir, "calibration.toml")) - with open(temp_file, "w") as f: - toml.dump(calibration_dict, f) - cgroup: CameraGroup = super().load(temp_file) + # Taken from aniposelib.cameras.CameraGroup.load, but without sorting keys + keys = calibration_dict.keys() + items = [calibration_dict[k] for k in keys if k != "metadata"] + cgroup = CameraGroup.from_dicts(items) + if "metadata" in calibration_dict: + cgroup.metadata = calibration_dict["metadata"] return cls(cameras=cgroup.cameras, metadata=cgroup.metadata) @@ -394,6 +398,35 @@ def to_calibration_dict(self) -> Dict[str, str]: return calibration_dict + # TODO(LM): Remove this function once aniposelib is updated. + def optim_points( + self, + points, + p3ds, + constraints=[], + constraints_weak=[], + scale_smooth=4, + scale_length=2, + scale_length_weak=0.5, + reproj_error_threshold=15, + reproj_loss="soft_l1", + n_deriv_smooth=1, + scores=None, + verbose=False, + ): + """Overwrite parent function which does not handle nan values. + + This function is called when we triangulate. The triangulated points are stored + in p3ds, but the parent function optimizes the triangulated p3ds to better fit + the 2D points. The parent function does not handle nan values (yet), so we need + to overwrite it here. + + Reutrns: + p3ds: np.ndarray of shape (n_points, 3) + """ + + return p3ds + @define class InstanceGroup: @@ -408,6 +441,9 @@ class InstanceGroup: cameras: List of `Camcorder` objects that have an `Instance` associated. instances: List of `Instance` objects. instance_by_camcorder: Dictionary of `Instance` objects by `Camcorder`. + score: Optional score for the `InstanceGroup`. Setting the score will also + update the score for all `instances` already in the `InstanceGroup`. The + score for `instances` will not be updated upon initialization. """ _name: str = field() @@ -416,6 +452,7 @@ class InstanceGroup: _camcorder_by_instance: Dict[Instance, Camcorder] = field(factory=dict) _dummy_instance: Optional[Instance] = field(default=None) camera_cluster: Optional[CameraCluster] = field(default=None) + _score: Optional[float] = field(default=None) def __attrs_post_init__(self): """Initialize `InstanceGroup` object.""" @@ -537,7 +574,7 @@ def return_unique_name(cls, name_registry: Set[str]) -> str: return new_name @property - def instances(self) -> List[Instance]: + def instances(self) -> List[Union[Instance, PredictedInstance]]: """List of `Instance` objects.""" return list(self._instance_by_camcorder.values()) @@ -551,7 +588,35 @@ def instance_by_camcorder(self) -> Dict[Camcorder, Instance]: """Dictionary of `Instance` objects by `Camcorder`.""" return self._instance_by_camcorder - def numpy(self, pred_as_nan: bool = False, invisible_as_nan=True) -> np.ndarray: + @property + def score(self) -> Optional[float]: + """Score for the `InstanceGroup`.""" + return self._score + + @score.setter + def score(self, score: Optional[float]): + """Set the score for the `InstanceGroup`. + + Also sets the score for all instances in the `InstanceGroup` if they have a + `score` attribute. + + Args: + score: Score to set for the `InstanceGroup`. + """ + + for instance in self.instances: + if hasattr(instance, "score"): + instance.score = score + + self._score = score + + def numpy( + self, + pred_as_nan: bool = False, + invisible_as_nan=True, + cams_to_include: Optional[List[Camcorder]] = None, + undistort: bool = False, + ) -> np.ndarray: """Return instances as a numpy array of shape (n_views, n_nodes, 2). The ordering of views is based on the ordering of `Camcorder`s in the @@ -565,13 +630,21 @@ def numpy(self, pred_as_nan: bool = False, invisible_as_nan=True) -> np.ndarray: self.dummy_instance. Default is False. invisible_as_nan: If True, then replaces invisible points with nan. Default is True. + cams_to_include: List of `Camcorder`s to include in the numpy array. If + None, then all `Camcorder`s in the `CameraCluster` are included. Default + is None. + undistort: If True, then undistort the points using cv2.undistortPoints. + Default is False. Returns: Numpy array of shape (n_views, n_nodes, 2). """ instance_numpys: List[np.ndarray] = [] # len(M) x N x 2 - for cam in self.camera_cluster.cameras: + if cams_to_include is None: + cams_to_include = self.camera_cluster.cameras + + for cam in cams_to_include: instance = self.get_instance(cam) # Determine whether to use a dummy (all nan) instance @@ -586,6 +659,16 @@ def numpy(self, pred_as_nan: bool = False, invisible_as_nan=True) -> np.ndarray: instance_numpy: np.ndarray = instance.get_points_array( invisible_as_nan=invisible_as_nan ) # N x 2 + + if undistort: + instance_numpy_shape = instance_numpy.shape + instance_numpy = instance_numpy.reshape(-1, 2) + instance_numpy = cv2.undistortPoints( + instance_numpy.astype("float64"), + cameraMatrix=cam.camera.matrix, + distCoeffs=cam.camera.dist, + ).reshape(instance_numpy_shape) + instance_numpys.append(instance_numpy) return np.stack(instance_numpys, axis=0) # M x N x 2 @@ -645,7 +728,7 @@ def add_instance(self, cam: Camcorder, instance: Instance): ) # Add the instance to the `InstanceGroup` - self.replace_instance(cam, instance) + self.replace_instance(cam=cam, instance=instance) def replace_instance(self, cam: Camcorder, instance: Instance): """Replace an `Instance` in the `InstanceGroup`. @@ -668,6 +751,9 @@ def replace_instance(self, cam: Camcorder, instance: Instance): # Remove the instance if it already exists self.remove_instance(instance_or_cam=instance) + # Remove the instance currently at the cam (if any) + self.remove_instance(instance_or_cam=cam) + # Replace the instance in the `InstanceGroup` self._instance_by_camcorder[cam] = instance self._camcorder_by_instance[instance] = cam @@ -793,6 +879,11 @@ def update_points( f"Camcorders in `cams_to_include` ({len(cams_to_include)})." ) + # Calculate OKS scores for the points + gt_points = self.numpy( + pred_as_nan=True, invisible_as_nan=True, cams_to_include=cams_to_include + ) # M x N x 2 + oks_scores = np.full((n_views, n_nodes), np.nan) for cam_idx, cam in enumerate(cams_to_include): # Get the instance for the cam instance: Optional[Instance] = self.get_instance(cam) @@ -802,11 +893,22 @@ def update_points( ) continue - # Update the points (and scores) for the (predicted) instance + # Compute the OKS score for the instance if it is a ground truth instance + if not isinstance(instance, PredictedInstance): + instance_oks = compute_oks( + gt_points[cam_idx, :, :], + points[cam_idx, :, :], + ) + oks_scores[cam_idx] = instance_oks + + # Update the points for the instance instance.update_points( points=points[cam_idx, :, :], exclude_complete=exclude_complete ) + # Update the score for the InstanceGroup to be the average OKS score + self.score = np.nanmean(oks_scores) # scalar + def __getitem__( self, idx_or_key: Union[int, Camcorder, Instance] ) -> Union[Camcorder, Instance]: @@ -841,7 +943,8 @@ def __len__(self): def __repr__(self): return ( f"{self.__class__.__name__}(name={self.name}, frame_idx={self.frame_idx}, " - f"instances:{len(self)}, camera_cluster={self.camera_cluster})" + f"score={self.score}, instances:{len(self)}, camera_cluster=" + f"{self.camera_cluster})" ) def __hash__(self) -> int: @@ -853,6 +956,7 @@ def from_instance_by_camcorder_dict( instance_by_camcorder: Dict[Camcorder, Instance], name: str, name_registry: Set[str], + score: Optional[float] = None, ) -> Optional["InstanceGroup"]: """Creates an `InstanceGroup` object from a dictionary. @@ -860,6 +964,8 @@ def from_instance_by_camcorder_dict( instance_by_camcorder: Dictionary with `Camcorder` keys and `Instance` values. name: Name to use for the `InstanceGroup`. name_registry: Set of names to check for uniqueness. + score: Optional score for the `InstanceGroup`. This will NOT update the + score of the `Instance`s within the `InstanceGroup`. Default is None. Raises: ValueError: If the `InstanceGroup` name is already in use. @@ -903,6 +1009,7 @@ def from_instance_by_camcorder_dict( frame_idx=frame_idx, camera_cluster=camera_cluster, instance_by_camcorder=instance_by_camcorder_copy, + score=score, ) def to_dict( @@ -930,10 +1037,14 @@ def to_dict( for cam, instance in self._instance_by_camcorder.items() } - return { + instance_group_dict = { "name": self.name, "camcorder_to_lf_and_inst_idx_map": camcorder_to_lf_and_inst_idx_map, } + if self.score is not None: + instance_group_dict["score"] = str(round(self.score, 4)) + + return instance_group_dict @classmethod def from_dict( @@ -957,6 +1068,13 @@ def from_dict( `InstanceGroup` object. """ + # Get the score (if available) + score = ( + float(instance_group_dict["score"]) + if "score" in instance_group_dict + else None + ) + # Get the `Instance` objects camcorder_to_lf_and_inst_idx_map: Dict[ str, Tuple[str, str] @@ -978,6 +1096,7 @@ def from_dict( instance_by_camcorder=instance_by_camcorder, name=instance_group_dict["name"], name_registry=name_registry, + score=score, ) @@ -999,6 +1118,7 @@ class RecordingSession: frame_inds: List of frame indices. cams_to_include: List of `Camcorder`s to include in this `FrameGroup`. excluded_views: List of excluded views (names of `Camcorder`s). + projection_bounds: Projection bounds for `Camcorder`s in `self.cams_to_include`. """ # TODO(LM): Consider implementing Observer pattern for `camera_cluster` and `labels` @@ -1009,11 +1129,28 @@ class RecordingSession: _frame_group_by_frame_idx: Dict[int, "FrameGroup"] = field(factory=dict) _cams_to_include: Optional[List[Camcorder]] = field(default=None) _excluded_views: Optional[Tuple[str]] = field(default=None) + _projection_bounds: Optional[np.ndarray] = field(default=None) + + @property + def id(self) -> str: + """Unique identifier for the `RecordingSession`.""" + if self.labels is not None and self in self.labels.sessions: + return self.labels.sessions.index(self) + else: + return hash(self) @property def videos(self) -> List[Video]: """List of `Video`s.""" + # TODO(LM): Should these be in the same order as `self.labels.videos`? + # e.g. switching between views in GUI should keep the same order, but not enforced. + # We COULD implicitly enforce this by adding videos in the same order as + # `self.labels.videos`, but "explicit is better than implicit". + # Instead, we could sort the videos by their index in labels.videos. This might + # bottleneck switching between views for sessions with lots of cameras/videos. + # Unless! We do this (each time) when adding the videos to the session instead + # of when accessing the videos. This would be a good compromise. return self.camera_cluster._videos_by_session[self] @property @@ -1085,6 +1222,81 @@ def excluded_views(self) -> Optional[Tuple[str]]: return self._excluded_views + def _recalculate_projection_bounds(self): + """Calculate the projection bounds for `Camcorder`s in `self.cams_to_include`. + + This method recreates the `_projection_bounds` attribute based on the linked + `Video`'s height and width. The `_projection_bounds` are updated one by one for + each `Video` added to the `RecordingSession` through the `add_video` method. + However, the `_projection_bounds` will need to be recalculated if the + `Video.height` or `Video.width` attribute is changed after the `Video` is added + to the `RecordingSession`. + + Currently, this method is only called on initialization/deserialization of the + `RecordingSession` and yields an all nan array. + """ + + # Get the projection bounds for all `Camcorder`s in the `RecordingSession` + n_cameras = len(self.camera_cluster.cameras) + bounds = np.full((n_cameras, 2), np.nan) + for cam in self.linked_cameras: + video: Video = self.get_video(cam) + + # Get the video's width and height + x_max = video.width + y_max = video.height + + # Allow triangulating even if no video information is available + if x_max is None or y_max is None: + continue + + # Update the bounds + bounds[self.camera_cluster.cameras.index(cam)] = (x_max, y_max) + self._projection_bounds = bounds.copy() + + @property + def projection_bounds(self) -> Tuple[Tuple[float, float], Tuple[float, float]]: + """Projection bounds for `Camcorder`s in the `RecordingSession.cams_to_include`. + + The projection bounds are based off the linked `Video`'s height and width. + + To recalculate the projection bounds, set the `_projection_bounds` attribute to + None, then access the `projection_bounds` property. + + Returns: + NumPy array of shape (n_cameras, 2) where the first column is the width and + the second column is the height of the `Video` linked to the associated + `Camcorder`. + """ + + # If the projection bounds have not been set, then calculate them + if self._projection_bounds is None: + # Reconstruct the projection bounds for all `Camcorder`s in the session + self._recalculate_projection_bounds() + + # Ensure we don't accidentally modify the underlying projection bounds + bounds = self._projection_bounds.copy() + + # Only return the bounds for cams to include + bounds = bounds[[cam in self.cams_to_include for cam in self.camera_cluster]] + return bounds + + @projection_bounds.setter + def projection_bounds(self, bounds: np.ndarray): + """Raises error if trying to set projection bounds directly. + + The underlying self._projection_bounds is updated automatically when calling + `RecordingSession.add_video`. + + Raises: + ValueError: If trying to set projection bounds directly. + """ + + raise ValueError( + "Cannot set projection bounds directly. Projection bounds are updated " + "automatically when calling `RecordingSession.add_video`." + ) + def get_video(self, camcorder: Camcorder) -> Optional[Video]: """Retrieve `Video` linked to `Camcorder`. @@ -1106,7 +1318,7 @@ def get_video(self, camcorder: Camcorder) -> Optional[Video]: ) if camcorder not in self._video_by_camcorder: - logger.warning( + logger.debug( f"Camcorder {camcorder.name} is not linked to a video in this " f"RecordingSession." ) @@ -1148,6 +1360,12 @@ def add_video(self, video: Video, camcorder: Camcorder): f"{self.camera_cluster}." ) + # Ensure the `Video` is a `Video` object + if not isinstance(video, Video): + raise ValueError( + f"Expected a `Video` object, but got {type(video)} instead." + ) + # Add session-to-videos (1-to-many) map to `CameraCluster` if self not in self.camera_cluster._videos_by_session: self.camera_cluster.add_session(self) @@ -1178,6 +1396,14 @@ def add_video(self, video: Video, camcorder: Camcorder): if self.labels is not None: self.labels.update_session(self, video) + # TODO(LM): Use observer pattern to update bounds when `Video.shape` changes + # Update projection bounds + x_max = video.width + y_max = video.height + if not (x_max is None or y_max is None): + cam_idx = self.camera_cluster.cameras.index(camcorder) + self._projection_bounds[cam_idx] = (x_max, y_max) + def remove_video(self, video: Video): """Removes a `Video` from the `RecordingSession`. @@ -1202,6 +1428,10 @@ def remove_video(self, video: Video): if self.labels is not None and self.labels.get_session(video) is not None: self.labels.remove_session_video(video=video) + # Update projection bounds + cam_idx = self.camera_cluster.cameras.index(camcorder) + self._projection_bounds[cam_idx] = (np.nan, np.nan) + def new_frame_group(self, frame_idx: int): """Creates and adds an empty `FrameGroup` to the `RecordingSession`. @@ -1242,6 +1472,9 @@ def get_videos_from_selected_cameras( return videos + def __bool__(self): + return True + def __attrs_post_init__(self): self.camera_cluster.add_session(self) @@ -1249,6 +1482,9 @@ def __attrs_post_init__(self): if self._cams_to_include is not None: self.cams_to_include = self._cams_to_include + # Initialize `_projection_bounds` by calling the property + self.projection_bounds + def __iter__(self) -> Iterator[List[Camcorder]]: return iter(self.camera_cluster) @@ -1365,7 +1601,7 @@ def to_session_dict( # Store camcorder-to-video indices map where key is camcorder index # and value is video index from `Labels.videos` camcorder_to_video_idx_map = {} - for cam_idx, camcorder in enumerate(self.camera_cluster): + for cam_idx, camcorder in enumerate(self.camera_cluster.cameras): # Skip if Camcorder is not linked to any Video if camcorder not in self._video_by_camcorder: continue @@ -1560,7 +1796,7 @@ def __attrs_post_init__(self): for camera in self.session.camera_cluster.cameras: self._instances_by_cam[camera] = set() for instance_group in self.instance_groups: - self.add_instance_group(instance_group) + self.add_instance_group(instance_group=instance_group) @property def instance_groups(self) -> List[InstanceGroup]: @@ -1621,6 +1857,8 @@ def numpy( self, instance_groups: Optional[List[InstanceGroup]] = None, pred_as_nan: bool = False, + invisible_as_nan: bool = True, + undistort: bool = False, ) -> np.ndarray: """Numpy array of all `InstanceGroup`s in `FrameGroup.cams_to_include`. @@ -1629,10 +1867,13 @@ def numpy( self.instance_groups. pred_as_nan: If True, then replaces `PredictedInstance`s with all nan self.dummy_instance. Default is False. + invisible_as_nan: If True, then replaces invisible points with nan. Default + is True. + undistort: If True, then undistort the points. Default is False. Returns: Numpy array of shape (M, T, N, 2) where M is the number of views (determined - by self.cames_to_include), T is the number of `InstanceGroup`s, N is the + by self.cams_to_include), T is the number of `InstanceGroup`s, N is the number of Nodes, and 2 is for x, y. """ @@ -1648,19 +1889,19 @@ def numpy( f"{self.instance_groups}" ) - instance_group_numpys: List[np.ndarray] = [] # len(T) M=all x N x 2 + instance_group_numpys: List[np.ndarray] = [] # len(T) M=include x N x 2 for instance_group in instance_groups: instance_group_numpy = instance_group.numpy( - pred_as_nan=pred_as_nan - ) # M=all x N x 2 + pred_as_nan=pred_as_nan, + invisible_as_nan=invisible_as_nan, + cams_to_include=self.cams_to_include, + undistort=undistort, + ) # M=include x N x 2 instance_group_numpys.append(instance_group_numpy) - frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=all x T x N x 2 - cams_to_include_mask = np.array( - [cam in self.cams_to_include for cam in self.session.cameras] - ) # M=all x 1 + frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=include x TxNx2 - return frame_group_numpy[cams_to_include_mask] # M=include x T x N x 2 + return frame_group_numpy # M=include x T x N x 2 def add_instance( self, @@ -1702,6 +1943,15 @@ def add_instance( # Add the `Instance` to the `InstanceGroup` if instance_group is not None: + # Remove any existing `Instance` in given `InstanceGroup` at same `Camcorder` + preexisting_instance = instance_group.get_instance(camera) + if preexisting_instance is not None: + self.remove_instance(instance=preexisting_instance) + + # Remove the `Instance` from the `FrameGroup` if it is already exists + self.remove_instance(instance=instance, remove_empty_instance_group=True) + + # Add the `Instance` to the `InstanceGroup` instance_group.add_instance(cam=camera, instance=instance) else: self._raise_if_instance_not_in_instance_group(instance=instance) @@ -1715,11 +1965,15 @@ def add_instance( labeled_frame = instance.frame self.add_labeled_frame(labeled_frame=labeled_frame, camera=camera) - def remove_instance(self, instance: Instance): + def remove_instance( + self, instance: Instance, remove_empty_instance_group: bool = False + ): """Removes an `Instance` from the `FrameGroup`. Args: instance: `Instance` to remove from the `FrameGroup`. + remove_empty_instance_group: If True, then remove the `InstanceGroup` if it + is empty. Default is False. """ instance_group = self.get_instance_group(instance=instance) @@ -1736,12 +1990,22 @@ def remove_instance(self, instance: Instance): instance_group.remove_instance(instance_or_cam=instance) # Remove the `Instance` from the `FrameGroup` - self._instances_by_cam[camera].remove(instance) + if instance in self._instances_by_cam[camera]: + self._instances_by_cam[camera].remove(instance) + else: + logger.debug( + f"Instance {instance} not found in this FrameGroup: " + f"{self._instances_by_cam[camera]}." + ) # Remove "empty" `LabeledFrame`s from the `FrameGroup` if len(self._instances_by_cam[camera]) < 1: self.remove_labeled_frame(labeled_frame_or_camera=camera) + # Remove the `InstanceGroup` if it is empty + if remove_empty_instance_group and len(instance_group.instances) < 1: + self.remove_instance_group(instance_group=instance_group) + def add_instance_group( self, instance_group: Optional[InstanceGroup] = None ) -> InstanceGroup: @@ -1808,11 +2072,9 @@ def remove_instance_group(self, instance_group: InstanceGroup): # Remove the `Instance`s from the `FrameGroup` for camera, instance in instance_group.instance_by_camcorder.items(): self._instances_by_cam[camera].remove(instance) - - # Remove the `LabeledFrame` from the `FrameGroup` - labeled_frame = self.get_labeled_frame(camera=camera) - if labeled_frame is not None: - self.remove_labeled_frame(labeled_frame_or_camera=camera) + # Remove the `LabeledFrame` if no more grouped instances + if len(self._instances_by_cam[camera]) < 1: + self.remove_labeled_frame(labeled_frame_or_camera=camera) # TODO(LM): maintain this as a dictionary for quick lookups def get_instance_group(self, instance: Instance) -> Optional[InstanceGroup]: @@ -2032,7 +2294,6 @@ def upsert_points( This will update the points for existing `Instance`s in the `InstanceGroup`s and also add new `Instance`s if they do not exist. - Included cams are specified by `FrameGroup.cams_to_include`. The ordering of the `InstanceGroup`s in `instance_groups` should match the @@ -2056,6 +2317,22 @@ def upsert_points( ), f"Expected {len(instance_groups)} instances, got {n_instances}." assert n_coords == 2, f"Expected 2 coordinates, got {n_coords}." + # Ensure we are working with a float array + points = points.astype(float) + + # Get projection bounds (based on video height/width) + bounds = self.session.projection_bounds + bounds_expanded_x = bounds[:, None, None, 0] + bounds_expanded_y = bounds[:, None, None, 1] + + # Create masks for out-of-bounds x and y coordinates + out_of_bounds_x = (points[..., 0] < 0) | (points[..., 0] > bounds_expanded_x) + out_of_bounds_y = (points[..., 1] < 0) | (points[..., 1] > bounds_expanded_y) + + # Replace out-of-bounds x and y coordinates with nan + points[out_of_bounds_x, 0] = np.nan + points[out_of_bounds_y, 1] = np.nan + # Update points for each `InstanceGroup` for ig_idx, instance_group in enumerate(instance_groups): # Ensure that `InstanceGroup`s is in this `FrameGroup` diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 17f1555ce..43c094c50 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1,5 +1,5 @@ """ -A SLEAP dataset collects labeled video frames, together with required metadata. +A SLEAP dataset collects labeled video frames, together with required metadata. This contains labeled frame data (user annotations and/or predictions), together with all the other data that is saved for a SLEAP project @@ -37,6 +37,7 @@ the file will be saved in the corresponding format. You can also specify the default extension to use if none is provided in the filename. """ + import itertools import os from collections.abc import MutableSequence @@ -111,6 +112,7 @@ def rebuild_cache(self): self._track_occupancy = dict() self._frame_count_cache = dict() self._session_by_video: Dict[Video, RecordingSession] = dict() + self._linkage_of_videos = {"linked": [], "unlinked": []} # Loop through labeled frames only once for lf in self.labels: @@ -128,6 +130,16 @@ def rebuild_cache(self): for video in session.videos: self._session_by_video[video] = session + # Build linkage of videos by session + for video in self.labels.videos: + if ( + video not in self._session_by_video + or self._session_by_video[video] is None + ): + self._linkage_of_videos["unlinked"].append(video) + else: + self._linkage_of_videos["linked"].append(video) + def add_labeled_frame(self, new_frame: LabeledFrame): """Add a new labeled frame to the cache. @@ -163,6 +175,20 @@ def add_video_to_session(self, session: RecordingSession, new_video: Video): self._session_by_video[new_video] = session + def update_linkage_of_videos(self): + """Updates a dictionary of linked and unlinked videos.""" + temp = {"linked": [], "unlinked": []} + for video in self.labels.videos: + if ( + video not in self._session_by_video + or self._session_by_video[video] is None + ): + temp["unlinked"].append(video) + else: + temp["linked"].append(video) + + self._linkage_of_videos = temp + def update( self, new_item: Optional[ @@ -183,6 +209,8 @@ def update( elif isinstance(new_item, tuple): self.add_video_to_session(*new_item) + self.update_linkage_of_videos() + def find_frames( self, video: Video, frame_idx: Optional[Union[int, Iterable[int]]] = None ) -> Optional[List[LabeledFrame]]: @@ -1293,9 +1321,11 @@ def get_template_instance_points(self, skeleton: Skeleton): template_points = np.stack( [ - node_positions[node] - if node in node_positions - else np.random.randint(0, 50, size=2) + ( + node_positions[node] + if node in node_positions + else np.random.randint(0, 50, size=2) + ) for node in skeleton.nodes ] ) @@ -1411,6 +1441,14 @@ def remove_instance( if not in_transaction: self._cache.remove_instance(frame, instance) + # TODO: Do NOT merge into develop, this next line is handled by InstancesList + # Check that if a `PredictedInstance` is removed, that it is not referenced + from_predicted_instances = [inst.from_predicted for inst in frame.instances] + if instance in from_predicted_instances: + containing_inst_idx: int = from_predicted_instances.index(instance) + containing_inst: Instance = frame.instances[containing_inst_idx] + containing_inst.from_predicted = None + # Also remove instance from `InstanceGroup` if any session = self.get_session(frame.video) if session is None: @@ -1453,6 +1491,10 @@ def add_instance(self, frame: LabeledFrame, instance: Instance): # Add instance and track to labels frame.instances.append(instance) + + # TODO: Do NOT merge into develop, this next line is handled by InstancesList + instance.frame = frame # Needed to add instance to instance group + if (instance.track is not None) and (instance.track not in self.tracks): self.add_track(video=frame.video, track=instance.track) @@ -1762,6 +1804,15 @@ def remove_session_video(self, video: Video): if session.get_camera(video) is not None: session.remove_video(video) + def remove_recording_session(self, session: RecordingSession): + """Remove a session from self.sessions. + + Args: + session: `RecordingSession` instance + """ + if session in self._sessions: + self._sessions.remove(session) + @classmethod def from_json(cls, *args, **kwargs): from sleap.io.format.labels_json import LabelsJsonAdaptor @@ -2696,9 +2747,11 @@ def set_track( # whether they're tracked. n_insts = max( [ - lf.n_user_instances - if lf.n_user_instances > 0 # take user instances over predicted - else lf.n_predicted_instances + ( + lf.n_user_instances + if lf.n_user_instances > 0 # take user instances over predicted + else lf.n_predicted_instances + ) for lf in lfs ] ) @@ -2911,7 +2964,6 @@ def find_path_using_paths(missing_path: Text, search_paths: List[Text]) -> Text: # Look for file with that name in each of the search path directories for search_path in search_paths: - if os.path.isfile(search_path): path_dir = os.path.dirname(search_path) else: diff --git a/sleap/io/visuals.py b/sleap/io/visuals.py index f2dde0be3..adb51f602 100644 --- a/sleap/io/visuals.py +++ b/sleap/io/visuals.py @@ -456,7 +456,7 @@ def _plot_instance_cv( ) if self.show_edges: - for (src, dst) in instance.skeleton.edge_inds: + for src, dst in instance.skeleton.edge_inds: # Get points for the nodes connected by this edge src_x, src_y = points_array[src] dst_x, dst_y = points_array[dst] @@ -542,8 +542,9 @@ def save_labeled_video( for what instance/node/edge palette: SLEAP color palette to use. Options include: "alphabet", "five+", "solarized", or "standard". Only used if `color_manager` is None. - distinctly_color: Specify how to color instances. Options include: "instances", - "edges", and "nodes". Only used if `color_manager` is None. + distinctly_color: Specify how to color instances. Options include: + "instance groups", "instances", "edges", and "nodes". Only used if + `color_manager` is None. gui_progress: Whether to show Qt GUI progress dialog. Returns: diff --git a/sleap/nn/evals.py b/sleap/nn/evals.py index 002f8a143..936d0b4de 100644 --- a/sleap/nn/evals.py +++ b/sleap/nn/evals.py @@ -45,6 +45,7 @@ TopDownMultiClassPredictor, SingleInstancePredictor, ) +from sleap.util import compute_oks logger = logging.getLogger(__name__) @@ -113,143 +114,6 @@ def find_frame_pairs( return frame_pairs -def compute_instance_area(points: np.ndarray) -> np.ndarray: - """Compute the area of the bounding box of a set of keypoints. - - Args: - points: A numpy array of coordinates. - - Returns: - The area of the bounding box of the points. - """ - if points.ndim == 2: - points = np.expand_dims(points, axis=0) - - min_pt = np.nanmin(points, axis=-2) - max_pt = np.nanmax(points, axis=-2) - - return np.prod(max_pt - min_pt, axis=-1) - - -def compute_oks( - points_gt: np.ndarray, - points_pr: np.ndarray, - scale: Optional[float] = None, - stddev: float = 0.025, - use_cocoeval: bool = True, -) -> np.ndarray: - """Compute the object keypoints similarity between sets of points. - - Args: - points_gt: Ground truth instances of shape (n_gt, n_nodes, n_ed), - where n_nodes is the number of body parts/keypoint types, and n_ed - is the number of Euclidean dimensions (typically 2 or 3). Keypoints - that are missing/not visible should be represented as NaNs. - points_pr: Predicted instance of shape (n_pr, n_nodes, n_ed). - use_cocoeval: Indicates whether the OKS score is calculated like cocoeval - method or not. True indicating the score is calculated using the - cocoeval method (widely used and the code can be found here at - https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L192C5-L233C20) - and False indicating the score is calculated using the method exactly - as given in the paper referenced in the Notes below. - scale: Size scaling factor to use when weighing the scores, typically - the area of the bounding box of the instance (in pixels). This - should be of the length n_gt. If a scalar is provided, the same - number is used for all ground truth instances. If set to None, the - bounding box area of the ground truth instances will be calculated. - stddev: The standard deviation associated with the spread in the - localization accuracy of each node/keypoint type. This should be of - the length n_nodes. "Easier" keypoint types will have lower values - to reflect the smaller spread expected in localizing it. - - Returns: - The object keypoints similarity between every pair of ground truth and - predicted instance, a numpy array of of shape (n_gt, n_pr) in the range - of [0, 1.0], with 1.0 denoting a perfect match. - - Notes: - It's important to set the stddev appropriately when accounting for the - difficulty of each keypoint type. For reference, the median value for - all keypoint types in COCO is 0.072. The "easiest" keypoint is the left - eye, with stddev of 0.025, since it is easy to precisely locate the - eyes when labeling. The "hardest" keypoint is the left hip, with stddev - of 0.107, since it's hard to locate the left hip bone without external - anatomical features and since it is often occluded by clothing. - - The implementation here is based off of the descriptions in: - Ronch & Perona. "Benchmarking and Error Diagnosis in Multi-Instance Pose - Estimation." ICCV (2017). - """ - if points_gt.ndim == 2: - points_gt = np.expand_dims(points_gt, axis=0) - if points_pr.ndim == 2: - points_pr = np.expand_dims(points_pr, axis=0) - - if scale is None: - scale = compute_instance_area(points_gt) - - n_gt, n_nodes, n_ed = points_gt.shape # n_ed = 2 or 3 (euclidean dimensions) - n_pr = points_pr.shape[0] - - # If scalar scale was provided, use the same for each ground truth instance. - if np.isscalar(scale): - scale = np.full(n_gt, scale) - - # If scalar standard deviation was provided, use the same for each node. - if np.isscalar(stddev): - stddev = np.full(n_nodes, stddev) - - # Compute displacement between each pair. - displacement = np.reshape(points_gt, (n_gt, 1, n_nodes, n_ed)) - np.reshape( - points_pr, (1, n_pr, n_nodes, n_ed) - ) - assert displacement.shape == (n_gt, n_pr, n_nodes, n_ed) - - # Convert to pairwise Euclidean distances. - distance = (displacement ** 2).sum(axis=-1) # (n_gt, n_pr, n_nodes) - assert distance.shape == (n_gt, n_pr, n_nodes) - - # Compute the normalization factor per keypoint. - if use_cocoeval: - # If use_cocoeval is True, then compute normalization factor according to cocoeval. - spread_factor = (2 * stddev) ** 2 - scale_factor = 2 * (scale + np.spacing(1)) - else: - # If use_cocoeval is False, then compute normalization factor according to the paper. - spread_factor = stddev ** 2 - scale_factor = 2 * ((scale + np.spacing(1)) ** 2) - normalization_factor = np.reshape(spread_factor, (1, 1, n_nodes)) * np.reshape( - scale_factor, (n_gt, 1, 1) - ) - assert normalization_factor.shape == (n_gt, 1, n_nodes) - - # Since a "miss" is considered as KS < 0.5, we'll set the - # distances for predicted points that are missing to inf. - missing_pr = np.any(np.isnan(points_pr), axis=-1) # (n_pr, n_nodes) - assert missing_pr.shape == (n_pr, n_nodes) - distance[:, missing_pr] = np.inf - - # Compute the keypoint similarity as per the top of Eq. 1. - ks = np.exp(-(distance / normalization_factor)) # (n_gt, n_pr, n_nodes) - assert ks.shape == (n_gt, n_pr, n_nodes) - - # Set the KS for missing ground truth points to 0. - # This is equivalent to the visibility delta function of the bottom - # of Eq. 1. - missing_gt = np.any(np.isnan(points_gt), axis=-1) # (n_gt, n_nodes) - assert missing_gt.shape == (n_gt, n_nodes) - ks[np.expand_dims(missing_gt, axis=1)] = 0 - - # Compute the OKS. - n_visible_gt = np.sum( - (~missing_gt).astype("float64"), axis=-1, keepdims=True - ) # (n_gt, 1) - oks = np.sum(ks, axis=-1) / n_visible_gt - assert oks.shape == (n_gt, n_pr) - - return oks - - def match_instances( frame_gt: LabeledFrame, frame_pr: LabeledFrame, diff --git a/sleap/nn/tracker/components.py b/sleap/nn/tracker/components.py index 10b2953b7..06ce82812 100644 --- a/sleap/nn/tracker/components.py +++ b/sleap/nn/tracker/components.py @@ -12,15 +12,18 @@ """ + +from __future__ import annotations + import operator from collections import defaultdict -from typing import List, Tuple, Optional, TypeVar, Callable +from typing import Callable, List, Optional, Tuple, TypeVar import attr import numpy as np from scipy.optimize import linear_sum_assignment -from sleap import PredictedInstance, Instance, Track +from sleap import Instance, LabeledFrame, PredictedInstance, Track from sleap.nn import utils InstanceType = TypeVar("InstanceType", Instance, PredictedInstance) @@ -110,11 +113,27 @@ def greedy_matching(cost_matrix: np.ndarray) -> List[Tuple[int, int]]: def nms_instances( - instances, iou_threshold, target_count=None -) -> Tuple[List[PredictedInstance], List[PredictedInstance]]: + instances: list[PredictedInstance], + iou_threshold: float | None, + target_count: int | None = None, +) -> tuple[list[PredictedInstance], list[PredictedInstance]]: + """Finds `Instance`s to keep using non-maximum suppression. + + Args: + instances: The list of `PredictedInstance` objects to filter. + iou_threshold: The IOU threshold for suppression. If None, then no suppression + is applied and `PredictedInstance.score` is used to determine which + instances to keep. + target_count: The maximum number of instances to keep. Default is None, + which means all instances are kept. + + Returns: + A tuple of two lists: the first list contains the instances to keep, and + the second list contains the instances to remove. + """ boxes = np.array([inst.bounding_box for inst in instances]) scores = np.array([inst.score for inst in instances]) - picks = nms_fast(boxes, scores, iou_threshold, target_count) + picks: list[int] = nms_fast(boxes, scores, iou_threshold, target_count) to_keep = [inst for i, inst in enumerate(instances) if i in picks] to_remove = [inst for i, inst in enumerate(instances) if i not in picks] @@ -122,90 +141,100 @@ def nms_instances( return to_keep, to_remove -def nms_fast(boxes, scores, iou_threshold, target_count=None) -> List[int]: - """https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/""" +def nms_fast( + boxes: np.ndarray, + scores: np.ndarray, + iou_threshold: float, + target_count: int | None = None, +) -> list[int]: + """Finds indices of boxes to keep using non-maximum suppression. + + https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/ - # if there are no boxes, return an empty list + Args: + boxes: The bounding boxes to filter. Each box is represented by its coordinates + in the format (x1, y1, x2, y2) where (x1, y1) is the corner closest to + (0, 0) and (x2, y2) is the corner farthest from (0, 0). Shape is (N, 4) + where N is the number of boxes. + scores: The scores for each bounding box (e.g., confidence scores associated + with `PredictedInstance`). Scores are used to pick which boxes to evaluate + first (i.e. the highest scoring box is never removed). Shape is (N,) where + N is the number of boxes. + iou_threshold: The IOU threshold for suppression. This is a soft threshold since + boxes are added back if there are too few boxes picked (determined by + `target_count`). + target_count: The maximum number of boxes to keep. Default is None, which + means only boxes with IOU less than the threshold are kept. + + Returns: + A list of indices of the boxes that have an IOU less than the IOU threshold. + """ + # Return an empty list if no boxes. if len(boxes) == 0: return [] - # if we already have fewer boxes than the target count, return all boxes + # Return all boxes (if target_count is None or greater than the number of boxes). if target_count and len(boxes) < target_count: return list(range(len(boxes))) - # if the bounding boxes coordinates are integers, convert them to floats -- - # this is important since we'll be doing a bunch of divisions + # Convert boxes to float if they are integers (for higher precision division). if boxes.dtype.kind == "i": boxes = boxes.astype("float") - # initialize the list of picked indexes - picked_idxs = [] - - # init list of boxes removed by nms - nms_idxs = [] - - # grab the coordinates of the bounding boxes - x1 = boxes[:, 0] - y1 = boxes[:, 1] + # Grab the coordinates of all the bounding boxes. + x1 = boxes[:, 0] # x1 <= x2 + y1 = boxes[:, 1] # y1 <= y2 x2 = boxes[:, 2] y2 = boxes[:, 3] - # compute the area of the bounding boxes and sort the bounding - # boxes by their scores - area = (x2 - x1 + 1) * (y2 - y1 + 1) - idxs = np.argsort(scores) - - # keep looping while some indexes still remain in the indexes list - while len(idxs) > 0: + # Compute the area of all the bounding boxes. + areas = (x2 - x1 + 1) * (y2 - y1 + 1) - # we want to add the best box which is the last box in sorted list + picked_idxs = [] + nms_idxs = [] + idxs = np.argsort(scores) # The higher-scoring boxes are at the end of the list. + while len(idxs) > 0: # Each iteration, we remove the last box in `idxs`. + # Get highest score box (last in list) and add to picked boxes. picked_box_idx = idxs[-1] - - # last = len(idxs) - 1 - # i = idxs[last] picked_idxs.append(picked_box_idx) - # find the largest (x, y) coordinates for the start of - # the bounding box and the smallest (x, y) coordinates - # for the end of the bounding box + # Find the smallest (x, y) coordinates for corner 1 of the bounding box. xx1 = np.maximum(x1[picked_box_idx], x1[idxs[:-1]]) yy1 = np.maximum(y1[picked_box_idx], y1[idxs[:-1]]) + + # Find the largest (x, y) coordinates for corner 2 of the bounding box. xx2 = np.minimum(x2[picked_box_idx], x2[idxs[:-1]]) yy2 = np.minimum(y2[picked_box_idx], y2[idxs[:-1]]) - # compute the width and height of the bounding box + # Compute the ratio of overlap. w = np.maximum(0, xx2 - xx1 + 1) h = np.maximum(0, yy2 - yy1 + 1) + overlap = (w * h) / areas[idxs[:-1]] - # compute the ratio of overlap - overlap = (w * h) / area[idxs[:-1]] - - # find boxes with iou over threshold + # Find and remove boxes with iou over threshold. nms_for_new_box = np.where(overlap > iou_threshold)[0] - nms_idxs.extend(list(idxs[nms_for_new_box])) + nms_idxs.extend(list(idxs[nms_for_new_box])) # In case we need to add back. + idxs = np.delete(idxs, nms_for_new_box) - # delete new box (last in list) plus nms boxes - idxs = np.delete(idxs, nms_for_new_box)[:-1] + # Remove the last box (the one we just picked). + idxs = idxs[:-1] - # if we're below the target number of boxes, add some back + # Add some boxes back if we have too few picked boxes. if target_count and nms_idxs and len(picked_idxs) < target_count: - # sort by descending score + # Add back boxes with the highest scores. nms_idxs.sort(key=lambda idx: -scores[idx]) - add_back_count = min(len(nms_idxs), len(picked_idxs) - target_count) picked_idxs.extend(nms_idxs[:add_back_count]) - # return the list of picked boxes return picked_idxs def cull_instances( - frames: List["LabeledFrame"], + frames: List[LabeledFrame], instance_count: int, iou_threshold: Optional[float] = None, -): - """ - Removes instances from frames over instance per frame threshold. +) -> None: + """Removes instances from frames over instance per frame threshold. Args: frames: The list of `LabeledFrame` objects with predictions. @@ -256,52 +285,70 @@ def cull_instances( def cull_frame_instances( - instances_list: List[InstanceType], - instance_count: int, - iou_threshold: Optional[float] = None, -) -> List["LabeledFrame"]: - """ - Removes instances (for single frame) over instance per frame threshold. + instances_list: list[InstanceType], + instance_count: int | None = None, + iou_threshold: float | None = None, + general_iou_threshold: float | None = None, +) -> list[InstanceType]: + """Removes instances (for single frame) over instance per frame threshold. Args: instances_list: The list of instances for a single frame. - instance_count: The maximum number of instances we want per frame. - iou_threshold: Intersection over Union (IOU) threshold to use when - removing overlapping instances over target count; if None, then - only use score to determine which instances to remove. + instance_count: The maximum number of instances we want per frame. If None, then + no limit is applied. Default is None. + iou_threshold: Intersection over Union (IOU) threshold to use when removing + overlapping instances over `instance_count`. If None, then only use score to + determine which instances to remove over `instance_count`. Default is None. + general_iou_threshold: Intersection over Union (IOU) threshold to use when + removing overlapping instances - regardless of `instance_count`. If None, + then no general IOU threshold is applied. Default is None. Returns: - Updated list of frames, also modifies frames in place. + Updated `instances_list` (modified in-place). """ if not instances_list: return - if len(instances_list) > instance_count: - # List of instances which we'll pare down - keep_instances = instances_list + # List of instances which we'll pare down + keep_instances = instances_list + + # First, let's remove instances over the general IOU threshold + if general_iou_threshold is not None: # Use NMS to remove overlapping instances over target count - if iou_threshold: - keep_instances, extra_instances = nms_instances( - keep_instances, - iou_threshold=iou_threshold, - target_count=instance_count, - ) - # Remove the extra instances - for inst in extra_instances: - instances_list.remove(inst) - - # Use lower score to remove instances over target count - if len(keep_instances) > instance_count: - # Sort by ascending score, get target number of instances - # from the end of list (i.e., with highest score) - extra_instances = sorted(keep_instances, key=operator.attrgetter("score"))[ - :-instance_count - ] - - # Remove the extra instances - for inst in extra_instances: - instances_list.remove(inst) + keep_instances, extra_instances = nms_instances( + keep_instances, + iou_threshold=general_iou_threshold, + ) + + # Remove the extra instances + for inst in extra_instances: + instances_list.remove(inst) + + # If we have no restrictions on instance count, return the list. + if instance_count is None or len(instances_list) <= instance_count: + return instances_list + + # Otherwise, let's determine instances to remove over the target count... + extra_instances = [] + + # ...using NMS to remove overlapping instances over target count. + if iou_threshold is not None: + keep_instances, extra_instances = nms_instances( + keep_instances, + iou_threshold=iou_threshold, + target_count=instance_count, + ) + + # ...using lower score to remove instances over target count. + elif len(keep_instances) > instance_count: # Only true if no iou threshold. + extra_instances = sorted(keep_instances, key=operator.attrgetter("score"))[ + :-instance_count + ] + + # Remove the extra instances. + for inst in extra_instances: + instances_list.remove(inst) return instances_list diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 9865b7db5..a1701c9f0 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -827,6 +827,7 @@ def make_tracker_by_name( target_instance_count: int = 0, pre_cull_to_target: bool = False, pre_cull_iou_threshold: Optional[float] = None, + pre_cull_general_iou_threshold: Optional[float] = None, # Post-tracking options to connect broken tracks post_connect_single_breaks: bool = False, # TODO: deprecate these post-tracking cleaning options @@ -878,13 +879,18 @@ def make_tracker_by_name( ) pre_cull_function = None - if target_instance_count and pre_cull_to_target: + if ( + target_instance_count + and pre_cull_to_target + or pre_cull_general_iou_threshold + ): def pre_cull_function(inst_list): cull_frame_instances( inst_list, instance_count=target_instance_count, iou_threshold=pre_cull_iou_threshold, + general_iou_threshold=pre_cull_general_iou_threshold, ) tracker_obj = cls( @@ -961,6 +967,14 @@ def get_by_name_factory_options(cls): ) options.append(option) + option = dict(name="pre_cull_general_iou_threshold", default=0) + option["type"] = float + option["help"] = ( + "If non-zero, then use IOU threshold to remove overlapping instances " + "regardless of the target count *before* tracking." + ) + options.append(option) + option = dict(name="post_connect_single_breaks", default=0) option["type"] = int option["help"] = ( diff --git a/sleap/prefs.py b/sleap/prefs.py index 3d5a2113e..fc1efaf1d 100644 --- a/sleap/prefs.py +++ b/sleap/prefs.py @@ -15,6 +15,7 @@ class Preferences(object): "medium step size": 10, "large step size": 100, "color predicted": False, + "distinctly_color": "instances", "propagate track labels": True, "palette": "standard", "bold lines": False, diff --git a/sleap/util.py b/sleap/util.py index c27cb6c09..75e24b423 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -82,6 +82,143 @@ def deep_iterable_converter(member_converter, iterable_converter=None): return _DeepIterableConverter(member_converter, iterable_converter) +def compute_instance_area(points: np.ndarray) -> np.ndarray: + """Compute the area of the bounding box of a set of keypoints. + + Args: + points: A numpy array of coordinates. + + Returns: + The area of the bounding box of the points. + """ + if points.ndim == 2: + points = np.expand_dims(points, axis=0) + + min_pt = np.nanmin(points, axis=-2) + max_pt = np.nanmax(points, axis=-2) + + return np.prod(max_pt - min_pt, axis=-1) + + +def compute_oks( + points_gt: np.ndarray, + points_pr: np.ndarray, + scale: Optional[float] = None, + stddev: float = 0.025, + use_cocoeval: bool = True, +) -> np.ndarray: + """Compute the object keypoints similarity between sets of points. + + Args: + points_gt: Ground truth instances of shape (n_gt, n_nodes, n_ed), + where n_nodes is the number of body parts/keypoint types, and n_ed + is the number of Euclidean dimensions (typically 2 or 3). Keypoints + that are missing/not visible should be represented as NaNs. + points_pr: Predicted instance of shape (n_pr, n_nodes, n_ed). + use_cocoeval: Indicates whether the OKS score is calculated like cocoeval + method or not. True indicating the score is calculated using the + cocoeval method (widely used and the code can be found here at + https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L192C5-L233C20) + and False indicating the score is calculated using the method exactly + as given in the paper referenced in the Notes below. + scale: Size scaling factor to use when weighing the scores, typically + the area of the bounding box of the instance (in pixels). This + should be of the length n_gt. If a scalar is provided, the same + number is used for all ground truth instances. If set to None, the + bounding box area of the ground truth instances will be calculated. + stddev: The standard deviation associated with the spread in the + localization accuracy of each node/keypoint type. This should be of + the length n_nodes. "Easier" keypoint types will have lower values + to reflect the smaller spread expected in localizing it. + + Returns: + The object keypoints similarity between every pair of ground truth and + predicted instance, a numpy array of of shape (n_gt, n_pr) in the range + of [0, 1.0], with 1.0 denoting a perfect match. + + Notes: + It's important to set the stddev appropriately when accounting for the + difficulty of each keypoint type. For reference, the median value for + all keypoint types in COCO is 0.072. The "easiest" keypoint is the left + eye, with stddev of 0.025, since it is easy to precisely locate the + eyes when labeling. The "hardest" keypoint is the left hip, with stddev + of 0.107, since it's hard to locate the left hip bone without external + anatomical features and since it is often occluded by clothing. + + The implementation here is based off of the descriptions in: + Ronch & Perona. "Benchmarking and Error Diagnosis in Multi-Instance Pose + Estimation." ICCV (2017). + """ + if points_gt.ndim == 2: + points_gt = np.expand_dims(points_gt, axis=0) + if points_pr.ndim == 2: + points_pr = np.expand_dims(points_pr, axis=0) + + if scale is None: + scale = compute_instance_area(points_gt) + + n_gt, n_nodes, n_ed = points_gt.shape # n_ed = 2 or 3 (euclidean dimensions) + n_pr = points_pr.shape[0] + + # If scalar scale was provided, use the same for each ground truth instance. + if np.isscalar(scale): + scale = np.full(n_gt, scale) + + # If scalar standard deviation was provided, use the same for each node. + if np.isscalar(stddev): + stddev = np.full(n_nodes, stddev) + + # Compute displacement between each pair. + displacement = np.reshape(points_gt, (n_gt, 1, n_nodes, n_ed)) - np.reshape( + points_pr, (1, n_pr, n_nodes, n_ed) + ) + assert displacement.shape == (n_gt, n_pr, n_nodes, n_ed) + + # Convert to pairwise Euclidean distances. + distance = (displacement ** 2).sum(axis=-1) # (n_gt, n_pr, n_nodes) + assert distance.shape == (n_gt, n_pr, n_nodes) + + # Compute the normalization factor per keypoint. + if use_cocoeval: + # If use_cocoeval is True, then compute normalization factor according to cocoeval. + spread_factor = (2 * stddev) ** 2 + scale_factor = 2 * (scale + np.spacing(1)) + else: + # If use_cocoeval is False, then compute normalization factor according to the paper. + spread_factor = stddev ** 2 + scale_factor = 2 * ((scale + np.spacing(1)) ** 2) + normalization_factor = np.reshape(spread_factor, (1, 1, n_nodes)) * np.reshape( + scale_factor, (n_gt, 1, 1) + ) + assert normalization_factor.shape == (n_gt, 1, n_nodes) + + # Since a "miss" is considered as KS < 0.5, we'll set the + # distances for predicted points that are missing to inf. + missing_pr = np.any(np.isnan(points_pr), axis=-1) # (n_pr, n_nodes) + assert missing_pr.shape == (n_pr, n_nodes) + distance[:, missing_pr] = np.inf + + # Compute the keypoint similarity as per the top of Eq. 1. + ks = np.exp(-(distance / normalization_factor)) # (n_gt, n_pr, n_nodes) + assert ks.shape == (n_gt, n_pr, n_nodes) + + # Set the KS for missing ground truth points to 0. + # This is equivalent to the visibility delta function of the bottom + # of Eq. 1. + missing_gt = np.any(np.isnan(points_gt), axis=-1) # (n_gt, n_nodes) + assert missing_gt.shape == (n_gt, n_nodes) + ks[np.expand_dims(missing_gt, axis=1)] = 0 + + # Compute the OKS. + n_visible_gt = np.sum( + (~missing_gt).astype("float64"), axis=-1, keepdims=True + ) # (n_gt, 1) + oks = np.sum(ks, axis=-1) / n_visible_gt + assert oks.shape == (n_gt, n_pr) + + return oks + + def json_loads(json_str: str) -> Dict: """A simple wrapper around the JSON decoder we are using. diff --git a/tests/data/videos/min_session_back.mp4 b/tests/data/videos/min_session_back.mp4 index ae6fb2582..0e925d00f 100644 Binary files a/tests/data/videos/min_session_back.mp4 and b/tests/data/videos/min_session_back.mp4 differ diff --git a/tests/data/videos/min_session_backL.mp4 b/tests/data/videos/min_session_backL.mp4 index 3108e2d47..1edbaf6d4 100644 Binary files a/tests/data/videos/min_session_backL.mp4 and b/tests/data/videos/min_session_backL.mp4 differ diff --git a/tests/data/videos/min_session_mid.mp4 b/tests/data/videos/min_session_mid.mp4 index ab11795f1..e845f8b12 100644 Binary files a/tests/data/videos/min_session_mid.mp4 and b/tests/data/videos/min_session_mid.mp4 differ diff --git a/tests/data/videos/min_session_midL.mp4 b/tests/data/videos/min_session_midL.mp4 index 80de34f9d..ee9699bdb 100644 Binary files a/tests/data/videos/min_session_midL.mp4 and b/tests/data/videos/min_session_midL.mp4 differ diff --git a/tests/data/videos/min_session_side.mp4 b/tests/data/videos/min_session_side.mp4 index 6c6db58a7..f7977d7f5 100644 Binary files a/tests/data/videos/min_session_side.mp4 and b/tests/data/videos/min_session_side.mp4 differ diff --git a/tests/data/videos/min_session_sideL.mp4 b/tests/data/videos/min_session_sideL.mp4 index 6834448d6..639d010cb 100644 Binary files a/tests/data/videos/min_session_sideL.mp4 and b/tests/data/videos/min_session_sideL.mp4 differ diff --git a/tests/data/videos/min_session_top.mp4 b/tests/data/videos/min_session_top.mp4 index 0f817b4a6..b3422b523 100644 Binary files a/tests/data/videos/min_session_top.mp4 and b/tests/data/videos/min_session_top.mp4 differ diff --git a/tests/data/videos/min_session_topL.mp4 b/tests/data/videos/min_session_topL.mp4 index ce59a45ff..d5614a402 100644 Binary files a/tests/data/videos/min_session_topL.mp4 and b/tests/data/videos/min_session_topL.mp4 differ diff --git a/tests/fixtures/cameras.py b/tests/fixtures/cameras.py index 6f30511e8..554dd6480 100644 --- a/tests/fixtures/cameras.py +++ b/tests/fixtures/cameras.py @@ -1,5 +1,9 @@ """Camera fixtures for pytest.""" +import shutil +import toml +from pathlib import Path + import pytest from sleap.io.cameras import CameraCluster, RecordingSession @@ -18,3 +22,35 @@ def min_session_camera_cluster(min_session_calibration_toml_path): @pytest.fixture def min_session_session(min_session_calibration_toml_path): return RecordingSession.load(min_session_calibration_toml_path) + + +@pytest.fixture +def min_session_directory(tmpdir, min_session_calibration_toml_path): + # Create a new RecordingSession object + camera_calibration_path = min_session_calibration_toml_path + + # Create temporary directory with the structured video files + temp_dir = tmpdir.mkdir("recording_session") + + # Copy and paste the calibration toml + shutil.copy(camera_calibration_path, temp_dir) + + # Create directories for each camera + calibration_data = toml.load(camera_calibration_path) + camera_dnames = [ + value["name"] for value in calibration_data.values() if "name" in value + ] + + for cam_name in camera_dnames: + cam_dir = Path(temp_dir, cam_name) + cam_dir.mkdir() + + # Copy and paste the videos in the directories (only min_session_[camera_name].mp4) + videos_dir = Path("tests/data/videos") + for file in videos_dir.iterdir(): + if file.suffix == ".mp4" and "min_session" in file.stem: + camera_fname = file.stem.split("_")[2] + if camera_fname in camera_dnames: + shutil.copy(file, Path(temp_dir, camera_fname)) + + return temp_dir diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 68c8fb578..0c86a838d 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -3,6 +3,7 @@ import time from pathlib import Path, PurePath from typing import List +import tempfile import numpy as np import pytest @@ -1102,4 +1103,561 @@ def test_TriangulateSession_do_action(multiview_min_session_frame_groups): else: assert np.allclose(inst_group_np, inst_group_np_post_tri, equal_nan=True) - # TODO(LM): Test with `PredictedInstance`s + # Test triangulate session with `triangulate_predictions` set to `True` and `False` + + # Looping through labeled frames and calling remove_instance on each instance to + # ensure only predicted instances in FrameGroup + for labeled_frame in frame_group.labeled_frames: + for instance in labeled_frame.user_instances: + labels.remove_instance(labeled_frame, instance) + for instance_group in frame_group.instance_groups: + for instance in instance_group.instances: + assert isinstance(instance, PredictedInstance) + frame_group_np = frame_group.numpy() + + # Test triangulate session with `triangulate_predictions` set to default + TriangulateSession.do_action(context, params) + assert np.allclose(frame_group_np, frame_group.numpy(), equal_nan=True) + + # Test triangulate session with `triangulate_predictions` set to `False` + params = { + "session": session, + "frame_idx": frame_idx, + "frame_group": frame_group, + "triangulate_predictions": False, + } + TriangulateSession.do_action(context, params) + assert np.allclose(frame_group_np, frame_group.numpy(), equal_nan=True) + + # Test triangulate session with `triangulate_predictions` set to `True` + params = { + "session": session, + "frame_idx": frame_idx, + "frame_group": frame_group, + "triangulate_predictions": True, + } + TriangulateSession.do_action(context, params) + assert not np.allclose(frame_group_np, frame_group.numpy(), equal_nan=True) + + +def test_triangulateSession_with_predictions(multiview_min_session_frame_groups): + """Test that `triangulateSession` command works with triangulate_predictions""" + + labels: Labels = multiview_min_session_frame_groups + session: RecordingSession = labels.sessions[0] + frame_idx: int = 0 + frame_group: FrameGroup = session.frame_groups[frame_idx] + + # Test triangulateSession command with triangulate_predictions set to True and False + + # Looping through labeled frames and calling remove_instance on each instance to + # ensure only predicted instances in FrameGroup + for labeled_frame in frame_group.labeled_frames: + for instance in labeled_frame.user_instances: + labels.remove_instance(labeled_frame, instance) + for instance_group in frame_group.instance_groups: + for instance in instance_group.instances: + assert isinstance(instance, PredictedInstance) + frame_group_np = frame_group.numpy() + + context = CommandContext.from_labels(labels) + + # Test triangulate session with triangulate_predictions set to default (False) + context.triangulateSession(session=session, frame_idx=frame_idx) + assert np.allclose(frame_group_np, frame_group.numpy(), equal_nan=True) + + # Test triangulate session with triangulate_predictions set to False + context.triangulateSession( + session=session, frame_idx=frame_idx, triangulate_predictions=False + ) + assert np.allclose(frame_group_np, frame_group.numpy(), equal_nan=True) + + # Test triangulate session with triangulate_predictions set to True + context.triangulateSession( + session=session, frame_idx=frame_idx, triangulate_predictions=True + ) + assert not np.allclose(frame_group_np, frame_group.numpy(), equal_nan=True) + + +def test_SetSelectedInstanceGroup(multiview_min_session_frame_groups: Labels): + """Test that setting a new instance group works.""" + + labels = multiview_min_session_frame_groups + session: RecordingSession = labels.sessions[0] + frame_idx = 0 + frame_group: FrameGroup = session.frame_groups[frame_idx] + labeled_frame: LabeledFrame = frame_group.labeled_frames[0] + video = labeled_frame.video + camera = session.get_camera(video=video) + + # We want to replace `instance_0` with `instance_1` in the `InstanceGroup` + instance_0 = labeled_frame.user_instances[0] + instance_group_0 = frame_group.get_instance_group(instance=instance_0) + instance_1 = labeled_frame.user_instances[1] + instance_group_1 = frame_group.get_instance_group(instance=instance_1) + + # Set-up CommandContext + context: CommandContext = CommandContext.from_labels(labels) + context.state["instance"] = instance_1 + context.state["video"] = video + + # No session + with pytest.raises(ValueError): + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # No frame_idx + context.state["session"] = session + with pytest.raises(ValueError): + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # With session and frame_idx + context.state["frame_idx"] = frame_idx + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + assert instance_group_0 in frame_group.instance_groups + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 not in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 5 + assert instance_0 not in instance_group_0.instances + assert instance_0 not in instance_group_1.instances + assert instance_1 in instance_group_0.instances + assert instance_1 not in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 not in instance_group_0._camcorder_by_instance + assert instance_0 not in instance_group_1._camcorder_by_instance + assert instance_1 in instance_group_0._camcorder_by_instance + assert instance_1 not in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 not in instance_group_0._instance_by_camcorder.values() + assert instance_0 not in instance_group_1._instance_by_camcorder.values() + assert instance_1 in instance_group_0._instance_by_camcorder.values() + assert instance_1 not in instance_group_1._instance_by_camcorder.values() + + # Let's move the instance to the other `InstanceGroup` + context.setInstanceGroup(instance_group=instance_group_1) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + assert instance_group_0 in frame_group.instance_groups + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 not in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 7 + assert len(instance_group_1.instances) == 6 + assert instance_0 not in instance_group_0.instances + assert instance_0 not in instance_group_1.instances + assert instance_1 not in instance_group_0.instances + assert instance_1 in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 not in instance_group_0._camcorder_by_instance + assert instance_0 not in instance_group_1._camcorder_by_instance + assert instance_1 not in instance_group_0._camcorder_by_instance + assert instance_1 in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 not in instance_group_0._instance_by_camcorder.values() + assert instance_0 not in instance_group_1._instance_by_camcorder.values() + assert instance_1 not in instance_group_0._instance_by_camcorder.values() + assert instance_1 in instance_group_1._instance_by_camcorder.values() + + # Let's move the other instance back to its original `InstanceGroup` + context.state["instance"] = instance_0 + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + assert instance_group_0 in frame_group.instance_groups + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + assert instance_0 in instance_group_0.instances + assert instance_0 not in instance_group_1.instances + assert instance_1 not in instance_group_0.instances + assert instance_1 in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 in instance_group_0._camcorder_by_instance + assert instance_0 not in instance_group_1._camcorder_by_instance + assert instance_1 not in instance_group_0._camcorder_by_instance + assert instance_1 in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 in instance_group_0._instance_by_camcorder.values() + assert instance_0 not in instance_group_1._instance_by_camcorder.values() + assert instance_1 not in instance_group_0._instance_by_camcorder.values() + assert instance_1 in instance_group_1._instance_by_camcorder.values() + + # Let's remove all but one instance from an `InstanceGroup` + for instance in instance_group_0.instances: + if instance == instance_0: + continue + frame_group.remove_instance(instance=instance) + assert len(instance_group_0.instances) == 1 + context.setInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + assert instance_group_0 in frame_group.instance_groups + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 1 + assert len(instance_group_1.instances) == 6 + assert instance_0 in instance_group_0.instances + assert instance_0 not in instance_group_1.instances + assert instance_1 not in instance_group_0.instances + assert instance_1 in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 in instance_group_0._camcorder_by_instance + assert instance_0 not in instance_group_1._camcorder_by_instance + assert instance_1 not in instance_group_0._camcorder_by_instance + assert instance_1 in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 in instance_group_0._instance_by_camcorder.values() + assert instance_0 not in instance_group_1._instance_by_camcorder.values() + assert instance_1 not in instance_group_0._instance_by_camcorder.values() + assert instance_1 in instance_group_1._instance_by_camcorder.values() + + # Let's switch the last instance to a different `InstanceGroup` + context.setInstanceGroup(instance_group=instance_group_1) + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 1 + assert instance_group_1 in frame_group.instance_groups + # Check FrameGroup._instances_by_camcorder + assert instance_0 in frame_group._instances_by_cam[camera] + assert instance_1 not in frame_group._instances_by_cam[camera] + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 0 + assert len(instance_group_1.instances) == 6 + assert instance_0 in instance_group_1.instances + assert instance_1 not in instance_group_1.instances + # Check InstanceGroup._camcorder_by_instance + assert instance_0 in instance_group_1._camcorder_by_instance + assert instance_1 not in instance_group_1._camcorder_by_instance + # Check InstanceGroup._instance_by_camcorder + assert instance_0 in instance_group_1._instance_by_camcorder.values() + assert instance_1 not in instance_group_1._instance_by_camcorder.values() + + +def test_AddInstanceGroup(multiview_min_session_frame_groups: Labels): + """Test that adding an instance group works.""" + + labels = multiview_min_session_frame_groups + session: RecordingSession = labels.sessions[0] + frame_idx = 1 + frame_group: FrameGroup = session.frame_groups[frame_idx] + instance_group_0: InstanceGroup = frame_group.instance_groups[0] + instance_group_1: InstanceGroup = frame_group.instance_groups[1] + labeled_frame: LabeledFrame = frame_group.labeled_frames[0] + video = labeled_frame.video + camera = session.get_camera(video=video) + + # Set-up CommandContext + context: CommandContext = CommandContext.from_labels(labels) + + # No session + with pytest.raises(ValueError): + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # No frame_idx + context.state["session"] = session + with pytest.raises(TypeError): + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # No instance + context.state["frame_idx"] = frame_idx + with pytest.raises(ValueError): + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + instance_group_2 = frame_group.instance_groups[-1] + assert len(frame_group.instance_groups) == 3 + assert instance_group_2 in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + assert len(instance_group_2.instances) == 0 + + # No video + context.state["instance"] = instance_group_0.get_instance(cam=camera) + with pytest.raises(ValueError): + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + instance_group_3 = frame_group.instance_groups[-1] + assert len(frame_group.instance_groups) == 4 + assert instance_group_3 in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + assert len(instance_group_2.instances) == 0 + assert len(instance_group_3.instances) == 0 + + # Everything, let's add an `InstanceGroup` and set the `Instance` to it + context.state["video"] = video + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + instance_group_4 = frame_group.instance_groups[-1] + assert len(frame_group.instance_groups) == 5 + assert instance_group_4 in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 7 + assert len(instance_group_1.instances) == 6 + assert len(instance_group_2.instances) == 0 + assert len(instance_group_3.instances) == 0 + assert len(instance_group_4.instances) == 1 + + # Everything, let's add an `InstanceGroup` and set the last `Instance` to it + context.state["video"] = video + context.addInstanceGroup() + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + instance_group_5 = frame_group.instance_groups[-1] + assert len(frame_group.instance_groups) == 5 + assert instance_group_4 not in frame_group.instance_groups + assert instance_group_5 in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 7 + assert len(instance_group_1.instances) == 6 + assert len(instance_group_2.instances) == 0 + assert len(instance_group_3.instances) == 0 + assert len(instance_group_4.instances) == 0 + assert len(instance_group_5.instances) == 1 + + +def test_DeleteInstanceGroup(multiview_min_session_frame_groups: Labels): + """Test that deleting an instance group works.""" + + labels = multiview_min_session_frame_groups + session: RecordingSession = labels.sessions[0] + frame_idx = 2 + frame_group: FrameGroup = session.frame_groups[frame_idx] + instance_group_0: InstanceGroup = frame_group.instance_groups[0] + instance_group_1: InstanceGroup = frame_group.instance_groups[1] + labeled_frame: LabeledFrame = frame_group.labeled_frames[0] + video = labeled_frame.video + camera = session.get_camera(video=video) + + # Set-up CommandContext + context: CommandContext = CommandContext.from_labels(labels) + + # No session + with pytest.raises(ValueError): + context.deleteInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # No frame_idx + context.state["session"] = session + with pytest.raises(ValueError): + context.deleteInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 2 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 2 + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # Everything, let's delete an `InstanceGroup` + context.state["frame_idx"] = frame_idx + context.deleteInstanceGroup(instance_group=instance_group_0) + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 1 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 1 + assert instance_group_0 not in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + # Everything, let's delete the last `InstanceGroup` + context.state["frame_idx"] = frame_idx + context.deleteInstanceGroup(instance_group=instance_group_1) + # Check FrameGroup._instances_by_camcorder + assert len(frame_group._instances_by_cam[camera]) == 0 + # Check FrameGroup.instance_groups + assert len(frame_group.instance_groups) == 0 + assert instance_group_1 not in frame_group.instance_groups + # Check InstanceGroup.instances + assert len(instance_group_0.instances) == 8 + assert len(instance_group_1.instances) == 6 + + +def test_automatic_addition_and_linkage_videos(min_session_directory): + """Test if the automatic addition of videos works.""" + # Create a new RecordingSession object + session_dir = Path(min_session_directory) + session_dir_video_paths = [ + video_path.as_posix() for video_path in session_dir.rglob("*.mp4") + ] + calibration_path = Path(session_dir, "calibration.toml") + + # Test find_video_paths + camera_by_video_paths = AddSession.find_video_paths( + camera_calibration=calibration_path + ) + assert len(camera_by_video_paths) == 8 + assert all([p in session_dir_video_paths for p in camera_by_video_paths]) + + # Create a new Label() object + labels = Labels() + context = CommandContext.from_labels(labels) + + # Case 1: No videos imported + params = {"camera_calibration": calibration_path} + AddSession.do_action(context, params) + + # Check if the session was added to the Label object + assert len(labels.sessions) == 1 + assert isinstance(context.state["session"], RecordingSession) + + # Check that no videos were added + assert len(labels.videos) == 0 + + # Case 2: Videos imported + template_import_params = { + "filename": "path/to/video.mp4", + "grayscale": True, + } + template_import_item = { + "params": template_import_params, + "video_type": "mp4", + "video_class": Video.from_media, + } + import_list = [] + cam_names_to_exclude = ["topL", "sideL"] + for video_path, cam_name in camera_by_video_paths.items(): + + # Only link videos for certain cameras + if cam_name in cam_names_to_exclude: + continue + + import_params = dict(template_import_params) + import_params["filename"] = video_path + template_import_item["params"] = import_params + import_list.append(dict(template_import_item)) + + params = { + "camera_calibration": calibration_path, + "import_list": import_list, + "camera_by_video_paths": camera_by_video_paths, + } + AddSession.do_action(context, params) + + # Check if the session was added to the Label object + assert len(labels.sessions) == 2 + assert isinstance(context.state["session"], RecordingSession) + session: RecordingSession = labels.sessions[-1] + + # Check that videos were added + assert len(labels.videos) == len(import_list) + + # Check that videos were linked + assert len(session.videos) == len(import_list) + assert len(session.videos) == len(session.cameras) - len(cam_names_to_exclude) + for cam in session.cameras: + if cam.name in cam_names_to_exclude: + assert session.get_video(camcorder=cam) is None + else: + video = session.get_video(camcorder=cam) + assert video in session.videos + assert session.get_camera(video=video) is cam + + +def test_link_video_to_session(min_session_session, centered_pair_vid): + """Test if the linkage of videos to a session works.""" + + # Create a new Label() object + session: RecordingSession = min_session_session + video: Video = centered_pair_vid + labels = Labels() + labels.add_session(session) + labels.add_video(video) + + # Create command context + context = CommandContext.from_labels(labels) + + # Call the function without a camera selected + with pytest.raises(ValueError): + context.linkVideoToSession() + + # Call the function without a recording session selected + camera = session.cameras[0] + with pytest.raises(ValueError): + context.linkVideoToSession(camera=camera) + context.state["selected_camera"] = camera + with pytest.raises(ValueError): + context.linkVideoToSession() + + # Call the function without a video selected + with pytest.raises(ValueError): + context.linkVideoToSession(session=session) + context.state["selected_session"] = session + with pytest.raises(ValueError): + context.linkVideoToSession() + + # Call the function with all parameters + context.linkVideoToSession(video=video) + assert video in session.videos + assert camera is session.get_camera(video=video) + assert video is session.get_video(camcorder=camera) + + +def test_setInstanceGroupName(multiview_min_session_frame_groups): + labels: Labels = multiview_min_session_frame_groups + session: RecordingSession = labels.sessions[0] + + # Set-up CommandContext + context: CommandContext = CommandContext.from_labels(labels) + context.state["session"] = session + + # Start test + frame_group = session.frame_groups[0] + instance_group: InstanceGroup = frame_group.instance_groups[0] + new_name = "New Name" + context.setInstanceGroupName(instance_group=instance_group, name=new_name) + assert instance_group.name == new_name diff --git a/tests/gui/test_dataviews.py b/tests/gui/test_dataviews.py index 9c62daf88..593bb2758 100644 --- a/tests/gui/test_dataviews.py +++ b/tests/gui/test_dataviews.py @@ -1,11 +1,7 @@ -import pytest -import pytestqt - from sleap.gui.dataviews import * def test_skeleton_nodes(qtbot, centered_pair_predictions): - table = GenericTableView( model=SkeletonNodesTableModel(items=centered_pair_predictions.skeletons[0]) ) @@ -74,6 +70,33 @@ def test_table_sort(qtbot, centered_pair_predictions): assert table.getSelectedRowItem().score == inst.score +def test_sessions_table(qtbot, min_session_session, hdf5_vid): + sessions = [] + sessions.append(min_session_session) + table = GenericTableView( + row_name="session", + is_sortable=True, + name_prefix="", + model=SessionsTableModel(items=sessions), + ) + table.selectRow(0) + assert len(table.getSelectedRowItem().videos) == 0 + assert len(table.getSelectedRowItem().camera_cluster.cameras) == 8 + assert len(table.getSelectedRowItem().camera_cluster.sessions) == 1 + + video = hdf5_vid + min_session_session.add_video( + video, + table.getSelectedRowItem().camera_cluster.cameras[0], + ) + + # Verify that modification of the recording session is reflected in the recording session stored in the table + assert len(table.getSelectedRowItem().videos) == 1 + + min_session_session.remove_video(video) + assert len(table.getSelectedRowItem().videos) == 0 + + def test_table_sort_string(qtbot): table_model = GenericTableModel( items=[dict(a=1, b=2), dict(a=2, b="")], properties=["a", "b"] @@ -84,3 +107,57 @@ def test_table_sort_string(qtbot): # Make sure we can sort with both numbers and strings (i.e., "") table.model().sort(0) table.model().sort(1) + + +def test_camera_table(qtbot, multiview_min_session_labels): + + session = multiview_min_session_labels.sessions[0] + camcorders = session.camera_cluster.cameras + + table_model = CamerasTableModel(items=session) + num_rows = table_model.rowCount() + + assert table_model.columnCount() == 2 + assert num_rows == len(camcorders) + + table = GenericTableView( + row_name="camera", + model=table_model, + ) + + # Test if all comcorders are presented in the correct row + for i in range(num_rows): + table.selectRow(i) + + # Check first column + assert table.getSelectedRowItem() == camcorders[i] + assert table.model().data(table.currentIndex()) == camcorders[i].name + + # Check second column + index = table.model().index(i, 1) + linked_video_filename = camcorders[i].get_video(session).filename + assert table.model().data(index) == linked_video_filename + + # Test if a comcorder change is reflected + idxs_to_remove = [1, 2, 7] + for idx in idxs_to_remove: + multiview_min_session_labels.sessions[0].remove_video( + camcorders[idx].get_video(multiview_min_session_labels.sessions[0]) + ) + table.model().items = session + + for i in range(num_rows): + table.selectRow(i) + + # Check first column + assert table.getSelectedRowItem() == camcorders[i] + assert table.model().data(table.currentIndex()) == camcorders[i].name + + # Check second column + index = table.model().index(i, 1) + linked_video = camcorders[i].get_video(session) + if i in idxs_to_remove: + assert table.model().data(index) == "" + else: + linked_video_filename = linked_video.filename + assert table.model().data(index) == linked_video_filename diff --git a/tests/gui/widgets/test_docks.py b/tests/gui/widgets/test_docks.py index 69fe56a56..0caa0278c 100644 --- a/tests/gui/widgets/test_docks.py +++ b/tests/gui/widgets/test_docks.py @@ -1,7 +1,9 @@ """Module for testing dock widgets for the `MainWindow`.""" from pathlib import Path + import pytest + from sleap import Labels, Video from sleap.gui.app import MainWindow from sleap.gui.commands import OpenSkeleton @@ -10,6 +12,7 @@ SuggestionsDock, VideosDock, SkeletonDock, + SessionsDock, ) @@ -107,3 +110,145 @@ def test_instances_dock(qtbot): assert dock.name == "Instances" assert dock.main_window is main_window assert dock.wgt_layout is dock.widget().layout() + + +def test_sessions_dock(qtbot): + """Test the `SessionsDock` class.""" + main_window = MainWindow() + dock = SessionsDock(main_window) + + assert dock.name == "Sessions" + assert dock.main_window is main_window + assert dock.wgt_layout is dock.widget().layout() + + +def test_sessions_dock_cameras_table(qtbot, multiview_min_session_labels): + labels = multiview_min_session_labels + session = labels.sessions[0] + camcorders = session.camera_cluster.cameras + main_window = MainWindow(labels=labels) + assert main_window.state["session"] == session + + dock = main_window.sessions_dock + table = dock.camera_table + + # Testing if cameras_table is loaded correctly + + # Test if all comcorders are presented in the correct row + for i, cam in enumerate(camcorders): + main_window.state["selected_session"] = session + table.selectRow(i) + + # Check first column + assert table.getSelectedRowItem() == cam + assert table.model().data(table.currentIndex()) == cam.name + + # Check second column + index = table.model().index(i, 1) + linked_video_filename = cam.get_video(session).filename + assert table.model().data(index) == linked_video_filename + + # Test if a comcorder change is reflected + idxs_to_remove = [1, 2, 7] + for idx in idxs_to_remove: + main_window.state["selected_camera"] = camcorders[idx] + main_window._buttons["unlink video"].click() + + for i, cam in enumerate(camcorders): + table.selectRow(i) + + # Check first column + assert table.getSelectedRowItem() == camcorders[i] + assert table.model().data(table.currentIndex()) == camcorders[i].name + + # Check second column + index = table.model().index(i, 1) + linked_video = camcorders[i].get_video(session) + if i in idxs_to_remove: + assert table.model().data(index) == "" + else: + linked_video_filename = linked_video.filename + assert table.model().data(index) == linked_video_filename + + +def test_sessions_dock_session_table(qtbot, multiview_min_session_labels): + """Test the SessionsDock.sessions_table.""" + + # Create dock + labels = multiview_min_session_labels + main_window = MainWindow(labels=labels) + + # Testing if sessions table is loaded correctly + sessions = multiview_min_session_labels.sessions + main_window.sessions_dock.sessions_table.selectRow(0) + assert main_window.sessions_dock.sessions_table.getSelectedRowItem() == sessions[0] + + # Testing if removal of selected session is reflected in sessions dock + main_window.state["selected_session"] = sessions[0] + main_window._buttons["remove session"].click() + + with pytest.raises(IndexError): + # There are no longer any sessions in the table + main_window.sessions_dock.sessions_table.selectRow(0) + + +def test_sessions_dock_unlinked_videos_table(qtbot, multiview_min_session_labels): + """Test the SessionsDock.unlinked_videos_table.""" + # Create dock + label = multiview_min_session_labels + main_window = MainWindow(labels=label) + dock = main_window.sessions_dock + assert main_window.state["session"] == label.sessions[0] + label_cache = label._cache + + # Selected Session + main_window.state["selected_session"] = label.sessions[0] + + # Testing if the unlinked videos table and its cache are loaded correctly + assert dock.unlinked_videos_table.model().rowCount() == 0 + assert label_cache._linkage_of_videos["unlinked"] == [] + assert label_cache._linkage_of_videos["linked"] == label.videos + + # Testing if the unlinked videos table and its cache are updated correctly + main_window.state["selected_camera"] = label.sessions[0].camera_cluster.cameras[0] + camera = main_window.state["selected_camera"] + video = camera.get_video(label.sessions[0]) + main_window._buttons["unlink video"].click() + + # Check unlinked videos tables + assert dock.unlinked_videos_table.model().rowCount() == 1 + + # Check cache + assert len(label_cache._linkage_of_videos["unlinked"]) == 1 + assert camera.get_video(label.sessions[0]) is None + assert video in label_cache._linkage_of_videos["unlinked"] + + # Test if the "Link" button functions correctly + main_window.state["selected_camera"] = label.sessions[0].camera_cluster.cameras[0] + main_window.state["selected_unlinked_video"] = video + main_window._buttons["link video"].click() + + # Check unlinked videos tables + assert dock.unlinked_videos_table.model().rowCount() == 0 + + # Check cache + assert len(label_cache._linkage_of_videos["unlinked"]) == 0 + assert video not in label_cache._linkage_of_videos["unlinked"] + + # Test multiple unlinked videos + indxs = [1, 3, 5] + original_length = len(label_cache._linkage_of_videos["linked"]) + for indx in indxs: + main_window.state["selected_camera"] = label.sessions[0].camera_cluster.cameras[ + indx + ] + camera = main_window.state["selected_camera"] + video = camera.get_video(label.sessions[0]) + main_window._buttons["unlink video"].click() + + # Check unlinked videos tables + assert dock.unlinked_videos_table.model().rowCount() == len(indxs) + + # Check cache + assert len(label_cache._linkage_of_videos["unlinked"]) == len(indxs) + assert len(label_cache._linkage_of_videos["linked"]) == original_length - len(indxs) diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index f5437389a..87b109134 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -4,6 +4,7 @@ import numpy as np import pytest +import toml from sleap.io.cameras import ( Camcorder, @@ -178,9 +179,21 @@ def test_recording_session( assert frame_group.frame_idx == 0 assert frame_group == session.frame_groups[0] - # Test add_video - camcorder = session.camera_cluster.cameras[0] + # Test add_video (and _projection_bounds) + cam_idx = 0 + camcorder = session.camera_cluster.cameras[cam_idx] + prev_cams_to_include = session.cams_to_include + n_prev_cams_to_include = len(prev_cams_to_include) + assert session.projection_bounds.shape == (n_prev_cams_to_include, 2) + assert np.all(np.isnan(session._projection_bounds[cam_idx])) session.add_video(centered_pair_vid, camcorder) + n_cams_to_include = len(session.cams_to_include) + assert n_cams_to_include == n_prev_cams_to_include + 1 + assert session.projection_bounds.shape == (n_cams_to_include, 2) + assert np.all( + session._projection_bounds[cam_idx] + == [centered_pair_vid.width, centered_pair_vid.height] + ) assert centered_pair_vid is session.camera_cluster._videos_by_session[session][0] assert centered_pair_vid is camcorder._video_by_session[session] assert session is session.camera_cluster._session_by_video[centered_pair_vid] @@ -246,8 +259,18 @@ def compare_sessions(session_1: RecordingSession, session_2: RecordingSession): ) compare_sessions(session, session_2) - # Test remove_video + # Test remove_video (and _projection_bounds) + assert session.projection_bounds.shape == (len(session.cams_to_include), 2) + cam_idx = session.camera_cluster.cameras.index( + session.get_camera(centered_pair_vid) + ) + assert np.all( + session._projection_bounds[cam_idx] + == [centered_pair_vid.width, centered_pair_vid.height] + ) session.remove_video(centered_pair_vid) + assert session.projection_bounds.shape == (len(session.cams_to_include), 2) + assert np.all(np.isnan(session._projection_bounds[cam_idx])) assert centered_pair_vid not in session.videos assert camcorder not in session.linked_cameras assert camcorder in session.unlinked_cameras @@ -263,6 +286,45 @@ def compare_sessions(session_1: RecordingSession, session_2: RecordingSession): # Test __getitem__ with `Camcorder` key assert session[camcorder] is None + # Test _projection_bounds + _projection_bounds = session._projection_bounds + n_cameras, n_coords = _projection_bounds.shape + assert n_cameras == len(session.camera_cluster.cameras) + assert n_coords == 2 + n_linked_cameras = len(session.linked_cameras) + assert n_linked_cameras < n_cameras + assert _projection_bounds[np.isnan(_projection_bounds)].shape == ( + n_coords * (n_cameras - n_linked_cameras), + ) + for cam_idx, cam in enumerate(session.camera_cluster.cameras): + if cam in session.linked_cameras: + linked_video = session.get_video(cam) + assert _projection_bounds[cam_idx, 0] == linked_video.width + assert _projection_bounds[cam_idx, 1] == linked_video.height + else: + assert np.all(np.isnan(_projection_bounds[cam_idx]) == True) + + # Test projection_bounds property + session.cams_to_include = session.camera_cluster.cameras[:6] + projection_bounds = session.projection_bounds + n_cams_to_include, n_coords = projection_bounds.shape + assert n_cams_to_include == len( + set(session.linked_cameras) & set(session._cams_to_include) + ) + assert n_coords == 2 + assert projection_bounds[np.isnan(projection_bounds)].shape == ( + n_coords * (n_cams_to_include - n_linked_cameras), + ) + pb_idx = 0 + for cam_idx, cam in enumerate(session.camera_cluster.cameras): + if cam in session.linked_cameras: + linked_video = session.get_video(cam) + assert projection_bounds[pb_idx, 0] == linked_video.width + assert projection_bounds[pb_idx, 1] == linked_video.height + pb_idx += 1 + else: + assert np.all(np.isnan(session._projection_bounds[cam_idx]) == True) + # Test make_cattr labeled_frame_to_idx = {lf: idx for idx, lf in enumerate(labels.labeled_frames)} sessions_cattr = RecordingSession.make_cattr( @@ -275,6 +337,11 @@ def compare_sessions(session_1: RecordingSession, session_2: RecordingSession): session_3 = sessions_cattr.structure(session_dict_2, RecordingSession) compare_sessions(session_2, session_3) + # Test id + assert session.id == hash(session) + labels.add_session(session) + assert session.id == labels.sessions.index(session) + def test_recording_session_get_videos_from_selected_cameras( multiview_min_session_labels: Labels, @@ -399,6 +466,7 @@ def create_instance_group( frame_idx: int, add_dummy: bool = False, name: Optional[str] = None, + score: Optional[float] = None, ) -> Union[ InstanceGroup, Tuple[InstanceGroup, Dict[Camcorder, Instance], Instance, Camcorder] ]: @@ -444,6 +512,7 @@ def create_instance_group( instance_by_camcorder=instance_by_camera, name="test_instance_group", name_registry={}, + score=score, ) return ( (instance_group, instance_by_camera, dummy_instance, cam) @@ -534,6 +603,20 @@ def test_instance_group( assert isinstance(instance_group_dict, dict) assert instance_group_dict["name"] == instance_group.name assert "camcorder_to_lf_and_inst_idx_map" in instance_group_dict + assert "score" not in instance_group_dict + + # Test `score` property (and `to_dict`) + assert instance_group.score is None + instance_group.score = 0.5 + for instance in instance_group.instances: + assert instance.score == 0.5 + instance_group.score = 0.75 + for instance in instance_group.instances: + assert instance.score == 0.75 + instance_group_dict = instance_group.to_dict( + instance_to_lf_and_inst_idx=instance_to_lf_and_inst_idx + ) + assert instance_group_dict["score"] == str(0.75) # Test `from_dict` instance_group_2 = InstanceGroup.from_dict( @@ -585,6 +668,20 @@ def test_instance_group( name="test_instance_group", name_registry={}, ) + instance_by_camera = { + cam: instance_group.get_instance(cam) for cam in instance_group.cameras + } + instance_group_from_dict = InstanceGroup.from_instance_by_camcorder_dict( + instance_by_camcorder=instance_by_camera, + name="test_instance_group", + name_registry={}, + score=0.5, + ) + assert instance_group_from_dict.score == 0.5 + # The score of instances will NOT be updated on initialization. + for instance in instance_group_from_dict.instances: + if isinstance(instance, PredictedInstance): + assert instance.score != instance_group_from_dict.score # Test `__repr__` print(instance_group) @@ -596,32 +693,27 @@ def test_instance_group( frame_group = session.frame_groups[frame_idx] instance_group = frame_group.instance_groups[0] - # Test `numpy` method - instance_group_numpy = instance_group.numpy() - n_views, n_nodes, n_coords = instance_group_numpy.shape - assert n_views == len(instance_group.camera_cluster.cameras) - assert n_nodes == len(instance_group.dummy_instance.skeleton.nodes) - assert n_coords == 2 - # Different instance groups should have different coordinates - for inst_idx, _ in enumerate(instance_group.instances[:-1]): - assert not np.allclose( - instance_group_numpy[:, inst_idx], - instance_group_numpy[:, inst_idx + 1], - equal_nan=True, - ) - # Different views should have different coordinates - for view_idx, _ in enumerate(instance_group.camera_cluster.cameras[:-1]): - assert not np.allclose( - instance_group_numpy[view_idx], - instance_group_numpy[view_idx + 1], - equal_nan=True, - ) - # Test `update_points` method + n_views = len(session.cams_to_include) + n_nodes = len(instance_group.dummy_instance.skeleton.nodes) + n_coords = 2 assert not np.all(instance_group.numpy(invisible_as_nan=False) == 72317) - instance_group.update_points(np.full((n_views, n_nodes, n_coords), 72317)) + # Remove some Instances to "expose" underlying PredictedInstances + for inst in instance_group.instances[:2]: + lf = inst.frame + labels.remove_instance(lf, inst) + instance_group.update_points(points=np.full((n_views, n_nodes, n_coords), 72317)) + for inst in instance_group.instances: + if isinstance(inst, PredictedInstance): + assert inst.score == instance_group.score + prev_score = instance_group.score + instance_group.update_points(points=np.full((n_views, n_nodes, n_coords), 72317)) + for inst in instance_group.instances: + if isinstance(inst, PredictedInstance): + assert inst.score == instance_group.score instance_group_numpy = instance_group.numpy(invisible_as_nan=False) assert np.all(instance_group_numpy == 72317) + assert instance_group.score == 1.0 # Score should be 1.0 because same points # Test `add_instance`, `replace_instance`, and `remove_instance` cam = instance_group.cameras[0] @@ -642,6 +734,30 @@ def test_instance_group( assert cam not in instance_group.cameras +def test_instance_group_numpy(multiview_min_session_frame_groups: Labels): + """Test `InstanceGroup.numpy` method.""" + labels = multiview_min_session_frame_groups + session = labels.sessions[0] + frame_group = session.frame_groups[0] + instance_group = frame_group.instance_groups[0] + + instance_group_numpy = instance_group.numpy() + n_views, n_nodes, n_coords = instance_group_numpy.shape + assert n_views == len(instance_group.camera_cluster.cameras) + assert n_nodes == len(instance_group.dummy_instance.skeleton.nodes) + assert n_coords == 2 + + # Test for undisorted points + instance_group_numpy_0 = instance_group.numpy(undistort=False) + instance_group_numpy_undistorted = instance_group.numpy(undistort=True) + assert np.allclose( + instance_group_numpy_0, instance_group_numpy, atol=1e-3, equal_nan=True + ) + assert not np.allclose( + instance_group_numpy, instance_group_numpy_undistorted, equal_nan=True + ) + + def test_frame_group( multiview_min_session_labels: Labels, multiview_min_session_frame_groups: Labels ): @@ -747,18 +863,6 @@ def test_frame_group( with pytest.raises(ValueError): frame_group.cams_to_include = session.linked_cameras - # Test `numpy` method - frame_group_np = frame_group.numpy() - n_views, n_inst_groups, n_nodes, n_coords = frame_group_np.shape - assert n_views == len(frame_group.cams_to_include) - assert n_inst_groups == len(frame_group.instance_groups) - assert n_nodes == len(labels.skeleton.nodes) - assert n_coords == 2 - # Different instance groups should have different coordinates - assert not np.allclose(frame_group_np[:, 0], frame_group_np[:, 1], equal_nan=True) - # Different views should have different coordinates - assert not np.allclose(frame_group_np[0], frame_group_np[1], equal_nan=True) - # Test `get_instance_group` instance_group = frame_group.instance_groups[0] camera = session.cameras[0] @@ -853,3 +957,151 @@ def test_frame_group( assert camera in frame_group.cameras assert labeled_frame_created in frame_group.labeled_frames assert labeled_frame in frame_group.session.labels.labeled_frames + + # Test `upsert_points` (all in bounds, all updated) + n_cameras = len(frame_group.cams_to_include) + n_instance_groups = len(frame_group.instance_groups) + n_nodes = len(frame_group.session.labels.skeleton.nodes) + n_coords = 2 + value = 100 + points = np.full((n_cameras, n_instance_groups, n_nodes, n_coords), value) + frame_group.upsert_points( + points=points, instance_groups=frame_group.instance_groups + ) + assert np.all(frame_group.numpy(invisible_as_nan=False) == value) + + # Test `upsert_points` (all out of bound, none updated) + projection_bounds = frame_group.session.projection_bounds + min_bound = projection_bounds.min() + prev_value = value + oob_value = 5000 + assert oob_value > min_bound + points = np.full((n_cameras, n_instance_groups, n_nodes, n_coords), oob_value) + frame_group.upsert_points( + points=points, instance_groups=frame_group.instance_groups + ) + assert np.any(frame_group.numpy(invisible_as_nan=False) == oob_value) == False + assert np.all(frame_group.numpy(invisible_as_nan=False) == prev_value) + + # Test `upsert_points` (some out of bound, some updated) + value = 200 + oob_value = 5000 + assert oob_value > min_bound + oob_mask = np.random.choice( + [True, False], size=(n_cameras, n_instance_groups, n_nodes, n_coords) + ) + points = np.full((n_cameras, n_instance_groups, n_nodes, n_coords), value) + points[oob_mask] = oob_value + frame_group.upsert_points( + points=points, instance_groups=frame_group.instance_groups + ) + # Get the logical or for either x or y being out of bounds + oob_mask_1d = np.any(oob_mask, axis=-1) # Collapse last axis + oob_mask_1d_expanded = np.expand_dims(oob_mask_1d, axis=-1) + oob_mask_1d_expanded = np.broadcast_to(oob_mask_1d_expanded, oob_mask.shape) + frame_group_numpy = frame_group.numpy(invisible_as_nan=False) + assert np.any(frame_group_numpy > min_bound) == False + assert np.all(frame_group_numpy[oob_mask_1d_expanded] == prev_value) # Not updated + assert np.all(frame_group_numpy[~oob_mask_1d_expanded] == value) # Updated + + # Test `upsert_points` (between x,y bounds, some out of bound, some updated) + value = 300 + points = np.full((n_cameras, n_instance_groups, n_nodes, n_coords), value) + # Reset the points to all in bounds + frame_group.upsert_points( + points=points, instance_groups=frame_group.instance_groups + ) + assert np.all(frame_group.numpy(invisible_as_nan=False) == value) + # Add some out of bounds points + prev_value = value + value = 400 + points = np.full((n_cameras, n_instance_groups, n_nodes, n_coords), value) + max_bound = projection_bounds.max() + oob_value = max_bound - 1 + assert oob_value < max_bound and oob_value > min_bound + oob_mask = np.random.choice( + [True, False], size=(n_cameras, n_instance_groups, n_nodes, n_coords) + ) + points[oob_mask] = oob_value + frame_group.upsert_points( + points=points, instance_groups=frame_group.instance_groups + ) + # Get the logical or for either x or y being out of bounds + bound_x, bound_y = projection_bounds[:, 0].min(), projection_bounds[:, 1].min() + oob_mask_x = np.where(points[:, :, :, 0] > bound_x, True, False) + oob_mask_y = np.where(points[:, :, :, 1] > bound_y, True, False) + oob_mask_1d = np.logical_or(oob_mask_x, oob_mask_y) + oob_mask_1d_expanded = np.expand_dims(oob_mask_1d, axis=-1) + oob_mask_1d_expanded = np.broadcast_to(oob_mask_1d_expanded, oob_mask.shape) + frame_group_numpy = frame_group.numpy(invisible_as_nan=False) + assert np.any(frame_group_numpy[:, :, :, 0] > bound_x) == False + assert np.any(frame_group_numpy[:, :, :, 1] > bound_y) == False + assert np.all(frame_group_numpy[oob_mask_1d_expanded] == prev_value) # Not updated + oob_value_mask = np.logical_and(~oob_mask_1d_expanded, oob_mask) + value_mask = np.logical_and(~oob_mask_1d_expanded, ~oob_mask) + assert np.all(frame_group_numpy[oob_value_mask] == oob_value) # Updated to oob + assert np.all(frame_group_numpy[value_mask] == value) # Updated to value + + +def test_frame_group_numpy(multiview_min_session_frame_groups: Labels): + """Test `FrameGroup.numpy` method.""" + labels = multiview_min_session_frame_groups + session = labels.sessions[0] + frame_group = session.frame_groups[0] + + # Test `numpy` method + frame_group_np = frame_group.numpy() + n_views, n_inst_groups, n_nodes, n_coords = frame_group_np.shape + assert n_views == len(frame_group.cams_to_include) + assert n_inst_groups == len(frame_group.instance_groups) + assert n_nodes == len(labels.skeleton.nodes) + assert n_coords == 2 + + # Undistored points should be different from distorted + frame_group_np_0 = frame_group.numpy(undistort=False) + frame_group_np_undistorted = frame_group.numpy(undistort=True) + assert np.allclose(frame_group_np_0, frame_group_np, atol=1e-3, equal_nan=True) + assert not np.allclose(frame_group_np, frame_group_np_undistorted, equal_nan=True) + + +def test_cameras_are_not_sorted(): + """Test that cameras are not sorted in `RecordingSession`. + + Sorting will invalidate the correspondence between camera index and video index when + re-opening project. + + [cam_0] + name = "back" + size = [ 1280, 1024,] + matrix = [ [ 762.513822135494, 0.0, 639.5,], [ 0.0, 762.513822135494, 511.5,], [ 0.0, 0.0, 1.0,],] + distortions = [ -0.2868458380166852, 0.0, 0.0, 0.0, 0.0,] + rotation = [ 0.3571857188780474, 0.8879473292757126, 1.6832001677006176,] + translation = [ -555.4577842902744, -294.43494957092884, -190.82196458369515,] + """ + + # Make a calibration file with more than 10 cameras + num_cameras = 20 + calibration_dict = {} + for camera_idx in range(num_cameras): + cam_name = f"cam_{camera_idx}" + calibration_dict[cam_name] = { + "name": cam_name, + "size": (1024, 1024), + "matrix": np.eye(3).tolist(), + "distortions": [0, 0, 0, 0, 0], + "rotation": [0, 0, 0], + "translation": [10 * camera_idx, 0, 0], + } + + # Save the dict to a toml file + calibration_file = "calibration.toml" + with open(calibration_file, "w") as f: + toml.dump(calibration_dict, f) + + # Load the calibration file + camera_cluster = CameraCluster.load(calibration_file) + assert len(camera_cluster.cameras) == num_cameras + + # Ensure that cameras are still in correct order + for camera_idx, camera in enumerate(camera_cluster.cameras): + assert camera.name == f"cam_{camera_idx}" diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index a544b7703..766b56ab5 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -1633,3 +1633,43 @@ def test_export_nwb(centered_pair_predictions: Labels, tmpdir): # Read from NWB file read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename)) assert_read_labels_match(centered_pair_predictions, read_labels) + + +def test_remove_instance_removes_from_predicted_reference(): + + # Create skeleton + skeleton = Skeleton() + skeleton.add_node("nodeA") + + # Dummy video + video = Video(backend=None) + + # Create predicted instance + predicted_instance = PredictedInstance( + skeleton=skeleton, points={"nodeA": Point(1, 1)}, score=0.95 + ) + user_instance = Instance(skeleton=skeleton, points={"nodeA": Point(2, 2)}) + user_instance.from_predicted = predicted_instance + + # Create labeled_frame and add to labels + labeled_frame = LabeledFrame( + video=video, frame_idx=0, instances=[user_instance, predicted_instance] + ) + labels = Labels(labeled_frames=[labeled_frame]) + + # Ensure the `from_predicted` reference exists + assert user_instance.from_predicted is predicted_instance + + # Remove predicted instance + labels.remove_instance(labeled_frame, predicted_instance) + + # Ensure the predicted instance is removed from the frame + assert predicted_instance not in labeled_frame.instances + assert len(labeled_frame.instances) == 1 + + # Ensure the `from_predicted` reference is removed + assert user_instance.from_predicted is None + + # Ensure no lingering references to the predicted instance + assert predicted_instance not in labels.all_instances + assert all(predicted_instance not in lf.instances for lf in labels.labeled_frames) diff --git a/tests/nn/test_evals.py b/tests/nn/test_evals.py index 265994056..a60398623 100644 --- a/tests/nn/test_evals.py +++ b/tests/nn/test_evals.py @@ -13,7 +13,6 @@ from sleap.nn.evals import ( compute_dists, compute_dist_metrics, - compute_oks, load_metrics, evaluate_model, ) @@ -23,48 +22,6 @@ sleap.use_cpu_only() -def test_compute_oks(): - # Test compute_oks function with the cocoutils implementation - inst_gt = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr) - np.testing.assert_allclose(oks, 1) - - inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr) - np.testing.assert_allclose(oks, 2 / 3) - - inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr) - np.testing.assert_allclose(oks, 1) - - inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr) - np.testing.assert_allclose(oks, 1) - - # Test compute_oks function with the implementation from the paper - inst_gt = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr, False) - np.testing.assert_allclose(oks, 1) - - inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr, False) - np.testing.assert_allclose(oks, 2 / 3) - - inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr, False) - np.testing.assert_allclose(oks, 1) - - inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") - oks = compute_oks(inst_gt, inst_pr, False) - np.testing.assert_allclose(oks, 1) - - def test_compute_dists(instances, predicted_instances): # Make some changes to the instances error_start = 10 diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index dedf0d324..078ae1409 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -1408,12 +1408,19 @@ def load_instance(labels_in: Labels): assert new_inst.track != old_inst.track -@pytest.mark.parametrize("cmd", ["--max_instances 1", "-n 1"]) -def test_valid_cli_command(cmd): +def test_valid_cli_command(): """Test that sleap-track CLI command is valid.""" parser = _make_cli_parser() + + for cmd in ["--max_instances 1", "-n 1"]: + args = parser.parse_args(cmd.split()) + assert args.max_instances == 1 + + cmd = "--tracking.pre_cull_general_iou_threshold 0.5" args = parser.parse_args(cmd.split()) - assert args.max_instances == 1 + assert getattr(args, "tracking.pre_cull_general_iou_threshold") == 0.5 + tracker = _make_tracker_from_cli(args) + assert tracker.pre_cull_function is not None def test_make_predictor_from_cli( diff --git a/tests/nn/test_tracker_components.py b/tests/nn/test_tracker_components.py index f861241ee..824e02f3c 100644 --- a/tests/nn/test_tracker_components.py +++ b/tests/nn/test_tracker_components.py @@ -1,16 +1,17 @@ -import pytest import numpy as np +import pytest -from sleap.nn.tracking import Tracker +from sleap.instance import LabeledFrame, PredictedInstance +from sleap.io.dataset import Labels from sleap.nn.tracker.components import ( - nms_instances, - nms_fast, - cull_instances, FrameMatches, + cull_instances, + cull_frame_instances, greedy_matching, + nms_fast, + nms_instances, ) - -from sleap.instance import PredictedInstance +from sleap.nn.tracking import Tracker from sleap.skeleton import Skeleton @@ -22,7 +23,7 @@ @pytest.mark.parametrize("count", [0, 2]) def test_tracker_by_name(tracker, similarity, match, count): t = Tracker.make_tracker_by_name( - "flow", "instance", "greedy", clean_instance_count=2 + tracker=tracker, similarity=similarity, match=match, clean_instance_count=count ) t.track([]) t.final_pass([]) @@ -42,6 +43,144 @@ def test_cull_instances(centered_pair_predictions): assert len(frame.instances) == 1 +def test_cull_frame_instances_no_target(centered_pair_predictions: Labels): + labels = centered_pair_predictions + video = labels.video + labeled_frame: LabeledFrame = labels.find_last(video=video, frame_idx=1098) + + # There will never be an IOU greater than 1, so expect all instances back. + assert len(labeled_frame.instances) == 3 + cull_frame_instances( + instances_list=labeled_frame.instances, general_iou_threshold=1 + ) + assert len(labeled_frame.instances) == 3 + + # There is an instance with an IOU of 1 though, so expect 2 instances back. + assert len(labeled_frame.instances) == 3 + cull_frame_instances( + instances_list=labeled_frame.instances, general_iou_threshold=0.999999999999999 + ) + assert len(labeled_frame.instances) == 2 + + # Test with Tracker + + tracker: Tracker = Tracker.make_tracker_by_name( + pre_cull_general_iou_threshold=0.999999999999999, + ) + assert tracker.pre_cull_function is not None + + # There is also an instance with an IOU of 0.67, so expect 1 instance back. + assert len(labeled_frame.instances) == 2 + tracker: Tracker = Tracker.make_tracker_by_name( + pre_cull_general_iou_threshold=0.6, + ) + assert tracker.pre_cull_function is not None + tracker.pre_cull_function(inst_list=labeled_frame.instances) + assert len(labeled_frame.instances) == 1 + + +def test_cull_frame_instances_with_target(centered_pair_predictions: Labels): + labels = centered_pair_predictions + video = labels.video + labeled_frame: LabeledFrame = labels.find_last(video=video, frame_idx=1098) + + # Target count equal to the number of instances. Expect all instances back. + target_count = 3 + + # No IOU threshold. + assert len(labeled_frame.instances) == target_count + cull_frame_instances(instances_list=labeled_frame.instances, instance_count=3) + assert len(labeled_frame.instances) == target_count + + # With IOU threshold. + assert len(labeled_frame.instances) == target_count + cull_frame_instances( + instances_list=labeled_frame.instances, + instance_count=target_count, + iou_threshold=0.0, + ) + assert len(labeled_frame.instances) == target_count + + # Target count less than the number of instances. Expect target count instances back + + # Without IOU. + target_count = 2 + assert len(labeled_frame.instances) == 3 + cull_frame_instances( + instances_list=labeled_frame.instances, instance_count=target_count + ) + assert len(labeled_frame.instances) == target_count + + # With IOU. + target_count = 1 + assert len(labeled_frame.instances) == 2 + cull_frame_instances( + instances_list=labeled_frame.instances, + instance_count=target_count, + iou_threshold=0.0, + ) + assert len(labeled_frame.instances) == target_count + + # Test with both target count and general IOU threshold. Switching frames and using + # Tracker. + + labeled_frame: LabeledFrame = labels.find_last(video=video, frame_idx=1095) + tracker: Tracker = Tracker.make_tracker_by_name(target_instance_count=target_count) + assert tracker.pre_cull_function is None + + # No instances removed. + + target_count = 4 + general_iou_threshold = 1 + tracker: Tracker = Tracker.make_tracker_by_name( + target_instance_count=target_count, + pre_cull_general_iou_threshold=general_iou_threshold, + ) + assert tracker.pre_cull_function is not None + + # Without non-general IOU. + assert len(labeled_frame.instances) == target_count + tracker.pre_cull_function(inst_list=labeled_frame.instances) + assert len(labeled_frame.instances) == target_count + + # With non-general IOU. + iou_threshold = 0.0 + assert len(labeled_frame.instances) == target_count + tracker: Tracker = Tracker.make_tracker_by_name( + target_instance_count=target_count, + pre_cull_iou_threshold=iou_threshold, + pre_cull_general_iou_threshold=general_iou_threshold, + ) + assert tracker.pre_cull_function is not None + tracker.pre_cull_function(inst_list=labeled_frame.instances) + assert len(labeled_frame.instances) == target_count + + # Instance removed via general IOU. + target_count = 4 + general_iou_threshold = 0.999999999999999 + assert len(labeled_frame.instances) == 4 + tracker: Tracker = Tracker.make_tracker_by_name( + target_instance_count=target_count, + pre_cull_general_iou_threshold=general_iou_threshold, + ) + assert tracker.pre_cull_function is not None + tracker.pre_cull_function(inst_list=labeled_frame.instances) + assert len(labeled_frame.instances) == target_count - 1 + + # Instance removed via non-general IOU. + target_count = 2 + iou_threshold = 0.0 + assert len(labeled_frame.instances) == 3 + tracker: Tracker = Tracker.make_tracker_by_name( + target_instance_count=target_count, + pre_cull_to_target=True, + pre_cull_iou_threshold=iou_threshold, + ) + assert tracker.pre_cull_function is not None + tracker.pre_cull_function(inst_list=labeled_frame.instances) + assert len(labeled_frame.instances) == target_count + + def test_nms(): boxes = np.array( [[10, 10, 20, 20], [10, 10, 15, 15], [30, 30, 40, 40], [32, 32, 42, 42]] diff --git a/tests/test_instance.py b/tests/test_instance.py index 74a8b192e..141f19468 100644 --- a/tests/test_instance.py +++ b/tests/test_instance.py @@ -529,3 +529,79 @@ def test_instance_structuring_from_predicted(centered_pair_predictions): # Unstructure -> structure labels_copy = labels.copy() + + +def test_instance_update_points(multiview_min_session_frame_groups): + """Test updating points of an instance.""" + + labels: Labels = multiview_min_session_frame_groups + lf: LabeledFrame = labels.labeled_frames[0] + instance = lf.user_instances[0] + pred_instance = lf.predicted_instances[0] + n_nodes = len(labels.skeleton.nodes) + + # Case 0. User instance with incomplete and invisible points. + point_val = 0 + for inst in [instance, pred_instance]: + for points in inst._points: + points.visible = False + points.complete = False + inst.update_points(points=np.full((n_nodes, 2), point_val)) + + # All points should be updated. + assert np.all(inst.get_points_array(invisible_as_nan=False) == point_val) + + # Case 1. User instance with incomplete and visible points. + point_val = 1 + for inst in [instance, pred_instance]: + for points in inst._points: + points.visible = True + points.complete = False + inst.update_points(points=np.full((n_nodes, 2), point_val)) + + # All points should be updated. + assert np.all(inst.get_points_array(invisible_as_nan=False) == point_val) + + # Case 2. User instance with complete and visible points. + old_point_val = point_val + point_val = 2 + for inst in [instance, pred_instance]: + for points in inst._points: + points.visible = True + points.complete = True + inst.update_points( + points=np.full((n_nodes, 2), point_val), exclude_complete=True + ) + + # All points should be updated IF predicted + is_predicted = isinstance(inst, PredictedInstance) + if is_predicted: + assert np.all(inst.get_points_array(invisible_as_nan=False) == point_val) + else: + assert np.all( + inst.get_points_array(invisible_as_nan=False) == old_point_val + ) + + # Test that we can update just a single point + old_point_val = point_val + point_val = 4 + inst = instance + points = inst._points[0] + points.visible = False + points.complete = False + inst.update_points(points=np.full((n_nodes, 2), point_val), exclude_complete=True) + # Only a single point should be updated + assert np.sum(inst.get_points_array(invisible_as_nan=False) == point_val) == 2 + + # Case 3. User instance with complete and invisible points. + point_val = 3 + for inst in [instance, pred_instance]: + for points in inst._points: + points.visible = False + points.complete = True + inst.update_points( + points=np.full((n_nodes, 2), point_val), exclude_complete=True + ) + + # All points should be updated + assert np.all(inst.get_points_array(invisible_as_nan=False) == point_val) diff --git a/tests/test_util.py b/tests/test_util.py index a7916d47f..cd4e01aad 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -154,3 +154,45 @@ def test_decode_preview_image(flies13_skeleton: Skeleton): img_b64 = skeleton.preview_image img = decode_preview_image(img_b64) assert img.mode == "RGBA" + + +def test_compute_oks(): + # Test compute_oks function with the cocoutils implementation + inst_gt = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") + inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") + oks = compute_oks(inst_gt, inst_pr) + np.testing.assert_allclose(oks, 1) + + inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") + oks = compute_oks(inst_gt, inst_pr) + np.testing.assert_allclose(oks, 2 / 3) + + inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") + inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") + oks = compute_oks(inst_gt, inst_pr) + np.testing.assert_allclose(oks, 1) + + inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") + inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") + oks = compute_oks(inst_gt, inst_pr) + np.testing.assert_allclose(oks, 1) + + # Test compute_oks function with the implementation from the paper + inst_gt = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") + inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") + oks = compute_oks(inst_gt, inst_pr, False) + np.testing.assert_allclose(oks, 1) + + inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") + oks = compute_oks(inst_gt, inst_pr, False) + np.testing.assert_allclose(oks, 2 / 3) + + inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") + inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32") + oks = compute_oks(inst_gt, inst_pr, False) + np.testing.assert_allclose(oks, 1) + + inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") + inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32") + oks = compute_oks(inst_gt, inst_pr, False) + np.testing.assert_allclose(oks, 1)