Skip to content

Commit 1d11e34

Browse files
Added clear_cache option
1 parent 5a1e630 commit 1d11e34

File tree

4 files changed

+61
-25
lines changed

4 files changed

+61
-25
lines changed

kilosort/clustering_qr.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import gc
2+
13
import numpy as np
24
import torch
35
from torch import sparse_coo_tensor as coo
@@ -301,7 +303,8 @@ def y_centers(ops):
301303
return centers
302304

303305

304-
def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_bar=None):
306+
def run(ops, st, tF, mode = 'template', device=torch.device('cuda'),
307+
progress_bar=None, clear_cache=False):
305308

306309
if mode == 'template':
307310
xy, iC = xy_templates(ops)
@@ -362,11 +365,16 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_b
362365

363366
# find new clusters
364367
iclust, iclust0, M, iclust_init = cluster(Xd, nskip=nskip, lam=1,
365-
seed=5, device=device)
368+
seed=5, device=device)
369+
if clear_cache:
370+
gc.collect()
371+
torch.cuda.empty_cache()
366372

367373
xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)
368374

369-
xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0)
375+
xtree, tstat = swarmsplitter.split(
376+
Xd.numpy(), xtree, tstat,iclust, my_clus, meta=st0
377+
)
370378

371379
iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)
372380

kilosort/datashift.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ def kernel2D(x, y, sig = 1):
183183
Kn = np.exp(-ds / (2*sig**2))
184184
return Kn
185185

186-
def run(ops, bfile, device=torch.device('cuda'), progress_bar=None):
186+
def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
187+
clear_cache=False):
187188
""" this step computes a drift correction model
188189
it returns vertical correction amplitudes for each batch, and for multiple blocks in a batch if nblocks > 1.
189190
"""
@@ -194,7 +195,10 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None):
194195
return ops, None
195196

196197
# the first step is to extract all spikes using the universal templates
197-
st, _, ops = spikedetect.run(ops, bfile, device=device, progress_bar=progress_bar)
198+
st, _, ops = spikedetect.run(
199+
ops, bfile, device=device, progress_bar=progress_bar,
200+
clear_cache=clear_cache
201+
)
198202

199203
# spikes are binned by amplitude and y-position to construct a "fingerprint" for each batch
200204
F, ysamp = bin_spikes(ops, st)

kilosort/run_kilosort.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
def run_kilosort(settings, probe=None, probe_name=None, filename=None,
2727
data_dir=None, file_object=None, results_dir=None,
2828
data_dtype=None, do_CAR=True, invert_sign=False, device=None,
29-
progress_bar=None, save_extra_vars=False,
29+
progress_bar=None, save_extra_vars=False, clear_cache=False,
3030
save_preprocessed_copy=False, bad_channels=None):
3131
"""Run full spike sorting pipeline on specified data.
3232
@@ -82,6 +82,12 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
8282
not need to specify this.
8383
save_extra_vars : bool; default=False.
8484
If True, save tF and Wall to disk after sorting.
85+
clear_cache : bool; default=False.
86+
If True, force pytorch to free up memory reserved for its cache in
87+
between memory-intensive operations.
88+
Note that setting `clear_cache=True` is NOT recommended unless you
89+
encounter GPU out-of-memory errors, since this can result in slower
90+
sorting.
8591
save_preprocessed_copy : bool; default=False.
8692
If True, save a pre-processed copy of the data (including drift
8793
correction) to `temp_wh.dat` in the results directory and format Phy
@@ -150,6 +156,8 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
150156
try:
151157
logger.info(f"Kilosort version {kilosort.__version__}")
152158
logger.info(f"Sorting {filename}")
159+
if clear_cache:
160+
logger.info('clear_cache=True')
153161
logger.info('-'*40)
154162

155163
if data_dtype is None:
@@ -189,15 +197,14 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
189197
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
190198
logger.debug(f"Initial ops:\n{print_ops}\n")
191199

192-
193200
# Set preprocessing and drift correction parameters
194201
ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
195202
np.random.seed(1)
196203
torch.cuda.manual_seed_all(1)
197204
torch.random.manual_seed(1)
198205
ops, bfile, st0 = compute_drift_correction(
199206
ops, device, tic0=tic0, progress_bar=progress_bar,
200-
file_object=file_object
207+
file_object=file_object, clear_cache=clear_cache,
201208
)
202209

203210
# Check scale of data for log file
@@ -208,14 +215,20 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
208215
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
209216

210217
# Sort spikes and save results
211-
st,tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
212-
progress_bar=progress_bar)
213-
clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0,
214-
progress_bar=progress_bar)
218+
st,tF, _, _ = detect_spikes(
219+
ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
220+
clear_cache=clear_cache
221+
)
222+
clu, Wall = cluster_spikes(
223+
st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
224+
clear_cache=clear_cache
225+
)
215226
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
216-
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
217-
save_extra_vars=save_extra_vars,
218-
save_preprocessed_copy=save_preprocessed_copy)
227+
save_sorting(
228+
ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
229+
save_extra_vars=save_extra_vars,
230+
save_preprocessed_copy=save_preprocessed_copy
231+
)
219232
except:
220233
# This makes sure the full traceback is written to log file.
221234
logger.exception('Encountered error in `run_kilosort`:')
@@ -456,7 +469,7 @@ def compute_preprocessing(ops, device, tic0=np.nan, file_object=None):
456469

457470

458471
def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
459-
file_object=None):
472+
file_object=None, clear_cache=False):
460473
"""Compute drift correction parameters and save them to `ops`.
461474
462475
Parameters
@@ -504,7 +517,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
504517
file_object=file_object
505518
)
506519

