Skip to content

Commit 98c415f

Browse files
committed
initial napari epoch playback
1 parent 2eb0ee1 commit 98c415f

File tree

1 file changed

+60
-1
lines changed

1 file changed

+60
-1
lines changed

src/spyglass/common/common_task.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
import datajoint as dj
22
import ndx_franklab_novela
3+
import numpy as np
34
import pynwb
5+
from nwb_data_viewer.time_series import NwbTimeSeriesViewer
6+
from nwb_data_viewer.video import NwbVideoViewer
7+
from nwb_data_viewer.viewer_mixin import MultiViewer
48

59
from spyglass.common.common_device import CameraDevice # noqa: F401
10+
from spyglass.common.common_ephys import Electrode, ElectrodeGroup, Raw
611
from spyglass.common.common_interval import IntervalList
712
from spyglass.common.common_nwbfile import Nwbfile
813
from spyglass.common.common_session import Session # noqa: F401
14+
from spyglass.settings import video_dir
915
from spyglass.utils import SpyglassMixin, logger
10-
from spyglass.utils.nwb_helper_fn import get_config, get_nwb_file
16+
from spyglass.utils.nwb_helper_fn import (
17+
get_config,
18+
get_electrode_indices,
19+
get_nwb_file,
20+
)
1121

1222
schema = dj.schema("common_task")
1323

@@ -306,3 +316,52 @@ def is_nwb_task_epoch(cls, task_table: pynwb.core.DynamicTable) -> bool:
306316
and hasattr(task_table, "camera_id")
307317
and hasattr(task_table, "task_epochs")
308318
)
319+
320+
def make_epoch_playback(self, key, max_channels=16):
321+
from spyglass.common.common_behav import VideoFile
322+
323+
key = (self & key).fetch1("KEY")
324+
# eseries info
325+
eseries_obj = (Raw & key).fetch_nwb()[0]["raw"]
326+
# epoch interval info
327+
interval_list_name = (self & key).fetch1("interval_list_name")
328+
key["interval_list_name"] = interval_list_name
329+
valid_times = (IntervalList & key).fetch1("valid_times")
330+
# channels groupings
331+
channel_groups = {}
332+
for group_key in (ElectrodeGroup & key).fetch("KEY"):
333+
334+
channel_groups[int(group_key["electrode_group_name"])] = (
335+
get_electrode_indices(
336+
eseries_obj, (Electrode & group_key).fetch("electrode_id")
337+
)
338+
)
339+
# get the video file obj
340+
video_obj = (VideoFile() & key).fetch_nwb()[0]["video_file"]
341+
342+
# Make the viewers
343+
interval_range = (valid_times[0][0], valid_times[-1][1])
344+
eseries_viewer = NwbTimeSeriesViewer(
345+
eseries_obj,
346+
interval_range=interval_range,
347+
n_plot=3000,
348+
max_channels=max_channels,
349+
channel_index=np.arange(max_channels),
350+
image_shape=(512, 512),
351+
# channel_options=np.arange,
352+
channel_groups=channel_groups,
353+
)
354+
video_viewer = NwbVideoViewer(
355+
video_obj,
356+
interval_range=interval_range,
357+
tmp_dir="/cumulus/sam/cache",
358+
video_dir=video_dir,
359+
)
360+
361+
multi_viewer = MultiViewer(
362+
viewer_list=[video_viewer, eseries_viewer],
363+
interval_range=interval_range,
364+
)
365+
multi_viewer.compile()
366+
multi_viewer.run()
367+
return multi_viewer

0 commit comments

Comments
 (0)