Skip to content

Commit 7ddd8a2

Browse files
Merge pull request #919 from MouseLand/jacob/wip
Add WIP changes to main
2 parents f32209e + 903d3be commit 7ddd8a2

File tree

3 files changed

+133
-122
lines changed

3 files changed

+133
-122
lines changed

kilosort/gui/sorter.py

Lines changed: 14 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88

99
import kilosort
1010
from kilosort.run_kilosort import (
11-
setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
12-
detect_spikes, cluster_spikes, save_sorting, close_logger
11+
# setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
12+
# detect_spikes, cluster_spikes, save_sorting, close_logger
13+
setup_logger, _sort
1314
)
1415
from kilosort.io import save_preprocessing
1516
from kilosort.utils import (
@@ -46,105 +47,17 @@ def run(self):
4647
results_dir.mkdir(parents=True)
4748

4849
setup_logger(results_dir)
49-
verbose = settings['verbose_log']
5050

51-
try:
52-
logger.info(f"Kilosort version {kilosort.__version__}")
53-
logger.info(f"Sorting {self.data_path}")
54-
clear_cache = settings['clear_cache']
55-
if clear_cache:
56-
logger.info('clear_cache=True')
57-
logger.info('-'*40)
58-
59-
tic0 = time.time()
60-
61-
if probe['chanMap'].max() >= settings['n_chan_bin']:
62-
raise ValueError(
63-
f'Largest value of chanMap exceeds channel count of data, '
64-
'make sure chanMap is 0-indexed.'
65-
)
66-
67-
if settings['nt0min'] is None:
68-
settings['nt0min'] = int(20 * settings['nt']/61)
69-
data_dtype = settings['data_dtype']
70-
device = self.device
71-
save_preprocessed_copy = settings['save_preprocessed_copy']
72-
do_CAR = settings['do_CAR']
73-
invert_sign = settings['invert_sign']
74-
if not do_CAR:
75-
logger.info("Skipping common average reference.")
76-
77-
ops = initialize_ops(settings, probe, data_dtype, do_CAR,
78-
invert_sign, device, save_preprocessed_copy)
79-
80-
# Pretty-print ops and probe for log
81-
logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n")
82-
logger.debug(f"Probe dictionary:\n\n{probe_as_string(ops['probe'])}\n")
83-
84-
# TODO: add support for file object through data conversion
85-
# Set preprocessing and drift correction parameters
86-
ops = compute_preprocessing(ops, self.device, tic0=tic0,
87-
file_object=self.file_object)
88-
np.random.seed(1)
89-
torch.cuda.manual_seed_all(1)
90-
torch.random.manual_seed(1)
91-
ops, bfile, st0 = compute_drift_correction(
92-
ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
93-
file_object=self.file_object, clear_cache=clear_cache
94-
)
95-
96-
# Check scale of data for log file
97-
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
98-
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")
99-
100-
if save_preprocessed_copy:
101-
save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
102-
103-
# Will be None if nblocks = 0 (no drift correction)
104-
if st0 is not None:
105-
self.dshift = ops['dshift']
106-
self.st0 = st0
107-
self.plotDataReady.emit('drift')
108-
109-
# Sort spikes and save results
110-
st, tF, Wall0, clu0 = detect_spikes(
111-
ops, self.device, bfile, tic0=tic0,
112-
progress_bar=self.progress_bar, clear_cache=clear_cache,
113-
verbose=verbose
114-
)
115-
116-
self.Wall0 = Wall0
117-
self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
118-
self.clu0 = clu0
119-
self.plotDataReady.emit('diagnostics')
120-
121-
clu, Wall = cluster_spikes(
122-
st, tF, ops, self.device, bfile, tic0=tic0,
123-
progress_bar=self.progress_bar, clear_cache=clear_cache,
124-
verbose=verbose
125-
)
126-
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
127-
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)
128-
129-
except Exception as e:
130-
if isinstance(e, torch.cuda.OutOfMemoryError):
131-
logger.exception('Out of memory error, printing performance...')
132-
log_performance(logger, level='info')
133-
log_cuda_details(logger)
134-
# This makes sure the full traceback is written to log file.
135-
logger.exception('Encountered error in `run_kilosort`:')
136-
# Annoyingly, this will print the error message twice for console
137-
# but I haven't found a good way around that.
138-
raise
139-
140-
finally:
141-
close_logger()
142-
143-
self.ops = ops
144-
self.st = st[kept_spikes]
145-
self.clu = clu[kept_spikes]
146-
self.tF = tF[kept_spikes]
147-
self.is_refractory = is_ref
148-
self.plotDataReady.emit('probe')
51+
# NOTE: All but `gui_sorter` are positional args,
52+
# don't move these around.
53+
_ = _sort(
54+
settings['filename'], results_dir, probe, settings,
55+
settings['data_dtype'], self.device, settings['do_CAR'],
56+
settings['clear_cache'], settings['invert_sign'],
57+
settings['save_preprocessed_copy'], settings['verbose_log'],
58+
False, self.file_object, self.progress_bar, gui_sorter=self
59+
)
60+
# Hard-coded `False` is for "save_extra_vars", which isn't an option
61+
# in the GUI right now (and isn't likely to be added).
14962

