-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfhs_split_dataframe.py
126 lines (115 loc) · 4.11 KB
/
fhs_split_dataframe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
fhs_split_dataframe.py
module for creating folds, splitting datasets;
"""
import os
import random
import numpy as np
def get_fhs_ids(df_pts):
    """
    build the FHS ID of every row as '<idtype>-<id>' (id zero-padded to 4 chars);
    return the unique IDs as a (sorted, via np.unique) numpy array;
    """
    id_types = df_pts.idtype.values.ravel('K')
    pt_ids = df_pts.id.values.ravel('K')
    labels = []
    for id_type, pt_id in zip(id_types, pt_ids):
        labels.append(f'{id_type}-{str(pt_id).zfill(4)}')
    return np.unique(labels)
def has_transcript(df_raw):
    """
    keep only the rows of df_raw whose "pt_has_tscript" value is the string "1"
    (pts that have at least one tscript);
    """
    tscript_mask = df_raw["pt_has_tscript"] == "1"
    return df_raw.loc[tscript_mask]
def get_static_and_remaining_ids(df_raw, get_static, get_pt_ids):
    """
    split the IDs in df_raw into those in the static fold and all the others;

    df_raw: data passed to both callables;
    get_static: callable returning the subset of df_raw that forms the static fold;
    get_pt_ids: callable returning the IDs contained in a given data subset;
    returns: (static_ids, remaining_ids) where remaining_ids is a numpy array of the
        IDs from get_pt_ids(df_raw) that are not static, in their original order;
    """
    all_ids = get_pt_ids(df_raw)
    static_ids = get_pt_ids(get_static(df_raw))
    ## build the lookup once: `x in ndarray` rescans the whole array per ID (O(n*m))
    static_lookup = set(static_ids)
    return static_ids, np.array([i for i in all_ids if i not in static_lookup])
def create_folds(sample_ids, num_folds, seed):
    """
    shuffle the positions of sample_ids and deal them round-robin into num_folds folds;

    sample_ids: sized collection of samples (only its length is used);
    num_folds: number of folds to create;
    seed: passed to random.seed; NOTE this reseeds the module-level `random`
        generator, so it also fixes later random.* calls (e.g. random.choice in
        yield_rand_seg_and_mni);
    returns: list of num_folds integer numpy arrays holding positions into
        sample_ids; fold sizes differ by at most one;
    """
    random.seed(seed)
    ## np.arange keeps an integer dtype even for empty input
    ## (np.array(range(0)) is float64 and cannot be used to index sample_ids)
    order = np.arange(len(sample_ids))
    random.shuffle(order)
    ## shuffled position p goes to fold p % num_folds
    return [order[i::num_folds] for i in range(num_folds)]
def get_fold(sample_ids, folds, vld_idx, tst_idx, mode):
    """
    select the sample IDs belonging to one split of a k-fold partition;

    sample_ids: numpy array of sample IDs, indexed by the positions stored in folds;
    folds: sequence of index arrays, one per fold (e.g. the output of create_folds);
    vld_idx: validation fold index;
    tst_idx: test fold index;
    mode: 'VLD', 'TST', 'TRN'
        'VLD' -> samples of fold vld_idx;
        'TST' -> samples of fold tst_idx;
        'TRN' -> samples of every fold except vld_idx and tst_idx;
    returns: sample_ids indexed by the selected positions;
    NOTE(review): in 'TRN' mode folds are excluded by comparing their position with
    vld_idx / tst_idx, so pass non-negative indices -- a negative index (e.g. -1)
    selects a fold for VLD/TST but does not exclude it from TRN;
    """
    assert mode in {'TRN', 'VLD', 'TST'}, f"{mode} is not TRN VLD OR TST"
    if mode == 'VLD':
        idx = folds[vld_idx]
    elif mode == 'TST':
        idx = folds[tst_idx]
    else:
        ## keep every fold except the VLD and TST folds;
        ## e.g. with 5 folds, vld_idx=0, tst_idx=1 -> folds 2, 3, 4
        idx = np.concatenate(
            [fold for i, fold in enumerate(folds) if i != vld_idx and i != tst_idx]
        )
    return sample_ids[idx]
def get_holdout_fold(sample_ids, folds, vld_idx, mode):
    """
    two-way (TRN / VLD) variant of get_fold with no test fold; per the original
    note it is meant for when holdout_test is True;

    sample_ids: numpy array of sample IDs, indexed by the positions stored in folds;
    folds: sequence of index arrays, one per fold;
    vld_idx: validation fold index;
    mode: 'TRN' (every fold except vld_idx) or 'VLD' (fold vld_idx);
    """
    assert mode in {'TRN', 'VLD'}, f"{mode} is not TRN OR VLD"
    if mode == 'VLD':
        return sample_ids[folds[vld_idx]]
    train_folds = [fold for pos, fold in enumerate(folds) if pos != vld_idx]
    return sample_ids[np.concatenate(train_folds)]
def yield_aud_and_mni(row):
    """
    generator producing a single (audio fn, mni vector) pair taken from the
    'seg_fp' and 'mni_brain' entries of row;
    """
    audio_fp = row['seg_fp']
    mni_vec = row['mni_brain']
    yield audio_fp, mni_vec
def yield_rand_seg_and_mni(row, **kwargs):
    """
    yield <num_pt_segments> distinct random segments of pt speech, each <seg_min>
    minutes long (5 by default), together with the mni vector;

    row: record providing 'pt_npy' (path of a .npy array whose first axis holds
        10-millisecond MFCC units, i.e. 100 per second, per the comments below)
        and 'mni_brain';
    kwargs:
        num_pt_segments: number of distinct segments to yield (required);
        seg_min: segment length in minutes (default 5);
    yields: (pt_npy path, mni vector, start, end) with start/end already converted
        to 10-millisecond units (seconds * 100);
    raises: AssertionError if the array length is not a multiple of 100;
        ValueError if more segments are requested than there are distinct start
        positions (the rejection loop could otherwise never terminate);
        IndexError (from random.choice) if the audio is shorter than one segment;
    """
    mni_vector = row['mni_brain']
    num_pt_segments = kwargs.get('num_pt_segments')
    seg_min = kwargs.get('seg_min', 5)
    pt_only_fp = row['pt_npy']
    pt_only_npy = np.load(pt_only_fp)
    seg_dur = seg_min * 60
    assert pt_only_npy.shape[0] % 100 == 0, row['pt_npy']
    ## the latest start we can pick is the length of the audio minus length of segment
    ## length of audio in seconds is shape / 100 bc each MFCC unit represents 10 milliseconds
    ## length of segment is in seconds already;
    last_start = int(pt_only_npy.shape[0] / 100 - seg_dur)
    all_pairs = [(s, s + seg_dur) for s in range(last_start + 1)]
    ## every pair has a distinct start, so at most len(all_pairs) distinct segments
    ## exist; asking for more would make the `while` below spin forever
    if all_pairs and num_pt_segments > len(all_pairs):
        raise ValueError(
            f'{pt_only_fp}: requested {num_pt_segments} distinct segments but only '
            f'{len(all_pairs)} start positions fit'
        )
    chosen_pairs = set()
    for _ in range(num_pt_segments):
        ## rejection sampling: redraw until we get a pair not yielded yet
        pair = random.choice(all_pairs)
        while pair in chosen_pairs:
            pair = random.choice(all_pairs)
        chosen_pairs.add(pair)
        start, end = pair
        ## convert timestamp<seconds> to timestamp<10-milliseconds> for MFCC indexing
        start *= 100
        end *= 100
        yield (pt_only_fp, mni_vector, start, end)
def create_pt_segment(row, pt_segment_root, pt_only_npy, start, end):
    """
    save pt_only_npy[start*100:end*100] as a .npy file (skipped if the file
    already exists) and return its path;

    row: record providing 'id_date'; the text before its first '_' names the
        parent directory;
    pt_segment_root: root directory for segment files;
    pt_only_npy: array sliced along its first axis;
    start, end: segment bounds in seconds; multiplied by 100 for MFCC indexing;
    returns: <pt_segment_root>/<id_date prefix>/<id_date>/start_<start>_end_<end>_<id_date>.npy;
    NOTE(review): yield_rand_seg_and_mni yields start/end already multiplied by 100;
    passing those values here would scale them twice -- confirm the caller's units;
    """
    id_date = row['id_date']
    npy_dir = os.path.join(pt_segment_root, id_date.split('_')[0], id_date)
    npy_fp = os.path.join(npy_dir, f'start_{start}_end_{end}_{id_date}.npy')
    if not os.path.isfile(npy_fp):
        ## exist_ok avoids a FileExistsError if another process creates the
        ## directory between a separate isdir() check and makedirs()
        os.makedirs(npy_dir, exist_ok=True)
        ## convert timestamp<seconds> to timestamp<10-milliseconds> for MFCC indexing
        pt_segment = pt_only_npy[start*100:end*100]
        np.save(npy_fp, pt_segment)
        print(f'created {npy_fp};')
    return npy_fp