Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion src/spyglass/common/common_task.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
import datajoint as dj
import ndx_franklab_novela
import numpy as np
import pynwb
from nwb_data_viewer.time_series import NwbTimeSeriesViewer
from nwb_data_viewer.video import NwbVideoViewer
from nwb_data_viewer.viewer_mixin import MultiViewer

from spyglass.common.common_device import CameraDevice # noqa: F401
from spyglass.common.common_ephys import Electrode, ElectrodeGroup, Raw
from spyglass.common.common_interval import IntervalList
from spyglass.common.common_nwbfile import Nwbfile
from spyglass.common.common_session import Session # noqa: F401
from spyglass.settings import video_dir
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import get_config, get_nwb_file
from spyglass.utils.nwb_helper_fn import (
get_config,
get_electrode_indices,
get_nwb_file,
)

schema = dj.schema("common_task")

Expand Down Expand Up @@ -306,3 +316,52 @@ def is_nwb_task_epoch(cls, task_table: pynwb.core.DynamicTable) -> bool:
and hasattr(task_table, "camera_id")
and hasattr(task_table, "task_epochs")
)

def make_epoch_playback(self, key, max_channels=16):
from spyglass.common.common_behav import VideoFile

key = (self & key).fetch1("KEY")
# eseries info
eseries_obj = (Raw & key).fetch_nwb()[0]["raw"]
# epoch interval info
interval_list_name = (self & key).fetch1("interval_list_name")
key["interval_list_name"] = interval_list_name
valid_times = (IntervalList & key).fetch1("valid_times")
# channels groupings
channel_groups = {}
for group_key in (ElectrodeGroup & key).fetch("KEY"):

channel_groups[int(group_key["electrode_group_name"])] = (
get_electrode_indices(
eseries_obj, (Electrode & group_key).fetch("electrode_id")
)
)
# get the video file obj
video_obj = (VideoFile() & key).fetch_nwb()[0]["video_file"]

# Make the viewers
interval_range = (valid_times[0][0], valid_times[-1][1])
eseries_viewer = NwbTimeSeriesViewer(
eseries_obj,
interval_range=interval_range,
n_plot=3000,
max_channels=max_channels,
channel_index=np.arange(max_channels),
image_shape=(512, 512),
# channel_options=np.arange,
channel_groups=channel_groups,
)
video_viewer = NwbVideoViewer(
video_obj,
interval_range=interval_range,
tmp_dir="/cumulus/sam/cache",
video_dir=video_dir,
)

multi_viewer = MultiViewer(
viewer_list=[video_viewer, eseries_viewer],
interval_range=interval_range,
)
multi_viewer.compile()
multi_viewer.run()
return multi_viewer
Loading