Skip to content

Commit 3f12ff4

Browse files
added fix for OMP-related warnings
1 parent aaac099 commit 3f12ff4

File tree

4 files changed

+24
-10
lines changed

4 files changed

+24
-10
lines changed

kilosort/gui/sorter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
setup_logger, initialize_ops, compute_preprocessing, compute_drift_correction,
1414
detect_spikes, cluster_spikes, save_sorting
1515
)
16-
1716
from kilosort.io import save_preprocessing
17+
from kilosort.utils import log_performance, log_cuda_details
1818

1919
#logger = setup_logger(__name__)
2020

@@ -128,7 +128,11 @@ def run(self):
128128
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
129129
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)
130130

131-
except:
131+
except Exception as e:
132+
if isinstance(e, torch.cuda.OutOfMemoryError):
133+
logger.exception('Out of memory error, printing performance...')
134+
log_performance(logger, level='info')
135+
log_cuda_details(logger)
132136
# This makes sure the full traceback is written to log file.
133137
logger.exception('Encountered error in `run_kilosort`:')
134138
# Annoyingly, this will print the error message twice for console

kilosort/hierarchical.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from scipy.sparse import csr_matrix
22
import numpy as np
3-
import faiss
4-
from sklearn.cluster import KMeans
53

64

75
def cluster_qr(M, iclust, iclust0):

kilosort/run_kilosort.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,12 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
230230
save_extra_vars=save_extra_vars,
231231
save_preprocessed_copy=save_preprocessed_copy
232232
)
233-
except:
233+
except Exception as e:
234+
if isinstance(e, torch.cuda.OutOfMemoryError):
235+
logger.exception('Out of memory error, printing performance...')
236+
log_performance(logger, level='info')
237+
log_cuda_details(logger)
238+
234239
# This makes sure the full traceback is written to log file.
235240
logger.exception('Encountered error in `run_kilosort`:')
236241
# Annoyingly, this will print the error message twice for console, but

kilosort/spikedetect.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from io import StringIO
2+
import os
23
import logging
4+
import warnings
35
logger = logging.getLogger(__name__)
46

57
from torch.nn.functional import max_pool2d, avg_pool2d, conv1d, max_pool1d
@@ -11,8 +13,6 @@
1113

1214
from kilosort.utils import template_path, log_performance
1315

14-
device = torch.device('cuda')
15-
1616

1717
def my_max2d(X, dt):
1818
Xmax = max_pool2d(
@@ -72,9 +72,16 @@ def extract_wPCA_wTEMP(ops, bfile, nt=61, twav_min=20, Th_single_ch=6, nskip=25,
7272
model = TruncatedSVD(n_components=ops['settings']['n_pcs']).fit(clips)
7373
wPCA = torch.from_numpy(model.components_).to(device).float()
7474

75-
model = KMeans(n_clusters=ops['settings']['n_templates'], n_init = 10).fit(clips)
76-
wTEMP = torch.from_numpy(model.cluster_centers_).to(device).float()
77-
wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
75+
with warnings.catch_warnings():
76+
warnings.filterwarnings("ignore", message="")
77+
# Prevents memory leak for KMeans when using MKL on Windows
78+
msg = 'KMeans is known to have a memory leak on Windows with MKL'
79+
nthread = os.environ.get('OMP_NUM_THREADS', msg)
80+
os.environ['OMP_NUM_THREADS'] = '7'
81+
model = KMeans(n_clusters=ops['settings']['n_templates'], n_init = 10).fit(clips)
82+
wTEMP = torch.from_numpy(model.cluster_centers_).to(device).float()
83+
wTEMP = wTEMP / (wTEMP**2).sum(1).unsqueeze(1)**.5
84+
os.environ['OMP_NUM_THREADS'] = nthread
7885

7986
return wPCA, wTEMP
8087

0 commit comments

Comments (0)