Skip to content

Commit 4f22615

Browse files
committed
2 parents 15eabc8 + 6f109e9 commit 4f22615

7 files changed

Lines changed: 1207 additions & 12 deletions

File tree

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
__pycache__
22
slurm_logs
3-
.env
3+
.env
4+
fig
5+
figs

all_categories.py

Lines changed: 457 additions & 0 deletions
Large diffs are not rendered by default.

generate_behav.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import pandas as pd
2+
import mne
3+
import numpy as np
4+
import os
5+
from preprocessing_utils import read_eeg_data
6+
from pathlib import Path
7+
from tqdm import tqdm
8+
import seaborn as sns
9+
import matplotlib.pyplot as plt
10+
11+
cnt = np.zeros((6, 20), dtype=int)
12+
data_dir = Path("/srv/eeg_reconstruction/shared/data/raw_eeg/Alljoined-1.6M")
13+
14+
for sub in tqdm(range(1, 21)):
15+
for sess in tqdm(range(1, 5), leave=False):
16+
try:
17+
for b in range(1, 20):
18+
raw = read_eeg_data(
19+
data_dir
20+
/ f"sub-{sub:02d}"
21+
/ f"session_{sess:02d}"
22+
/ f"block_{b:02d}"
23+
)
24+
events, event_id = mne.events_from_annotations(
25+
raw, regexp="behav.*", verbose=False
26+
)
27+
28+
for e in event_id.keys():
29+
behav_val = int(e.split(",")[1])
30+
cnt[behav_val, sub - 1] += 1
31+
except FileNotFoundError:
32+
continue
33+
34+
plt.figure(figsize=(20, 6))
35+
sns.heatmap(cnt, annot=True, fmt="d", cmap="Blues")
36+
plt.xlabel("Subject")
37+
plt.xticks(ticks=np.arange(20) + 0.5, labels=np.arange(1, 21))
38+
plt.ylabel("Behavioral Value")
39+
plt.title("Count of Behavioral Values per Subject")
40+
plt.savefig("behavioral_counts.png")
41+
42+
tot = cnt.sum(axis=0)
43+
pct = np.vstack(
44+
(
45+
(cnt[0] + cnt[3]).reshape(1, -1),
46+
(cnt[1] + cnt[2]).reshape(1, -1),
47+
(cnt[4] + cnt[5]).reshape(1, -1),
48+
)
49+
)
50+
pct = pct / tot
51+
plt.figure(figsize=(20, 6))
52+
sns.heatmap(pct, annot=True, fmt=".2f", cmap="Blues")
53+
plt.xlabel("Subject")
54+
plt.xticks(ticks=np.arange(20) + 0.5, labels=np.arange(1, 21))
55+
plt.ylabel("Accuracy percent")
56+
plt.yticks(
57+
ticks=np.arange(3) + 0.5, labels=["Correct", "Incorrect", "No Response"]
58+
)
59+
plt.title("Count of Behavioral Values per Subject")
60+
plt.savefig("behavioral_accuracy.png")

