-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtranscript_preprocessing.py
More file actions
366 lines (306 loc) · 11.6 KB
/
transcript_preprocessing.py
File metadata and controls
366 lines (306 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
import re
from typing import List, Tuple, Optional
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
# --- Regex patterns for transcript cleaning ---------------------------------

# Speaker label at the start of a line, e.g. "Pat:" or " oth :".
# Group 1 captures the label as written ("Pat"/"Oth", any case).
SPEAKER_PATTERN = re.compile(r"^\s*(Pat|Oth)\s*:", re.IGNORECASE)

# Timed pause annotation such as "(3 seconds)" or "(1 second)".
# Group 1 captures the whole-number duration in seconds.
PAUSE_PATTERN = re.compile(r"\((\d+)\s*second[s]?\)", re.IGNORECASE)

# Nonverbal markers that are kept: (regex, replacement token) pairs, used in
# list order by normalize_nonverbals.
# NOTE(review): "(oth laughs)" presumably marks the *other* speaker laughing,
# yet it is mapped to <LAUGH> inside patient speech -- confirm this is intended.
KEEP_NONVERBAL_PATTERNS: List[Tuple[str, str]] = (
    # laughter
    [(pat, "<LAUGH>") for pat in (r"\(laughs?\)", r"\(laughter\)", r"\(oth laughs\)")]
    # sighs
    + [(r"\(sighs\)", "<SIGH>")]
    # tongue-click / tut
    + [(r"\(tuts\)", "<TUT>")]
    + [(pat, "<CLICK_TONGUE>") for pat in (r"\(clicking tongue\)", r"\(clicks tongue\)")]
)

# Nonverbal markers that are deleted outright (judged not useful for
# cognitive-status classification). Plain regex strings, used in list order.
REMOVE_NONVERBAL_PATTERNS: List[str] = [
    # throat / cough / sniff
    *(r"\(clears throat\)", r"\(coughs\)", r"\(oth coughs\)", r"\(sniffs\)"),
    # whistling
    *(r"\(whistling\)", r"\(whistles through teeth\)"),
    # environmental / device noises
    *(
        r"\(buzzer sounds\)",
        r"\(mobile phone alert\)",
        r"\(phone\s*ringing\)",
        r"\(phone stops\s*ringing\)",
        r"\(alert sound on computer\)",
        r"\(mouse click\)",
        r"\(dog barking\)",
        r"\(noise\)",
    ),
    # speech style / name information
    *(r"\(whispering\)", r"\(int2 name\)", r"\(oth name\)"),
    # unclear-speech marker "(?)"
    r"\(\?\)",
]
def extract_patient_speech(text: str, patient_tag: str = "Pat") -> str:
    """
    Keep only the patient's turns from a transcript, with speaker labels removed.

    Each line is checked against SPEAKER_PATTERN ("Pat:" / "Oth:", case-insensitive).
    A labelled line switches the active speaker; an unlabelled line continues the
    active speaker's turn. Text that appears before the first label is dropped.

    NOTE(review): only "Pat"/"Oth" labels are recognized, so a line carrying any
    other label would be attributed to the previous speaker -- confirm the
    transcripts use only these two labels.

    Args:
        text (str): Raw transcript text. Non-string values (e.g. NaN) yield "".
        patient_tag (str): Tag identifying the patient. Only its first three
            characters, lower-cased, are compared against the matched label.

    Returns:
        str: The patient's utterances joined by single spaces.
    """
    if not isinstance(text, str):
        return ""

    wanted = patient_tag.lower()[:3]
    speaker = None  # lower-cased label of the turn in progress; None before any label
    kept: List[str] = []

    for raw in text.splitlines():
        remainder = raw.rstrip()
        label = SPEAKER_PATTERN.match(remainder)
        if label is not None:
            speaker = label.group(1).lower()
            remainder = remainder[label.end():]
        utterance = remainder.strip()
        if speaker is not None and speaker.startswith(wanted) and utterance:
            kept.append(utterance)

    return re.sub(r"\s+", " ", " ".join(kept)).strip()
def replace_pauses(text: str) -> str:
    """
    Turn timed pause annotations into categorical pause tokens.

    Every "(N second)" / "(N seconds)" match of PAUSE_PATTERN is replaced by:
        N <= 2      -> <PAUSE_SHORT>
        3 <= N <= 5 -> <PAUSE_MEDIUM>
        N > 5       -> <PAUSE_LONG>
    Newlines become spaces and all whitespace runs are collapsed.

    Args:
        text (str): Transcript text. Non-string values yield "".

    Returns:
        str: Transcript text with pause tokens.
    """
    if not isinstance(text, str):
        return ""

    def _bucket(found: re.Match) -> str:
        duration = int(found.group(1))
        if duration > 5:
            label = "<PAUSE_LONG>"
        elif duration > 2:
            label = "<PAUSE_MEDIUM>"
        else:
            label = "<PAUSE_SHORT>"
        # Pad with spaces so the token never fuses with neighbouring words.
        return f" {label} "

    single_line = text.replace("\n", " ")
    tokenised = PAUSE_PATTERN.sub(_bucket, single_line)
    return re.sub(r"\s+", " ", tokenised).strip()
def normalize_nonverbals(text: str) -> str:
    """
    Normalize nonverbal markers: map informative ones to tokens, drop the rest.

    Each (pattern, token) pair in KEEP_NONVERBAL_PATTERNS is substituted first,
    with the token padded by spaces. Then each pattern in
    REMOVE_NONVERBAL_PATTERNS is replaced by a space. All matching is
    case-insensitive. Finally, whitespace runs are collapsed and the result is
    stripped.

    Args:
        text (str): Transcript text. Non-string values (e.g. None/NaN) yield "".

    Returns:
        str: Transcript text with normalized nonverbal markers.
    """
    if not isinstance(text, str):
        # Return "" rather than echoing the input so the result always honors
        # the `-> str` contract, like extract_patient_speech and replace_pauses.
        return ""
    for pattern, token in KEEP_NONVERBAL_PATTERNS:
        text = re.sub(pattern, f" {token} ", text, flags=re.IGNORECASE)
    for pattern in REMOVE_NONVERBAL_PATTERNS:
        text = re.sub(pattern, " ", text, flags=re.IGNORECASE)
    return re.sub(r"\s+", " ", text).strip()
def preprocess_transcript(raw_text: str) -> str:
    """
    Run the complete cleaning pipeline on one raw transcript field.

    Steps, in order:
        1. extract_patient_speech -- keep only the patient's ("Pat:") turns.
        2. replace_pauses         -- "(N seconds)" -> pause tokens.
        3. normalize_nonverbals   -- map or drop nonverbal markers.
        4. Lowercase (tokens therefore end up as e.g. "<laugh>") and
           collapse whitespace.

    Args:
        raw_text (str): Raw transcript text.

    Returns:
        str: Fully preprocessed transcript text.
    """
    cleaned = raw_text
    for step in (extract_patient_speech, replace_pauses, normalize_nonverbals):
        cleaned = step(cleaned)
    return re.sub(r"\s+", " ", cleaned.lower()).strip()
def preprocess_dataframe(
    input_csv_path: str,
    transcript_cols: Optional[List[str]] = None,
    output_csv_path: Optional[str] = None,
) -> pd.DataFrame:
    """
    Load the raw dataset CSV, preprocess its transcript columns, and return a
    modeling subset.

    This function:
      - Loads the raw CSV.
      - Checks that every requested transcript column exists before doing any
        work, so a wrong name fails fast instead of after other columns have
        already been preprocessed.
      - Applies preprocess_transcript to each transcript column (patient-only
        speech, pause tokens, nonverbal normalization, lowercase).
      - Keeps only "Record-ID", "Class" and the transcript columns, and adds a
        numeric "Label" column (HC -> 0, MCI -> 1, Dementia -> 2; any other
        "Class" value becomes NaN).
      - Optionally writes the result to a CSV.

    Args:
        input_csv_path (str): Path to the raw CSV file.
        transcript_cols (list[str], optional): Transcript columns to clean.
            Defaults to ["Transcript_PFT", "Transcript_CTD", "Transcript_SFT"].
        output_csv_path (str, optional): If provided, the result is written to
            this path as a CSV (without the index).

    Returns:
        pd.DataFrame: Columns ["Record-ID", "Class"] + transcript_cols + ["Label"],
        with all transcript fields fully preprocessed and ready for tokenization.

    Raises:
        ValueError: If a transcript column is not present in the CSV.
        KeyError: If "Record-ID" or "Class" is missing (raised by pandas).
    """
    if transcript_cols is None:
        # Fresh list per call instead of a shared mutable default argument.
        transcript_cols = ["Transcript_PFT", "Transcript_CTD", "Transcript_SFT"]
    # Accept any sequence (e.g. a tuple) for the list concatenation below.
    transcript_cols = list(transcript_cols)

    df_full = pd.read_csv(input_csv_path)

    # Validate all columns up front, reporting the first missing one.
    for col in transcript_cols:
        if col not in df_full.columns:
            raise ValueError(f"Column '{col}' not found in CSV.")

    for col in transcript_cols:
        df_full[col] = df_full[col].apply(preprocess_transcript)

    required_cols = ["Record-ID", "Class"] + transcript_cols
    transcript_df = df_full[required_cols].copy()

    # Numeric class label; classes outside this map become NaN.
    label_map = {"HC": 0, "MCI": 1, "Dementia": 2}
    transcript_df["Label"] = transcript_df["Class"].map(label_map)

    if output_csv_path is not None:
        transcript_df.to_csv(output_csv_path, index=False)
    return transcript_df
def get_stratified_kfold_splits(
    transcript_df: pd.DataFrame,
    transcript_col: str,
    label_col: str = "Class",
    n_splits: int = 5,
    seed: int = 42,
):
    """
    Build stratified k-fold train/test splits for one transcript column.

    Folds preserve the class distribution of `label_col`. The column checks run
    at call time, so a bad column name raises immediately rather than on the
    first iteration of the returned iterator.

    Args:
        transcript_df (pd.DataFrame): DataFrame containing at least the
            transcript column and the label column.
        transcript_col (str): Name of the transcript column used as features.
        label_col (str, optional): Column holding the class labels used for
            stratification (default: "Class"). String labels (e.g.
            "HC"/"MCI"/"Dementia") and numeric labels (e.g. "Label" with
            0/1/2) both work.
        n_splits (int, optional): Number of folds for cross-validation.
            Defaults to 5, which corresponds to ≈80/20 train-test splits.
        seed (int, optional): Random seed used for shuffling, for reproducible
            splits. Defaults to 42.

    Returns:
        Iterator of tuples (fold_idx, train_idx, test_idx), produced lazily:
            fold_idx (int): The 1-based fold number.
            train_idx (np.ndarray): Positional row indices of training samples.
            test_idx (np.ndarray): Positional row indices of test samples.
        The indices are positions, not index labels: select rows with
        `transcript_df.iloc[...]`, because the index may have gaps (e.g. after
        dropping NaN rows).

    Raises:
        ValueError: If transcript_col or label_col is not present in the
            supplied DataFrame.
    """
    if transcript_col not in transcript_df.columns:
        raise ValueError(f"Transcript column '{transcript_col}' not found in transcript_df.")
    if label_col not in transcript_df.columns:
        raise ValueError(f"Label column '{label_col}' not found in transcript_df.")

    X = transcript_df[transcript_col].to_numpy()
    y = transcript_df[label_col].to_numpy()
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    # Generator expression keeps fold generation lazy while the checks above
    # stay eager.
    return (
        (fold_idx, train_idx, test_idx)
        for fold_idx, (train_idx, test_idx) in enumerate(skf.split(X, y), start=1)
    )
def load_transcript_splits(
    csv_path: str,
    transcript_cols: Optional[List[str]] = None,
) -> dict:
    """
    Load the cleaned-transcripts CSV and build one DataFrame per transcript type.

    For each transcript column, only the rows whose value in that column is
    missing are dropped. Other columns are ignored for this decision, so each
    entry keeps every row that has that transcript. With pandas' default NA
    parsing, empty fields in the CSV are read as NaN and are therefore
    dropped too.

    The original row index is kept (not reset), so the returned frames may have
    gaps in their index. Use positional indexing (`.iloc`) with fold indices
    from get_stratified_kfold_splits.

    Args:
        csv_path (str): Path to the cleaned transcripts CSV.
        transcript_cols (list[str], optional): Transcript columns to split on.
            Defaults to ["Transcript_PFT", "Transcript_CTD", "Transcript_SFT"].

    Returns:
        dict: Maps each transcript column name to its filtered DataFrame copy.

    Raises:
        KeyError: If a transcript column is not present in the CSV (raised by
            pandas `dropna`).
    """
    if transcript_cols is None:
        # Fresh list per call instead of a shared mutable default argument.
        transcript_cols = ["Transcript_PFT", "Transcript_CTD", "Transcript_SFT"]

    transcript_df = pd.read_csv(csv_path)
    return {
        col: transcript_df.dropna(subset=[col]).copy()
        for col in transcript_cols
    }
def add_binary_label_column(
    df: pd.DataFrame,
    label_col: str = "Label",
    new_col: str = "Binary_Label",
    save_path: Optional[str] = None,
) -> pd.DataFrame:
    """
    Add a binary label column for HC vs Impaired (MCI + Dementia).

    The new column encodes:
        0 -> Healthy Control (HC)          (from label 0)
        1 -> Impaired (MCI or Dementia)    (from labels 1 and 2)
    Any other value in `label_col`, including NaN, becomes NaN.

    Args:
        df:
            Input DataFrame; must contain `label_col`. It is not modified.
        label_col:
            Name of the numeric label column with values {0, 1, 2}.
        new_col:
            Name of the binary label column to create.
        save_path:
            Optional path to save the updated DataFrame as a CSV file
            (without the index).

    Returns:
        A copy of `df` with an added `new_col` column where:
            0 = Healthy Control (HC)
            1 = Impaired (MCI or Dementia)

    Raises:
        ValueError: If `label_col` is not a column of `df`.
    """
    if label_col not in df.columns:
        raise ValueError(
            f"`{label_col}` not found in DataFrame. Cannot construct `{new_col}`."
        )

    df_out = df.copy()
    # HC stays 0; MCI (1) and Dementia (2) collapse into the impaired class.
    mapping = {
        0: 0,  # HC
        1: 1,  # MCI
        2: 1,  # Dementia
    }
    df_out[new_col] = df_out[label_col].map(mapping)

    if save_path is not None:
        df_out.to_csv(save_path, index=False)
    return df_out
if __name__ == "__main__":
    INPUT_CSV = "data/dementia_data.csv"
    OUTPUT_CSV = "data/transcripts_cleaned.csv"
    TRANSCRIPT_COLS = ["Transcript_PFT", "Transcript_CTD", "Transcript_SFT"]

    # 1. Preprocess transcripts and add the binary label. The CSV is written
    #    only once, after Binary_Label exists, so the saved file includes it.
    transcript_df = preprocess_dataframe(
        input_csv_path=INPUT_CSV,
        transcript_cols=TRANSCRIPT_COLS,
        output_csv_path=None,
    )
    df = add_binary_label_column(
        transcript_df,
        label_col="Label",
        new_col="Binary_Label",
        save_path=OUTPUT_CSV,
    )
    print("df shape:", df.shape)
    print(df.head())

    # 2. Print token-length statistics for the cleaned transcripts.
    for col in TRANSCRIPT_COLS:
        if col not in df.columns:
            continue
        lengths = df[col].fillna("").str.split().apply(len)
        print(f"{col}:")
        print(f" max length (tokens) = {lengths.max()}")
        print(f" mean length (tokens) = {lengths.mean():.2f}")
        print(f" 95th percentile = {lengths.quantile(0.95):.0f}")
        print()

    # 3. Show stratified k-fold split sizes for one transcript type.
    # NOTE(review): rows with an empty PFT transcript are still included here,
    # whereas load_transcript_splits drops them after a CSV round-trip --
    # confirm which population the splits should cover.
    for fold_idx, train_idx, test_idx in get_stratified_kfold_splits(
        df, transcript_col="Transcript_PFT", label_col="Class", n_splits=5
    ):
        print(f"Fold {fold_idx}: ")
        print(f"train size = {len(train_idx)}, test size = {len(test_idx)}")