@@ -112,74 +112,82 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
112112 filename , data_dir , results_dir , probe = \
113113 set_files (settings , filename , probe , probe_name , data_dir , results_dir )
114114 setup_logger (results_dir )
115- logger .info (f"Kilosort version { kilosort .__version__ } " )
116- logger .info (f"Sorting { filename } " )
117- logger .info ('-' * 40 )
118115
119- if data_dtype is None :
120- logger .info (
121- "Interpreting binary file as default dtype='int16'. If data was "
122- "saved in a different format, specify `data_dtype`."
116+ try :
117+ logger .info (f"Kilosort version { kilosort .__version__ } " )
118+ logger .info (f"Sorting { filename } " )
119+ logger .info ('-' * 40 )
120+
121+ if data_dtype is None :
122+ logger .info (
123+ "Interpreting binary file as default dtype='int16'. If data was "
124+ "saved in a different format, specify `data_dtype`."
125+ )
126+ data_dtype = 'int16'
127+
128+ if not do_CAR :
129+ logger .info ("Skipping common average reference." )
130+
131+ if device is None :
132+ if torch .cuda .is_available ():
133+ logger .info ('Using GPU for PyTorch computations. '
134+ 'Specify `device` to change this.' )
135+ device = torch .device ('cuda' )
136+ else :
137+ logger .info ('Using CPU for PyTorch computations. '
138+ 'Specify `device` to change this.' )
139+ device = torch .device ('cpu' )
140+
141+ if probe ['chanMap' ].max () >= settings ['n_chan_bin' ]:
142+ raise ValueError (
143+ f'Largest value of chanMap exceeds channel count of data, '
144+ 'make sure chanMap is 0-indexed.'
123145 )
124- data_dtype = 'int16'
125-
126- if not do_CAR :
127- logger .info ("Skipping common average reference." )
128146
129- if device is None :
130- if torch .cuda .is_available ():
131- logger .info ('Using GPU for PyTorch computations. '
132- 'Specify `device` to change this.' )
133- device = torch .device ('cuda' )
134- else :
135- logger .info ('Using CPU for PyTorch computations. '
136- 'Specify `device` to change this.' )
137- device = torch .device ('cpu' )
138-
139- if probe ['chanMap' ].max () >= settings ['n_chan_bin' ]:
140- raise ValueError (
141- f'Largest value of chanMap exceeds channel count of data, '
142- 'make sure chanMap is 0-indexed.'
143- )
144-
145- tic0 = time .time ()
146- ops = initialize_ops (settings , probe , data_dtype , do_CAR , invert_sign ,
147- device , save_preprocessed_copy )
148- # Remove some stuff that doesn't need to be printed twice, then pretty-print
149- # format for log file.
150- ops_copy = ops .copy ()
151- _ = ops_copy .pop ('settings' )
152- _ = ops_copy .pop ('probe' )
153- print_ops = pprint .pformat (ops_copy , indent = 4 , sort_dicts = False )
154- logger .debug (f"Initial ops:\n { print_ops } \n " )
155-
156-
157- # Set preprocessing and drift correction parameters
158- ops = compute_preprocessing (ops , device , tic0 = tic0 , file_object = file_object )
159- np .random .seed (1 )
160- torch .cuda .manual_seed_all (1 )
161- torch .random .manual_seed (1 )
162- ops , bfile , st0 = compute_drift_correction (
163- ops , device , tic0 = tic0 , progress_bar = progress_bar ,
164- file_object = file_object
165- )
166-
167- # Check scale of data for log file
168- b1 = bfile .padded_batch_to_torch (0 ).cpu ().numpy ()
169- logger .debug (f"First batch min, max: { b1 .min (), b1 .max ()} " )
170-
171- if save_preprocessed_copy :
172- io .save_preprocessing (results_dir / 'temp_wh.dat' , ops , bfile )
147+ tic0 = time .time ()
148+ ops = initialize_ops (settings , probe , data_dtype , do_CAR , invert_sign ,
149+ device , save_preprocessed_copy )
150+ # Remove some stuff that doesn't need to be printed twice, then pretty-print
151+ # format for log file.
152+ ops_copy = ops .copy ()
153+ _ = ops_copy .pop ('settings' )
154+ _ = ops_copy .pop ('probe' )
155+ print_ops = pprint .pformat (ops_copy , indent = 4 , sort_dicts = False )
156+ logger .debug (f"Initial ops:\n { print_ops } \n " )
157+
158+
159+ # Set preprocessing and drift correction parameters
160+ ops = compute_preprocessing (ops , device , tic0 = tic0 , file_object = file_object )
161+ np .random .seed (1 )
162+ torch .cuda .manual_seed_all (1 )
163+ torch .random .manual_seed (1 )
164+ ops , bfile , st0 = compute_drift_correction (
165+ ops , device , tic0 = tic0 , progress_bar = progress_bar ,
166+ file_object = file_object
167+ )
173168
174- # Sort spikes and save results
175- st ,tF , _ , _ = detect_spikes (ops , device , bfile , tic0 = tic0 ,
176- progress_bar = progress_bar )
177- clu , Wall = cluster_spikes (st , tF , ops , device , bfile , tic0 = tic0 ,
178- progress_bar = progress_bar )
179- ops , similar_templates , is_ref , est_contam_rate , kept_spikes = \
180- save_sorting (ops , results_dir , st , clu , tF , Wall , bfile .imin , tic0 ,
181- save_extra_vars = save_extra_vars ,
182- save_preprocessed_copy = save_preprocessed_copy )
169+ # Check scale of data for log file
170+ b1 = bfile .padded_batch_to_torch (0 ).cpu ().numpy ()
171+ logger .debug (f"First batch min, max: { b1 .min (), b1 .max ()} " )
172+
173+ if save_preprocessed_copy :
174+ io .save_preprocessing (results_dir / 'temp_wh.dat' , ops , bfile )
175+
176+ # Sort spikes and save results
177+ st ,tF , _ , _ = detect_spikes (ops , device , bfile , tic0 = tic0 ,
178+ progress_bar = progress_bar )
179+ clu , Wall = cluster_spikes (st , tF , ops , device , bfile , tic0 = tic0 ,
180+ progress_bar = progress_bar )
181+ ops , similar_templates , is_ref , est_contam_rate , kept_spikes = \
182+ save_sorting (ops , results_dir , st , clu , tF , Wall , bfile .imin , tic0 ,
183+ save_extra_vars = save_extra_vars ,
184+ save_preprocessed_copy = save_preprocessed_copy )
185+ except :
186+ # This makes sure the full traceback is written to log file.
187+ logger .exception ('Encountered error in `run_kilosort`:' )
188+ # Annoyingly, this will print the error message twice for console, but
189+ # I haven't found a good way around that.
190+ raise
183191
184192 return ops , st , clu , tF , Wall , similar_templates , \
185193 is_ref , est_contam_rate , kept_spikes
@@ -435,6 +443,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
435443 Wrapped file object for handling data.
436444
437445 """
446+
438447 tic = time .time ()
439448 logger .info (' ' )
440449 logger .info ('Computing drift correction.' )
0 commit comments