diff --git a/spikeinterface_gui/basescatterview.py b/spikeinterface_gui/basescatterview.py index a74fa67..7c8f253 100644 --- a/spikeinterface_gui/basescatterview.py +++ b/spikeinterface_gui/basescatterview.py @@ -185,7 +185,7 @@ def _qt_make_layout(self): tb = self.qt_widget.view_toolbar self.combo_seg = QT.QComboBox() tb.addWidget(self.combo_seg) - self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ]) + self.combo_seg.addItems([f'Segment {segment_index}' for segment_index in range(self.controller.num_segments)]) self.combo_seg.currentIndexChanged.connect(self._qt_change_segment) add_stretch_to_qtoolbar(tb) self.lasso_but = QT.QPushButton("select", checkable = True) @@ -278,7 +278,7 @@ def _qt_refresh(self): # make a copy of the color color = QT.QColor(self.get_unit_color(unit_id)) color.setAlpha(int(self.settings['alpha']*255)) - self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color) + self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color) color = self.get_unit_color(unit_id) curve = pg.PlotCurveItem(hist_count, hist_bins[:-1], fillLevel=None, fillOutline=True, brush=color, pen=color) diff --git a/spikeinterface_gui/controller.py b/spikeinterface_gui/controller.py index 9dc75eb..44e043b 100644 --- a/spikeinterface_gui/controller.py +++ b/spikeinterface_gui/controller.py @@ -7,11 +7,8 @@ from spikeinterface.widgets.utils import get_unit_colors from spikeinterface import compute_sparsity -from spikeinterface.core import get_template_extremum_channel -import spikeinterface.postprocessing -import spikeinterface.qualitymetrics +from spikeinterface.core import get_template_extremum_channel, BaseEvent from spikeinterface.core.sorting_tools import spike_vector_to_indices -from spikeinterface.core.core_tools import check_json from spikeinterface.curation import validate_curation_dict from spikeinterface.curation.curation_model import CurationModel from spikeinterface.widgets.utils import make_units_table_from_analyzer @@ -33,10 +30,23 @@ class Controller(): - def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save_on_compute=False, - curation=False, curation_data=None, label_definitions=None, with_traces=True, - displayed_unit_properties=None, - extra_unit_properties=None, skip_extensions=None, disable_save_settings_button=False): + def __init__( + self, + analyzer=None, + backend="qt", + parent=None, + verbose=False, + save_on_compute=False, + curation=False, + curation_data=None, + label_definitions=None, + with_traces=True, + displayed_unit_properties=None, + extra_unit_properties=None, + skip_extensions=None, + disable_save_settings_button=False, + events=None + ): self.views = [] skip_extensions = skip_extensions if skip_extensions is not None else [] @@ -220,6 +230,62 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self.pc_ext = pc_ext self._potential_merges = None + # some direct attribute + self.num_segments = self.analyzer.get_num_segments() + self.sampling_frequency = self.analyzer.sampling_frequency + + self.events = None + if events is not None: + if verbose: + print('\tLoading events') + self.events = {} + if isinstance(events, dict): + for key, val in events.items(): + if not isinstance(val, dict): + if verbose: + print(f'\tSkipping event {key}: not a dict') + continue + if 'samples' not in val and 'times' not in val: + if verbose: + print(f'\tSkipping event {key}: missing samples or times') + continue + if 'times' in val: + samples_data = val['times'] + convert_to_samples = True + else: + samples_data = val['samples'] + convert_to_samples = False + if self.num_segments > 1: + if not len(samples_data) == self.num_segments: + if verbose: + print(f'\tSkipping event {key}: inconsistent number of samples') + continue + else: + # here we make sure samples is a list of list + if np.array(samples_data).ndim == 1: + samples_data = [samples_data] + if convert_to_samples: + self.events[key] = [np.array(self.time_to_sample_index(s)) for s in samples_data] + else: + self.events[key] = [np.array(s) for s in samples_data] + elif isinstance(events, BaseEvent): + event_names = events.channel_ids + self.events = { + event_name: [] for event_name in event_names + } + for event_name in event_names: + for segment_index in range(self.num_segments): + event_times_segment = events.get_event_times( + channel_id=event_name, + segment_index=segment_index + ) + event_samples_segment = self.analyzer.time_to_sample_index( + event_times_segment + ) + self.events[event_name].append(np.array(event_samples_segment)) + + if len(self.events) == 0: + self.events = None t1 = time.perf_counter() if verbose: @@ -229,10 +295,6 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save self._extremum_channel = get_template_extremum_channel(self.analyzer, peak_sign='neg', outputs='index') - # some direct attribute - self.num_segments = self.analyzer.get_num_segments() - self.sampling_frequency = self.analyzer.sampling_frequency - # spikeinterface handle colors in matplotlib style tuple values in range (0,1) self.refresh_colors() @@ -489,6 +551,13 @@ def time_to_sample_index(self, time): else: return int(time * self.sampling_frequency) + def get_events(self, event_name): + if self.events is None: + return None + if event_name not in self.events: + return None + return self.events[event_name][self.time_info['segment_index']] + def get_information_txt(self): nseg = self.analyzer.get_num_segments() nchan = self.analyzer.get_num_channels() @@ -715,6 +784,8 @@ def set_channel_visibility(self, visible_channel_inds): def has_extension(self, extension_name): if extension_name == 'recording': return self.analyzer.has_recording() or self.analyzer.has_temporary_recording() + elif extension_name == 'events': + return self.events is not None else: # extension needs to be loaded if extension_name in self.skip_extensions: diff --git a/spikeinterface_gui/eventview.py b/spikeinterface_gui/eventview.py new file mode 100644 index 0000000..72204d1 --- /dev/null +++ b/spikeinterface_gui/eventview.py @@ -0,0 +1,289 @@ +import numpy as np +from .view_base import ViewBase + +class EventView(ViewBase): + _supported_backend = ['qt', 'panel'] + _depend_on = ["events"] + _settings = [ + {'name': 'max_trials', 'type': 'int', 'value' : 50 }, + {'name': 'window_start', 'type': 'float', 'value': -0.2}, + {'name': 'window_end', 'type': 'float', 'value': 0.5}, + {'name': 'alpha_psth', 'type': 'float', 'value': 0.5}, + {'name': 'num_bins', 'type': 'int', 'value': 50}, + ] + _need_compute = False + + def __init__(self, controller=None, parent=None, backend="qt"): + self.mode = 'rasters' # or 'psth' + self.selected_unit = None + self.selected_event_key = None + ViewBase.__init__(self, controller=controller, parent=parent, backend=backend) + + + def get_aligned_spikes(self, unit_ids): + event_times = np.array(self.controller.get_events(self.selected_event_key)) + window_s = [self.settings['window_start'], self.settings['window_end']] + window_samples = [int(w * self.controller.sampling_frequency) for w in window_s] + + if len(event_times) > self.settings['max_trials']: + event_times = event_times[np.random.choice(len(event_times), self.settings['max_trials'], replace=False)] + + aligned_spikes_dict = {} + for selected_unit in unit_ids: + aligned_spikes = [] + # TODO: deal with this!!! (at controller level) + segment_index = 0 + inds = self.controller.get_spike_indices(selected_unit, segment_index=segment_index) + spike_times = self.controller.spikes["sample_index"][inds] + + for et in event_times: + rel_spikes = spike_times - et + rel_spikes = rel_spikes[(rel_spikes >= window_samples[0]) & (rel_spikes <= window_samples[1])] + aligned_spikes.append(rel_spikes / self.controller.sampling_frequency) # convert to seconds + aligned_spikes_dict[selected_unit] = aligned_spikes + return aligned_spikes_dict + + def _qt_make_layout(self): + import pyqtgraph as pg + from .myqt import QT, QtWidgets + + layout = QtWidgets.QVBoxLayout() + # Mode selection + toolbar = QtWidgets.QHBoxLayout() + self.mode_combo = QtWidgets.QComboBox() + self.mode_combo.addItems(['Rasters', 'PSTH']) + self.mode_combo.currentIndexChanged.connect(self._qt_on_mode_changed) + toolbar.addWidget(self.mode_combo) + # Event key selection + event_keys = list(self.controller.events.keys()) + if len(event_keys) > 1: + self.event_combo = QtWidgets.QComboBox() + self.event_combo.addItems(event_keys) + self.event_combo.currentIndexChanged.connect(self._qt_on_event_changed) + toolbar.addWidget(self.event_combo) + self.selected_event_key = event_keys[0] if event_keys else None + layout.addLayout(toolbar) + # Pyqtgraph PlotWidget + self.pg_plot = pg.PlotWidget() + self.scatter = pg.ScatterPlotItem(size=10, pxMode=True) + self.pg_plot.addItem(self.scatter) + + # Create vertical line at x=0 once + self.zero_line = pg.InfiniteLine(pos=0, angle=90, pen=pg.mkPen('gray', width=2, style=QT.Qt.DashLine)) + self.pg_plot.addItem(self.zero_line) + + layout.addWidget(self.pg_plot) + self.layout = layout + + def _qt_on_mode_changed(self, idx): + self.mode = 'rasters' if idx == 0 else 'psth' + self._qt_refresh() + + def _qt_on_event_changed(self, idx): + self.selected_event_key = self.event_combo.currentText() + self._qt_refresh() + + def _qt_refresh(self): + from .myqt import QT + import pyqtgraph as pg + + self.scatter.clear() + # Clear everything including scatter + self.pg_plot.clear() + self.pg_plot.addItem(self.zero_line) + + if self.mode == 'rasters': + # Clear all plot items except scatter + self.pg_plot.addItem(self.scatter) + + # Get visible units from controller + visible_units = self.controller.get_visible_unit_ids() + if not visible_units or self.selected_event_key is None: + return + + aligned_spikes_by_unit = self.get_aligned_spikes(visible_units) + window_s = [self.settings['window_start'], self.settings['window_end']] + + for selected_unit in visible_units: + aligned_spikes = aligned_spikes_by_unit[selected_unit] + color = QT.QColor(self.get_unit_color(selected_unit)) + + if self.mode == 'rasters': + all_x = [] + all_y = [] + for i, trial in enumerate(aligned_spikes): + if len(trial) > 0: + all_x.extend(trial) + y = [i]*len(trial) + all_y.extend(y) + if all_x: + self.scatter.addPoints(x=np.array(all_x), y=np.array(all_y), pen=pg.mkPen(None), brush=color, symbol="|") + else: + from pyqtgraph import BarGraphItem + + all_spikes = np.concatenate(aligned_spikes) if aligned_spikes else np.array([]) + all_y_hists = [] + if len(all_spikes) > 0: + bins = np.linspace(window_s[0], window_s[1], 51) + y, x = np.histogram(all_spikes, bins=bins) + # Use bin centers for plotting + bin_centers = (x[:-1] + x[1:]) / 2 + # Create a bar graph item instead of using stepMode + width = (x[1] - x[0]) * 0.8 # 80% of bin width + color.setAlpha(int(self.settings['alpha_psth']*255)) + bg = BarGraphItem(x=bin_centers, height=y, width=width, brush=color, pen=pg.mkPen(color, width=2)) + self.pg_plot.addItem(bg) + all_y_hists.extend(y) + # Set ranges + if self.mode == 'rasters': + self.pg_plot.setYRange(-0.5, len(aligned_spikes)+0.5, padding=0) + self.pg_plot.setXRange(window_s[0], window_s[1], padding=0) + self.pg_plot.setLabel('left', 'Event #') + self.pg_plot.setLabel('bottom', 'Time (s)') + self.pg_plot.setTitle(f'Rasters aligned to {self.selected_event_key}') + else: + self.pg_plot.setXRange(window_s[0], window_s[1], padding=0.05) + if len(all_y_hists) > 0: + self.pg_plot.setYRange(0, max(all_y_hists)*1.1, padding=0) + self.pg_plot.setLabel('left', 'Spike count') + self.pg_plot.setLabel('bottom', 'Time (s)') + self.pg_plot.setTitle(f'PSTH aligned to {self.selected_event_key}') + + + def _panel_make_layout(self): + import panel as pn + import bokeh.plotting as bpl + from bokeh.models import ColumnDataSource, Span, Range1d + from .utils_panel import _bg_color + + top_items = [] + self.panel_mode_select = pn.widgets.Select(name="Mode", value="Rasters", options=["Rasters", "PSTH"]) + self.panel_mode_select.param.watch(self._panel_on_mode_changed, 'value') + top_items.append(self.panel_mode_select) + event_keys = list(self.controller.events.keys()) + if len(event_keys) > 1: + self.panel_event_select = pn.widgets.Select(name="Event", value=event_keys[0], options=event_keys) + self.panel_event_select.param.watch( self._panel_on_event_changed, 'value') + top_items.append(self.panel_event_select) + self.selected_event_key = event_keys[0] + + top_bar = pn.Row(*top_items, sizing_mode="stretch_width") + self.bins = np.linspace( + self.settings["window_start"], + self.settings["window_end"], + self.settings["num_bins"] + ) + self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2 + self.scatter_source = ColumnDataSource(data={"x": [], "y": [], "color": []}) + self.hist_source = ColumnDataSource(data={"center": [], "height": [], "color": []}) + self.x_range = Range1d(start=self.settings['window_start'], end=self.settings['window_end']) + self.panel_fig = bpl.figure( + sizing_mode="stretch_both", + tools="reset,wheel_zoom", + active_scroll="wheel_zoom", + background_fill_color=_bg_color, + border_fill_color=_bg_color, + outline_line_color="white", + x_range=self.x_range, + styles={"flex": "1"} + ) + self.scatter = self.panel_fig.scatter( + "x", + "y", + source=self.scatter_source, + color="color", + ) + self.bar = self.panel_fig.vbar( + x="center", + top="height", + width=self.bins[1] - self.bins[0], + color="color", + source=self.hist_source, + alpha=self.settings['alpha_psth'] + ) + self.vline = Span(location=0, dimension='height', line_color='white', line_width=2, line_dash='dashed') + + self.panel_fig.yaxis.axis_label = 'Event #' + self.panel_fig.xaxis.axis_label = 'Time (s)' + self.panel_fig.toolbar.logo = None + self.panel_fig.add_layout(self.vline) + self.panel_plot_pane = pn.pane.Bokeh(self.panel_fig, sizing_mode="stretch_both") + self.layout = pn.Column( + top_bar, + self.panel_plot_pane, + sizing_mode="stretch_both", + ) + + def _panel_on_mode_changed(self, event): + self.mode = 'rasters' if event.new == 'Rasters' else 'psth' + self._panel_refresh() + + def _panel_on_event_changed(self, event): + self.selected_event_key = event.new + self._panel_refresh() + + def _panel_refresh(self): + import numpy as np + import bokeh.plotting as bpl + + visible_units = self.controller.get_visible_unit_ids() + aligned_spikes_by_unit = self.get_aligned_spikes(visible_units) + if self.mode == 'rasters': + self.hist_source.data = {"center": [], "height": [], "color": []} # Clear histogram data + self.panel_fig.title.text = f'Rasters aligned to {self.selected_event_key}' + self.panel_fig.yaxis.axis_label = 'Event #' + all_x = [] + all_y = [] + all_colors = [] + for selected_unit in visible_units: + aligned_spikes = aligned_spikes_by_unit[selected_unit] + color = self.get_unit_color(selected_unit) + for i, trial in enumerate(aligned_spikes): + if len(trial) > 0: + all_x.extend(trial) + y = [i] * len(trial) + all_y.extend(y) + all_colors.extend([color] * len(trial)) + self.scatter_source.data = { + "x": np.array(all_x), + "y": np.array(all_y), + "color": all_colors + } + else: + self.scatter_source.data = {"x": [], "y": [], "color": []} # Clear scatter data + + all_centers = [] + all_heights = [] + all_colors = [] + for selected_unit in visible_units: + aligned_spikes = aligned_spikes_by_unit[selected_unit] + all_spikes = np.concatenate(aligned_spikes) if aligned_spikes else np.array([]) + hist, _ = np.histogram(all_spikes, bins=self.bins) + all_centers.extend(list(self.bin_centers)) + all_heights.extend(list(hist)) + all_colors.extend([self.get_unit_color(selected_unit)] * len(hist)) + self.hist_source.data = { + "center": all_centers, + "height": all_heights, + "color": all_colors + } + self.panel_fig.yaxis.axis_label = 'Spike count' + self.panel_fig.title.text = f'PSTH aligned to {self.selected_event_key}' + + # adjust x_range if needed + if self.settings["window_start"] != self.x_range.start: + self.x_range.start = self.settings["window_start"] + if self.settings["window_end"] != self.x_range.end: + self.x_range.end = self.settings["window_end"] + + def _panel_on_settings_changed(self): + self.bins = np.linspace( + self.settings["window_start"], + self.settings["window_end"], + self.settings["num_bins"] + ) + self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2 + self.x_range.start = self.settings['window_start'] + self.x_range.end = self.settings['window_end'] + self.bar.glyph.width = self.bins[1] - self.bins[0] + self._panel_refresh() diff --git a/spikeinterface_gui/layout_presets.py b/spikeinterface_gui/layout_presets.py index f77c127..133f13d 100644 --- a/spikeinterface_gui/layout_presets.py +++ b/spikeinterface_gui/layout_presets.py @@ -54,7 +54,7 @@ def get_layout_description(preset_name, layout=None): default_layout = dict( zone1=['curation', 'spikelist'], zone2=['unitlist', 'merge'], - zone3=['trace', 'tracemap', 'spikeamplitude', 'spikedepth', 'spikerate'], + zone3=['trace', 'tracemap', 'spikeamplitude', 'spikedepth', 'spikerate', 'event'], zone4=[], zone5=['probe'], zone6=['ndscatter', 'similarity'], diff --git a/spikeinterface_gui/main.py b/spikeinterface_gui/main.py index da41740..c07f8b6 100644 --- a/spikeinterface_gui/main.py +++ b/spikeinterface_gui/main.py @@ -6,7 +6,7 @@ import warnings from spikeinterface import load_sorting_analyzer, load -from spikeinterface.core.core_tools import is_path_remote +from spikeinterface.core import BaseRecording, SortingAnalyzer, BaseEvent from spikeinterface.core.sortinganalyzer import get_available_analyzer_extensions from .utils_global import get_config_folder from spikeinterface_gui.layout_presets import get_layout_description @@ -16,26 +16,27 @@ from spikeinterface_gui.viewlist import possible_class_views def run_mainwindow( - analyzer, - mode="desktop", - with_traces=True, - curation=False, - curation_dict=None, - label_definitions=None, - displayed_unit_properties=None, - extra_unit_properties=None, - skip_extensions=None, - recording=None, - start_app=True, - layout_preset=None, - layout=None, - address="localhost", - port=0, - panel_start_server_kwargs=None, - panel_window_servable=True, - verbose=False, - user_settings=None, - disable_save_settings_button=False, + analyzer: SortingAnalyzer, + mode: str = "desktop", + with_traces: bool = True, + curation: bool = False, + curation_dict: dict | None = None, + label_definitions: dict | None = None, + displayed_unit_properties: list | None=None, + extra_unit_properties: list | None=None, + skip_extensions: list | None = None, + recording: BaseRecording | None = None, + events: BaseEvent | dict | None = None, + start_app: bool = True, + layout_preset: str | None = None, + layout: dict | None = None, + address: str = "localhost", + port: int = 0, + panel_start_server_kwargs: dict | None = None, + panel_window_servable: bool = True, + verbose: bool = False, + user_settings: dict | None = None, + disable_save_settings_button: bool = False, ): """ Create the main window and start the QT app loop. @@ -65,6 +66,9 @@ def run_mainwindow( recording: RecordingExtractor | None, default: None The recording object to display traces. This can be used when the SortingAnalyzer is recordingless. + events: BaseEvent | dict | None, default: None + The events to display in the GUI. This can be a BaseEvent object or a dictionary + with keys as event names and another dictionary as values with "samples" or "times". start_qt_app: bool, default: True If True, the QT app loop is started layout_preset : str | None @@ -136,7 +140,8 @@ def run_mainwindow( displayed_unit_properties=displayed_unit_properties, extra_unit_properties=extra_unit_properties, skip_extensions=skip_extensions, - disable_save_settings_button=disable_save_settings_button + disable_save_settings_button=disable_save_settings_button, + events=events ) if verbose: t1 = time.perf_counter() diff --git a/spikeinterface_gui/tracemapview.py b/spikeinterface_gui/tracemapview.py index 97818d0..8a4ff9f 100644 --- a/spikeinterface_gui/tracemapview.py +++ b/spikeinterface_gui/tracemapview.py @@ -72,8 +72,7 @@ def _qt_make_layout(self, **kargs): self.layout = QT.QVBoxLayout() - self._qt_create_toolbar() - + self._qt_create_toolbars() # create graphic view and 2 scroll bar g = QT.QGridLayout() @@ -87,14 +86,7 @@ def _qt_make_layout(self, **kargs): self.scatter = pg.ScatterPlotItem(size=10, pxMode = True) self.plot.addItem(self.scatter) - - self.scroll_time = QT.QScrollBar(orientation=QT.Qt.Horizontal) - g.addWidget(self.scroll_time, 1,1) - self.scroll_time.valueChanged.connect(self._qt_on_scroll_time) - - - # self.on_params_changed(do_refresh=False) - #this do refresh + self.layout.addWidget(self.bottom_toolbar) self._qt_change_segment(0) def _qt_on_settings_changed(self, do_refresh=True): @@ -116,6 +108,7 @@ def _qt_on_spike_selection_changed(self): self._qt_seek_with_selected_spike() def _qt_refresh(self): + self._qt_remove_event_line() t, _ = self.controller.get_time() self._qt_seek(t) @@ -245,16 +238,21 @@ def _panel_make_layout(self): x="x", y="y", size=10, fill_color="color", fill_alpha=self.settings['alpha'], source=self.spike_source ) + self.event_source = ColumnDataSource({"x": [], "y": []}) + self.event_renderer = self.figure.line( + x="x", y="y", source=self.event_source, line_color="yellow", line_width=2, line_dash='dashed' + ) + # # Add hover tool for spikes # hover_spikes = HoverTool(renderers=[self.spike_renderer], tooltips=[("Unit", "@unit_id")]) # self.figure.add_tools(hover_spikes) - self._panel_create_toolbar() + self._panel_create_toolbars() self.layout = pn.Column( pn.Column( # Main content area self.toolbar, self.figure, - self.time_slider, + self.bottom_toolbar, styles={"flex": "1"}, sizing_mode="stretch_both" ), @@ -263,6 +261,7 @@ def _panel_make_layout(self): ) def _panel_refresh(self): + self._panel_remove_event_line() t, segment_index = self.controller.get_time() xsize = self.xsize t1, t2 = t - xsize / 3.0, t + xsize * 2 / 3.0 diff --git a/spikeinterface_gui/traceview.py b/spikeinterface_gui/traceview.py index 8e97d61..a4ca029 100644 --- a/spikeinterface_gui/traceview.py +++ b/spikeinterface_gui/traceview.py @@ -100,7 +100,7 @@ def get_data_in_chunk(self, t1, t2, segment_index): return times_chunk, data_curves, scatter_x, scatter_y, scatter_colors ## Qt ## - def _qt_create_toolbar(self): + def _qt_create_toolbars(self): from .myqt import QT import pyqtgraph as pg from .utils_qt import TimeSeeker, add_stretch_to_qtoolbar @@ -133,6 +133,44 @@ def _qt_create_toolbar(self): but.clicked.connect(self.auto_scale) tb.addWidget(but) + + self.scroll_time = QT.QScrollBar(orientation=QT.Qt.Horizontal) + self.scroll_time.valueChanged.connect(self._qt_on_scroll_time) + if self.controller.has_extension("events"): + bottom_layout = QT.QHBoxLayout() + bottom_layout.addWidget(self.scroll_time, stretch=8) + bottom_layout.addStretch() # Push button to the right + + event_layout = QT.QHBoxLayout() + event_keys = list(self.controller.events.keys()) + if len(event_keys) > 1: + self.event_type_combo = QT.QComboBox() + self.event_type_combo.addItems(event_keys) + self.event_type_combo.currentIndexChanged.connect(self._qt_on_event_type_changed) + event_layout.addWidget(QT.QLabel("Event:"), stretch=2) + event_layout.addWidget(self.event_type_combo, stretch=3) + else: + self.event_key = event_keys[0] + + self.prev_event_button = QT.QPushButton("◀") + self.next_event_button = QT.QPushButton("▶") + self.prev_event_button.setMaximumWidth(30) + self.next_event_button.setMaximumWidth(30) + self.next_event_button.clicked.connect(self._qt_on_next_event) + self.prev_event_button.clicked.connect(self._qt_on_prev_event) + event_layout.addWidget(self.prev_event_button, stretch=1) + event_layout.addWidget(self.next_event_button, stretch=1) + + # Wrap event_layout in a QWidget + event_widget = QT.QWidget() + event_widget.setLayout(event_layout) + bottom_layout.addWidget(event_widget) + bottom_widget = QT.QWidget() + bottom_widget.setLayout(bottom_layout) + else: + bottom_widget = self.scroll_time + + self.bottom_toolbar = bottom_widget def _qt_initialize_plot(self): from .myqt import QT @@ -245,9 +283,59 @@ def _qt_scatter_item_clicked(self, x, y): self.notify_spike_selection_changed() self._qt_seek_with_selected_spike() + # change selected unit + unit_id = self.controller.unit_ids[self.controller.spikes[ind_spike_nearest]["unit_index"]] + self.controller.set_visible_unit_ids([unit_id]) + self.notify_unit_visibility_changed() + + def _qt_on_event_type_changed(self): + self.event_key = self.event_type_combo.currentText() + self.refresh() + + def _qt_add_event_line(self): + import pyqtgraph as pg + from .myqt import QT + + # Add vertical line at event time + evt_time = self.controller.get_time()[0] + if hasattr(self, 'event_line'): + self.event_line.setValue(evt_time) + self.event_line.show() + else: + pen = pg.mkPen(color=(255, 255, 0, 180), width=2, style=QT.Qt.DotLine) + self.event_line = pg.InfiniteLine(pos=evt_time, angle=90, movable=False, pen=pen) + self.plot.addItem(self.event_line) + + def _qt_remove_event_line(self): + if hasattr(self, 'event_line'): + self.plot.removeItem(self.event_line) + del self.event_line + + def _qt_on_next_event(self): + + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + next_events = event_samples[event_samples > current_sample] + if next_events.size > 0: + next_evt_sample = next_events[0] + evt_time = self.controller.sample_index_to_time(next_evt_sample) + self.controller.set_time(time=evt_time) + self.timeseeker.seek(evt_time) + self._qt_add_event_line() + + def _qt_on_prev_event(self): + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + prev_events = event_samples[event_samples < current_sample] + if prev_events.size > 0: + prev_evt_sample = prev_events[-1] + evt_time = self.controller.sample_index_to_time(prev_evt_sample) + self.controller.set_time(time=evt_time) + self.timeseeker.seek(evt_time) + self._qt_add_event_line() ## panel ## - def _panel_create_toolbar(self): + def _panel_create_toolbars(self): import panel as pn segment_index = self.controller.get_time()[1] @@ -287,6 +375,25 @@ def _panel_create_toolbar(self): value_throttled=0, sizing_mode="stretch_width") self.time_slider.param.watch(self._panel_on_time_slider_changed, "value_throttled") + bottom_toolbar_items = [self.time_slider] + if self.controller.has_extension("events"): + event_keys = list(self.controller.events.keys()) + if len(event_keys) > 1: + self.event_type_selector = pn.widgets.Select(name="Event", options=event_keys, value=event_keys[0]) + self.event_type_selector.param.watch(self._panel_on_event_type_changed, "value") + bottom_toolbar_items.append(self.event_type_selector) + else: + self.event_key = event_keys[0] + + self.prev_event_button = pn.widgets.Button(name="◀", button_type="default", width=40) + self.next_event_button = pn.widgets.Button(name="▶", button_type="default", width=40) + self.prev_event_button.on_click(self._panel_on_prev_event) + self.next_event_button.on_click(self._panel_on_next_event) + + bottom_toolbar_items.append(self.prev_event_button) + bottom_toolbar_items.append(self.next_event_button) + self.bottom_toolbar = pn.Row(*bottom_toolbar_items, sizing_mode="stretch_width") + def _panel_on_segment_changed(self, event): segment_index = int(event.new.split()[-1]) self._panel_change_segment(segment_index) @@ -354,6 +461,49 @@ def _panel_on_double_tap(self, event): self.controller.set_indices_spike_selected([ind_spike_nearest]) self._panel_seek_with_selected_spike() self.notify_spike_selection_changed() + # change selected unit + unit_id = self.controller.unit_ids[self.controller.spikes[ind_spike_nearest]["unit_index"]] + self.controller.set_visible_unit_ids([unit_id]) + self.notify_unit_visibility_changed() + + def _panel_on_event_type_changed(self, event): + self.event_key = event.new + self.refresh() + + def _panel_on_next_event(self, event): + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + next_events = event_samples[event_samples > current_sample] + if next_events.size > 0: + next_evt_sample = next_events[0] + evt_time = self.controller.sample_index_to_time(next_evt_sample) + self.controller.set_time(time=evt_time) + self.time_slider.value = evt_time + self._panel_refresh() + self._panel_add_event_line() + + def _panel_on_prev_event(self, event): + current_sample = self.controller.time_to_sample_index(self.controller.get_time()[0]) + event_samples = self.controller.get_events(self.event_key) + prev_events = event_samples[event_samples < current_sample] + if prev_events.size > 0: + prev_evt_sample = prev_events[-1] + evt_time = self.controller.sample_index_to_time(prev_evt_sample) + self.controller.set_time(time=evt_time) + self.time_slider.value = evt_time + self._panel_refresh() + self._panel_add_event_line() + + def _panel_add_event_line(self): + # Add vertical line at event time + evt_time = self.controller.get_time()[0] + # get yspan from self.figure + fig = self.figure + yspan = [fig.y_range.start, fig.y_range.end] + self.event_source.data = dict(x=[evt_time, evt_time], y=yspan) + + def _panel_remove_event_line(self): + self.event_source.data = dict(x=[], y=[]) # TODO: pan behavior like Qt? # def _panel_on_pan_start(self, event): @@ -440,24 +590,20 @@ def _qt_make_layout(self): self.layout = QT.QVBoxLayout() # self.setLayout(self.layout) - self._qt_create_toolbar() - - + self._qt_create_toolbars() + # create graphic view and 2 scroll bar - g = QT.QGridLayout() - self.layout.addLayout(g) + # g = QT.QGridLayout() + # self.layout.addLayout(g) self.graphicsview = pg.GraphicsView() - g.addWidget(self.graphicsview, 0,1) + # g.addWidget(self.graphicsview, 0, 1) + self.layout.addWidget(self.graphicsview) MixinViewTrace._qt_initialize_plot(self) self.scatter = pg.ScatterPlotItem(size=10, pxMode = True) self.plot.addItem(self.scatter) - self.scroll_time = QT.QScrollBar(orientation=QT.Qt.Horizontal) - g.addWidget(self.scroll_time, 1,1) - self.scroll_time.valueChanged.connect(self._qt_on_scroll_time) - - + self.layout.addWidget(self.bottom_toolbar) self._qt_update_scroll_limits() def _qt_on_settings_changed(self): @@ -476,6 +622,7 @@ def _qt_on_spike_selection_changed(self): MixinViewTrace._qt_seek_with_selected_spike(self) def _qt_refresh(self): + self._qt_remove_event_line() t, _ = self.controller.get_time() self._qt_seek(t) @@ -600,20 +747,26 @@ def _panel_make_layout(self): x="x", y="y", size=10, fill_color="color", fill_alpha=self.settings['alpha'], source=self.spike_source ) + self.event_source = ColumnDataSource({"x": [], "y": []}) + self.event_renderer = self.figure.line( + x="x", y="y", source=self.event_source, line_color="yellow", line_width=2, line_dash='dashed' + ) + self.figure.on_event(DoubleTap, self._panel_on_double_tap) - self._panel_create_toolbar() + self._panel_create_toolbars() self.layout = pn.Column( self.toolbar, self.figure, - self.time_slider, + self.bottom_toolbar, styles={"display": "flex", "flex-direction": "column"}, sizing_mode="stretch_both" ) def _panel_refresh(self): + self._panel_remove_event_line() t, segment_index = self.controller.get_time() xsize = self.xsize t1, t2 = t - xsize / 3.0, t + xsize * 2 / 3.0 diff --git a/spikeinterface_gui/viewlist.py b/spikeinterface_gui/viewlist.py index c48e0ac..e95c632 100644 --- a/spikeinterface_gui/viewlist.py +++ b/spikeinterface_gui/viewlist.py @@ -16,6 +16,7 @@ from .mainsettingsview import MainSettingsView from .metricsview import MetricsView from .spikerateview import SpikeRateView +from .eventview import EventView # probe and mainsettings view are first, since they affect other views (e.g., time info) possible_class_views = dict( @@ -36,5 +37,6 @@ tracemap = TraceMapView, curation = CurationView, spikerate = SpikeRateView, - metrics = MetricsView, + metrics = MetricsView, + event = EventView, ) diff --git a/spikeinterface_gui/waveformview.py b/spikeinterface_gui/waveformview.py index ca17218..2f51a6b 100644 --- a/spikeinterface_gui/waveformview.py +++ b/spikeinterface_gui/waveformview.py @@ -1378,7 +1378,8 @@ def _panel_clear_data_sources(self): self.vlines_data_source_std.data = dict(xs=[], ys=[], colors=[]) def _panel_on_spike_selection_changed(self): - self._panel_refresh_one_spike() + if self.settings["plot_selected_spike"] and self.settings["overlap"]: + self._panel_refresh_one_spike() def _panel_on_channel_visibility_changed(self): keep_range = not self.settings["auto_move_on_unit_selection"]