507-
ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar)
520+
ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar,
521+
clear_cache=clear_cache)
508522
bfile.close()
509523
logger.info(f'drift computed in {time.time()-tic : .2f}s; ' +
510524
f'total {time.time()-tic0 : .2f}s')
@@ -526,7 +540,8 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
526540
return ops, bfile, st
527541

528542

529-
def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
543+
def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
544+
clear_cache=False):
530545
"""Detect spikes via template deconvolution.
531546
532547
Parameters
@@ -563,7 +578,10 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
563578
logger.info(' ')
564579
logger.info(f'Extracting spikes using templates')
565580
logger.info('-'*40)
566-
st0, tF, ops = spikedetect.run(ops, bfile, device=device, progress_bar=progress_bar)
581+
st0, tF, ops = spikedetect.run(
582+
ops, bfile, device=device, progress_bar=progress_bar,
583+
clear_cache=clear_cache
584+
)
567585
tF = torch.from_numpy(tF)
568586
logger.info(f'{len(st0)} spikes extracted in {time.time()-tic : .2f}s; ' +
569587
f'total {time.time()-tic0 : .2f}s')
@@ -576,8 +594,10 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
576594
logger.info(' ')
577595
logger.info('First clustering')
578596
logger.info('-'*40)
579-
clu, Wall = clustering_qr.run(ops, st0, tF, mode='spikes', device=device,
580-
progress_bar=progress_bar)
597+
clu, Wall = clustering_qr.run(
598+
ops, st0, tF, mode='spikes', device=device, progress_bar=progress_bar,
599+
clear_cache=clear_cache
600+
)
581601
Wall3 = template_matching.postprocess_templates(Wall, ops, clu, st0, device=device)
582602
logger.info(f'{clu.max()+1} clusters found, in {time.time()-tic : .2f}s; ' +
583603
f'total {time.time()-tic0 : .2f}s')
@@ -600,7 +620,8 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
600620
return st, tF, Wall, clu
601621

602622

603-
def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None):
623+
def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
624+
clear_cache=False):
604625
"""Cluster spikes using graph-based methods.
605626
606627
Parameters
@@ -636,8 +657,10 @@ def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None):
636657
logger.info(' ')
637658
logger.info('Final clustering')
638659
logger.info('-'*40)
639-
clu, Wall = clustering_qr.run(ops, st, tF, mode = 'template', device=device,
640-
progress_bar=progress_bar)
660+
clu, Wall = clustering_qr.run(
661+
ops, st, tF, mode = 'template', device=device, progress_bar=progress_bar,
662+
clear_cache=clear_cache
663+
)
641664
logger.info(f'{clu.max()+1} clusters found, in {time.time()-tic : .2f}s; ' +
642665
f'total {time.time()-tic0 : .2f}s')
643666
logger.debug(f'clu shape: {clu.shape}')

kilosort/spikedetect.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ def yweighted(yc, iC, adist, xy, device=torch.device('cuda')):
194194
yct = (cF0 * yy[:,xy[:,0]]).sum(0)
195195
return yct
196196

197-
def run(ops, bfile, device=torch.device('cuda'), progress_bar=None):
197+
def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
198+
clear_cache=False):
198199
sig = ops['settings']['min_template_size']
199200
nsizes = ops['settings']['template_sizes']
200201

0 commit comments

Comments
 (0)