Skip to content

Commit c6cd856

Browse files
Fixed printing errors to log file
1 parent 3e010bb commit c6cd856

File tree

2 files changed

+159
-137
lines changed

2 files changed

+159
-137
lines changed

kilosort/gui/sorter.py

Lines changed: 86 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -48,79 +48,92 @@ def run(self):
4848
results_dir.mkdir(parents=True)
4949

5050
setup_logger(results_dir)
51-
logger.info(f"Kilosort version {kilosort.__version__}")
52-
logger.info(f"Sorting {self.data_path}")
53-
logger.info('-'*40)
54-
55-
tic0 = time.time()
56-
57-
# TODO: make these options in GUI
58-
do_CAR=True
59-
invert_sign=False
60-
61-
if not do_CAR:
62-
logger.info("Skipping common average reference.")
63-
64-
if probe['chanMap'].max() >= settings['n_chan_bin']:
65-
raise ValueError(
66-
f'Largest value of chanMap exceeds channel count of data, '
67-
'make sure chanMap is 0-indexed.'
68-
)
69-
70-
if settings['nt0min'] is None:
71-
settings['nt0min'] = int(20 * settings['nt']/61)
72-
data_dtype = settings['data_dtype']
73-
device = self.device
74-
save_preprocessed_copy = settings['save_preprocessed_copy']
75-
76-
ops = initialize_ops(settings, probe, data_dtype, do_CAR,
77-
invert_sign, device, save_preprocessed_copy)
78-
# Remove some stuff that doesn't need to be printed twice, then pretty-print
79-
# format for log file.
80-
ops_copy = ops.copy()
81-
_ = ops_copy.pop('settings')
82-
_ = ops_copy.pop('probe')
83-
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
84-
logger.debug(f"Initial ops:\n{print_ops}\n")
85-
86-
# TODO: add support for file object through data conversion
87-
# Set preprocessing and drift correction parameters
88-
ops = compute_preprocessing(ops, self.device, tic0=tic0,
89-
file_object=self.file_object)
90-
np.random.seed(1)
91-
torch.cuda.manual_seed_all(1)
92-
torch.random.manual_seed(1)
93-
ops, bfile, st0 = compute_drift_correction(
94-
ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
95-
file_object=self.file_object
96-
)
97-
98-
# Check scale of data for log file
99-
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
100-
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")
101-
102-
if save_preprocessed_copy:
103-
save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
104-
105-
# Will be None if nblocks = 0 (no drift correction)
106-
if st0 is not None:
107-
self.dshift = ops['dshift']
108-
self.st0 = st0
109-
self.plotDataReady.emit('drift')
110-
111-
# Sort spikes and save results
112-
st, tF, Wall0, clu0 = detect_spikes(ops, self.device, bfile, tic0=tic0,
113-
progress_bar=self.progress_bar)
114-
115-
self.Wall0 = Wall0
116-
self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
117-
self.clu0 = clu0
118-
self.plotDataReady.emit('diagnostics')
119-
120-
clu, Wall = cluster_spikes(st, tF, ops, self.device, bfile, tic0=tic0,
121-
progress_bar=self.progress_bar)
122-
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
123-
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)
51+
52+
try:
53+
logger.info(f"Kilosort version {kilosort.__version__}")
54+
logger.info(f"Sorting {self.data_path}")
55+
logger.info('-'*40)
56+
57+
tic0 = time.time()
58+
59+
# TODO: make these options in GUI
60+
do_CAR=True
61+
invert_sign=False
62+
63+
if not do_CAR:
64+
logger.info("Skipping common average reference.")
65+
66+
if probe['chanMap'].max() >= settings['n_chan_bin']:
67+
raise ValueError(
68+
f'Largest value of chanMap exceeds channel count of data, '
69+
'make sure chanMap is 0-indexed.'
70+
)
71+
72+
if settings['nt0min'] is None:
73+
settings['nt0min'] = int(20 * settings['nt']/61)
74+
data_dtype = settings['data_dtype']
75+
device = self.device
76+
save_preprocessed_copy = settings['save_preprocessed_copy']
77+
78+
ops = initialize_ops(settings, probe, data_dtype, do_CAR,
79+
invert_sign, device, save_preprocessed_copy)
80+
# Remove some stuff that doesn't need to be printed twice,
81+
# then pretty-print format for log file.
82+
ops_copy = ops.copy()
83+
_ = ops_copy.pop('settings')
84+
_ = ops_copy.pop('probe')
85+
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
86+
logger.debug(f"Initial ops:\n{print_ops}\n")
87+
88+
# TODO: add support for file object through data conversion
89+
# Set preprocessing and drift correction parameters
90+
ops = compute_preprocessing(ops, self.device, tic0=tic0,
91+
file_object=self.file_object)
92+
np.random.seed(1)
93+
torch.cuda.manual_seed_all(1)
94+
torch.random.manual_seed(1)
95+
ops, bfile, st0 = compute_drift_correction(
96+
ops, self.device, tic0=tic0, progress_bar=self.progress_bar,
97+
file_object=self.file_object
98+
)
99+
100+
# Check scale of data for log file
101+
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
102+
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")
103+
104+
if save_preprocessed_copy:
105+
save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
106+
107+
# Will be None if nblocks = 0 (no drift correction)
108+
if st0 is not None:
109+
self.dshift = ops['dshift']
110+
self.st0 = st0
111+
self.plotDataReady.emit('drift')
112+
113+
# Sort spikes and save results
114+
st, tF, Wall0, clu0 = detect_spikes(
115+
ops, self.device, bfile, tic0=tic0,
116+
progress_bar=self.progress_bar
117+
)
118+
119+
self.Wall0 = Wall0
120+
self.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
121+
self.clu0 = clu0
122+
self.plotDataReady.emit('diagnostics')
123+
124+
clu, Wall = cluster_spikes(
125+
st, tF, ops, self.device, bfile, tic0=tic0,
126+
progress_bar=self.progress_bar
127+
)
128+
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
129+
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0)
130+
131+
except:
132+
# This makes sure the full traceback is written to log file.
133+
logger.exception('Encountered error in `run_kilosort`:')
134+
# Annoyingly, this will print the error message twice for console
135+
# but I haven't found a good way around that.
136+
raise
124137

