-
Notifications
You must be signed in to change notification settings - Fork 124
(3->2) Add method to update Instances across views in RecordingSession
#1279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: liezl/ars-add-sessions-to-cache
Are you sure you want to change the base?
Changes from all commits
190b82c
4aa285b
7754a66
4dca8be
f5d41b3
619d584
96df7f4
280d48e
9559da7
45b8475
ba2df88
35c1521
8098ae1
addc8be
74a8f83
4f65679
243ad3c
bb0fda1
f203953
2bc5c47
e0a57f2
a1965e9
62f4f7f
9b67128
99ef17e
38599f4
df1bc72
e85e4fb
7105656
bffeca8
c293d73
001c849
06dff9f
4706fb2
4994f04
7cc7a5a
a2326d3
2f56a1c
712fc70
b1d2372
0d26728
8d3e232
507eefa
55c6fd4
07ea17b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1231,8 +1231,17 @@ def plotFrame(self, *args, **kwargs): | |
|
|
||
| def _after_plot_update(self, frame_idx): | ||
| """Run after plot is updated, but stay on same frame.""" | ||
|
|
||
| video = self.state["video"] | ||
|
|
||
| # Redraw trails | ||
| overlay: TrackTrailOverlay = self.overlays["trails"] | ||
| overlay.redraw(self.state["video"], frame_idx) | ||
| overlay.redraw(video, frame_idx) | ||
|
|
||
| # Replot connected views for multi-camera projects | ||
| # TODO(LM): Use context.state["session"] in command instead (when implemented) | ||
| session = self.labels.get_session(video) | ||
| self.commands.triangulateSession(session=session) | ||
|
Comment on lines
+1241
to
+1244
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider addressing the TODO regarding the use of Would you like me to help implement this or should I open a GitHub issue to track this enhancement? |
||
|
|
||
| def _after_plot_change(self, player, frame_idx, selected_inst): | ||
| """Called each time a new frame is drawn.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -35,14 +35,15 @@ class which inherits from `AppCommand` (or a more specialized class such as | |||||
| import traceback | ||||||
| from enum import Enum | ||||||
| from glob import glob | ||||||
| from itertools import permutations, product | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unused imports to clean up the code. - from itertools import permutations, productCommittable suggestion
Suggested change
|
||||||
| from pathlib import Path, PurePath | ||||||
| from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union | ||||||
| from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unused import to clean up the code. - from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
+ from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, UnionCommittable suggestion
Suggested change
|
||||||
|
|
||||||
| import attr | ||||||
| import cv2 | ||||||
| import numpy as np | ||||||
| from qtpy import QtCore, QtWidgets, QtGui | ||||||
| from qtpy.QtWidgets import QMessageBox, QProgressDialog | ||||||
| from sleap_anipose import triangulate, reproject | ||||||
| from qtpy import QtCore, QtGui, QtWidgets | ||||||
|
|
||||||
| from sleap.gui.dialogs.delete import DeleteDialog | ||||||
| from sleap.gui.dialogs.filedialog import FileDialog | ||||||
|
|
@@ -53,7 +54,7 @@ class which inherits from `AppCommand` (or a more specialized class such as | |||||
| from sleap.gui.state import GuiState | ||||||
| from sleap.gui.suggestions import VideoFrameSuggestions | ||||||
| from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track | ||||||
| from sleap.io.cameras import RecordingSession | ||||||
| from sleap.io.cameras import Camcorder, InstanceGroup, FrameGroup, RecordingSession | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unused import to clean up the code. - from sleap.io.cameras import Camcorder, InstanceGroup, FrameGroup, RecordingSession
+ from sleap.io.cameras import InstanceGroup, FrameGroup, RecordingSessionCommittable suggestion
Suggested change
|
||||||
| from sleap.io.convert import default_analysis_filename | ||||||
| from sleap.io.dataset import Labels | ||||||
| from sleap.io.format.adaptor import Adaptor | ||||||
|
|
@@ -613,6 +614,20 @@ def generateSuggestions(self, params: Dict): | |||||
| """Generates suggestions using given params dictionary.""" | ||||||
| self.execute(GenerateSuggestions, **params) | ||||||
|
|
||||||
| def triangulateSession( | ||||||
| self, | ||||||
| session: Optional[RecordingSession] = None, | ||||||
| frame_idx: Optional[int] = None, | ||||||
| instance: Optional[Instance] = None, | ||||||
| ): | ||||||
| """Triangulates `Instance`s for selected views in a `RecordingSession`.""" | ||||||
| self.execute( | ||||||
| TriangulateSession, | ||||||
| session=session, | ||||||
| frame_idx=frame_idx, | ||||||
| instance=instance, | ||||||
| ) | ||||||
|
|
||||||
|
Comment on lines
+617
to
+630
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure proper documentation for the The method |
||||||
| def openWebsite(self, url): | ||||||
| """Open a website from URL using the native system browser.""" | ||||||
| self.execute(OpenWebsite, url=url) | ||||||
|
|
@@ -1929,7 +1944,6 @@ class AddSession(EditCommand): | |||||
|
|
||||||
| @staticmethod | ||||||
| def do_action(context: CommandContext, params: dict): | ||||||
|
|
||||||
| camera_calibration = params["camera_calibration"] | ||||||
| session = RecordingSession.load(filename=camera_calibration) | ||||||
|
|
||||||
|
|
@@ -2914,6 +2928,26 @@ def do_action(cls, context: CommandContext, params: dict): | |||||
| if context.state["labeled_frame"] not in context.labels.labels: | ||||||
| context.labels.append(context.state["labeled_frame"]) | ||||||
|
|
||||||
| # Also add the instance to the frame group if it exists | ||||||
| video = context.state["video"] | ||||||
| session = context.labels.get_session(video=video) | ||||||
| if session is None: | ||||||
| return | ||||||
|
|
||||||
| frame_idx = context.state["frame_idx"] | ||||||
| frame_group = session.frame_groups.get(frame_idx, None) | ||||||
| if frame_group is None: | ||||||
| return | ||||||
|
|
||||||
| instance_group = frame_group.get_instance_group(instance=from_predicted) | ||||||
| if instance_group is None: | ||||||
| return | ||||||
|
|
||||||
| camera = session.get_camera(video=video) | ||||||
| frame_group.add_instance( | ||||||
| instance=new_instance, camera=camera, instance_group=instance_group | ||||||
| ) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def create_new_instance( | ||||||
| context: CommandContext, | ||||||
|
|
@@ -3369,6 +3403,143 @@ def do_action(cls, context: CommandContext, params: dict): | |||||
| context.labels.append(current_frame) | ||||||
|
|
||||||
|
|
||||||
| class TriangulateSession(EditCommand): | ||||||
| topics = [UpdateTopic.frame, UpdateTopic.project_instances] | ||||||
|
|
||||||
| @classmethod | ||||||
| def do_action(cls, context: CommandContext, params: dict): | ||||||
| """Triangulate, reproject, and update instances in a session at a frame index. | ||||||
|
|
||||||
| Args: | ||||||
| context: The command context. | ||||||
| params: The command parameters. | ||||||
| session: The `RecordingSession` object to use. Default is current | ||||||
| video's session. | ||||||
| frame_idx: The frame index to use. Default is current frame index. | ||||||
| instance: The `Instance` object to use. Default is current instance. | ||||||
| show_dialog: If True, then show a warning dialog. Default is True. | ||||||
| ask_again: If True, then ask for views/instances again. Default is False. | ||||||
| """ | ||||||
|
|
||||||
| session: RecordingSession = ( | ||||||
| params.get("session", None) or context.state["session"] | ||||||
| ) | ||||||
| if session is None: | ||||||
| return | ||||||
|
|
||||||
| # Get `FrameGroup` for the current frame index | ||||||
| frame_idx: int = params.get("frame_idx", None) | ||||||
| if frame_idx is None: | ||||||
| frame_idx = context.state["frame_idx"] | ||||||
| frame_group: FrameGroup = session.frame_groups.get(frame_idx, None) | ||||||
| if frame_group is None: | ||||||
| return | ||||||
|
|
||||||
| # Get the `InstanceGroup` from `Instance` if any | ||||||
| instance = params.get("instance", None) or context.state["instance"] | ||||||
| instance_group = frame_group.get_instance_group(instance) | ||||||
|
|
||||||
| # If instance_group is None, then we will try to triangulate entire frame_group | ||||||
| instance_groups = ( | ||||||
| [instance_group] | ||||||
| if instance_group is not None | ||||||
| else frame_group.instance_groups | ||||||
| ) | ||||||
|
|
||||||
| # Retain instance groups that have enough views/instances for triangulation | ||||||
| instance_groups = TriangulateSession.has_enough_instances( | ||||||
| frame_group=frame_group, | ||||||
| instance_groups=instance_groups, | ||||||
| frame_idx=frame_idx, | ||||||
| instance=instance, | ||||||
| ) | ||||||
| if instance_groups is None or len(instance_groups) == 0: | ||||||
| return # Not enough instances for triangulation | ||||||
|
|
||||||
| # Get the `FrameGroup` of shape M=include x T x N x 2 | ||||||
| fg_tensor = frame_group.numpy(instance_groups=instance_groups, pred_as_nan=True) | ||||||
|
|
||||||
| # Add extra dimension for number of frames | ||||||
| frame_group_tensor = np.expand_dims(fg_tensor, axis=1) # M=include x F=1 xTxNx2 | ||||||
|
|
||||||
| # Triangulate to one 3D pose per instance | ||||||
| points_3d = triangulate( | ||||||
| p2d=frame_group_tensor, | ||||||
| calib=session.camera_cluster, | ||||||
| excluded_views=frame_group.excluded_views, | ||||||
| ) # F x T x N x 3 | ||||||
|
|
||||||
| # Reproject onto all views | ||||||
| pts_reprojected = reproject( | ||||||
| points_3d, | ||||||
| calib=session.camera_cluster, | ||||||
| excluded_views=frame_group.excluded_views, | ||||||
| ) # M=include x F=1 x T x N x 2 | ||||||
|
|
||||||
| # Sqeeze back to the original shape | ||||||
| points_reprojected = np.squeeze(pts_reprojected, axis=1) # M=include x TxNx2 | ||||||
|
|
||||||
| # Update or create/insert ("upsert") instance points | ||||||
| frame_group.upsert_points( | ||||||
| points=points_reprojected, | ||||||
| instance_groups=instance_groups, | ||||||
| exclude_complete=True, | ||||||
| ) | ||||||
|
|
||||||
| @classmethod | ||||||
| def has_enough_instances( | ||||||
| cls, | ||||||
| frame_group: FrameGroup, | ||||||
| instance_groups: Optional[List[InstanceGroup]] = None, | ||||||
| frame_idx: Optional[int] = None, | ||||||
| instance: Optional[Instance] = None, | ||||||
| ) -> Optional[List[InstanceGroup]]: | ||||||
| """Filters out instance groups without enough instances for triangulation. | ||||||
|
|
||||||
| Args: | ||||||
| frame_group: The `FrameGroup` object to use. | ||||||
| instance_groups: A list of `InstanceGroup` objects to use. Default is None. | ||||||
| frame_idx: The frame index to use (only used in logging). Default is None. | ||||||
| instance: The `Instance` object to use (only used in logging). Default None. | ||||||
|
|
||||||
| Returns: | ||||||
| A list of `InstanceGroup` objects with enough instances for triangulation. | ||||||
| """ | ||||||
|
|
||||||
| if instance is None: | ||||||
| instance = "" # Just used for logging | ||||||
|
|
||||||
| if frame_idx is None: | ||||||
| frame_idx = "" # Just used for logging | ||||||
|
|
||||||
| if instance_groups is None: | ||||||
| instance_groups = frame_group.instance_groups | ||||||
|
|
||||||
| if len(instance_groups) < 1: | ||||||
| logger.warning( | ||||||
| f"Require at least 1 instance group, but found " | ||||||
| f"{len(frame_group.instance_groups)} for frame group {frame_group} at " | ||||||
| f"frame {frame_idx}." | ||||||
| f"\nSkipping triangulation." | ||||||
| ) | ||||||
| return None # No instance groups found | ||||||
|
|
||||||
| # Assert that there are enough views and instances | ||||||
| instance_groups_to_tri = [] | ||||||
| for instance_group in instance_groups: | ||||||
| instances = instance_group.get_instances(frame_group.cams_to_include) | ||||||
| if len(instances) < 2: | ||||||
| # Not enough instances | ||||||
| logger.warning( | ||||||
| f"Not enough instances in {instance_group} for triangulation." | ||||||
| f"\nSkipping instance group." | ||||||
| ) | ||||||
| continue | ||||||
| instance_groups_to_tri.append(instance_group) | ||||||
|
|
||||||
| return instance_groups_to_tri # `InstanceGroup`s with enough instances | ||||||
|
|
||||||
|
|
||||||
|
Comment on lines
+3406
to
+3542
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Review the implementation of
|
||||||
| def open_website(url: str): | ||||||
| """Open website in default browser. | ||||||
|
|
||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.