-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpremerging.py
More file actions
221 lines (186 loc) · 11.5 KB
/
premerging.py
File metadata and controls
221 lines (186 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# get all the recordings on that day
# allocate destination folder and move the ephys folder on the server to Beast lab user
from pathlib import Path
import os
import shutil
import numpy as np
import spikeinterface.sorters
import spikeinterface.full as si
import scipy.signal
import spikeinterface.extractors as se
import spikeinterface.comparison
import spikeinterface.exporters
import spikeinterface.curation
import spikeinterface.widgets
import docker
from datetime import datetime
import itertools
import scipy.io as sio
# Record wall-clock start so later stages can report elapsed time.
startTime = datetime.now()
print('Start Time:' + startTime.strftime("%m/%d/%Y, %H:%M:%S"))
''' this section defines the animal and dates and fetch the recordings from the server to Beast'''
import sys
# CLI: <mouse> <dates,comma,separated> <save_date> <local_folder> <no_probe> <use_ks4> <use_ks3>
mouse = sys.argv[1]
# One comma-separated argument holding every acquisition-date folder name.
dates = sys.argv[2].split(',')
# Name of the folder the merged/sorted output is saved under.
save_date = sys.argv[3]
local_folder = sys.argv[4]
no_probe = sys.argv[5]  # number of probes; parsed with int() where used
print(mouse)
print('acquisition folder: ',dates)
print(save_date)
# Accept the usual truthy spellings for the sorter toggles.
use_ks4 = sys.argv[6].lower() in ['true', '1', 't', 'y', 'yes']
use_ks3 = sys.argv[7].lower() in ['true', '1', 't', 'y', 'yes']
base_folder = '/mnt/rds01/ibn-vision/DATA/SUBJECTS/'
save_folder = local_folder + save_date +'/'
# get all the recordings on that day
import os
import subprocess
# Raise this process's open-file limit. The previous
# subprocess.run('ulimit -n 4096', shell=True) only changed the limit of the
# short-lived child shell, never of this Python process, so it had no effect.
import resource
try:
    _soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (max(_soft, 4096), _hard))
except (ValueError, OSError) as _rl_err:
    # Best-effort, like the original: never abort the pipeline over a limit.
    print('could not raise RLIMIT_NOFILE:', _rl_err)
def sorting_key(s):
    """Return the integer run index encoded after the last '_g' in a
    SpikeGLX run-folder name (e.g. 'mouse_date_g3' -> 3)."""
    _, _, run_index = s.rpartition('_g')
    return int(run_index)
#grab recordings from the server to local machine (Beast)
# Shared parallel-processing settings for spikeinterface save/compute calls.
job_kwargs = dict(n_jobs=32, chunk_duration='1s', progress_bar=True)
print(dates)
g_files_all = []
# iterate over all directories in source folder
date_count = 0
for date in dates:
    print('acquisition folder:',date)
    date_count = date_count + 1
    # Server-side source and local destination for this acquisition day.
    # (The original assigned ephys_folder twice; the duplicate is removed.)
    ephys_folder = base_folder + mouse + '/ephys/' + date +'/'
    dst_folder = local_folder + date + '/'
    g_files = []
    print('copying ephys data from:' + ephys_folder)
    for dirname in os.listdir(ephys_folder):
        # only grab recording folders ('_g' marks SpikeGLX run folders) -
        # there might be some other existing folders for analysis or sorted data
        if '_g' in dirname:
            g_files.append(dirname)
            source = os.path.join(ephys_folder, dirname)
            destination = os.path.join(dst_folder, dirname)
            # copy the directory to the destination folder
            # NOTE(review): copytree raises if destination already exists -
            # presumably each run starts from a clean local folder; confirm.
            shutil.copytree(source, destination)
    print('Start to copying files to local folder: ')
    print(datetime.now() - startTime)
    ''' read spikeglx recordings and preprocess them'''
    # Sort runs by the number after '_g' so segments concatenate in
    # acquisition order, then append to the across-days list.
    g_files = sorted(g_files, key=sorting_key)
    g_files_all = g_files_all + g_files
    print(g_files)
    print('all g files:',g_files_all)
