Skip to content

Commit eac3c3a

Browse files
Added verbose arg stub to datashift and spikedetect
1 parent 4c8a1e7 commit eac3c3a

File tree

4 files changed

+15
-6
lines changed

4 files changed

+15
-6
lines changed

kilosort/clustering_qr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,9 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
502502

503503
logger.debug(f'Center {ii} | Xd shape: {Xd.shape} | ntemp: {ntemp}')
504504
if verbose and Xd.nelement() > 10**8:
505+
logger.info(f'Resetting cuda memory stats for Center {ii}')
506+
if device == torch.device('cuda'):
507+
torch.cuda.reset_peak_memory_stats(device)
505508
v = True
506509

507510
if Xd is None:

kilosort/datashift.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def kernel2D(x, y, sig = 1):
184184
return Kn
185185

186186
def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
187-
clear_cache=False):
187+
clear_cache=False, verbose=False):
188188
""" this step computes a drift correction model
189189
it returns vertical correction amplitudes for each batch, and for multiple blocks in a batch if nblocks > 1.
190190
"""
@@ -197,7 +197,7 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
197197
# the first step is to extract all spikes using the universal templates
198198
st, _, ops = spikedetect.run(
199199
ops, bfile, device=device, progress_bar=progress_bar,
200-
clear_cache=clear_cache
200+
clear_cache=clear_cache, verbose=verbose
201201
)
202202

203203
# spikes are binned by amplitude and y-position to construct a "fingerprint" for each batch

kilosort/run_kilosort.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
232232
ops, bfile, st0 = compute_drift_correction(
233233
ops, device, tic0=tic0, progress_bar=progress_bar,
234234
file_object=file_object, clear_cache=clear_cache,
235+
verbose=verbose_log
235236
)
236237

237238
# Check scale of data for log file
@@ -531,7 +532,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None):
531532

532533

533534
def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
534-
file_object=None, clear_cache=False):
535+
file_object=None, clear_cache=False, verbose=False):
535536
"""Compute drift correction parameters and save them to `ops`.
536537
537538
Parameters
@@ -548,6 +549,11 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
548549
Must have 'shape' and 'dtype' attributes and support array-like
549550
indexing (e.g. [:100,:], [5, 7:10], etc). For example, a numpy
550551
array or memmap.
552+
clear_cache : bool; False.
553+
If True, force pytorch to clear cached cuda memory after some
554+
memory-intensive steps in the pipeline.
555+
verbose : bool; False.
556+
If True, include additional debug-level logging statements.
551557
552558
Returns
553559
-------
@@ -580,7 +586,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
580586
)
581587

582588
ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar,
583-
clear_cache=clear_cache)
589+
clear_cache=clear_cache, verbose=verbose)
584590
bfile.close()
585591
logger.info(f'drift computed in {time.time()-tic : .2f}s; ' +
586592
f'total {time.time()-tic0 : .2f}s')
@@ -650,7 +656,7 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
650656
logger.info('-'*40)
651657
st0, tF, ops = spikedetect.run(
652658
ops, bfile, device=device, progress_bar=progress_bar,
653-
clear_cache=clear_cache
659+
clear_cache=clear_cache, verbose=verbose
654660
)
655661
tF = torch.from_numpy(tF)
656662
logger.info(f'{len(st0)} spikes extracted in {time.time()-tic : .2f}s; ' +

kilosort/spikedetect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def yweighted(yc, iC, adist, xy, device=torch.device('cuda')):
190190
return yct
191191

192192
def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
193-
clear_cache=False):
193+
clear_cache=False, verbose=False):
194194
sig = ops['settings']['min_template_size']
195195
nsizes = ops['settings']['template_sizes']
196196

0 commit comments

Comments
 (0)