15063
self.finishedSpikesort.emit(self.context)

kilosort/run_kilosort.py

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,26 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
170170
settings = {**DEFAULT_SETTINGS, **settings}
171171
# NOTE: This modifies settings in-place
172172
filename, data_dir, results_dir, probe = \
173-
set_files(settings, filename, probe, probe_name, data_dir, results_dir, bad_channels)
173+
set_files(settings, filename, probe, probe_name, data_dir,
174+
results_dir, bad_channels)
174175
setup_logger(results_dir, verbose_console=verbose_console)
175176

177+
ops, st, clu, tF, Wall, similar_templates, \
178+
is_ref, est_contam_rate, kept_spikes = _sort(
179+
filename, results_dir, probe, settings, data_dtype, device, do_CAR,
180+
clear_cache, invert_sign, save_preprocessed_copy, verbose_log,
181+
save_extra_vars, file_object, progress_bar
182+
)
183+
184+
return ops, st, clu, tF, Wall, similar_templates, \
185+
is_ref, est_contam_rate, kept_spikes
186+
187+
188+
def _sort(filename, results_dir, probe, settings, data_dtype, device, do_CAR,
189+
clear_cache, invert_sign, save_preprocessed_copy, verbose_log,
190+
save_extra_vars, file_object, progress_bar, gui_sorter=None):
191+
"""Run sorting pipeline. See `run_kilosort` for documentation."""
192+
176193
try:
177194
logger.info(f"Kilosort version {kilosort.__version__}")
178195
logger.info(f"Python version {platform.python_version()}")
@@ -218,7 +235,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
218235

219236
tic0 = time.time()
220237
ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
221-
device, save_preprocessed_copy)
238+
device, save_preprocessed_copy)
222239

223240
# Pretty-print ops and probe for log
224241
logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n")
@@ -242,24 +259,62 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
242259
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
243260
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")
244261

262+
# Save preprocessing steps
245263
if save_preprocessed_copy:
246264
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
247265

