diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 41d696f0c..dc2d91b93 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -474,6 +474,14 @@ def add_submenu_choices(menu, title, options, key): fileMenu, "replace videos", "Replace Videos...", self.commands.replaceVideo ) + fileMenu.addSeparator() + add_menu_item( + fileMenu, + "add session", + "Add Recording Session...", + self.commands.addSession, + ) + fileMenu.addSeparator() add_menu_item(fileMenu, "save", "Save", self.commands.saveProject) add_menu_item(fileMenu, "save as", "Save As...", self.commands.saveProjectAs) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index ef6055a45..ee0b2edd8 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -41,7 +41,8 @@ class which inherits from `AppCommand` (or a more specialized class such as import attr import cv2 import numpy as np -from qtpy import QtCore, QtGui, QtWidgets +from qtpy import QtCore, QtWidgets, QtGui +from qtpy.QtWidgets import QMessageBox, QProgressDialog from sleap.gui.dialogs.delete import DeleteDialog from sleap.gui.dialogs.filedialog import FileDialog @@ -52,6 +53,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from sleap.gui.state import GuiState from sleap.gui.suggestions import VideoFrameSuggestions from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track +from sleap.io.cameras import RecordingSession from sleap.io.convert import default_analysis_filename from sleap.io.dataset import Labels from sleap.io.format.adaptor import Adaptor @@ -430,6 +432,10 @@ def removeVideo(self): """Removes selected video from project.""" self.execute(RemoveVideo) + def addSession(self): + """Shows gui for adding `RecordingSession`s to the project.""" + self.execute(AddSession) + def openSkeletonTemplate(self): """Shows gui for loading saved skeleton into project.""" self.execute(OpenSkeleton, template=True) @@ -1918,6 +1924,38 @@ def ask(context: CommandContext, params: dict) -> bool: return True +class AddSession(EditCommand): + # topics = [UpdateTopic.session] + + @staticmethod + def do_action(context: CommandContext, params: dict): + + camera_calibration = params["camera_calibration"] + session = RecordingSession.load(filename=camera_calibration) + + # Add session + context.labels.add_session(session) + + # Load if no video currently loaded + if context.state["session"] is None: + context.state["session"] = session + + @staticmethod + def ask(context: CommandContext, params: dict) -> bool: + """Shows gui for adding video to project.""" + filters = ["Camera calibration (*.toml)"] + filename, selected_filter = FileDialog.open( + context.app, + dir=None, + caption="Select camera calibration...", + filter=";;".join(filters), + ) + + params["camera_calibration"] = filename + + return len(filename) > 0 + + class OpenSkeleton(EditCommand): topics = [UpdateTopic.skeleton] diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 84f39762f..4d0630f23 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -1,11 +1,22 @@ """Module for storing information for camera groups.""" +import logging +from pathlib import Path +import tempfile +import cattr +import toml +from typing import List, Optional, Union, Iterator, Any, Dict, Tuple -from typing import List, Optional, Union, Iterator - -from attrs import define, field from aniposelib.cameras import Camera, FisheyeCamera, CameraGroup +from attrs import define, field +from attrs.validators import deep_iterable, instance_of import numpy as np +from sleap.util import deep_iterable_converter +from sleap.io.video import Video + + +logger = logging.getLogger(__name__) + @define class Camcorder: @@ -13,9 +24,37 @@ class Camcorder: Attributes: camera: `Camera` or `FishEyeCamera` object. + videos: List of `Video` objects. """ - camera: Optional[Union[Camera, FisheyeCamera]] = field(default=None) + camera: Union[Camera, FisheyeCamera] + camera_cluster: "CameraCluster" = None + _video_by_session: Dict["RecordingSession", Video] = field(factory=dict) + + @property + def videos(self) -> List[Video]: + return list(self.camera_cluster._session_by_video.keys()) + + @property + def sessions(self) -> List["RecordingSession"]: + return list(self._video_by_session.keys()) + + def get_video(self, session: "RecordingSession") -> Optional[Video]: + if session not in self._video_by_session: + logger.warning(f"{session} not found in {self}.") + return None + return self._video_by_session[session] + + def get_session(self, video: Video) -> Optional["RecordingSession"]: + if video not in self.camera_cluster._session_by_video: + logger.warning(f"{video} not found in {self}.") + return None + return self.camera_cluster._session_by_video[video] + + def __attrs_post_init__(self): + # Avoid overwriting `CameraCluster` if already set. + if not isinstance(self.camera_cluster, CameraCluster): + self.camera_cluster = CameraCluster() def __eq__(self, other): if not isinstance(other, Camcorder): @@ -41,6 +80,34 @@ def __getattr__(self, attr): ) return getattr(self.camera, attr) + def __getitem__( + self, key: Union[str, "RecordingSession", Video] + ) -> Union["RecordingSession", Video]: # Raises KeyError if key not found + """Return linked `Video` or `RecordingSession`. + + Args: + key: Key to use for lookup. Can be a `RecordingSession` or `Video` object. + + Returns: + `Video` or `RecordingSession` object. + + Raises: + KeyError: If key is not found. + """ + + # If key is a RecordingSession, return the Video + if isinstance(key, RecordingSession): + return self._video_by_session[key] + + # If key is a Video, return the RecordingSession + elif isinstance(key, Video): + return self.camera_cluster._session_by_video[key] + + raise KeyError(f"Key {key} not found in {self}.") + + def __hash__(self) -> int: + return hash(self.camera) + def __repr__(self): return f"{self.__class__.__name__}(name={self.name}, size={self.size})" @@ -60,6 +127,31 @@ def from_dict(cls, d) -> "Camcorder": cam = Camera.from_dict(d) return Camcorder(cam) + @classmethod + def from_camera( + cls, cam: Union[Camera, FisheyeCamera], *args, **kwargs + ) -> "Camcorder": + """Creates a `Camcorder` object from a `Camera` or `FishEyeCamera` object. + + Args: + cam: `Camera` or `FishEyeCamera` object. + + Returns: + `Camcorder` object. + """ + # Do not convert if already a Camcorder + if isinstance(cam, Camcorder): + return cam + + # Do not convert if not a `Camera` or `FishEyeCamera` + if not isinstance(cam, Camera): + raise TypeError( + f"Expected `Camera` or `FishEyeCamera` object, got {type(cam)}" + ) + + # Convert! + return Camcorder(cam) + @define class CameraCluster(CameraGroup): @@ -67,29 +159,176 @@ class CameraCluster(CameraGroup): Attributes: cameras: List of `Camcorder`s. - metadata: Set of metadata. + metadata: Dictionary of metadata. + sessions: List of `RecordingSession`s. + videos: List of `Video`s. """ - cameras: List[Camcorder] = field(factory=list) - metadata: set = field(factory=set) + cameras: List[Camcorder] = field( + factory=list, + validator=deep_iterable( + member_validator=instance_of(Camcorder), + iterable_validator=instance_of(list), + ), + converter=deep_iterable_converter( + member_converter=Camcorder.from_camera, + iterable_converter=list, + ), + ) + metadata: dict = field(factory=dict) + _videos_by_session: Dict["RecordingSession", List[Video]] = field(factory=dict) + _session_by_video: Dict[Video, "RecordingSession"] = field(factory=dict) + _camcorder_by_video: Dict[Video, Camcorder] = field(factory=dict) + + @property + def sessions(self) -> List["RecordingSession"]: + """List of `RecordingSession`s.""" + return list(self._videos_by_session.keys()) + + @property + def videos(self) -> List[Video]: + """List of `Video`s.""" + return list(self._session_by_video.keys()) + + def get_videos_from_session( + self, session: "RecordingSession" + ) -> Optional[List[Video]]: + """Get `Video`s from `RecordingSession` object. + + Args: + session: `RecordingSession` object. + + Returns: + List of `Video` objects or `None` if not found. + """ + if session not in self.sessions: + logger.warning( + f"RecordingSession not linked to {self}. " + "Use `self.add_session(session)` to add it." + ) + return None + return self._videos_by_session[session] + + def get_session_from_video(self, video: Video) -> Optional["RecordingSession"]: + """Get `RecordingSession` from `Video` object. + + Args: + video: `Video` object. + + Returns: + `RecordingSession` object or `None` if not found. + """ + if video not in self.videos: + logger.warning(f"Video not linked to any RecordingSession in {self}.") + return None + return self._session_by_video[video] + + def get_camcorder_from_video(self, video: Video) -> Optional[Camcorder]: + """Get `Camcorder` from `Video` object. + + Args: + video: `Video` object. + + Returns: + `Camcorder` object or `None` if not found. + """ + if video not in self.videos: + logger.warning(f"Video not linked to any Camcorders in {self}.") + return None + return self._camcorder_by_video[video] + + def get_videos_from_camcorder(self, camcorder: Camcorder) -> List[Video]: + """Get `Video`s from `Camcorder` object. + + Args: + camcorder: `Camcorder` object. + + Returns: + List of `Video` objects. + + Raises: + ValueError: If `camcorder` is not in `self.cameras`. + """ + if camcorder not in self.cameras: + raise ValueError(f"Camcorder not in {self}.") + return camcorder.videos + + def add_session(self, session: "RecordingSession"): + """Adds a `RecordingSession` to the `CameraCluster`.""" + self._videos_by_session[session] = [] + session.camera_cluster = self def __attrs_post_init__(self): + """Initialize `CameraCluster` object.""" super().__init__(cameras=self.cameras, metadata=self.metadata) + for cam in self.cameras: + cam.camera_cluster = self + + def __contains__(self, item): + return item in self.cameras + + def __iter__(self) -> Iterator[Camcorder]: + return iter(self.cameras) def __len__(self): return len(self.cameras) - def __getitem__(self, idx): - return self.cameras[idx] + def __getitem__( + self, idx_or_key: Union[int, Video, Camcorder, "RecordingSession", str] + ) -> Optional[ + Union[Camcorder, Tuple[Camcorder, Video], List[Video], "RecordingSession", Any] + ]: + """Get item from `CameraCluster`. - def __iter__(self) -> Iterator[List[Camcorder]]: - return iter(self.cameras) + Args: + idx_or_key: Index, `Video`, `Camcorder`, `RecordingSession`, or `str` name. - def __contains__(self, item): - return item in self.cameras + Returns: + `Camcorder`, (`Camcorder`, `Video`), `List[Video]`, `RecordingSession`, + metadata value, or None if not found. + + Raises: + ValueError: If `idx_or_key` used as a metadata key and not found or + `idx_or_key` is a `Camcorder` which is not in `self.cameras`. + """ + + # If key is int, index into cameras -> Camcorder + if isinstance(idx_or_key, int): + return self.cameras[idx_or_key] + + # If key is Video, return linked + # (Camcorder, RecordingSession) -> Optional[Tuple[Camcorder, Video]] + elif isinstance(idx_or_key, Video): + camcorder = self.get_camcorder_from_video(idx_or_key) + session = self.get_session_from_video(idx_or_key) + if camcorder is None or session is None: + return None + return (camcorder, session) + + # If key is Camcorder, return linked Videos -> Optional[List[Video]] + elif isinstance(idx_or_key, Camcorder): + return self.get_videos_from_camcorder(idx_or_key) + + # If key is RecordingSession, return linked Videos -> Optional[List[Video]] + elif isinstance(idx_or_key, RecordingSession): + return self.get_videos_from_session(idx_or_key) + + # Last resort: look in metadata for matching key -> Any + elif idx_or_key in self.metadata: + return self.metadata[idx_or_key] + + # Raise error if not found + else: + raise KeyError( + f"Key {idx_or_key} not found in {self.__class__.__name__} or " + "associated metadata." + ) def __repr__(self): - message = f"{self.__class__.__name__}(len={len(self)}: " + message = ( + f"{self.__class__.__name__}(sessions={len(self.sessions)}, " + f"cameras={len(self)}: " + ) for cam in self: message += f"{cam.name}, " return f"{message[:-2]})" @@ -104,13 +343,378 @@ def load(cls, filename) -> "CameraCluster": Returns: `CameraCluster` object. """ + cgroup: CameraGroup = super().load(filename) + return cls(cameras=cgroup.cameras, metadata=cgroup.metadata) + + @classmethod + def from_calibration_dict(cls, calibration_dict: Dict[str, str]) -> "CameraCluster": + """Structure a cluster dictionary to a `CameraCluster`. + + This method is intended to be used for restructuring a `CameraCluster` object + (that was previously unstructured to a serializable format). Note: this method + does not handle any mapping between `Video`s, `RecordingSession`s, and + `Camcorder`s. + + Args: + calibration_dict: A dictionary containing just the calibration info needed + to partially restructure a `CameraCluster` (no mapping between `Video`s, + `RecordingSession`s, and `Camcorder`s). + + Returns: + `CameraCluster` object. + """ + + # Save the calibration dictionary to a temp file and load as `CameraGroup` + with tempfile.TemporaryDirectory() as temp_dir: + temp_file = str(Path(temp_dir, "calibration.toml")) + with open(temp_file, "w") as f: + toml.dump(calibration_dict, f) + cgroup: CameraGroup = super().load(temp_file) + + return cls(cameras=cgroup.cameras, metadata=cgroup.metadata) + + def to_calibration_dict(self) -> Dict[str, str]: + """Unstructure the `CameraCluster` object to a dictionary. + + This method is intended to be used for unstructuring a `CameraCluster` object + to a serializable format. Note: this method does not save any mapping between + `Video`s, `RecordingSession`s, and `Camcorders`. + + Returns: + Dictionary of `CameraCluster` object. + """ + + # Use existing `CameraGroup.dump` method to get the calibration dictionary + with tempfile.TemporaryDirectory() as temp_dir: + temp_file = str(Path(temp_dir, "calibration.toml")) + self.dump(fname=temp_file) + calibration_dict = toml.load(temp_file) + + return calibration_dict + + +@define(eq=False) +class RecordingSession: + """Class for storing information for a recording session. + + Attributes: + camera_cluster: `CameraCluster` object. + metadata: Dictionary of metadata. + videos: List of `Video`s that have been linked to a `Camcorder` in the + `self.camera_cluster`. + linked_cameras: List of `Camcorder`s in the `self.camera_cluster` that are + linked to a `Video`. + unlinked_cameras: List of `Camcorder`s in the `self.camera_cluster` that are + not linked to a `Video`. + """ + + camera_cluster: CameraCluster = field(factory=CameraCluster) + metadata: dict = field(factory=dict) + _video_by_camcorder: Dict[Camcorder, Video] = field(factory=dict) + + @property + def videos(self) -> List[Video]: + """List of `Video`s.""" + + return self.camera_cluster._videos_by_session[self] + + @property + def linked_cameras(self) -> List[Camcorder]: + """List of `Camcorder`s in `self.camera_cluster` that are linked to a video.""" + + return list(self._video_by_camcorder.keys()) + + @property + def unlinked_cameras(self) -> List[Camcorder]: + """List of `Camcorder`s in `self.camera_cluster` that are not linked to a video.""" + + return list(set(self.camera_cluster.cameras) - set(self.linked_cameras)) + + def get_video(self, camcorder: Camcorder) -> Optional[Video]: + """Retrieve `Video` linked to `Camcorder`. + + Args: + camcorder: `Camcorder` object. + + Returns: + If `Camcorder` in `self.camera_cluster`, then `Video` object if found, else + `None` if `Camcorder` has no linked `Video`. + + Raises: + ValueError: If `Camcorder` is not in `self.camera_cluster`. + """ + + if camcorder not in self.camera_cluster: + raise ValueError( + f"Camcorder {camcorder.name} is not in this RecordingSession's " + f"{self.camera_cluster}." + ) + + if camcorder not in self._video_by_camcorder: + logger.warning( + f"Camcorder {camcorder.name} is not linked to a video in this " + f"RecordingSession." + ) + return None + + return self._video_by_camcorder[camcorder] + + def get_camera(self, video: Video) -> Optional[Camcorder]: + """Retrieve `Camcorder` linked to `Video`. + + Args: + video: `Video` object. + + Returns: + `Camcorder` object if found, else `None`. + """ + + if video not in self.camera_cluster._camcorder_by_video: + logger.warning( + f"{video} is not linked to a Camcorder in this " + f"RecordingSession's {self.camera_cluster}." + ) + return None + return self.camera_cluster._camcorder_by_video[video] + + def add_video(self, video: Video, camcorder: Camcorder): + """Adds a `Video` to the `RecordingSession`. + + Args: + video: `Video` object. + camcorder: `Camcorder` object. + """ + + # Ensure the `Camcorder` is in this `RecordingSession`'s `CameraCluster` try: - cam_group: CameraGroup = super().load(filename) - except FileNotFoundError as e: - raise FileNotFoundError( - f"Could not find calibration file at {filename}." - ) from e - - cameras = [Camcorder(cam) for cam in cam_group.cameras] - return cls(cameras=cameras, metadata=cam_group.metadata) + assert camcorder in self.camera_cluster + except AssertionError: + raise ValueError( + f"Camcorder {camcorder.name} is not in this RecordingSession's " + f"{self.camera_cluster}." + ) + + # Add session-to-videos (1-to-many) map to `CameraCluster` + if self not in self.camera_cluster._videos_by_session: + self.camera_cluster.add_session(self) + if video not in self.camera_cluster._videos_by_session[self]: + self.camera_cluster._videos_by_session[self].append(video) + + # Add session-to-video (1-to-1) map to `Camcorder` + if video not in camcorder._video_by_session: + camcorder._video_by_session[self] = video + + # Add video-to-session (1-to-1) map to `CameraCluster` + self.camera_cluster._session_by_video[video] = self + + # Add video-to-camcorder (1-to-1) map to `CameraCluster` + if video not in self.camera_cluster._camcorder_by_video: + self.camera_cluster._camcorder_by_video[video] = [] + self.camera_cluster._camcorder_by_video[video] = camcorder + + # Add camcorder-to-video (1-to-1) map to `RecordingSession` + self._video_by_camcorder[camcorder] = video + + def remove_video(self, video: Video): + """Removes a `Video` from the `RecordingSession`. + + Args: + video: `Video` object. + """ + + # Remove video-to-camcorder map from `CameraCluster` + camcorder = self.camera_cluster._camcorder_by_video.pop(video) + + # Remove video-to-session map from `CameraCluster` + self.camera_cluster._session_by_video.pop(video) + + # Remove session-to-video(s) maps from related `CameraCluster` and `Camcorder` + self.camera_cluster._videos_by_session[self].remove(video) + camcorder._video_by_session.pop(self) + + # Remove camcorder-to-video map from `RecordingSession` + self._video_by_camcorder.pop(camcorder) + + def __attrs_post_init__(self): + self.camera_cluster.add_session(self) + + def __iter__(self) -> Iterator[List[Camcorder]]: + return iter(self.camera_cluster) + + def __len__(self): + return len(self.videos) + + def __getattr__(self, attr: str) -> Any: + + """Try to find the attribute in the camera_cluster next.""" + return getattr(self.camera_cluster, attr) + + def __getitem__( + self, idx_or_key: Union[int, Video, Camcorder, str] + ) -> Union[Camcorder, Video, Any]: + """Grab a `Camcorder`, `Video`, or metadata from the `RecordingSession`. + + Try to index into `camera_cluster.cameras` first, then check + video-to-camera map and camera-to-video map. Lastly check in the `metadata`s. + """ + + # Try to find in `self.camera_cluster.cameras` + if isinstance(idx_or_key, int): + try: + return self.camera_cluster[idx_or_key] + except IndexError: + pass # Try to find in metadata + + # Return a `Camcorder` if `idx_or_key` is a `Video + if isinstance(idx_or_key, Video): + return self.get_camera(idx_or_key) + + # Return a `Video` if `idx_or_key` is a `Camcorder` + elif isinstance(idx_or_key, Camcorder): + return self.get_video(idx_or_key) + + # Try to find in `self.metadata` + elif idx_or_key in self.metadata: + return self.metadata[idx_or_key] + + # Try to find in `self.camera_cluster.metadata` + elif idx_or_key in self.camera_cluster.metadata: + return self.camera_cluster.metadata[idx_or_key] + + # Raise error if not found + else: + raise KeyError( + f"Key {idx_or_key} not found in {self.__class__.__name__} or " + "associated metadata." + ) + + def __repr__(self): + return f"{self.__class__.__name__}(camera_cluster={self.camera_cluster})" + + @classmethod + def load( + cls, + filename, + metadata: Optional[dict] = None, + ) -> "RecordingSession": + """Loads cameras as `Camcorder`s from a calibration.toml file. + + Args: + filename: Path to calibration.toml file. + metadata: Dictionary of metadata. + + Returns: + `RecordingSession` object. + """ + + camera_cluster: CameraCluster = CameraCluster.load(filename) + return cls( + camera_cluster=camera_cluster, + metadata=(metadata or {}), + ) + + @classmethod + def from_calibration_dict(cls, calibration_dict: dict) -> "RecordingSession": + """Loads cameras as `Camcorder`s from a calibration dictionary. + + Args: + calibration_dict: Dictionary of calibration data. + + Returns: + `RecordingSession` object. + """ + + camera_cluster: CameraCluster = CameraCluster.from_calibration_dict( + calibration_dict + ) + return cls(camera_cluster=camera_cluster) + + def to_session_dict(self, video_to_idx: Dict[Video, int]) -> dict: + """Unstructure `RecordingSession` to an invertible dictionary. + + Returns: + Dictionary of "calibration" and "camcorder_to_video_idx_map" needed to + restructure a `RecordingSession`. + """ + + # Unstructure `CameraCluster` and `metadata` + calibration_dict = self.camera_cluster.to_calibration_dict() + + # Store camcorder-to-video indices map where key is camcorder index + # and value is video index from `Labels.videos` + camcorder_to_video_idx_map = {} + for cam_idx, camcorder in enumerate(self.camera_cluster): + + # Skip if Camcorder is not linked to any Video + if camcorder not in self._video_by_camcorder: + continue + + # Get video index from `Labels.videos` + video = self._video_by_camcorder[camcorder] + video_idx = video_to_idx.get(video, None) + + if video_idx is not None: + camcorder_to_video_idx_map[cam_idx] = video_idx + else: + logger.warning( + f"Video {video} not found in `Labels.videos`. " + "Not saving to `RecordingSession` serialization." + ) + + return { + "calibration": calibration_dict, + "camcorder_to_video_idx_map": camcorder_to_video_idx_map, + } + + @classmethod + def from_session_dict( + cls, session_dict, videos_list: List[Video] + ) -> "RecordingSession": + """Restructure `RecordingSession` from an invertible dictionary. + + Args: + session_dict: Dictionary of "calibration" and "camcorder_to_video_idx_map" + needed to fully restructure a `RecordingSession`. + videos_list: List containing `Video` objects (expected `Labels.videos`). + + Returns: + `RecordingSession` object. + """ + + # Restructure `RecordingSession` without `Video` to `Camcorder` mapping + calibration_dict = session_dict["calibration"] + session: RecordingSession = RecordingSession.from_calibration_dict( + calibration_dict + ) + + # Retrieve all `Camcorder` and `Video` objects, then add to `RecordingSession` + camcorder_to_video_idx_map = session_dict["camcorder_to_video_idx_map"] + for cam_idx, video_idx in camcorder_to_video_idx_map.items(): + camcorder = session.camera_cluster.cameras[cam_idx] + video = videos_list[video_idx] + session.add_video(video, camcorder) + + return session + + @staticmethod + def make_cattr(videos_list: List[Video]): + """Make a `cattr.Converter` for `RecordingSession` serialization. + + Args: + videos_list: List containing `Video` objects (expected `Labels.videos`). + + Returns: + `cattr.Converter` object. + """ + sessions_cattr = cattr.Converter() + sessions_cattr.register_structure_hook( + RecordingSession, + lambda x, cls: RecordingSession.from_session_dict(x, videos_list), + ) + + video_to_idx = {video: i for i, video in enumerate(videos_list)} + sessions_cattr.register_unstructure_hook( + RecordingSession, lambda x: x.to_session_dict(video_to_idx) + ) + return sessions_cattr diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 45280cc54..a9782bad7 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -62,6 +62,7 @@ import numpy as np import datetime from sklearn.model_selection import train_test_split +from sleap.io.cameras import RecordingSession try: from typing import ForwardRef @@ -419,6 +420,7 @@ class Labels(MutableSequence): nodes: List[Node] = attr.ib(default=attr.Factory(list)) tracks: List[Track] = attr.ib(default=attr.Factory(list)) suggestions: List[SuggestionFrame] = attr.ib(default=attr.Factory(list)) + sessions: List[RecordingSession] = attr.ib(default=attr.Factory(list)) negative_anchors: Dict[Video, list] = attr.ib(default=attr.Factory(dict)) provenance: Dict[Text, Union[str, int, float, bool]] = attr.ib( default=attr.Factory(dict) @@ -1584,6 +1586,19 @@ def remove_video(self, video: Video): self.videos.remove(video) self._cache.remove_video(video) + def add_session(self, session: RecordingSession): + """Add a recording session to the labels. + Args: + session: `RecordingSession` instance + """ + if not isinstance(session, RecordingSession): + raise TypeError( + f"Expected a RecordingSession instance. Received type: {type(session)}" + ) + + if session not in self.sessions: + self.sessions.append(session) + @classmethod def from_json(cls, *args, **kwargs): from sleap.io.format.labels_json import LabelsJsonAdaptor @@ -1918,16 +1933,21 @@ def to_dict(self, skip_labels: bool = False) -> Dict[str, Any]: label_cattr.register_unstructure_hook( Track, lambda x: str(self.tracks.index(x)) ) + label_cattr.register_unstructure_hook( + RecordingSession, lambda x: str(self.sessions.index(x)) + ) # Make a converter for the top level skeletons list. idx_to_node = {i: self.nodes[i] for i in range(len(self.nodes))} - skeleton_cattr = Skeleton.make_cattr(idx_to_node) # Make attr for tracks so that we save as tuples rather than dicts; # this can save a lot of space when there are lots of tracks. track_cattr = cattr.Converter(unstruct_strat=cattr.UnstructureStrategy.AS_TUPLE) + # Make converter for recording sessions + sessions_cattr = RecordingSession.make_cattr(videos_list=self.videos) + # Serialize the skeletons, videos, and labels dicts = { "version": LABELS_JSON_FILE_VERSION, @@ -1936,6 +1956,7 @@ def to_dict(self, skip_labels: bool = False) -> Dict[str, Any]: "videos": Video.cattr().unstructure(self.videos), "tracks": track_cattr.unstructure(self.tracks), "suggestions": label_cattr.unstructure(self.suggestions), + "sessions": sessions_cattr.unstructure(self.sessions), "negative_anchors": label_cattr.unstructure(self.negative_anchors), "provenance": label_cattr.unstructure(self.provenance), } diff --git a/sleap/io/format/labels_json.py b/sleap/io/format/labels_json.py index f284731a6..765b7d7cd 100644 --- a/sleap/io/format/labels_json.py +++ b/sleap/io/format/labels_json.py @@ -7,6 +7,7 @@ also the videos/frames as HDF5 datasets. """ import atexit +import logging import os import re import shutil @@ -16,6 +17,8 @@ import cattr +from sleap.io.cameras import RecordingSession + from .adaptor import Adaptor, SleapObjectType from .filehandle import FileHandle @@ -30,6 +33,8 @@ from sleap.skeleton import Node, Skeleton from sleap.util import json_loads, json_dumps, weak_filename_match +logger = logging.getLogger(__name__) + class LabelsJsonAdaptor(Adaptor): FORMAT_ID = 1 @@ -498,6 +503,16 @@ def from_json_data( print(e) pass + try: + sessions_cattr = RecordingSession.make_cattr(videos_list=videos) + sessions = sessions_cattr.structure( + dicts["sessions"], List[RecordingSession] + ) + except Exception as e: + logger.warning("Error while loading `RecordingSession`s:") + logger.warning(e) + sessions = [] + if "negative_anchors" in dicts: negative_anchors_cattr = cattr.Converter() negative_anchors_cattr.register_structure_hook( @@ -527,6 +542,9 @@ def from_json_data( label_cattr.register_structure_hook( Track, lambda x, type: None if x is None else tracks[int(x)] ) + label_cattr.register_structure_hook( + RecordingSession, lambda x, type: sessions[int(x)] + ) labels = label_cattr.structure(dicts["labels"], List[LabeledFrame]) else: @@ -538,6 +556,7 @@ def from_json_data( skeletons=skeletons, nodes=nodes, suggestions=suggestions, + sessions=sessions, negative_anchors=negative_anchors, tracks=tracks, provenance=provenance, diff --git a/sleap/util.py b/sleap/util.py index 5edbf164b..c27cb6c09 100644 --- a/sleap/util.py +++ b/sleap/util.py @@ -11,11 +11,13 @@ from collections import defaultdict from io import BytesIO from pathlib import Path -from typing import Any, Dict, Hashable, Iterable, List, Optional +from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional from urllib.parse import unquote, urlparse from urllib.request import url2pathname import attr +from attrs import field +from attrs.validators import is_callable, optional, and_ import h5py as h5 import numpy as np import psutil @@ -30,6 +32,55 @@ import sleap.version as sleap_version +# TODO(LM): Open a PR to attrs to add this to the library, then remove once we upgrade. +@attr.s(repr=False, slots=True, hash=True) +class _DeepIterableConverter: + member_converter: Callable = field(validator=is_callable()) + iterable_converter: Optional[Callable] = field( + default=None, validator=optional(is_callable()) + ) + + def __call__(self, value): + """We use a callable class to be able to change the ``__repr__``.""" + + new_value = [] + for member in value: + new_value.append(self.member_converter(member)) + + if self.iterable_converter is not None: + return self.iterable_converter(new_value) + else: + return type(value)(new_value) + + def __repr__(self): + iterable_identifier = ( + "" if self.iterable_converter is None else f" {self.iterable_converter!r}" + ) + return ( + "" + ).format( + iterable_identifier=iterable_identifier, + member=self.member_converter, + ) + + +# TODO(LM): Open a PR to attrs to add this to the library, then remove once we upgrade. +def deep_iterable_converter(member_converter, iterable_converter=None): + """A converter that performs deep conversion of an iterable. + + :param member_converter: Converter(s) to apply to iterable members + :param iterable_converter: Converter to apply to iterable itself + (optional) + + .. versionadded:: not added to attrs yet + + :raises TypeError: if any sub-converters fail + """ + if isinstance(member_converter, (list, tuple)): + member_converter = and_(*member_converter) + return _DeepIterableConverter(member_converter, iterable_converter) + def json_loads(json_str: str) -> Dict: """A simple wrapper around the JSON decoder we are using. diff --git a/tests/fixtures/cameras.py b/tests/fixtures/cameras.py index 720dd1b11..6f30511e8 100644 --- a/tests/fixtures/cameras.py +++ b/tests/fixtures/cameras.py @@ -2,7 +2,19 @@ import pytest +from sleap.io.cameras import CameraCluster, RecordingSession + @pytest.fixture def min_session_calibration_toml_path(): return "tests/data/cameras/minimal_session/calibration.toml" + + +@pytest.fixture +def min_session_camera_cluster(min_session_calibration_toml_path): + return CameraCluster.load(min_session_calibration_toml_path) + + +@pytest.fixture +def min_session_session(min_session_calibration_toml_path): + return RecordingSession.load(min_session_calibration_toml_path) diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 13aa60e6b..6048e13ef 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -1,20 +1,21 @@ -import pytest import shutil import sys import time - -from pathlib import PurePath, Path +from pathlib import Path, PurePath from typing import List -from sleap import Skeleton, Track, PredictedInstance +import pytest + +from sleap import PredictedInstance, Skeleton, Track from sleap.gui.commands import ( + AddSession, CommandContext, ExportAnalysisFile, ExportDatasetWithImages, ImportDeepLabCutFolder, + OpenSkeleton, RemoveVideo, ReplaceVideo, - OpenSkeleton, SaveProjectAs, get_new_version_filename, ) @@ -30,8 +31,8 @@ # These imports cause trouble when running `pytest.main()` from within the file # Comment out to debug tests file via VSCode's "Debug Python File" from tests.info.test_h5 import extract_meta_hdf5 -from tests.io.test_video import assert_video_params from tests.io.test_formats import read_nix_meta +from tests.io.test_video import assert_video_params def test_delete_user_dialog(centered_pair_predictions): @@ -922,3 +923,32 @@ def no_gui_ask(cls, context, params): # Case 3: Export all frames and suggested frames with image data. context.exportFullPackage() assert_loaded_package_similar(path_to_pkg, sugg=True, pred=True) + + +def test_AddSession( + min_tracks_2node_labels: Labels, + min_session_calibration_toml_path: str, +): + """Test that adding a session works.""" + labels = min_tracks_2node_labels + camera_calibration = min_session_calibration_toml_path + + # Set-up CommandContext + context: CommandContext = CommandContext.from_labels(labels) + + # Case 1: No session selected + assert context.state["session"] is None + assert labels.sessions == [] + + params = {"camera_calibration": camera_calibration} + AddSession.do_action(context, params) + assert len(labels.sessions) == 1 + session = labels.sessions[0] + assert context.state["session"] is session + + # Case 2: Session selected + params = {"camera_calibration": camera_calibration} + AddSession.do_action(context, params) + assert len(labels.sessions) == 2 + assert context.state["session"] is session + assert labels.sessions[1] is not session diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index 51a7953f6..1201a6a68 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -2,14 +2,19 @@ import numpy as np import pytest -from sleap.io.cameras import Camcorder, CameraCluster +from sleap.io.cameras import Camcorder, CameraCluster, RecordingSession +from sleap.io.video import Video -def test_camcorder(min_session_calibration_toml_path): + +def test_camcorder( + min_session_session: RecordingSession, + centered_pair_vid: Video, +): """Test `Camcorder` data structure.""" - calibration = min_session_calibration_toml_path - cameras = CameraCluster.load(calibration) - cam: Camcorder = cameras[0] + session: RecordingSession = min_session_session + cam: Camcorder = session.cameras[0] + video: Video = centered_pair_vid # Test from_dict cam_dict = cam.get_dict() @@ -30,20 +35,210 @@ def test_camcorder(min_session_calibration_toml_path): # Test __eq__ assert cam == cam2 + # Test videos property + assert cam.videos == [] + session.add_video(video, cam) + assert cam.videos == [video] + + # Test sessions property + assert cam.sessions == [session] + + # Test __getitem__ + assert cam[session] == video + assert cam[video] == session + with pytest.raises(KeyError): + cam["foo"] + -def test_camera_cluster(min_session_calibration_toml_path): +def test_camera_cluster( + min_session_calibration_toml_path: str, + min_session_session: RecordingSession, + centered_pair_vid: Video, +): """Test `CameraCluster` data structure.""" + # Test load calibration = min_session_calibration_toml_path - cameras = CameraCluster.load(calibration) + camera_cluster = CameraCluster.load(calibration) # Test __len__ - assert len(cameras) == len(cameras.cameras) - assert len(cameras) == 4 + assert len(camera_cluster) == len(camera_cluster.cameras) + assert len(camera_cluster) == 4 # Test __getitem__, __iter__, and __contains__ - for idx, cam in enumerate(cameras): - assert cam == cameras[idx] - assert cam in cameras + for idx, cam in enumerate(camera_cluster): + assert cam == camera_cluster[idx] + assert cam in camera_cluster + + # Test __repr__ + assert f"{camera_cluster.__class__.__name__}(" in repr(camera_cluster) + + # Test validator + with pytest.raises(TypeError): + camera_cluster.cameras = [1, 2, 3] + + # Test converter + assert isinstance(camera_cluster.cameras[0], Camcorder) + + # Test sessions property and add_session + assert camera_cluster.sessions == [] + camera_cluster.add_session(min_session_session) + assert camera_cluster.sessions == [min_session_session] + + # Test videos property + camera = camera_cluster.cameras[0] + min_session_session.add_video(centered_pair_vid, camera) + assert camera_cluster.videos == [centered_pair_vid] + + # Test __getitem__ + assert camera_cluster[centered_pair_vid] == (camera, min_session_session) + assert camera_cluster[camera] == [centered_pair_vid] + assert camera_cluster[min_session_session] == [centered_pair_vid] + min_session_session.remove_video(centered_pair_vid) + assert camera_cluster[centered_pair_vid] is None + assert camera_cluster[camera] == [] + assert camera_cluster[min_session_session] == [] + + # Test to_calibration_dict + calibration_dict = camera_cluster.to_calibration_dict() + assert isinstance(calibration_dict, dict) + for cam_idx, cam in enumerate(camera_cluster): + cam_key = f"cam_{cam_idx}" + cam_value = calibration_dict[cam_key] + + assert calibration_dict[cam_key]["name"] == cam.name + assert np.array_equal(cam_value["matrix"], cam.matrix) + assert np.array_equal(cam_value["distortions"], cam.dist) + assert np.array_equal(cam_value["size"], cam.size) + assert np.array_equal(cam_value["rotation"], cam.rvec) + assert np.array_equal(cam_value["translation"], cam.tvec) + + # Test from_calibration_dict + camera_cluster2 = CameraCluster.from_calibration_dict(calibration_dict) + assert isinstance(camera_cluster2, CameraCluster) + assert len(camera_cluster2) == len(camera_cluster) + for cam_1, cam_2 in zip(camera_cluster, camera_cluster2): + assert cam_1 == cam_2 + assert camera_cluster2.sessions == [] + + +def test_recording_session( + min_session_calibration_toml_path: str, + min_session_camera_cluster: CameraCluster, + centered_pair_vid: Video, + hdf5_vid: Video, +): + """Test `RecordingSession` data structure.""" + calibration: str = min_session_calibration_toml_path + camera_cluster: CameraCluster = min_session_camera_cluster + + # Test load + session = RecordingSession.load(calibration) + session.metadata = {"test": "we can access this information!"} + session.camera_cluster.metadata = { + "another_test": "we can even access this information!" + } + + # Test __attrs_post_init__ + assert session in session.camera_cluster.sessions + + # Test __iter__, __contains__, and __getitem__ (with int key) + for idx, cam in enumerate(session): + assert isinstance(cam, Camcorder) + assert cam in camera_cluster + assert cam == camera_cluster[idx] + + # Test __getattr__ + assert session.cameras == camera_cluster.cameras + + # Test __getitem__ with string key + assert session["test"] == "we can access this information!" + assert session["another_test"] == "we can even access this information!" + + # Test __len__ + assert len(session) == len(session.videos) # Test __repr__ - assert f"{cameras.__class__.__name__}(" in repr(cameras) + assert f"{session.__class__.__name__}(" in repr(session) + + # Test add_video + camcorder = session.camera_cluster.cameras[0] + session.add_video(centered_pair_vid, camcorder) + assert centered_pair_vid is session.camera_cluster._videos_by_session[session][0] + assert centered_pair_vid is camcorder._video_by_session[session] + assert session is session.camera_cluster._session_by_video[centered_pair_vid] + assert camcorder is session.camera_cluster._camcorder_by_video[centered_pair_vid] + assert centered_pair_vid is session._video_by_camcorder[camcorder] + + # Test videos property + assert centered_pair_vid in session.videos + + # Test linked_cameras property + assert camcorder in session.linked_cameras + assert camcorder not in session.unlinked_cameras + + # Test __getitem__ with `Video` key + assert session[centered_pair_vid] is camcorder + + # Test __getitem__ with `Camcorder` key + assert session[camcorder] is centered_pair_vid + + # Test from_calibration_dict + def compare_cameras(session_1: RecordingSession, session_2: RecordingSession): + assert len(session_2.camera_cluster) == len(session.camera_cluster) + for cam_1, cam_2 in zip(session, session_2): + assert cam_1 == cam_2 + + calibration_dict = session.camera_cluster.to_calibration_dict() + session_2 = RecordingSession.from_calibration_dict(calibration_dict) + assert isinstance(session_2, RecordingSession) + assert len(session_2.videos) == 0 + compare_cameras(session, session_2) + + # Test to_session_dict + camcorder_2 = session.camera_cluster.cameras[2] + session.add_video(hdf5_vid, camcorder_2) + videos_list = [centered_pair_vid, hdf5_vid] + video_to_idx = {video: idx for idx, video in enumerate(videos_list)} + session_dict = session.to_session_dict(video_to_idx) + assert isinstance(session_dict, dict) + assert session_dict["calibration"] == calibration_dict + assert session_dict["camcorder_to_video_idx_map"] == { + 0: video_to_idx[centered_pair_vid], + 2: video_to_idx[hdf5_vid], + } + + # Test from_session_dict + def compare_sessions(session_1: RecordingSession, session_2: RecordingSession): + assert isinstance(session_2, RecordingSession) + assert not (session_2 == session) # Not the same object in memory + assert len(session_2.camera_cluster) == len(session_1.camera_cluster) + compare_cameras(session_1, session_2) + assert len(session_2.videos) == len(session_1.videos) + assert np.array_equal(session_2.videos, session_1.videos) + + session_2 = RecordingSession.from_session_dict(session_dict, videos_list) + compare_sessions(session, session_2) + + # Test remove_video + session.remove_video(centered_pair_vid) + assert centered_pair_vid not in session.videos + assert camcorder not in session.linked_cameras + assert camcorder in session.unlinked_cameras + assert centered_pair_vid not in session.camera_cluster._videos_by_session[session] + assert session not in camcorder._video_by_session + assert centered_pair_vid not in session.camera_cluster._session_by_video + assert centered_pair_vid not in session.camera_cluster._camcorder_by_video + assert camcorder not in session._video_by_camcorder + + # Test __getitem__ with `Video` key + assert session[centered_pair_vid] is None + + # Test __getitem__ with `Camcorder` key + assert session[camcorder] is None + + # Test make_cattr + sessions_cattr = RecordingSession.make_cattr(videos_list) + session_dict_2 = sessions_cattr.unstructure(session_2) + assert session_dict_2 == session_dict + session_3 = sessions_cattr.structure(session_dict_2, RecordingSession) + compare_sessions(session_2, session_3) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 5592ae437..cb8842ddc 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -4,6 +4,7 @@ from pathlib import Path, PurePath import sleap +from sleap.io.cameras import RecordingSession from sleap.skeleton import Skeleton from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track from sleap.io.video import Video, MediaVideo @@ -967,6 +968,45 @@ def test_save_labels_with_images(min_labels_slp, tmpdir): assert Labels.load_file(fn).video.embedded_frame_inds == [0, 1, 2] +def test_save_labels_with_sessions( + min_labels_slp: Labels, min_session_session: RecordingSession, tmpdir +): + """Test that we can save labels with sessions attribute.""" + + labels = min_labels_slp + session = min_session_session + + assert labels.sessions == [] + labels.add_session(session) + assert len(labels.sessions) == 1 + + new_path = str(Path(tmpdir, "test.slp")) + labels.save(new_path) + + loaded_labels: Labels = Labels.load_file(new_path) + loaded_session = loaded_labels.sessions[0] + + assert len(loaded_labels.sessions) == 1 + assert isinstance(loaded_session, RecordingSession) + assert not (loaded_session == session) # Not the same object in memory + assert len(loaded_session.camera_cluster) == len(session.camera_cluster) + assert len(loaded_session.videos) == len(session.videos) + assert np.array_equal(loaded_session.videos, session.videos) + assert len(loaded_session.camera_cluster) == len(session.camera_cluster) + for cam_1, cam_2 in zip(session, loaded_session): + assert cam_1 == cam_2 + + +def test_add_session(min_labels_slp: Labels, min_session_session: RecordingSession): + """Test that we can add a `RecordingSession` to a `Labels` object.""" + + labels = min_labels_slp + session = min_session_session + + labels.add_session(session) + assert labels.sessions == [session] + + def test_labels_hdf5(multi_skel_vid_labels, tmpdir): labels = multi_skel_vid_labels filename = os.path.join(tmpdir, "test.h5")