Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion spikeinterface_gui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@

from .version import version as __version__

from .main import run_mainwindow, run_launcher
from .main import run_mainwindow, run_launcher, run_compare_analyzer

155 changes: 155 additions & 0 deletions spikeinterface_gui/backend_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import markdown
import numpy as np
from copy import copy
import itertools

import weakref

Expand Down Expand Up @@ -420,3 +421,157 @@ def refresh(self):
'horizontal' : QT.Qt.Horizontal,
'vertical' : QT.Qt.Vertical,
}


class ControllerSynchronizer(QT.QWidget):
def __init__(self, sorting_comparison, controllers, windows, names, parent=None):
QT.QWidget.__init__(self, parent=parent)

self.comp = sorting_comparison
self.controllers = controllers
self.windows = windows
self.names = names

self.layout = QT.QVBoxLayout()
self.setLayout(self.layout)

self.label = QT.QLabel('')
self.layout.addWidget(self.label)


for i, window in enumerate(windows):

# this is not working ???!!!!!
# callback = lambda: self.on_unit_visibility_changed(win_ind=i)

# so uggly solution
callback = [self.on_unit_visibility_changed_0, self.on_unit_visibility_changed_1][i]

for view in window.views.values():
view.notifier.unit_visibility_changed.connect(callback)

settings = [
{'name': 'mode', 'type': 'list', 'limits' : ['best', 'all',] },
{'name': 'thresh', 'type': 'float', 'value' : 0.3, 'step': 0.01, 'limits': (0, 1.)},
]
self.settings = pg.parametertree.Parameter.create(name="settings", type='group', children=settings)

# not that the parent is not the view (not Qt anymore) itself but the widget
self.tree_settings = pg.parametertree.ParameterTree(parent=self)
self.tree_settings.header().hide()
self.tree_settings.setParameters(self.settings, showTop=True)
self.tree_settings.setWindowTitle('Settings')
self.layout.addWidget(self.tree_settings)

from .utils_qt import ViewBoxHandlingClickToPositionWithCtrl

self.graphicsview = pg.GraphicsView()
self.layout.addWidget(self.graphicsview)
self.viewBox = ViewBoxHandlingClickToPositionWithCtrl()
self.viewBox.clicked.connect(self._qt_select_pair)
self.viewBox.disableAutoRange()
self.plot = pg.PlotItem(viewBox=self.viewBox)
self.graphicsview.setCentralItem(self.plot)
self.plot.hideButtons()
self.image = pg.ImageItem()
self.plot.addItem(self.image)
self.plot.hideAxis('bottom')
self.plot.hideAxis('left')


import matplotlib
N = 512
cmap_name = 'viridis'
cmap = matplotlib.colormaps[cmap_name].resampled(N)
lut = []
for i in range(N):
r,g,b,_ = matplotlib.colors.ColorConverter().to_rgba(cmap(i))
lut.append([r*255,g*255,b*255])
self.lut = np.array(lut, dtype='uint8')

# agreement = self.comp.agreement_scores.values
self.agreement_ordered = self.comp.get_ordered_agreement_scores()
self.image.setImage(self.agreement_ordered.values , lut=self.lut, levels=[0, 1])
self.image.show()
self.plot.setXRange(0, self.agreement_ordered.shape[0])
self.plot.setLabel('bottom', names[0])
self.plot.setYRange(0, self.agreement_ordered.shape[1])
self.plot.setLabel('left', names[1])



def on_unit_visibility_changed_0(self):
self.on_unit_visibility_changed(0)

def on_unit_visibility_changed_1(self):
self.on_unit_visibility_changed(1)


def on_unit_visibility_changed(self, win_ind):
changed_controller = self.controllers[win_ind]
visible_unit_inds = changed_controller.get_visible_unit_indices()
visible_unit_ids = changed_controller.get_visible_unit_ids()
if len(visible_unit_inds) != 1:
# TODO handle several units at once
return

unit_ind = visible_unit_inds[0]

agreement = self.comp.agreement_scores.values
if win_ind == 1:
agreement = agreement.T

thresh = self.settings['thresh']
mode = self.settings['mode']

other_ind = (win_ind + 1) % 2
other_controller = self.controllers[other_ind]
other_window = self.windows[other_ind]

if mode == 'all':
other_visible_inds = agreement[unit_ind, :] > thresh
elif mode == 'best':
best_ind = np.argmax(agreement[unit_ind, :])
if agreement[unit_ind, best_ind] > thresh:
other_visible_inds = [best_ind]
else:
other_visible_inds = []