266+
# Generate drift plots
267+
# st0 will be None if nblocks = 0 (no drift correction)
268+
if st0 is not None:
269+
if gui_sorter is not None:
270+
gui_sorter.dshift = ops['dshift']
271+
gui_sorter.st0 = st0
272+
gui_sorter.plotDataReady.emit('drift')
273+
else:
274+
# TODO: save non-GUI version of plot to results.
275+
pass
276+
248277
# Sort spikes and save results
249-
st,tF, _, _ = detect_spikes(
278+
st,tF, Wall0, clu0 = detect_spikes(
250279
ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
251280
clear_cache=clear_cache, verbose=verbose_log
252281
)
253-
clu, Wall = cluster_spikes(
282+
283+
# Generate diagnostic plots
284+
if gui_sorter is not None:
285+
gui_sorter.Wall0 = Wall0
286+
gui_sorter.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
287+
gui_sorter.clu0 = clu0
288+
gui_sorter.plotDataReady.emit('diagnostics')
289+
else:
290+
# TODO: save non-GUI version of plot to results.
291+
pass
292+
293+
clu, Wall, st, tF = cluster_spikes(
254294
st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
255-
clear_cache=clear_cache, verbose=verbose_log
295+
clear_cache=clear_cache, verbose=verbose_log,
256296
)
257297
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
258298
save_sorting(
259299
ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
260300
save_extra_vars=save_extra_vars,
261301
save_preprocessed_copy=save_preprocessed_copy
262302
)
303+
304+
# Generate spike positions plot
305+
if gui_sorter is not None:
306+
# TODO: re-use spike positions saved by `save_sorting` instead of
307+
# computing them again in `kilosort.gui.sanity_plots`.
308+
gui_sorter.ops = ops
309+
gui_sorter.st = st[kept_spikes]
310+
gui_sorter.clu = clu[kept_spikes]
311+
gui_sorter.tF = tF[kept_spikes]
312+
gui_sorter.is_refractory = is_ref
313+
gui_sorter.plotDataReady.emit('probe')
314+
else:
315+
# TODO: save non-GUI version of plot to results.
316+
pass
317+
263318
except Exception as e:
264319
if isinstance(e, torch.cuda.OutOfMemoryError):
265320
logger.exception('Out of memory error, printing performance...')
@@ -275,6 +330,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
275330
finally:
276331
close_logger()
277332

333+
278334
return ops, st, clu, tF, Wall, similar_templates, \
279335
is_ref, est_contam_rate, kept_spikes
280336

@@ -676,7 +732,9 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
676732
ops, st0, tF, mode='spikes', device=device, progress_bar=progress_bar,
677733
clear_cache=clear_cache, verbose=verbose
678734
)
679-
Wall3 = template_matching.postprocess_templates(Wall, ops, clu, st0, device=device)
735+
Wall3 = template_matching.postprocess_templates(
736+
Wall, ops, clu, st0, tF, device=device
737+
)
680738
logger.info(f'{clu.max()+1} clusters found, in {time.time()-tic : .2f}s; ' +
681739
f'total {time.time()-tic0 : .2f}s')
682740
logger.debug(f'clu shape: {clu.shape}')
@@ -686,8 +744,9 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
686744
logger.info(' ')
687745
logger.info('Extracting spikes using cluster waveforms')
688746
logger.info('-'*40)
689-
st, tF, ops = template_matching.extract(ops, bfile, Wall3, device=device,
690-
progress_bar=progress_bar)
747+
st, tF, ops = template_matching.extract(
748+
ops, bfile, Wall3, device=device, progress_bar=progress_bar
749+
)
691750
logger.info(f'{len(st)} spikes extracted in {time.time()-tic : .2f}s; ' +
692751
f'total {time.time()-tic0 : .2f}s')
693752
logger.debug(f'st shape: {st.shape}')
@@ -756,8 +815,9 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
756815
logger.info(' ')
757816
logger.info('Merging clusters')
758817
logger.info('-'*40)
759-
Wall, clu, is_ref = template_matching.merging_function(ops, Wall, clu, st[:,0],
760-
device=device)
818+
Wall, clu, is_ref, st, tF = template_matching.merging_function(
819+
ops, Wall, clu, st, tF, device=device, check_dt=True
820+
)
761821
clu = clu.astype('int32')
762822
logger.info(f'{clu.max()+1} units found, in {time.time()-tic : .2f}s; ' +
763823
f'total {time.time()-tic0 : .2f}s')
@@ -767,7 +827,7 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
767827
log_performance(logger, 'info', 'Resource usage after clustering')
768828
log_cuda_details(logger)
769829

770-
return clu, Wall
830+
return clu, Wall, st, tF
771831

772832

773833
def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,

kilosort/template_matching.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,13 @@ def align_U(U, ops, device=torch.device('cuda')):
140140
return Unew, imax
141141

142142

143-
def postprocess_templates(Wall, ops, clu, st, device=torch.device('cuda')):
143+
def postprocess_templates(Wall, ops, clu, st, tF, device=torch.device('cuda')):
144144
Wall2, _ = align_U(Wall, ops, device=device)
145145
#Wall3, _= remove_duplicates(ops, Wall2)
146-
Wall3, _, _ = merging_function(ops, Wall2.transpose(1,2), clu, st[:,0],
147-
0.9, 'mu', device=device)
146+
Wall3, _, _, _, _ = merging_function(
147+
ops, Wall2.transpose(1,2), clu, st, tF,
148+
0.9, 'mu', check_dt=False, device=device
149+
)
148150
Wall3 = Wall3.transpose(1,2).to(device)
149151
return Wall3
150152

@@ -241,7 +243,8 @@ def run_matching(ops, X, U, ctc, device=torch.device('cuda')):
241243
return st, amps, th_amps, Xres
242244

243245

244-
def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.device('cuda')):
246+
def merging_function(ops, Wall, clu, st, tF, r_thresh=0.5, mode='ccg', check_dt=True,
247+
device=torch.device('cuda')):
245248
clu2 = clu.copy()
246249
clu_unq, ns = np.unique(clu2, return_counts = True)
247250