# stream_names, stream_ids = si.get_neo_streams('spikeglx',dst_folder)
# print(stream_names)
# print(stream_ids)
# ---------------------------------------------------------------------------
# Per-probe preprocessing: for each probe, read every acquisition day's
# SpikeGLX recording, track per-segment start/end sample frames (1-based,
# cumulative across days), filter/re-reference, and concatenate everything
# into one recording.  NOTE(review): indentation was reconstructed from
# context (it was lost in this copy) - verify nesting against the repository.
# ---------------------------------------------------------------------------
for probe in range(int(no_probe)):
    date_count = 0
    # NOTE(review): '..._fames' is a typo and this list is never read; the
    # list actually used below is probe0_start_sample_frames, which is
    # assigned in the date_count == 1 branch on the first day.
    probe0_start_sample_fames = []
    probe0_end_sample_frames = []
    for date in dates:
        date_count = date_count + 1
        # SpikeGLX stream id for this probe, e.g. 'imec0.ap'.
        probe_name = 'imec' + str(probe) + '.ap'
        dst_folder = local_folder + date + '/'
        probe0_raw = si.read_spikeglx(dst_folder,stream_name=probe_name)
        print(probe0_raw)
        # Frame count of every segment (one segment per SpikeGLX run).
        probe0_num_segments = [probe0_raw.get_num_frames(segment_index=i) for i in range(probe0_raw.get_num_segments())]
        # Running totals give each segment's end frame within this day.
        probe0_end_sample_frames_tmp = list(itertools.accumulate(probe0_num_segments))
        if date_count == 1:
            # First day: frames are 1-based; segment i starts one frame
            # after segment i-1 ends.
            probe0_start_sample_frames = [1] + [probe0_end_sample_frames_tmp[i] + 1 for i in range(0, len(probe0_num_segments)-1)]
            probe0_end_sample_frames = probe0_end_sample_frames + probe0_end_sample_frames_tmp
        else:
            # Later days: offset this day's cumulative frames by the last
            # end frame recorded so far, so indices keep growing across days.
            probe0_start_sample_frames = probe0_start_sample_frames + [probe0_end_sample_frames[-1]+1] + \
                [probe0_end_sample_frames[-1]+probe0_end_sample_frames_tmp[i] + 1 for i in range(0, len(probe0_num_segments)-1)]
            probe0_end_sample_frames = probe0_end_sample_frames + [probe0_end_sample_frames_tmp[i] + probe0_end_sample_frames[-1] for i in range(0, len(probe0_num_segments))]
        #several preprocessing steps and concatenation of the recordings
        #highpass filter - threhsold at 300Hz
        probe0_highpass = si.highpass_filter(probe0_raw,freq_min=300.)
        #detect bad channels
        #phase shift correction - equivalent to T-SHIFT in catGT
        probe0_phase_shift = si.phase_shift(probe0_highpass)
        # Global common-median reference across channels.
        probe0_common_reference = si.common_reference(probe0_phase_shift,operator='median',reference='global')
        probe0_preprocessed = probe0_common_reference
        # Collapse this day's segments into one continuous recording.
        probe0_cat = si.concatenate_recordings([probe0_preprocessed])
        print('probe0_preprocessed',probe0_preprocessed)
        print('probe0 concatenated',probe0_cat)
        if date_count == 1:
            probe0_cat_all = probe0_cat
        else:
            # Append this day's recording to the across-days concatenation.
            probe0_cat_all = si.concatenate_recordings([probe0_cat_all,probe0_cat])
    # Bad channels are detected once on the full concatenation and dropped
    # before sorting.
    bad_channel_ids, channel_labels = si.detect_bad_channels(probe0_cat_all)
    probe0_cat_all = probe0_cat_all.remove_channels(bad_channel_ids)
    print('probe0_bad_channel_ids',bad_channel_ids)
    '''Motion Drift Correction'''
    #motion correction if needed
    #this is nonrigid correction - need to do parallel computing to speed up
    #assign parallel processing as job_kwargs
    #probe0_nonrigid_accurate = si.correct_motion(recording=probe0_cat_all, preset="kilosort_like",**job_kwargs)
    print('Start to motion correction finished:')
    print(datetime.now() - startTime)
    #kilosort like to mimic kilosort - no need if you are just running kilosort
    # probe0_kilosort_like = correct_motion(recording=probe0_cat, preset="kilosort_like")
    # probe1_kilosort_like = correct_motion(recording=probe1_cat, preset="kilosort_like")
    '''save preprocessed bin file before sorting'''
    #after saving, sorters will read this preprocessed binary file instead
    probe0_preprocessed_corrected = probe0_cat_all.save(folder=save_folder+'probe'+str(probe)+'_preprocessed', format='binary', **job_kwargs)
    print('Start to prepocessed files saved:')
    print(datetime.now() - startTime)
    #probe0_preprocessed_corrected = si.load_extractor(save_folder+'/probe0_preprocessed')
    #probe0_preprocessed_corrected = si.load_extractor(save_folder+'/probe1_preprocessed')
    ''' prepare sorters - currently using the default parameters and motion correction is turned off as it was corrected already above
you can check if the parameters using:
params = get_default_sorter_params('kilosort3')
print("Parameters:\n", params)
desc = get_sorter_params_description('kilosort3')
print("Descriptions:\n", desc)
Beware that moutainsort5 is commented out as the sorter somehow stops midway with no clue - currently raising this issue on their github page
'''
import pandas as pd
def save_spikes_to_csv(spikes, save_folder):
    """Write a sorter's spike bookkeeping arrays to <save_folder>/spikes.csv.

    Parameters
    ----------
    spikes : mapping or numpy structured array
        Must expose 'unit_index', 'segment_index' and 'sample_index' fields
        (e.g. the spikes.npy written by kilosort via spikeinterface).
    save_folder : str
        Destination directory. A trailing separator is now optional: the
        original required one because it concatenated strings; os.path.join
        tolerates both, which is backward compatible for existing callers.
    """
    spikes_df = pd.DataFrame({
        'unit_index': spikes['unit_index'],
        'segment_index': spikes['segment_index'],
        'sample_index': spikes['sample_index'],
    })
    # index=False: the CSV row number carries no meaning for spike events.
    spikes_df.to_csv(os.path.join(save_folder, 'spikes.csv'), index=False)
