Skip to content

Commit 00f0ca8

Browse files
Merge branch 'main' of github.com:MouseLand/Kilosort
2 parents b73f0fa + a5b43f0 commit 00f0ca8

File tree

4 files changed

+28
-13
lines changed

4 files changed

+28
-13
lines changed

docs/tutorials/plotting_example.ipynb

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
"# Example plots using kilosort.data_tools"
88
]
99
},
10+
{
11+
"cell_type": "markdown",
12+
"metadata": {},
13+
"source": [
14+
"##### Note that `kilosort.data_tools` was added in `v4.0.21`, so you will need to update Kilosort4 to at least that version to use these examples. This can be done using `pip install kilosort --upgrade`."
15+
]
16+
},
1017
{
1118
"cell_type": "code",
1219
"execution_count": 5,
@@ -42,7 +49,7 @@
4249
"from kilosort.io import load_ops\n",
4350
"from kilosort.data_tools import (\n",
4451
" mean_waveform, cluster_templates, get_good_cluster, get_cluster_spikes,\n",
45-
" get_spike_waveforms, get_best_channel\n",
52+
" get_spike_waveforms, get_best_channels\n",
4653
" )\n",
4754
"\n",
4855
"\n",
@@ -101,7 +108,7 @@
101108
"# Time in s for spike time axis\n",
102109
"t2 = spike_times / ops['fs']\n",
103110
"# Get single-channel waveform for each spike\n",
104-
"chan = get_best_channel(cluster_id, results_dir)\n",
111+
"chan = get_best_channels(results_dir)[cluster_id]\n",
105112
"waves = get_spike_waveforms(spike_times, results_dir, chan=chan)\n",
106113
"\n",
107114
"# Plot each waveform, using spike time as 3rd dimension\n",

kilosort/data_tools.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def mean_waveform(cluster_id, results_dir, n_spikes=np.inf, bfile=None, best=Tru
3434
"""
3535
results_dir = Path(results_dir)
3636
if best:
37-
chan = get_best_channel(cluster_id, results_dir)
37+
chan = get_best_channels(results_dir)[cluster_id]
3838
else:
3939
chan = None
4040

@@ -45,12 +45,14 @@ def mean_waveform(cluster_id, results_dir, n_spikes=np.inf, bfile=None, best=Tru
4545
return mean_wave
4646

4747

48-
def get_best_channel(cluster_id, results_dir):
49-
"""Get channel number with largest template norm for this cluster."""
48+
def get_best_channels(results_dir):
49+
"""Get channel numbers with largest template norm for each cluster."""
5050
templates = np.load(results_dir / 'templates.npy')
51-
chan = (templates**2).sum(axis=1).argmax(axis=-1)[cluster_id]
52-
return chan
51+
best_chans = (templates**2).sum(axis=1).argmax(axis=-1)
52+
return best_chans
5353

54+
def get_best_channel(results_dir, cluster_id):
55+
return get_best_channels(results_dir)[cluster_id]
5456

5557
def get_cluster_spikes(cluster_id, results_dir, n_spikes=np.inf):
5658
"""Get `n_spikes` random spike times assigned to `cluster_id`."""

kilosort/gui/probe_view_box.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def update_spots_variables(self, probe, template_args):
122122
for ind, (xc, yc) in enumerate(zip(self.xc, self.yc)):
123123
self.channel_map_dict[(xc, yc)] = ind
124124

125-
def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers, device):
125+
def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers):
126126
ops = {
127127
'yc': self.yc, 'xc': self.xc, 'max_channel_distance': max_dist,
128128
'x_centers': x_centers, 'settings': {'dmin': dmin, 'dminx': dminx},
@@ -131,7 +131,9 @@ def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers, device):
131131
ops = template_centers(ops)
132132
[ys, xs] = np.meshgrid(ops['yup'], ops['xup'])
133133
ys, xs = ys.flatten(), xs.flatten()
134-
iC, ds = nearest_chans(ys, self.yc, xs, self.xc, nC, device=device)
134+
iC, ds = nearest_chans(
135+
ys, self.yc, xs, self.xc, nC, device=self.gui.device
136+
)
135137

136138
igood = ds[0,:] <= max_dist**2
137139
iC = iC[:,igood]

kilosort/gui/settings_box.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
_DEFAULT_DTYPE = 'int16'
2121
_ALLOWED_FILE_TYPES = ['.bin', '.dat', '.bat', '.raw'] # For binary data
22+
_PROBE_SETTINGS = [
23+
'nearest_chans', 'dmin', 'dminx', 'max_channel_distance', 'x_centers'
24+
]
2225

2326
class SettingsBox(QtWidgets.QGroupBox):
2427
settingsUpdated = QtCore.Signal()
@@ -247,6 +250,8 @@ def setup(self):
247250
)
248251
inp = getattr(self, f'{k}_input')
249252
inp.editingFinished.connect(self.update_parameter)
253+
if k in _PROBE_SETTINGS:
254+
inp.editingFinished.connect(self.show_probe_layout())
250255

251256
row_count += rspan
252257
layout.addWidget(
@@ -550,10 +555,7 @@ def update_settings(self):
550555

551556
def get_probe_template_args(self):
552557
epw = self.extra_parameters_window
553-
template_args = [
554-
epw.nearest_chans, epw.dmin, epw.dminx,
555-
epw.max_channel_distance, epw.x_centers, self.gui.device
556-
]
558+
template_args = [getattr(epw, k) for k in _PROBE_SETTINGS]
557559
return template_args
558560

559561
@QtCore.Slot()
@@ -862,6 +864,8 @@ def __init__(self, parent):
862864
layout.addWidget(getattr(self, f'{k}_input'), row_count, col+3, 1, 2)
863865
inp = getattr(self, f'{k}_input')
864866
inp.editingFinished.connect(self.update_parameter)
867+
if k in _PROBE_SETTINGS:
868+
inp.editingFinished.connect(self.main_settings.show_probe_layout)
865869

866870
self.setLayout(layout)
867871

0 commit comments

Comments
 (0)