-
Notifications
You must be signed in to change notification settings - Fork 26
Tracking feature : CoTracker integration for automated label generation #155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 22 commits
5bde176
cf2f0e7
f05549c
0bbc3a6
79c0e63
7c28dc4
28ed829
67c1d0d
9eb9ecf
fd010e0
89f894c
d29154e
50ff568
3e43c47
f458ddc
d0e3636
1794a7c
8549847
a14509d
c551951
33eeacb
016d611
4e10ab8
a8b913d
bf5a036
e04a6b1
b816980
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,8 @@ | |||||
| from skimage.io import imsave | ||||||
|
|
||||||
| from napari_deeplabcut import _writer, keypoints | ||||||
| from napari_deeplabcut.tracking._data import TrackingModelInputs, TrackingWorkerData, TrackingWorkerOutput | ||||||
| from napari_deeplabcut.tracking._models import AVAILABLE_TRACKERS, RawModelOutputs, TrackingModel | ||||||
|
|
||||||
| # os.environ["NAPARI_DLC_HIDE_TUTORIAL"] = "True" # no longer on by default | ||||||
|
|
||||||
|
|
@@ -32,6 +34,10 @@ def viewer(make_napari_viewer_proxy): | |||||
| "napari-deeplabcut", | ||||||
| "Keypoint controls", | ||||||
| ) | ||||||
| tracking_dock_widget, tracking_plugin_widget = viewer.window.add_plugin_dock_widget( | ||||||
| "napari-deeplabcut", | ||||||
| "Tracking controls", | ||||||
| ) | ||||||
|
|
||||||
| try: | ||||||
| yield viewer | ||||||
|
|
@@ -151,3 +157,127 @@ def video_path(tmp_path_factory): | |||||
| writer.write(frame) | ||||||
| writer.release() | ||||||
| return output_path | ||||||
|
|
||||||
|
|
||||||
| # --- Tracking fixtures --- | ||||||
| DUMMY_TRACKER_NAME = "TestTracker" | ||||||
|
|
||||||
|
|
||||||
| class DummyTracker(TrackingModel): | ||||||
| """ | ||||||
| Minimal tracker that: | ||||||
| - echoes inputs to outputs with a tiny deterministic transform, | ||||||
| - emits progress via the callback, | ||||||
| - honors stop_callback. | ||||||
| """ | ||||||
|
|
||||||
| name = DUMMY_TRACKER_NAME | ||||||
| info_text = "Dummy tracker for unit testing." | ||||||
|
|
||||||
| def load_model(self, device: str): | ||||||
| # No-op model; keep a simple config to emulate 'step' like CoTracker. | ||||||
| class _NoOpModel: | ||||||
| step = 3 | ||||||
|
|
||||||
| return _NoOpModel() | ||||||
|
|
||||||
| def prepare_inputs(self, cfg: "TrackingWorkerData", **kwargs) -> TrackingModelInputs: | ||||||
| # Ensure video is (T, H, W, C) and keypoints is (K, 3) where columns: [frame_idx, x, y] or [id, x, y] | ||||||
| video = np.asarray(cfg.video) | ||||||
| queries = np.asarray(cfg.keypoints).copy() | ||||||
| metadata = { | ||||||
| "keypoint_range": cfg.keypoint_range, | ||||||
| "backward_tracking": getattr(cfg, "backward_tracking", False), | ||||||
| } | ||||||
| return TrackingModelInputs(video=video, keypoints=queries, metadata=metadata) | ||||||
|
|
||||||
| def run(self, inputs: TrackingModelInputs, progress_callback, stop_callback, **kwargs) -> RawModelOutputs: | ||||||
| # Fake progression per frame; stop if requested. | ||||||
| T = inputs.video.shape[0] | ||||||
| K = inputs.keypoints.shape[0] | ||||||
|
|
||||||
| # Produce tracks of shape (T, K, 2) with a deterministic offset (e.g., +1 pixel) | ||||||
| tracks = np.zeros((T, K, 2), dtype=float) | ||||||
| for t in range(T): | ||||||
| progress_callback(t, T) | ||||||
| if stop_callback(): | ||||||
| # Return partial result up to t | ||||||
| tracks = tracks[: t + 1] | ||||||
| vis = np.ones_like(tracks[..., 0], dtype=bool) # visibility dummy | ||||||
| return RawModelOutputs(keypoints=tracks, keypoint_features={"visibility": vis}) | ||||||
| # Use the input (x, y) for all K points and add a tiny drift proportional to t | ||||||
| tracks[t, :, 0] = inputs.keypoints[:, 1] + 0.1 * t # x | ||||||
| tracks[t, :, 1] = inputs.keypoints[:, 2] + 0.1 * t # y | ||||||
|
|
||||||
| vis = np.ones_like(tracks[..., 0], dtype=bool) | ||||||
| return RawModelOutputs(keypoints=tracks, keypoint_features={"visibility": vis}) | ||||||
|
|
||||||
| def prepare_outputs( | ||||||
| self, model_outputs: RawModelOutputs, worker_inputs: "TrackingWorkerData" = None, **kwargs | ||||||
| ) -> "TrackingWorkerOutput": | ||||||
| # Flatten (T, K, 2) -> (N, 3) with [frame_idx, x, y] | ||||||
| tracks = model_outputs.keypoints | ||||||
| T = tracks.shape[0] | ||||||
| K = tracks.shape[1] | ||||||
|
|
||||||
| T1, T2 = worker_inputs.keypoint_range | ||||||
| frame_ids = np.repeat(np.arange(T1, T1 + T), K) | ||||||
| flat = tracks.reshape(-1, 2) | ||||||
| keypoints = np.column_stack((frame_ids, flat)) # (N, 3) | ||||||
|
|
||||||
| # Minimal features: concat original per-keypoint features replicated per frame | ||||||
| keypoints_features = pd.concat( | ||||||
| [worker_inputs.keypoint_features] * T, | ||||||
| ignore_index=True, | ||||||
| ) | ||||||
|
|
||||||
| return TrackingWorkerOutput( | ||||||
| keypoints=keypoints, | ||||||
| keypoint_features=keypoints_features, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| @pytest.fixture(autouse=True) | ||||||
| def register_dummy_tracker(): | ||||||
| """ | ||||||
| Auto-register DummyTracker for all tests and restore registry afterwards. | ||||||
| """ | ||||||
| prev = dict(AVAILABLE_TRACKERS) | ||||||
| AVAILABLE_TRACKERS[DUMMY_TRACKER_NAME] = {"class": DummyTracker} | ||||||
| try: | ||||||
| yield | ||||||
| finally: | ||||||
| AVAILABLE_TRACKERS.clear() | ||||||
| AVAILABLE_TRACKERS.update(prev) | ||||||
|
|
||||||
|
|
||||||
| @pytest.fixture | ||||||
| def track_worker_inputs(): | ||||||
| """ | ||||||
| Provide minimal valid TrackingWorkerData with: | ||||||
| - 5-frame RGB video of 4x4 pixels, | ||||||
| - 2 keypoints, | ||||||
| - keypoint_range covering all frames, | ||||||
| - simple features DataFrame. | ||||||
| """ | ||||||
| video = np.zeros((5, 4, 4, 3), dtype=np.uint8) | ||||||
|
|
||||||
| keypoints = np.array( | ||||||
| [ | ||||||
| [0, 10.0, 20.0], | ||||||
| [1, 30.0, 40.0], | ||||||
|
||||||
| [1, 30.0, 40.0], | |
| [0, 30.0, 40.0], |
Uh oh!
There was an error while loading. Please reload this page.