# NOTE(review): the statements below run once per probe, inside the
# 'for probe in range(int(no_probe)):' loop above; the original indentation
# was lost in this copy, so nesting was reconstructed - verify against repo.
#probe0_sorting_ks2_5 = si.run_sorter(sorter_name= 'kilosort2_5',recording=probe0_preprocessed_corrected,output_folder=dst_folder+'probe0/sorters/kilosort2_5/',docker_image="spikeinterface/kilosort2_5-compiled-base:latest",do_correction=False)
#probe1_sorting_ks2_5 = si.run_sorter(sorter_name= 'kilosort2_5',recording=probe1_preprocessed_corrected,output_folder=dst_folder+'probe1/sorters/kilosort2_5/',docker_image="spikeinterface/kilosort2_5-compiled-base:latest",do_correction=False)
#probe0_sorting_ks3 = si.run_sorter(sorter_name= 'kilosort3',recording=probe0_preprocessed_corrected,output_folder=dst_folder+'probe0/sorters/kilosort3/',docker_image="spikeinterface/kilosort3-compiled-base:latest",do_correction=False)
#probe1_sorting_ks3 = si.run_sorter(sorter_name= 'kilosort3',recording=probe1_preprocessed_corrected,output_folder=dst_folder+'probe1/sorters/kilosort3/',docker_image="spikeinterface/kilosort3-compiled-base:latest",do_correction=False)
# probe0_sorting_ks3 = si.remove_duplicated_spikes(sorting = probe0_sorting_ks3, censored_period_ms=0.3,method='keep_first')
# probe1_sorting_ks2_5 = si.remove_duplicated_spikes(sorting = probe1_sorting_ks2_5, censored_period_ms=0.3,method='keep_first')
# probe1_sorting_ks3 = si.remove_duplicated_spikes(sorting = probe1_sorting_ks3, censored_period_ms=0.3,method='keep_first')
if use_ks4:
    # Run Kilosort4 inside its docker image on the saved preprocessed binary.
    probe0_sorting_ks4 = si.run_sorter(sorter_name= 'kilosort4',recording=probe0_preprocessed_corrected,output_folder=save_folder+'probe'+str(probe)+'/sorters/kilosort4/',docker_image='spikeinterface/kilosort4-base:latest')
    # Drop same-unit spikes closer than 0.3 ms, keeping the first of each pair.
    probe0_sorting_ks4 = si.remove_duplicated_spikes(sorting = probe0_sorting_ks4, censored_period_ms=0.3,method='keep_first')
    # Sparse sorting analyzer persisted to a binary folder for later QC.
    probe0_we_ks4 = si.create_sorting_analyzer(probe0_sorting_ks4, probe0_preprocessed_corrected,
        format = 'binary_folder',folder=save_folder +'probe'+str(probe)+'/waveform/kilosort4',
        sparse = True,overwrite = True,
        **job_kwargs)
    probe0_we_ks4.compute('random_spikes')
    probe0_we_ks4.compute('waveforms',ms_before=1.0, ms_after=2.0,**job_kwargs)
    # Re-export the sorter's raw spikes.npy as CSV for downstream tools.
    probe0_ks4_spikes = np.load(save_folder + 'probe'+str(probe)+'/sorters/kilosort4/in_container_sorting/spikes.npy')
    save_spikes_to_csv(probe0_ks4_spikes,save_folder + 'probe'+str(probe)+'/sorters/kilosort4/in_container_sorting/')
