From 98c415f4aaafdb6393f8238d70ca92345b51e127 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Mon, 16 Jun 2025 12:55:12 -0700 Subject: [PATCH] initial napari epoch playback --- src/spyglass/common/common_task.py | 61 +++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py index 946727595..8b3371c19 100644 --- a/src/spyglass/common/common_task.py +++ b/src/spyglass/common/common_task.py @@ -1,13 +1,23 @@ import datajoint as dj import ndx_franklab_novela +import numpy as np import pynwb +from nwb_data_viewer.time_series import NwbTimeSeriesViewer +from nwb_data_viewer.video import NwbVideoViewer +from nwb_data_viewer.viewer_mixin import MultiViewer from spyglass.common.common_device import CameraDevice # noqa: F401 +from spyglass.common.common_ephys import Electrode, ElectrodeGroup, Raw from spyglass.common.common_interval import IntervalList from spyglass.common.common_nwbfile import Nwbfile from spyglass.common.common_session import Session # noqa: F401 +from spyglass.settings import video_dir from spyglass.utils import SpyglassMixin, logger -from spyglass.utils.nwb_helper_fn import get_config, get_nwb_file +from spyglass.utils.nwb_helper_fn import ( + get_config, + get_electrode_indices, + get_nwb_file, +) schema = dj.schema("common_task") @@ -306,3 +316,52 @@ def is_nwb_task_epoch(cls, task_table: pynwb.core.DynamicTable) -> bool: and hasattr(task_table, "camera_id") and hasattr(task_table, "task_epochs") ) + + def make_epoch_playback(self, key, max_channels=16): + from spyglass.common.common_behav import VideoFile + + key = (self & key).fetch1("KEY") + # eseries info + eseries_obj = (Raw & key).fetch_nwb()[0]["raw"] + # epoch interval info + interval_list_name = (self & key).fetch1("interval_list_name") + key["interval_list_name"] = interval_list_name + valid_times = (IntervalList & key).fetch1("valid_times") + # channels groupings + channel_groups = {} + for group_key in (ElectrodeGroup & key).fetch("KEY"): + + channel_groups[int(group_key["electrode_group_name"])] = ( + get_electrode_indices( + eseries_obj, (Electrode & group_key).fetch("electrode_id") + ) + ) + # get the video file obj + video_obj = (VideoFile() & key).fetch_nwb()[0]["video_file"] + + # Make the viewers + interval_range = (valid_times[0][0], valid_times[-1][1]) + eseries_viewer = NwbTimeSeriesViewer( + eseries_obj, + interval_range=interval_range, + n_plot=3000, + max_channels=max_channels, + channel_index=np.arange(max_channels), + image_shape=(512, 512), + # channel_options=np.arange, + channel_groups=channel_groups, + ) + video_viewer = NwbVideoViewer( + video_obj, + interval_range=interval_range, + tmp_dir="/cumulus/sam/cache", + video_dir=video_dir, + ) + + multi_viewer = MultiViewer( + viewer_list=[video_viewer, eseries_viewer], + interval_range=interval_range, + ) + multi_viewer.compile() + multi_viewer.run() + return multi_viewer