@@ -232,6 +232,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
232232 ops , bfile , st0 = compute_drift_correction (
233233 ops , device , tic0 = tic0 , progress_bar = progress_bar ,
234234 file_object = file_object , clear_cache = clear_cache ,
235+ verbose = verbose_log
235236 )
236237
237238 # Check scale of data for log file
@@ -531,7 +532,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None):
531532
532533
533534def compute_drift_correction (ops , device , tic0 = np .nan , progress_bar = None ,
534- file_object = None , clear_cache = False ):
535+ file_object = None , clear_cache = False , verbose = False ):
535536 """Compute drift correction parameters and save them to `ops`.
536537
537538 Parameters
@@ -548,6 +549,11 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
548549 Must have 'shape' and 'dtype' attributes and support array-like
549550 indexing (e.g. [:100,:], [5, 7:10], etc). For example, a numpy
550551 array or memmap.
552+ clear_cache : bool; False.
553+ If True, force pytorch to clear cached cuda memory after some
554+ memory-intensive steps in the pipeline.
555+ verbose : bool; False.
556+ If true, include additional debug-level logging statements.
551557
552558 Returns
553559 -------
@@ -580,7 +586,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
580586 )
581587
582588 ops , st = datashift .run (ops , bfile , device = device , progress_bar = progress_bar ,
583- clear_cache = clear_cache )
589+ clear_cache = clear_cache , verbose = verbose )
584590 bfile .close ()
585591 logger .info (f'drift computed in { time .time ()- tic : .2f} s; ' +
586592 f'total { time .time ()- tic0 : .2f} s' )
@@ -650,7 +656,7 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
650656 logger .info ('-' * 40 )
651657 st0 , tF , ops = spikedetect .run (
652658 ops , bfile , device = device , progress_bar = progress_bar ,
653- clear_cache = clear_cache
659+ clear_cache = clear_cache , verbose = verbose
654660 )
655661 tF = torch .from_numpy (tF )
656662 logger .info (f'{ len (st0 )} spikes extracted in { time .time ()- tic : .2f} s; ' +
0 commit comments