other_visible_ids = other_controller.unit_ids[other_visible_inds]
other_controller.set_visible_unit_ids(other_visible_ids)

for view in other_window.views.values():
view.refresh()


self._refresh_label()

def _refresh_label(self):

txt = ''
unit_ids0 = self.controllers[0].get_visible_unit_ids()
unit_ids1 = self.controllers[1].get_visible_unit_ids()
for unit_id0, unit_id1 in itertools.product(unit_ids0, unit_ids1):
a = self.comp.agreement_scores.loc[unit_id0, unit_id1]
txt += f'{self.names[0]} unit {unit_id0} - {self.names[1]} unit {unit_id1} agreement={a}\n'
self.label.setText(txt)

def _qt_select_pair(self, x, y, reset):
c0 = self.controllers[0]
c1 = self.controllers[1]

# used
ordered_unit_ids1 = self.agreement_ordered.index
ordered_unit_ids2 = self.agreement_ordered.columns
unit_id0 = ordered_unit_ids1[int(np.floor(x))]
unit_id1 = ordered_unit_ids2[int(np.floor(y))]

c0.set_visible_unit_ids([unit_id0])
c1.set_visible_unit_ids([unit_id1])

for win in self.windows:
for view in win.views.values():
view.refresh()

self._refresh_label()

80 changes: 79 additions & 1 deletion spikeinterface_gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,4 +377,82 @@ def find_skippable_extensions(layout_dict):

skippable_extensions = list(all_extensions.difference(set(needed_extensions)))

return skippable_extensions
return skippable_extensions


def run_compare_analyzer(
analyzers,
mode="desktop",
with_traces=False,
# displayed_unit_properties=None,
skip_extensions=None,
layout_preset=None,
layout=None,
verbose=False,
names=None,
# user_settings=None,
# disable_save_settings_button=False,
):

assert isinstance(analyzers, list)
assert len(analyzers) == 2
assert mode == "desktop"

from spikeinterface_gui.myqt import QT, mkQApp
from spikeinterface_gui.backend_qt import QtMainWindow, ControllerSynchronizer
from spikeinterface.comparison import compare_two_sorters

app = mkQApp()


layout_dict = get_layout_description(layout_preset, layout)

if names is None:
names = [f'Analyzer {i}' for i in range(2)]

controllers = []
windows = []
for i, analyzer in enumerate(analyzers):

if verbose:
import time
t0 = time.perf_counter()

controller = Controller(
analyzer, backend="qt",
# verbose=verbose,
verbose=False,

curation=False,
with_traces=with_traces,
skip_extensions=skip_extensions,
)

if verbose:
t1 = time.perf_counter()
print('controller init time', t1 - t0)


# Suppress a known pyqtgraph warning
warnings.filterwarnings("ignore", category=RuntimeWarning, module="pyqtgraph")
warnings.filterwarnings('ignore', category=UserWarning, message=".*QObject::connect.*")



win = QtMainWindow(controller, layout_dict=layout_dict) #, user_settings=user_settings)
name = names[i]
win.setWindowTitle(name)
# Set window icon
icon_file = Path(__file__).absolute().parent / 'img' / 'si.png'
if icon_file.exists():
app.setWindowIcon(QT.QIcon(str(icon_file)))
win.show()
windows.append(win)
controllers.append(controller)

comp = compare_two_sorters(analyzers[0].sorting, analyzers[1].sorting)

synchronizer = ControllerSynchronizer(comp, controllers, windows, names)
synchronizer.show()

app.exec()
43 changes: 43 additions & 0 deletions spikeinterface_gui/tests/test_compare_qt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from argparse import ArgumentParser
from spikeinterface_gui import run_compare_analyzer

from spikeinterface_gui.tests.testingtools import clean_all, make_analyzer_folder, make_curation_dict

from spikeinterface import load_sorting_analyzer


from pathlib import Path

import numpy as np
import sys




def setup_module():
global test_folder
case = test_folder.stem.split('_')[-1]
make_analyzer_folder(test_folder, case=case)

def teardown_module():
clean_all(test_folder)


def test_run_compare_analyzer():
analyzer = load_sorting_analyzer(test_folder / "sorting_analyzer")
analyzers = [analyzer, analyzer]
run_compare_analyzer(
analyzers,
mode="desktop",
verbose=True,
)

if __name__ == '__main__':
global test_folder

dataset = "small"
test_folder = Path(dataset).parent / f"my_dataset_{dataset}"
if not test_folder.is_dir():
setup_module()

win = test_run_compare_analyzer()