Skip to content

Commit 8d2c1df

Browse files
author
Jacob Pennington
committed
Added option to sort single shank to GUI
1 parent 944f833 commit 8d2c1df

File tree

3 files changed

+55
-10
lines changed

3 files changed

+55
-10
lines changed

kilosort/gui/main.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
DataConversionBox
1212
)
1313
from kilosort.gui.logger import setup_logger
14-
from kilosort.io import BinaryFiltered, remove_bad_channels
14+
from kilosort.io import BinaryFiltered, remove_bad_channels, select_shank
1515
from kilosort.utils import DOWNLOADS_DIR, download_probes
1616
from qtpy import QtCore, QtGui, QtWidgets
1717

@@ -311,10 +311,14 @@ def load_data(self):
311311
def set_parameters(self):
312312
settings = self.settings_box.settings
313313
bad_channels = self.settings_box.bad_channels
314+
shank_idx = self.settings_box.shank_idx
314315

315316
self.data_path = settings["data_file_path"]
316317
self.results_directory = settings["results_dir"]
317-
self.probe_layout = remove_bad_channels(settings["probe"], bad_channels)
318+
probe = remove_bad_channels(settings["probe"], bad_channels)
319+
if shank_idx is not None:
320+
probe = select_shank(probe, shank_idx)
321+
self.probe_layout = probe
318322
self.probe_name = settings["probe_name"]
319323
self.num_channels = settings["n_chan_bin"]
320324

@@ -324,9 +328,6 @@ def set_parameters(self):
324328
params['do_CAR'] = self.run_box.do_CAR_check.isChecked()
325329
params['invert_sign'] = self.run_box.invert_sign_check.isChecked()
326330
params['verbose_log'] = self.run_box.verbose_check.isChecked()
327-
328-
assert params
329-
330331
self.params = params
331332

332333
def do_load(self):

kilosort/gui/probe_view_box.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,13 @@ def generate_spots_list(self):
173173
self.center_spots = []
174174
self.number_spots = []
175175
bad_channels = self.gui.settings_box.get_bad_channels()
176+
shank_idx = self.gui.settings_box.shank_idx
177+
chan_map = self.active_layout['chanMap']
178+
if shank_idx is None:
179+
shank_channels = chan_map
180+
else:
181+
shank_map = (self.active_layout['kcoords'] == shank_idx).nonzero()[0]
182+
shank_channels = chan_map[shank_map]
176183
channel_size = 10 * self.spot_scale.value()/4
177184
template_size = 5 * self.spot_scale.value()/4
178185
center_size = 20 * self.spot_scale.value()/4
@@ -182,7 +189,7 @@ def generate_spots_list(self):
182189
for x_pos, y_pos in zip(self.xc, self.yc):
183190
index = self.channel_map_dict[(x_pos, y_pos)]
184191
channel = self.channel_map[index]
185-
if channel in bad_channels:
192+
if (channel in bad_channels) or (channel not in shank_channels):
186193
color = "b"
187194
else:
188195
color = "g"

kilosort/gui/settings_box.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,21 @@ def __init__(self, parent):
8989
else:
9090
self.bad_channels = []
9191

92+
self.shank_idx_text = QtWidgets.QLabel("Shank Index:")
93+
self.shank_idx_input = QtWidgets.QLineEdit()
94+
if self.gui.qt_settings.contains('shank_idx'):
95+
idx = self.gui.qt_settings.value('shank_idx')
96+
if isinstance(idx, str):
97+
self.shank_idx = float(idx)
98+
elif isinstance(idx, list):
99+
self.shank_idx = [float(s) for s in idx]
100+
else:
101+
self.shank_idx = idx
102+
else:
103+
self.shank_idx = None
104+
self.shank_idx_input.setText(str(self.shank_idx))
105+
106+
92107
self.dtype_selector_text = QtWidgets.QLabel("Data dtype:")
93108
self.dtype_selector = QtWidgets.QComboBox()
94109
self.populate_dtype_selector()
@@ -227,6 +242,18 @@ def setup(self):
227242
"excluding them from the probe dictionary."
228243
)
229244

245+
row_count += rspan
246+
layout.addWidget(self.shank_idx_text, row_count, col1, rspan, cspan1)
247+
layout.addWidget(self.shank_idx_input, row_count, col2, rspan, cspan2)
248+
self.shank_idx_input.editingFinished.connect(self.update_shank_idx)
249+
self.shank_idx_text.setToolTip(
250+
"If not None, only channels from the specified shank index will be used. "
251+
"If a list is provided, each shank will be sorted sequentially and results "
252+
"will be saved in separate subfolders. Note that the shank_idx value(s) "
253+
"must match the actual value specified in `probe['kcoords']`. For example, "
254+
"`probe_idx=0` will not work if `probe['kcoords']` uses 1,2,3,4."
255+
)
256+
230257

231258
# Add small vertical space for visual grouping
232259
row_count += rspan
@@ -755,13 +782,23 @@ def update_bad_channels(self):
755782
# Remove brackets and white space if present, convert to list of ints.
756783
self.bad_channels = self.get_bad_channels()
757784
self.gui.qt_settings.setValue('bad_channels', self.bad_channels)
758-
759785
if not self.pause_checks:
760-
# Trigger update so that probe layout in main gets updated, then
761-
# refresh probe view.
762-
self.update_settings()
763786
self.previewProbe.emit()
764787

788+
@QtCore.Slot()
789+
def update_shank_idx(self):
790+
# Remove brackets and white space if present, convert to list of floats.
791+
text = self.shank_idx_input.text()
792+
text = text.replace(']','').replace('[','').replace(' ','')
793+
if len(text) > 0 and text.lower() != 'none':
794+
idx = float(text)
795+
else:
796+
idx = None
797+
self.shank_idx = idx
798+
self.gui.qt_settings.setValue('shank_idx', self.shank_idx)
799+
800+
if not self.pause_checks:
801+
self.previewProbe.emit()
765802

766803
def on_data_dtype_selected(self, data_dtype):
767804
self.data_dtype = data_dtype

0 commit comments

Comments
 (0)