@@ -170,9 +170,26 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
170170 settings = {** DEFAULT_SETTINGS , ** settings }
171171 # NOTE: This modifies settings in-place
172172 filename , data_dir , results_dir , probe = \
173- set_files (settings , filename , probe , probe_name , data_dir , results_dir , bad_channels )
173+ set_files (settings , filename , probe , probe_name , data_dir ,
174+ results_dir , bad_channels )
174175 setup_logger (results_dir , verbose_console = verbose_console )
175176
177+ ops , st , clu , tF , Wall , similar_templates , \
178+ is_ref , est_contam_rate , kept_spikes = _sort (
179+ filename , results_dir , probe , settings , data_dtype , device , do_CAR ,
180+ clear_cache , invert_sign , save_preprocessed_copy , verbose_log ,
181+ save_extra_vars , file_object , progress_bar
182+ )
183+
184+ return ops , st , clu , tF , Wall , similar_templates , \
185+ is_ref , est_contam_rate , kept_spikes
186+
187+
188+ def _sort (filename , results_dir , probe , settings , data_dtype , device , do_CAR ,
189+ clear_cache , invert_sign , save_preprocessed_copy , verbose_log ,
190+ save_extra_vars , file_object , progress_bar , gui_sorter = None ):
191+ """Run sorting pipeline. See `run_kilosort` for documentation."""
192+
176193 try :
177194 logger .info (f"Kilosort version { kilosort .__version__ } " )
178195 logger .info (f"Python version { platform .python_version ()} " )
@@ -218,7 +235,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
218235
219236 tic0 = time .time ()
220237 ops = initialize_ops (settings , probe , data_dtype , do_CAR , invert_sign ,
221- device , save_preprocessed_copy )
238+ device , save_preprocessed_copy )
222239
223240 # Pretty-print ops and probe for log
224241 logger .debug (f"Initial ops:\n \n { ops_as_string (ops )} \n " )
@@ -242,24 +259,62 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
242259 b1 = bfile .padded_batch_to_torch (0 ).cpu ().numpy ()
243260 logger .debug (f"First batch min, max: { b1 .min (), b1 .max ()} " )
244261
262+ # Save preprocessing steps
245263 if save_preprocessed_copy :
246264 io .save_preprocessing (results_dir / 'temp_wh.dat' , ops , bfile )
247265
266+ # Generate drift plots
267+ # st0 will be None if nblocks = 0 (no drift correction)
268+ if st0 is not None :
269+ if gui_sorter is not None :
270+ gui_sorter .dshift = ops ['dshift' ]
271+ gui_sorter .st0 = st0
272+ gui_sorter .plotDataReady .emit ('drift' )
273+ else :
274+ # TODO: save non-GUI version of plot to results.
275+ pass
276+
248277 # Sort spikes and save results
249- st ,tF , _ , _ = detect_spikes (
278+ st ,tF , Wall0 , clu0 = detect_spikes (
250279 ops , device , bfile , tic0 = tic0 , progress_bar = progress_bar ,
251280 clear_cache = clear_cache , verbose = verbose_log
252281 )
253- clu , Wall = cluster_spikes (
282+
283+ # Generate diagnostic plots
284+ if gui_sorter is not None :
285+ gui_sorter .Wall0 = Wall0
286+ gui_sorter .wPCA = torch .clone (ops ['wPCA' ].cpu ()).numpy ()
287+ gui_sorter .clu0 = clu0
288+ gui_sorter .plotDataReady .emit ('diagnostics' )
289+ else :
290+ # TODO: save non-GUI version of plot to results.
291+ pass
292+
293+ clu , Wall , st , tF = cluster_spikes (
254294 st , tF , ops , device , bfile , tic0 = tic0 , progress_bar = progress_bar ,
255- clear_cache = clear_cache , verbose = verbose_log
295+ clear_cache = clear_cache , verbose = verbose_log ,
256296 )
257297 ops , similar_templates , is_ref , est_contam_rate , kept_spikes = \
258298 save_sorting (
259299 ops , results_dir , st , clu , tF , Wall , bfile .imin , tic0 ,
260300 save_extra_vars = save_extra_vars ,
261301 save_preprocessed_copy = save_preprocessed_copy
262302 )
303+
304+ # Generate spike positions plot
305+ if gui_sorter is not None :
306+ # TODO: re-use spike positions saved by `save_sorting` instead of
307+ # computing them again in `kilosort.gui.sanity_plots`.
308+ gui_sorter .ops = ops
309+ gui_sorter .st = st [kept_spikes ]
310+ gui_sorter .clu = clu [kept_spikes ]
311+ gui_sorter .tF = tF [kept_spikes ]
312+ gui_sorter .is_refractory = is_ref
313+ gui_sorter .plotDataReady .emit ('probe' )
314+ else :
315+ # TODO: save non-GUI version of plot to results.
316+ pass
317+
263318 except Exception as e :
264319 if isinstance (e , torch .cuda .OutOfMemoryError ):
265320 logger .exception ('Out of memory error, printing performance...' )
@@ -275,6 +330,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
275330 finally :
276331 close_logger ()
277332
333+
278334 return ops , st , clu , tF , Wall , similar_templates , \
279335 is_ref , est_contam_rate , kept_spikes
280336
@@ -676,7 +732,9 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
676732 ops , st0 , tF , mode = 'spikes' , device = device , progress_bar = progress_bar ,
677733 clear_cache = clear_cache , verbose = verbose
678734 )
679- Wall3 = template_matching .postprocess_templates (Wall , ops , clu , st0 , device = device )
735+ Wall3 = template_matching .postprocess_templates (
736+ Wall , ops , clu , st0 , tF , device = device
737+ )
680738 logger .info (f'{ clu .max ()+ 1 } clusters found, in { time .time ()- tic : .2f} s; ' +
681739 f'total { time .time ()- tic0 : .2f} s' )
682740 logger .debug (f'clu shape: { clu .shape } ' )
@@ -686,8 +744,9 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
686744 logger .info (' ' )
687745 logger .info ('Extracting spikes using cluster waveforms' )
688746 logger .info ('-' * 40 )
689- st , tF , ops = template_matching .extract (ops , bfile , Wall3 , device = device ,
690- progress_bar = progress_bar )
747+ st , tF , ops = template_matching .extract (
748+ ops , bfile , Wall3 , device = device , progress_bar = progress_bar
749+ )
691750 logger .info (f'{ len (st )} spikes extracted in { time .time ()- tic : .2f} s; ' +
692751 f'total { time .time ()- tic0 : .2f} s' )
693752 logger .debug (f'st shape: { st .shape } ' )
@@ -756,8 +815,9 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
756815 logger .info (' ' )
757816 logger .info ('Merging clusters' )
758817 logger .info ('-' * 40 )
759- Wall , clu , is_ref = template_matching .merging_function (ops , Wall , clu , st [:,0 ],
760- device = device )
818+ Wall , clu , is_ref , st , tF = template_matching .merging_function (
819+ ops , Wall , clu , st , tF , device = device , check_dt = True
820+ )
761821 clu = clu .astype ('int32' )
762822 logger .info (f'{ clu .max ()+ 1 } units found, in { time .time ()- tic : .2f} s; ' +
763823 f'total { time .time ()- tic0 : .2f} s' )
@@ -767,7 +827,7 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
767827 log_performance (logger , 'info' , 'Resource usage after clustering' )
768828 log_cuda_details (logger )
769829
770- return clu , Wall
830+ return clu , Wall , st , tF
771831
772832
773833def save_sorting (ops , results_dir , st , clu , tF , Wall , imin , tic0 = np .nan ,
0 commit comments