Skip to content

Commit 3c97c43

Browse files
chore: linter 1/2
1 parent 68823cf commit 3c97c43

File tree

1 file changed

+99
-84
lines changed

1 file changed

+99
-84
lines changed

packages/ecog2vec/ecog2vec/data_generator.py

+99-84
Original file line numberDiff line numberDiff line change
@@ -43,64 +43,64 @@ class NeuralDataGenerator():
4343
def __init__(self, nwb_dir, patient):
4444

4545
self.patient = patient
46-
46+
4747
file_list = os.listdir(nwb_dir)
4848
self.nwb_dir = nwb_dir
49-
self.nwb_files = [file
50-
for file in file_list
49+
self.nwb_files = [file
50+
for file in file_list
5151
if file.startswith(f"{patient}")]
5252
self.target_sr = 100
53-
53+
5454
self.bad_electrodes = []
5555
self.good_electrodes = list(np.arange(256))
56-
56+
5757
self.high_gamma_min = 70
5858
self.high_gamma_max = 199
5959

6060
# Bad electrodes are 1-indexed!
61-
61+
6262
if patient == 'EFC400':
6363
self.electrode_name = 'R256GridElectrode electrodes'
6464
self.grid_size = np.array([16, 16])
65-
self.bad_electrodes = [x - 1 for x in [1, 2, 33, 50, 54, 64,
65+
self.bad_electrodes = [x - 1 for x in [1, 2, 33, 50, 54, 64,
6666
128, 129, 193, 194, 256]]
6767
self.blocks_ID_mocha = [3, 23, 72]
68-
68+
6969
elif patient == 'EFC401':
7070
self.electrode_name = 'L256GridElectrode electrodes'
7171
self.grid_size = np.array([16, 16])
72-
self.bad_electrodes = [x - 1 for x in [1, 2, 63, 64, 65, 127,
73-
143, 193, 194, 195, 196,
74-
235, 239, 243, 252, 254,
72+
self.bad_electrodes = [x - 1 for x in [1, 2, 63, 64, 65, 127,
73+
143, 193, 194, 195, 196,
74+
235, 239, 243, 252, 254,
7575
255, 256]]
7676
self.blocks_ID_mocha = [4, 41, 57, 61, 66, 69, 73, 77, 83, 87]
7777

7878
elif patient == "EFC402":
7979
self.electrode_name = 'InferiorGrid electrodes'
8080
self.grid_size = np.array([8, 16])
8181
self.bad_electrodes = [x - 1 for x in list(range(129, 257))]
82-
self.blocks_ID_demo2 = [4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17,
83-
18, 19, 25, 26, 27, 33, 34, 35, 44, 45,
82+
self.blocks_ID_demo2 = [4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17,
83+
18, 19, 25, 26, 27, 33, 34, 35, 44, 45,
8484
46, 47, 48, 49, 58, 59, 60]
85-
85+
8686
elif patient == 'EFC403':
8787
self.electrode_name = 'Grid electrodes'
8888
self.grid_size = np.array([16, 16])
89-
self.bad_electrodes = [x - 1 for x in [129, 130, 131, 132, 133,
89+
self.bad_electrodes = [x - 1 for x in [129, 130, 131, 132, 133,
9090
134, 135, 136, 137, 138,
91-
139, 140, 141, 142, 143,
91+
139, 140, 141, 142, 143,
9292
144, 145, 146, 147, 148,
93-
149, 161, 162, 163, 164,
93+
149, 161, 162, 163, 164,
9494
165, 166, 167, 168, 169,
95-
170, 171, 172, 173, 174,
95+
170, 171, 172, 173, 174,
9696
175, 176, 177, 178, 179,
9797
180, 181]]
98-
self.blocks_ID_demo2 = [4, 7, 10, 13, 19, 20, 21, 28, 35, 39,
99-
52, 53, 54, 55, 56, 59, 60, 61, 62, 63,
100-
64, 70, 73, 74, 75, 76, 77, 83, 92, 93,
101-
94, 95, 97, 98, 99, 100, 101, 108, 109,
98+
self.blocks_ID_demo2 = [4, 7, 10, 13, 19, 20, 21, 28, 35, 39,
99+
52, 53, 54, 55, 56, 59, 60, 61, 62, 63,
100+
64, 70, 73, 74, 75, 76, 77, 83, 92, 93,
101+
94, 95, 97, 98, 99, 100, 101, 108, 109,
102102
110, 111, 112, 113, 114, 115]
103-
103+
104104
else:
105105
self.electrode_name = None
106106
self.grid_size = None
@@ -110,8 +110,8 @@ def __init__(self, nwb_dir, patient):
110110

111111
self.config = None
112112

113-
114-
def make_data(self,
113+
114+
def make_data(self,
115115
chopped_sentence_dir=None,
116116
sentence_dir=None,
117117
chopped_recording_dir=None,
@@ -135,11 +135,12 @@ def make_data(self,
135135
(None)
136136
"""
137137
all_example_dict = [] # not maintained at the moment; stores ALL example dicts
138-
138+
139139
block_pattern = re.compile(r'B(\d+)')
140140

141141
if BPR is None:
142-
raise ValueError("Please specify whether to use common average reference or bipolar referencing")
142+
raise ValueError("Please specify whether to use \
143+
common average reference or bipolar referencing")
143144

144145
if self.config is None:
145146
self.config = {
@@ -148,23 +149,23 @@ def make_data(self,
148149
'target sampling rate': None,
149150
'grid size': self.grid_size
150151
}
151-
152+
152153
for file in self.nwb_files:
153154

154155
create_training_data = True
155-
156+
156157
match = block_pattern.search(file)
157158
block = int(match.group(1))
158-
159+
159160
if self.patient == 'EFC400' or self.patient == 'EFC401':
160161
if block in self.blocks_ID_mocha:
161162
create_training_data = False
162163
elif self.patient == 'EFC402' or self.patient == 'EFC403':
163164
if block in self.blocks_ID_demo2:
164165
create_training_data = False
165-
166+
166167
path = os.path.join(self.nwb_dir, file)
167-
168+
168169
io = NWBHDF5IO(path, load_namespaces=True, mode='r')
169170
nwbfile = io.read()
170171

@@ -173,13 +174,13 @@ def make_data(self,
173174
with NeuralDataProcessor(
174175
nwb_path=path, config=self.config, WRITE=False
175176
) as processor:
176-
177+
177178
# Grab the electrode table and sampling rate,
178-
# and then process the raw ECoG data.
179+
# and then process the raw ECoG data.
179180

180181
electrode_table = nwbfile.acquisition["ElectricalSeries"].\
181182
electrodes.table[:]
182-
183+
183184
self.nwb_sr = nwbfile.acquisition["ElectricalSeries"].\
184185
rate
185186

@@ -197,22 +198,23 @@ def make_data(self,
197198

198199
nwbfile_electrodes = processor.nwb_file.processing['ecephys'].\
199200
data_interfaces['LFP'].\
200-
electrical_series[f'high gamma ({list(self.config["referencing"])[0]})'].\
201+
electrical_series[f'high gamma \
202+
({list(self.config["referencing"])[0]})'].\
201203
data[()][:, self.good_electrodes]
202204

203-
print(f"Number of good electrodes in {file}: {nwbfile_electrodes.shape[1]}")
204-
205+
print(f"Number of good electrodes in {file}: {nwbfile_electrodes.shape[1]}")
206+
205207
# Begin building the WAVE files for wav2vec training
206208
# and evaluation.
207209

208210
# Starts/stops for each intrablock trial.
209-
starts = [int(start)
210-
for start
211+
starts = [int(start)
212+
for start
211213
in list(nwbfile.trials[:]["start_time"] * self.nwb_sr)]
212214
stops = [int(start)
213215
for start
214216
in list(nwbfile.trials[:]["stop_time"] * self.nwb_sr)]
215-
217+
216218
# Manage the speaking segments only... as an option.
217219
# Training data for wav2vec as speaking segments only
218220
# will be saved in the `chopped_sentence_dir` directory.
@@ -222,22 +224,23 @@ def make_data(self,
222224
for start, stop in zip(starts, stops):
223225
speaking_segment = nwbfile_electrodes[start:stop,:]
224226
all_speaking_segments.append(speaking_segment)
225-
227+
226228
if sentence_dir:
227229
file_name = f'{sentence_dir}/{file}_{i}.wav'
228-
sf.write(file_name,
230+
sf.write(file_name,
229231
speaking_segment, 16000, subtype='FLOAT')
230-
232+
231233
i = i + 1
232-
234+
233235
concatenated_speaking_segments = np.concatenate(all_speaking_segments, axis=0)
234-
236+
235237
# Training data: speaking segments only
236238
if create_training_data and chopped_sentence_dir:
237239
num_full_chunks = len(concatenated_speaking_segments) // chunk_length
238240
# last_chunk_size = len(nwbfile_electrodes) % chunk_size
239241

240-
full_chunks = np.split(concatenated_speaking_segments[:num_full_chunks * chunk_length], num_full_chunks)
242+
full_chunks = np.split(concatenated_speaking_segments[:num_full_chunks * chunk_length],
243+
num_full_chunks)
241244
last_chunk = concatenated_speaking_segments[num_full_chunks * chunk_length:]
242245

243246
chunks = full_chunks # + [last_chunk] omit the last non-100000 chunk
@@ -247,54 +250,61 @@ def make_data(self,
247250
file_name = f'{chopped_sentence_dir}/{file}_{i}.wav'
248251
sf.write(file_name, chunk, 16000, subtype='FLOAT')
249252

250-
print(f'Out of distribution block. Number of chopped chunks w/o intertrial silences of length {chunk_length} added to training data: {num_full_chunks}')
251-
252-
253+
print(f'Out of distribution block. \
254+
Number of chopped chunks w/o intertrial silences \
255+
of length {chunk_length} added to training data: {num_full_chunks}')
256+
257+
253258
# Training data: silences included
254259
if create_training_data and chopped_recording_dir:
255260

256-
_nwbfile_electrodes = nwbfile_electrodes # [starts[0]:stops[-1],:] # remove starting/end silences
261+
_nwbfile_electrodes = nwbfile_electrodes # [starts[0]:stops[-1],:]
257262
num_full_chunks = len(_nwbfile_electrodes) // chunk_length
258263
# last_chunk_size = len(_nwbfile_electrodes) % chunk_size
259-
264+
260265
if num_full_chunks != 0:
261266

262-
full_chunks = np.split(_nwbfile_electrodes[:num_full_chunks * chunk_length], num_full_chunks)
267+
full_chunks = np.split(_nwbfile_electrodes[:num_full_chunks * chunk_length],
268+
num_full_chunks)
263269
last_chunk = _nwbfile_electrodes[num_full_chunks * chunk_length:]
264270

265271
chunks = full_chunks # + [last_chunk] omit the last non-100000 chunk
266-
272+
267273
# Checking lengths here
268274
# for chunk in chunks:
269275
# print(chunk.shape)
270276
# print(last_chunk.shape)
271277

272278
# Loop through the chunks and save them as WAV files
273279
for i, chunk in enumerate(chunks):
274-
file_name = f'{chopped_recording_dir}/{file}_{i}.wav' # CHANGE FOR EACH SUBJECT
275-
sf.write(file_name, chunk, 16000, subtype='FLOAT') # adjust as needed
280+
file_name = f'{chopped_recording_dir}/{file}_{i}.wav'
281+
sf.write(file_name, chunk, 16000, subtype='FLOAT')
282+
283+
print(f'Out of distribution block. \
284+
Number of chopped chunks w/ intertrial silences \
285+
of length {chunk_length} added to training data: {num_full_chunks}')
276286

277-
print(f'Out of distribution block. Number of chopped chunks w/ intertrial silences of length {chunk_length} added to training data: {num_full_chunks}')
278-
279287
if full_recording_dir:
280288
file_name = f'{full_recording_dir}/{file}.wav'
281289
sf.write(file_name, nwbfile_electrodes, 16000, subtype='FLOAT')
282290

283291
print('Full recording saved as a WAVE file.')
284292

285-
if (ecog_tfrecords_dir and
293+
if (ecog_tfrecords_dir and
286294
((self.patient in ('EFC402', 'EFC403') and (block in self.blocks_ID_demo2) or
287295
(self.patient in ('EFC400', 'EFC401') and (block in self.blocks_ID_mocha))))):
288-
296+
289297
# Create TFRecords for the ECoG data
290298

291-
high_gamma = downsample(nwbfile_electrodes,
292-
self.nwb_sr,
293-
self.target_sr,
299+
high_gamma = downsample(nwbfile_electrodes,
300+
self.nwb_sr,
301+
self.target_sr,
294302
'NWB',
295303
ZSCORE=True)
296304

297-
phoneme_transcriptions = nwbfile.processing['behavior'].data_interfaces['BehavioralEpochs'].interval_series #['phoneme transcription'].timestamps[:]
305+
phoneme_transcriptions = nwbfile.processing['behavior'].\
306+
data_interfaces['BehavioralEpochs'].\
307+
interval_series
298308

299309
token_type = 'word_sequence'
300310

@@ -327,44 +337,44 @@ def make_data(self,
327337

328338
i0 = np.rint(self.target_sr * t0).astype(int)
329339
iF = np.rint(self.target_sr * tF).astype(int)
330-
340+
331341
# ECOG (C) SEQUENCE
332342
c = high_gamma[i0:iF,:]
333343
# print(c.shape)
334344
# plt.plot(c[:,0])
335345
# break
336-
346+
337347
nsamples = c.shape[0]
338-
348+
339349
# TEXT SEQUENCE
340350
speech_string = trial['transcription'].values[0]
341-
text_sequence = sentence_tokenize(speech_string.split(' ')) # , 'text_sequence')
342-
343-
# AUDIO SEQUENCE
351+
text_sequence = sentence_tokenize(speech_string.split(' '))
352+
353+
# AUDIO SEQUENCE
344354
audio_sequence = []
345-
355+
346356
# PHONEME SEQUENCE
347-
357+
348358
M = iF - i0
349-
350-
max_seconds = max_seconds_dict.get(token_type) # , 0.2) # i don't think this 0.2 default is necessary for the scope of this
359+
360+
max_seconds = max_seconds_dict.get(token_type)
351361
max_samples = int(np.floor(self.target_sr * max_seconds))
352362
max_length = min(M, max_samples)
353-
363+
354364
phoneme_array = transcription_to_array(
355365
t0, tF, phoneme_onset_times, phoneme_offset_times,
356-
phoneme_transcript, max_length, self.target_sr
366+
phoneme_transcript, max_length, self.target_sr
357367
)
358-
368+
359369
phoneme_sequence = [ph.encode('utf-8') for ph in phoneme_array]
360-
370+
361371
if len(phoneme_sequence) != nsamples:
362372
if len(phoneme_sequence) > nsamples:
363373
phoneme_sequence = [phoneme_sequence[i] for i in range(nsamples)]
364374
else:
365375
for i in range(nsamples - len(phoneme_sequence)):
366376
phoneme_sequence.append(phoneme_sequence[len(phoneme_sequence) - 1])
367-
377+
368378
print('\n------------------------')
369379
print(f'For sentence {index}: ')
370380
print(c[0:5,0:5])
@@ -374,17 +384,21 @@ def make_data(self,
374384
print(f'Length of phoneme sequence: {len(phoneme_sequence)}')
375385
print(phoneme_sequence)
376386
print('------------------------\n')
377-
378-
example_dicts.append({'ecog_sequence': c, 'text_sequence': text_sequence, 'audio_sequence': [], 'phoneme_sequence': phoneme_sequence,})
387+
388+
example_dicts.append({'ecog_sequence': c,
389+
'text_sequence': text_sequence,
390+
'audio_sequence': [],
391+
'phoneme_sequence': phoneme_sequence,})
379392

380393
# all_example_dict.extend(example_dicts)
381394
# print(len(example_dicts))
382395
# print(len(all_example_dict))
383-
write_to_Protobuf(f'{ecog_tfrecords_dir}/{self.patient}_B{block}.tfrecord', example_dicts)
396+
write_to_Protobuf(f'{ecog_tfrecords_dir}/{self.patient}_B{block}.tfrecord',
397+
example_dicts)
384398

385399
print('In distribution block. TFRecords created.')
386400

387-
except Exception as e:
401+
except Exception as e:
388402
print(f"An error occured and block {path} is not inluded in the wav2vec training data: {e}")
389403

390404
io.close()
@@ -435,7 +449,8 @@ def transcription_to_array(trial_t0, trial_tF, onset_times, offset_times, transc
435449
# print('exactly one phoneme:', np.all(np.sum(indices, 0) == 1))
436450
assert np.all(np.sum(indices, 0) < 2)
437451
except:
438-
pdb.set_trace()
452+
# pdb.set_trace()
453+
pass
439454

440455
# ...but there can be locations with *zero* phonemes; assume 'pau' here
441456
transcript = np.insert(transcript, 0, 'pau')

0 commit comments

Comments
 (0)