2626def run_kilosort (settings , probe = None , probe_name = None , filename = None ,
2727 data_dir = None , file_object = None , results_dir = None ,
2828 data_dtype = None , do_CAR = True , invert_sign = False , device = None ,
29- progress_bar = None , save_extra_vars = False ,
29+ progress_bar = None , save_extra_vars = False , clear_cache = False ,
3030 save_preprocessed_copy = False , bad_channels = None ):
3131 """Run full spike sorting pipeline on specified data.
3232
@@ -82,6 +82,12 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
8282 not need to specify this.
8383 save_extra_vars : bool; default=False.
8484 If True, save tF and Wall to disk after sorting.
85+ clear_cache : bool; default=False.
86+ If True, force pytorch to free up memory reserved for its cache in
87+ between memory-intensive operations.
88+ Note that setting `clear_cache=True` is NOT recommended unless you
89+ encounter GPU out-of-memory errors, since this can result in slower
90+ sorting.
8591 save_preprocessed_copy : bool; default=False.
8692 If True, save a pre-processed copy of the data (including drift
8793 correction) to `temp_wh.dat` in the results directory and format Phy
@@ -150,6 +156,8 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
150156 try :
151157 logger .info (f"Kilosort version { kilosort .__version__ } " )
152158 logger .info (f"Sorting { filename } " )
159+ if clear_cache :
160+ logger .info ('clear_cache=True' )
153161 logger .info ('-' * 40 )
154162
155163 if data_dtype is None :
@@ -189,15 +197,14 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
189197 print_ops = pprint .pformat (ops_copy , indent = 4 , sort_dicts = False )
190198 logger .debug (f"Initial ops:\n { print_ops } \n " )
191199
192-
193200 # Set preprocessing and drift correction parameters
194201 ops = compute_preprocessing (ops , device , tic0 = tic0 , file_object = file_object )
195202 np .random .seed (1 )
196203 torch .cuda .manual_seed_all (1 )
197204 torch .random .manual_seed (1 )
198205 ops , bfile , st0 = compute_drift_correction (
199206 ops , device , tic0 = tic0 , progress_bar = progress_bar ,
200- file_object = file_object
207+ file_object = file_object , clear_cache = clear_cache ,
201208 )
202209
203210 # Check scale of data for log file
@@ -208,14 +215,20 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
208215 io .save_preprocessing (results_dir / 'temp_wh.dat' , ops , bfile )
209216
210217 # Sort spikes and save results
211- st ,tF , _ , _ = detect_spikes (ops , device , bfile , tic0 = tic0 ,
212- progress_bar = progress_bar )
213- clu , Wall = cluster_spikes (st , tF , ops , device , bfile , tic0 = tic0 ,
214- progress_bar = progress_bar )
218+ st ,tF , _ , _ = detect_spikes (
219+ ops , device , bfile , tic0 = tic0 , progress_bar = progress_bar ,
220+ clear_cache = clear_cache
221+ )
222+ clu , Wall = cluster_spikes (
223+ st , tF , ops , device , bfile , tic0 = tic0 , progress_bar = progress_bar ,
224+ clear_cache = clear_cache
225+ )
215226 ops , similar_templates , is_ref , est_contam_rate , kept_spikes = \
216- save_sorting (ops , results_dir , st , clu , tF , Wall , bfile .imin , tic0 ,
217- save_extra_vars = save_extra_vars ,
218- save_preprocessed_copy = save_preprocessed_copy )
227+ save_sorting (
228+ ops , results_dir , st , clu , tF , Wall , bfile .imin , tic0 ,
229+ save_extra_vars = save_extra_vars ,
230+ save_preprocessed_copy = save_preprocessed_copy
231+ )
219232 except :
220233 # This makes sure the full traceback is written to log file.
221234 logger .exception ('Encountered error in `run_kilosort`:' )
@@ -456,7 +469,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None):
456469
457470
458471def compute_drift_correction (ops , device , tic0 = np .nan , progress_bar = None ,
459- file_object = None ):
472+ file_object = None , clear_cache = False ):
460473 """Compute drift correction parameters and save them to `ops`.
461474
462475 Parameters
@@ -504,7 +517,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
504517 file_object = file_object
505518 )
506519
507- ops , st = datashift .run (ops , bfile , device = device , progress_bar = progress_bar )
520+ ops , st = datashift .run (ops , bfile , device = device , progress_bar = progress_bar ,
521+ clear_cache = clear_cache )
508522 bfile .close ()
509523 logger .info (f'drift computed in { time .time ()- tic : .2f} s; ' +
510524 f'total { time .time ()- tic0 : .2f} s' )
@@ -526,7 +540,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
526540 return ops , bfile , st
527541
528542
529- def detect_spikes (ops , device , bfile , tic0 = np .nan , progress_bar = None ):
543+ def detect_spikes (ops , device , bfile , tic0 = np .nan , progress_bar = None ,
544+ clear_cache = False ):
530545 """Detect spikes via template deconvolution.
531546
532547 Parameters
@@ -563,7 +578,10 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
563578 logger .info (' ' )
564579 logger .info (f'Extracting spikes using templates' )
565580 logger .info ('-' * 40 )
566- st0 , tF , ops = spikedetect .run (ops , bfile , device = device , progress_bar = progress_bar )
581+ st0 , tF , ops = spikedetect .run (
582+ ops , bfile , device = device , progress_bar = progress_bar ,
583+ clear_cache = clear_cache
584+ )
567585 tF = torch .from_numpy (tF )
568586 logger .info (f'{ len (st0 )} spikes extracted in { time .time ()- tic : .2f} s; ' +
569587 f'total { time .time ()- tic0 : .2f} s' )
@@ -576,8 +594,10 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
576594 logger .info (' ' )
577595 logger .info ('First clustering' )
578596 logger .info ('-' * 40 )
579- clu , Wall = clustering_qr .run (ops , st0 , tF , mode = 'spikes' , device = device ,
580- progress_bar = progress_bar )
597+ clu , Wall = clustering_qr .run (
598+ ops , st0 , tF , mode = 'spikes' , device = device , progress_bar = progress_bar ,
599+ clear_cache = clear_cache
600+ )
581601 Wall3 = template_matching .postprocess_templates (Wall , ops , clu , st0 , device = device )
582602 logger .info (f'{ clu .max ()+ 1 } clusters found, in { time .time ()- tic : .2f} s; ' +
583603 f'total { time .time ()- tic0 : .2f} s' )
@@ -600,7 +620,8 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
600620 return st , tF , Wall , clu
601621
602622
603- def cluster_spikes (st , tF , ops , device , bfile , tic0 = np .nan , progress_bar = None ):
623+ def cluster_spikes (st , tF , ops , device , bfile , tic0 = np .nan , progress_bar = None ,
624+ clear_cache = False ):
604625 """Cluster spikes using graph-based methods.
605626
606627 Parameters
@@ -636,8 +657,10 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None):
636657 logger .info (' ' )
637658 logger .info ('Final clustering' )
638659 logger .info ('-' * 40 )
639- clu , Wall = clustering_qr .run (ops , st , tF , mode = 'template' , device = device ,
640- progress_bar = progress_bar )
660+ clu , Wall = clustering_qr .run (
661+ ops , st , tF , mode = 'template' , device = device , progress_bar = progress_bar ,
662+ clear_cache = clear_cache
663+ )
641664 logger .info (f'{ clu .max ()+ 1 } clusters found, in { time .time ()- tic : .2f} s; ' +
642665 f'total { time .time ()- tic0 : .2f} s' )
643666 logger .debug (f'clu shape: { clu .shape } ' )
0 commit comments