@@ -256,7 +259,7 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.
256259
acg_threshold = ops['settings']['acg_threshold']
257260
ccg_threshold = ops['settings']['ccg_threshold']
258261
if mode == 'ccg':
259-
is_ref, est_contam_rate = CCG.refract(clu, st/ops['fs'],
262+
is_ref, est_contam_rate = CCG.refract(clu, st[:,0]/ops['fs'],
260263
acg_threshold=acg_threshold,
261264
ccg_threshold=ccg_threshold)
262265

@@ -287,13 +290,13 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.
287290
UtU = torch.einsum('lk, jlm -> jkm', Wnorm[kk], Wnorm)
288291
ctc = torch.einsum('jkm, kml -> jl', UtU, WtW)
289292

290-
cmax = ctc.max(1)[0]
293+
cmax, imax = ctc.max(1)
291294
cmax[kk] = 0
292295

293296
jsort = np.argsort(cmax.cpu().numpy())[::-1]
294297

295298
if mode == 'ccg':
296-
st0 = st[clu2==kk] / ops['fs']
299+
st0 = st[:,0][clu2==kk] / ops['fs']
297300

298301
is_ccg = 0
299302
for j in range(NN):
@@ -302,7 +305,7 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.
302305
break
303306
# compare with CCG
304307
if mode == 'ccg':
305-
st1 = st[clu2==jj] / ops['fs']
308+
st1 = st[:,0][clu2==jj] / ops['fs']
306309
_, is_ccg, _ = CCG.check_CCG(st0, st1, acg_threshold=acg_threshold,
307310
ccg_threshold=ccg_threshold)
308311
else:
@@ -311,9 +314,17 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.
311314

312315
if is_ccg:
313316
is_merged[jj] = 1
317+
dt = (imax[kk] -imax[jj]).item()
318+
if dt != 0 and check_dt:
319+
# Get spike indices for cluster jj
320+
idx = (clu2 == jj)
321+
# Update tF and Wall with shifted features
322+
tF, Wall = roll_features(W, tF, Ww, idx, jj, dt)
323+
# Shift spike times
324+
st[idx,0] -= dt
325+
314326
Ww[kk] = ns[kk]/(ns[kk]+ns[jj]) * Ww[kk] + ns[jj]/(ns[kk]+ns[jj]) * Ww[jj]
315327
Ww[jj] = 0
316-
317328
ns[kk] += ns[jj]
318329
ns[jj] = 0
319330
clu2[clu2==jj] = kk
@@ -337,4 +348,31 @@ def merging_function(ops, Wall, clu, st, r_thresh=0.5, mode='ccg', device=torch.
337348
else:
338349
is_ref = None
339350

340-
return Ww.cpu(), clu2, is_ref
351+
sorted_idx = np.argsort(st[:,0])
352+
st = np.take_along_axis(st, sorted_idx[..., np.newaxis], axis=0)
353+
clu2 = clu2[sorted_idx]
354+
tensor_idx = torch.from_numpy(sorted_idx)
355+
tF = tF[tensor_idx]
356+
357+
return Ww.cpu(), clu2, is_ref, st, tF
358+
359+
360+
def roll_features(wPCA, tF, Wall, spike_idx, clust_idx, dt):
361+
W = wPCA.cpu()
362+
# Project from PC space back to sample time, shift by dt
363+
feats = torch.roll(tF[spike_idx] @ W, shifts=dt, dims=2)
364+
temps = torch.roll(Wall[clust_idx:clust_idx+1] @ wPCA, shifts=dt, dims=2)
365+
366+
# For values that "rolled over the edge," set equal to next closest bin
367+
if dt > 0:
368+
feats[:,:,:dt] = feats[:,:,dt].unsqueeze(-1)
369+
temps[:,:,:dt] = temps[:,:,dt].unsqueeze(-1)
370+
elif dt < 0:
371+
feats[:,:,dt:] = feats[:,:,dt-1].unsqueeze(-1)
372+
temps[:,:,dt:] = temps[:,:,dt-1].unsqueeze(-1)
373+
374+
# Project back to PC space and update tF
375+
tF[spike_idx] = feats @ W.T
376+
Wall[clust_idx] = temps @ wPCA.T
377+
378+
return tF, Wall

0 commit comments

Comments
 (0)