125138
self.ops = ops
126139
self.st = st[kept_spikes]

kilosort/run_kilosort.py

Lines changed: 73 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -112,74 +112,82 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
112112
filename, data_dir, results_dir, probe = \
113113
set_files(settings, filename, probe, probe_name, data_dir, results_dir)
114114
setup_logger(results_dir)
115-
logger.info(f"Kilosort version {kilosort.__version__}")
116-
logger.info(f"Sorting {filename}")
117-
logger.info('-'*40)
118115

119-
if data_dtype is None:
120-
logger.info(
121-
"Interpreting binary file as default dtype='int16'. If data was "
122-
"saved in a different format, specify `data_dtype`."
116+
try:
117+
logger.info(f"Kilosort version {kilosort.__version__}")
118+
logger.info(f"Sorting {filename}")
119+
logger.info('-'*40)
120+
121+
if data_dtype is None:
122+
logger.info(
123+
"Interpreting binary file as default dtype='int16'. If data was "
124+
"saved in a different format, specify `data_dtype`."
125+
)
126+
data_dtype = 'int16'
127+
128+
if not do_CAR:
129+
logger.info("Skipping common average reference.")
130+
131+
if device is None:
132+
if torch.cuda.is_available():
133+
logger.info('Using GPU for PyTorch computations. '
134+
'Specify `device` to change this.')
135+
device = torch.device('cuda')
136+
else:
137+
logger.info('Using CPU for PyTorch computations. '
138+
'Specify `device` to change this.')
139+
device = torch.device('cpu')
140+
141+
if probe['chanMap'].max() >= settings['n_chan_bin']:
142+
raise ValueError(
143+
f'Largest value of chanMap exceeds channel count of data, '
144+
'make sure chanMap is 0-indexed.'
123145
)
124-
data_dtype = 'int16'
125-
126-
if not do_CAR:
127-
logger.info("Skipping common average reference.")
128146

129-
if device is None:
130-
if torch.cuda.is_available():
131-
logger.info('Using GPU for PyTorch computations. '
132-
'Specify `device` to change this.')
133-
device = torch.device('cuda')
134-
else:
135-
logger.info('Using CPU for PyTorch computations. '
136-
'Specify `device` to change this.')
137-
device = torch.device('cpu')
138-
139-
if probe['chanMap'].max() >= settings['n_chan_bin']:
140-
raise ValueError(
141-
f'Largest value of chanMap exceeds channel count of data, '
142-
'make sure chanMap is 0-indexed.'
143-
)
144-
145-
tic0 = time.time()
146-
ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
147-
device, save_preprocessed_copy)
148-
# Remove some stuff that doesn't need to be printed twice, then pretty-print
149-
# format for log file.
150-
ops_copy = ops.copy()
151-
_ = ops_copy.pop('settings')
152-
_ = ops_copy.pop('probe')
153-
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
154-
logger.debug(f"Initial ops:\n{print_ops}\n")
155-
156-
157-
# Set preprocessing and drift correction parameters
158-
ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
159-
np.random.seed(1)
160-
torch.cuda.manual_seed_all(1)
161-
torch.random.manual_seed(1)
162-
ops, bfile, st0 = compute_drift_correction(
163-
ops, device, tic0=tic0, progress_bar=progress_bar,
164-
file_object=file_object
165-
)
166-
167-
# Check scale of data for log file
168-
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
169-
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")
170-
171-
if save_preprocessed_copy:
172-
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
147+
tic0 = time.time()
148+
ops = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
149+
device, save_preprocessed_copy)
150+
# Remove some stuff that doesn't need to be printed twice, then pretty-print
151+
# format for log file.
152+
ops_copy = ops.copy()
153+
_ = ops_copy.pop('settings')
154+
_ = ops_copy.pop('probe')
155+
print_ops = pprint.pformat(ops_copy, indent=4, sort_dicts=False)
156+
logger.debug(f"Initial ops:\n{print_ops}\n")
157+
158+
159+
# Set preprocessing and drift correction parameters
160+
ops = compute_preprocessing(ops, device, tic0=tic0, file_object=file_object)
161+
np.random.seed(1)
162+
torch.cuda.manual_seed_all(1)
163+
torch.random.manual_seed(1)
164+
ops, bfile, st0 = compute_drift_correction(
165+
ops, device, tic0=tic0, progress_bar=progress_bar,
166+
file_object=file_object
167+
)
173168