lda_utils.py

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
import numpy as np
2+
from joblib import Parallel, delayed
3+
from sklearn.metrics import accuracy_score
4+
from sklearn.metrics import roc_auc_score
5+
from tqdm import tqdm
6+
7+
8+
# Function to process a single time point (t)
9+
def process_window(
10+
t,
11+
window,
12+
train_A,
13+
train_B,
14+
test_A,
15+
test_B,
16+
classifier,
17+
):
18+
# (t, window,num_iter, avg, A, B, n_sensors, super_sample,
19+
# n_trials, split_idx,sampling_ratio,classifier) = args
20+
win = []
21+
22+
# Setup window size
23+
t2 = min(t + window, train_A.shape[2])
24+
25+
trial_length = t2 - t
26+
win = np.arange(t, t2) # adjust for pythons 0 NOT SURE IF IT WORKS
27+
28+
# concatenate them
29+
train_x = np.concatenate((train_A[:, :, win], train_B[:, :, win]), axis=0)
30+
test_x = np.concatenate((test_A[:, :, win], test_B[:, :, win]), axis=0)
31+
train_y = np.array([1] * train_A.shape[0] + [2] * train_B.shape[0])
32+
test_y = np.array([1] * test_A.shape[0] + [2] * test_B.shape[0])
33+
34+
max_size_train = max(train_A.shape[0],train_B.shape[0]) # get the minimum size of data, ideally not
35+
max_size_test = max(test_A.shape[0],test_B.shape[0]) # get the minimum size of data, ideally not
36+
37+
#print("Calculating Averages")
38+
39+
train_x, train_y = average_trials(train_x, train_y, average_trials=10, max_sampling = max_size_train)
40+
test_x, test_y = average_trials(test_x, test_y, average_trials=10, max_sampling = max_size_test)
41+
42+
if np.ndim(train_x) > 2:
43+
train_x = train_x.reshape(
44+
train_x.shape[0], train_x.shape[1] * train_x.shape[2]
45+
)
46+
if np.ndim(test_x) > 2:
47+
test_x = test_x.reshape(
48+
test_x.shape[0], test_x.shape[1] * test_x.shape[2]
49+
)
50+
if np.ndim(train_y) > 2:
51+
train_y = train_y.reshape(
52+
train_y.shape[0], train_y.shape[1] * train_y.shape[2]
53+
)
54+
if np.ndim(test_y) > 2:
55+
test_y = test_y.reshape(
56+
test_y.shape[0], test_y.shape[1] * test_y.shape[2]
57+
)
58+
59+
if np.any(np.abs(train_x)) != 0:
60+
classifier.fit(train_x, train_y)
61+
62+
# test it on same time points
63+
pred_y = classifier.predict(test_x)
64+
acc = roc_auc_score(test_y, pred_y)
65+
return {"AUC": acc, "time": t}
66+
else:
67+
return {"AUC": 0, "time": t}
68+
69+
70+
def run_LDA(
71+
train_A,
72+
train_B,
73+
test_A,
74+
test_B,
75+
classifier,
76+
window=1,
77+
step=1,
78+
):
79+
80+
all_results = []
81+
82+
# Add progress bar to joblib.Parallel
83+
all_results = Parallel(n_jobs=-1)(
84+
delayed(process_window)(
85+
t,
86+
window,
87+
train_A,
88+
train_B,
89+
test_A,
90+
test_B,
91+
classifier,
92+
)
93+
for t in tqdm(range(0, train_A.shape[2] - window - 1, step))
94+
)
95+
96+
return all_results
97+
98+
99+
def prep_decoding_data_hierarchical(
100+
merged_train,
101+
merged_test,
102+
cat_a_spec,
103+
cat_b_spec,
104+
train_df,
105+
test_df,
106+
category_hierarchy,
107+
word_column="category_name",
108+
):
109+
"""
110+
Prepares epoched data for decoding based on potentially hierarchical categories.
111+
112+
Args:
113+
epoched_data: The MNE Epochs object containing all data.
114+
cat_a_spec: A string or list of strings specifying top-level categories
115+
(keys in category_hierarchy) or specific words for category A.
116+
cat_b_spec: A string or list of strings specifying top-level categories
117+
(keys in category_hierarchy) or specific words for category B.
118+
stim_df: Pandas DataFrame with stimulus information. Must include
119+
a column with the name specified in `word_column`.
120+
category_hierarchy: A potentially nested dictionary where keys are category
121+
names and values are lists/sets of words or further
122+
nested dictionaries of subcategories.
123+
word_column (str): The name of the column in stim_df containing the
124+
individual stimulus words to match against the hierarchy.
125+
Defaults to 'word'.
126+
127+
Returns:
128+
tuple: (data_a, data_b) containing the selected MNE Epochs objects
129+
for category A and category B, or empty Epochs selections if no
130+
data is found for a category.
131+
"""
132+
133+
# --- Get Words for Each Category Specification ---
134+
words_a = get_words_in_categories(cat_a_spec, category_hierarchy)
135+
words_b = get_words_in_categories(cat_b_spec, category_hierarchy)
136+
137+
print(
138+
f"Category A Specification '{cat_a_spec}' maps to"
139+
f" {len(words_a)} words."
140+
)
141+
if words_a:
142+
print(f" First few A words: {words_a[:10]}...")
143+
print(
144+
f"Category B Specification '{cat_b_spec}' maps to"
145+
f" {len(words_b)} words."
146+
)
147+
if words_b:
148+
print(f" First few B words: {words_b[:10]}...")
149+
150+
# --- Filter stim_df to Find Matching Trials ---
151+
# Use .copy() to avoid SettingWithCopyWarning if stim_df is modified later outside the function
152+
train_df_a = train_df[train_df[word_column].isin(words_a)]
153+
train_df_b = train_df[train_df[word_column].isin(words_b)]
154+
test_df_a = test_df[test_df[word_column].isin(words_a)]
155+
test_df_b = test_df[test_df[word_column].isin(words_b)]
156+
print(train_df_a)
157+
print(train_df_b)
158+
print(test_df_a)
159+
print(test_df_b)
160+
161+
# --- Extract Data Using Original Epoch Indices ---
162+
train_indices_a = train_df_a.index
163+
train_indices_b = train_df_b.index
164+
test_indices_a = test_df_a.index
165+
test_indices_b = test_df_b.index
166+
167+
print(
168+
f"Found {len(train_indices_a)} train epochs matching Category A spec."
169+
)
170+
if len(train_indices_a):
171+
print(f" First few A indices: {train_indices_a[:10]}...")
172+
else:
173+
print(
174+
"Warning: No train epochs found matching Category A specification."
175+
)
176+
print(
177+
f"Found {len(train_indices_b)} train epochs matching Category B spec."
178+
)
179+
if len(train_indices_b):
180+
print(f" First few B indices: {train_indices_b[:10]}...")
181+
else:
182+
print(
183+
"Warning: No train epochs found matching Category B specification."
184+
)
185+
print(f"Found {len(test_indices_a)} test epochs matching Category A spec.")
186+
if len(test_indices_a):
187+
print(f" First few A indices: {test_indices_a[:10]}...")
188+
else:
189+
print(
190+
"Warning: No test epochs found matching Category A specification."
191+
)
192+
print(f"Found {len(test_indices_b)} test epochs matching Category B spec.")
193+
if len(test_indices_b):
194+
print(f" First few B indices: {test_indices_b[:10]}...")
195+
else:
196+
print(
197+
"Warning: No test epochs found matching Category B specification."
198+
)
199+
200+
return (
201+
merged_train[train_indices_a],
202+
merged_train[train_indices_b],
203+
merged_test[test_indices_a],
204+
merged_test[test_indices_b],
205+
)
206+
207+
208+
def get_words_in_categories(categories_spec, hierarchy):
209+
"""
210+
Collects all unique words associated with the given category names or specific words,
211+
searching recursively/iteratively through the nested hierarchy starting from the specified items.
212+
213+
Args:
214+
categories_spec (list): A list of strings, where each string is either a
215+
top-level category key from the hierarchy or a
216+
specific word to include directly.
217+
hierarchy (dict): The potentially nested dictionary defining categories.
218+
Values can be lists/sets of words or nested dictionaries.
219+
220+
Returns:
221+
list: A list of unique words found under the specified categories or
222+
included directly from the spec.
223+
"""
224+
final_words = set()
225+
items_to_process = list(categories_spec) # Start with user-provided spec
226+
227+
while items_to_process:
228+
item = items_to_process.pop(0)
229+
230+
if not isinstance(item, str):
231+
print(
232+
f"Warning: Skipping non-string item in categories_spec: {item}"
233+
)
234+
continue
235+
236+
# Check if the item is a key in the *top level* of the hierarchy
237+
if item in hierarchy:
238+
# It's a category key, start traversal from its value
239+
value_queue = [
240+
hierarchy[item]
241+
] # Queue for values within this category branch
242+
243+
while value_queue:
244+
current_val = value_queue.pop(0)
245+
246+
if isinstance(current_val, dict):
247+
# If it's a sub-dictionary, add its values to the queue for processing
248+
for sub_val in current_val.values():
249+
value_queue.append(sub_val)
250+
elif isinstance(current_val, (list, set, tuple)):
251+
# If it's a list/set, assume it contains words
252+
final_words.update(
253+
w for w in current_val if isinstance(w, str)
254+
)
255+
elif isinstance(current_val, str):
256+
# If a string is found directly as a value (less common structure)
257+
final_words.add(current_val)
258+
# Ignore other data types found within the hierarchy values
259+
260+
else:
261+
# Item is not a top-level category key, assume it's a specific word
262+
final_words.add(item)
263+
264+
return list(final_words)
265+
266+
267+
268+
def average_trials(data, labels, average_trials=5,max_sampling=1000):
269+
270+
#print(f'Start Averaging {average_trials} Trials with Sampling {max_sampling}')
271+
if average_trials < 2:
272+
averaged_data = data
273+
averaged_labels = labels
274+
else:
275+
276+
averaged_data = []
277+
averaged_labels = []
278+
279+
# Separate data based on labels
280+
unique_labels = np.unique(labels)
281+
# PARALELLIZE
282+
for label in unique_labels:
283+
label_data = data[labels == label]
284+
285+
# Loop over the data and collect averages with substitution
286+
for _ in range(int(max_sampling)):
287+
# Sample with replacement
288+
indices = np.random.choice(label_data.shape[0], 5, replace=True)
289+
batch_data = label_data[indices]
290+
291+
# Compute average and append to list
292+
averaged_trial = np.mean(batch_data, axis=0)
293+
averaged_data.append(averaged_trial)
294+
averaged_labels.append(label)
295+
296+
return np.array(averaged_data), np.array(averaged_labels)

preprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _make_configs_from_args(args: argparse.Namespace) -> Configs:
166166
CONFIGS = _make_configs_from_args(ARGS)
167167

168168
OUTPUT_DIR = (
169-
PROJECT_DIR / "preprocessed_data" / "Alljoined-1.7M" / f"sub-{SUB:02d}"
169+
PROJECT_DIR / "preprocessed_data" / "Alljoined-1.6M" / f"sub-{SUB:02d}"
170170
)
171171
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
172172

0 commit comments

Comments
 (0)