Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions spikeinterface_gui/metricsview.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def _qt_creat_grid(self):

def _qt_refresh(self):
import pyqtgraph as pg
import pandas as pd
from .myqt import QT


Expand All @@ -111,27 +112,25 @@ def _qt_refresh(self):

scatter.setData(x=values2, y=values1)

visible_unit_ids = self.controller.get_visible_unit_ids()
visible_unit_ids = self.controller.get_visible_unit_indices()

for unit_ind, unit_id in self.controller.iter_visible_units():
color = self.get_unit_color(unit_id)
scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color)
if (not pd.isna(values2[unit_ind])) and (not pd.isna(values1[unit_ind])):
scatter.addPoints(x=[values2[unit_ind]], y=[values1[unit_ind]], pen=pg.mkPen(None), brush=color)

# self.scatter.addPoints(x=scatter_x[unit_id], y=scatter_y[unit_id], pen=pg.mkPen(None), brush=color)
# self.scatter_select.setData(selected_scatter_x, selected_scatter_y)
elif c == r:
values1 = units_table[visible_metrics[r]].values
values1_no_nans = values1[~np.isnan(values1)]

count, bins = np.histogram(values1, bins=self.settings['num_bins'])
count, bins = np.histogram(values1_no_nans, bins=self.settings['num_bins'])
curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=white_brush, pen=white_brush)
plot.addItem(curve)

for unit_ind, unit_id in self.controller.iter_visible_units():
x = values1[unit_ind]
color = self.get_unit_color(unit_id)
line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color)
plot.addItem(line)
if not pd.isna(x):
line = pg.InfiniteLine(pos=x, angle=90, movable=False, pen=color)
plot.addItem(line)

def _qt_select_metrics(self):
if not self.tree_visible_metrics.isVisible():
Expand Down Expand Up @@ -187,6 +186,7 @@ def _panel_on_metrics_changed(self, event):
self.refresh()

def _panel_refresh(self):
import pandas as pd
import panel as pn
import bokeh.plotting as bpl
from bokeh.layouts import gridplot
Expand All @@ -212,6 +212,8 @@ def _panel_refresh(self):
col2 = visible_metrics[c]
values1 = units_table[col1].values
values2 = units_table[col2].values
values1_no_nans = values1[~np.isnan(values1)]
values2_no_nans = values2[~np.isnan(values2)]

plot = bpl.figure(
width=plot_size, height=plot_size,
Expand All @@ -227,7 +229,7 @@ def _panel_refresh(self):
plot.xaxis.axis_label = col1
plot.yaxis.axis_label = "Count"
# Create histogram
hist, edges = np.histogram(values1, bins=self.settings['num_bins'])
hist, edges = np.histogram(values1_no_nans, bins=self.settings['num_bins'])
if len(hist) > 0 and max(hist) > 0:
plot.quad(
top=hist, bottom=0, left=edges[:-1], right=edges[1:],
Expand All @@ -238,8 +240,9 @@ def _panel_refresh(self):
max_hist = max(hist)
for unit_ind, unit_id in self.controller.iter_visible_units():
x = values1[unit_ind]
color = self.get_unit_color(unit_id)
plot.line([x, x], [0, max_hist], line_width=2, color=color, alpha=0.8)
if not pd.isna(x):
color = self.get_unit_color(unit_id)
plot.line([x, x], [0, max_hist], line_width=2, color=color, alpha=0.8)
else:
# Off-diagonal - scatter plot
plot.xaxis.axis_label = col2
Expand All @@ -251,8 +254,8 @@ def _panel_refresh(self):

# Plot all points in light color first
all_source = ColumnDataSource({
'x': values2,
'y': values1,
'x': values2_no_nans,
'y': values1_no_nans,
'color': colors
})
plot.scatter('x', 'y', source=all_source, size=8, color='color', alpha=0.5)
Expand Down