if use_ks3:
    # Same pipeline with Kilosort3's compiled docker image.
    probe0_sorting_ks3 = si.run_sorter(sorter_name= 'kilosort3',recording=probe0_preprocessed_corrected,output_folder=save_folder+'probe'+str(probe)+'/sorters/kilosort3/',docker_image='spikeinterface/kilosort3-compiled-base:latest')
    probe0_sorting_ks3 = si.remove_duplicated_spikes(sorting = probe0_sorting_ks3, censored_period_ms=0.3,method='keep_first')
    probe0_we_ks3 = si.create_sorting_analyzer(probe0_sorting_ks3, probe0_preprocessed_corrected,
        format = 'binary_folder',folder=save_folder +'probe'+str(probe)+'/waveform/kilosort3',
        sparse = True,overwrite = True,
        **job_kwargs)
    probe0_we_ks3.compute('random_spikes')
    probe0_we_ks3.compute('waveforms',ms_before=1.0, ms_after=2.0,**job_kwargs)
    probe0_ks3_spikes = np.load(save_folder + 'probe'+str(probe)+'/sorters/kilosort3/in_container_sorting/spikes.npy')
    save_spikes_to_csv(probe0_ks3_spikes,save_folder + 'probe'+str(probe)+'/sorters/kilosort3/in_container_sorting/')
print('Start to all sorting done:')
print(datetime.now() - startTime)
import pandas as pd
# Persist the per-segment frame bookkeeping (built in the preprocessing loop)
# so spike sample indices can be mapped back to individual run folders.
probe0_segment_frames = pd.DataFrame({'segment_info':g_files_all,'segment start frame': probe0_start_sample_frames, 'segment end frame': probe0_end_sample_frames})
probe0_segment_frames.to_csv(save_folder+'probe'+str(probe)+'/sorters/segment_frames.csv', index=False)
''' read sorters directly from the output folder - so you dont need to worry if something went wrong and you can't access the temp variables
This section reads sorter outputs and extract waveforms
'''
# Explicit success exit so a wrapping shell/cluster script sees return code 0.
sys.exit(0)