Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions spikeinterface_gui/basescatterview.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _qt_make_layout(self):
tb = self.qt_widget.view_toolbar
self.combo_seg = QT.QComboBox()
tb.addWidget(self.combo_seg)
self.combo_seg.addItems([ f'Segment {segment_index}' for segment_index in range(self.controller.num_segments) ])
self.combo_seg.addItems([f'Segment {segment_index}' for segment_index in range(self.controller.num_segments)])
self.combo_seg.currentIndexChanged.connect(self._qt_change_segment)
add_stretch_to_qtoolbar(tb)
self.lasso_but = QT.QPushButton("select", checkable = True)
Expand Down Expand Up @@ -278,7 +278,7 @@ def _qt_refresh(self):
# make a copy of the color
color = QT.QColor(self.get_unit_color(unit_id))
color.setAlpha(int(self.settings['alpha']*255))
self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color)
self.scatter.addPoints(x=spike_times, y=spike_data, pen=pg.mkPen(None), brush=color)

color = self.get_unit_color(unit_id)
curve = pg.PlotCurveItem(hist_count, hist_bins[:-1], fillLevel=None, fillOutline=True, brush=color, pen=color)
Expand Down
95 changes: 83 additions & 12 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@

from spikeinterface.widgets.utils import get_unit_colors
from spikeinterface import compute_sparsity
from spikeinterface.core import get_template_extremum_channel
import spikeinterface.postprocessing
import spikeinterface.qualitymetrics
from spikeinterface.core import get_template_extremum_channel, BaseEvent
from spikeinterface.core.sorting_tools import spike_vector_to_indices
from spikeinterface.core.core_tools import check_json
from spikeinterface.curation import validate_curation_dict
from spikeinterface.curation.curation_model import CurationModel
from spikeinterface.widgets.utils import make_units_table_from_analyzer
Expand All @@ -33,10 +30,23 @@


class Controller():
def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save_on_compute=False,
curation=False, curation_data=None, label_definitions=None, with_traces=True,
displayed_unit_properties=None,
extra_unit_properties=None, skip_extensions=None, disable_save_settings_button=False):
def __init__(
self,
analyzer=None,
backend="qt",
parent=None,
verbose=False,
save_on_compute=False,
curation=False,
curation_data=None,
label_definitions=None,
with_traces=True,
displayed_unit_properties=None,
extra_unit_properties=None,
skip_extensions=None,
disable_save_settings_button=False,
events=None
):
self.views = []
skip_extensions = skip_extensions if skip_extensions is not None else []

Expand Down Expand Up @@ -220,6 +230,62 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
self.pc_ext = pc_ext

self._potential_merges = None
# some direct attribute
self.num_segments = self.analyzer.get_num_segments()
self.sampling_frequency = self.analyzer.sampling_frequency

self.events = None
if events is not None:
if verbose:
print('\tLoading events')
self.events = {}
if isinstance(events, dict):
for key, val in events.items():
if not isinstance(val, dict):
if verbose:
print(f'\tSkipping event {key}: not a dict')
continue
if 'samples' not in val and 'times' not in val:
if verbose:
print(f'\tSkipping event {key}: missing samples or times')
continue
if 'times' in val:
samples_data = val['times']
convert_to_samples = True
else:
samples_data = val['samples']
convert_to_samples = False
if self.num_segments > 1:
if not len(samples_data) == self.num_segments:
if verbose:
print(f'\tSkipping event {key}: inconsistent number of samples')
continue
else:
# here we make sure samples is a list of list
if np.array(samples_data).ndim == 1:
samples_data = [samples_data]
if convert_to_samples:
self.events[key] = [np.array(self.time_to_sample_index(s)) for s in samples_data]
else:
self.events[key] = [np.array(s) for s in samples_data]
elif isinstance(events, BaseEvent):
event_names = events.channel_ids
self.events = {
event_name: [] for event_name in event_names
}
for event_name in event_names:
for segment_index in range(self.num_segments):
event_times_segment = events.get_event_times(
channel_id=event_name,
segment_index=segment_index
)
event_samples_segment = self.analyzer.time_to_sample_index(
event_times_segment
)
self.events[event_name].append(np.array(event_samples_segment))

if len(self.events) == 0:
self.events = None

t1 = time.perf_counter()
if verbose:
Expand All @@ -229,10 +295,6 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save

self._extremum_channel = get_template_extremum_channel(self.analyzer, peak_sign='neg', outputs='index')

# some direct attribute
self.num_segments = self.analyzer.get_num_segments()
self.sampling_frequency = self.analyzer.sampling_frequency

# spikeinterface handle colors in matplotlib style tuple values in range (0,1)
self.refresh_colors()

Expand Down Expand Up @@ -489,6 +551,13 @@ def time_to_sample_index(self, time):
else:
return int(time * self.sampling_frequency)

def get_events(self, event_name):
if self.events is None:
return None
if event_name not in self.events:
return None
return self.events[event_name][self.time_info['segment_index']]

def get_information_txt(self):
nseg = self.analyzer.get_num_segments()
nchan = self.analyzer.get_num_channels()
Expand Down Expand Up @@ -715,6 +784,8 @@ def set_channel_visibility(self, visible_channel_inds):
def has_extension(self, extension_name):
if extension_name == 'recording':
return self.analyzer.has_recording() or self.analyzer.has_temporary_recording()
elif extension_name == 'events':
return self.events is not None
else:
# extension needs to be loaded
if extension_name in self.skip_extensions:
Expand Down
Loading