174-
# Sort spikes and save results
175-
st,tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
176-
progress_bar=progress_bar)
177-
clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0,
178-
progress_bar=progress_bar)
179-
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
180-
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
181-
save_extra_vars=save_extra_vars,
182-
save_preprocessed_copy=save_preprocessed_copy)
169+
# Check scale of data for log file
170+
b1 = bfile.padded_batch_to_torch(0).cpu().numpy()
171+
logger.debug(f"First batch min, max: {b1.min(), b1.max()}")
172+
173+
if save_preprocessed_copy:
174+
io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
175+
176+
# Sort spikes and save results
177+
st,tF, _, _ = detect_spikes(ops, device, bfile, tic0=tic0,
178+
progress_bar=progress_bar)
179+
clu, Wall = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0,
180+
progress_bar=progress_bar)
181+
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
182+
save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
183+
save_extra_vars=save_extra_vars,
184+
save_preprocessed_copy=save_preprocessed_copy)
185+
except:
186+
# This makes sure the full traceback is written to log file.
187+
logger.exception('Encountered error in `run_kilosort`:')
188+
# Annoyingly, this will print the error message twice for console, but
189+
# I haven't found a good way around that.
190+
raise
183191

184192
return ops, st, clu, tF, Wall, similar_templates, \
185193
is_ref, est_contam_rate, kept_spikes
@@ -435,6 +443,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
435443
Wrapped file object for handling data.
436444
437445
"""
446+
438447
tic = time.time()
439448
logger.info(' ')
440449
logger.info('Computing drift correction.')

0 commit comments

Comments
 (0)