1+ import numpy as np
2+ from joblib import Parallel , delayed
3+ from sklearn .metrics import accuracy_score
4+ from sklearn .metrics import roc_auc_score
5+ from tqdm import tqdm
6+
7+
8+ # Function to process a single time point (t)
9+ def process_window (
10+ t ,
11+ window ,
12+ train_A ,
13+ train_B ,
14+ test_A ,
15+ test_B ,
16+ classifier ,
17+ ):
18+ # (t, window,num_iter, avg, A, B, n_sensors, super_sample,
19+ # n_trials, split_idx,sampling_ratio,classifier) = args
20+ win = []
21+
22+ # Setup window size
23+ t2 = min (t + window , train_A .shape [2 ])
24+
25+ trial_length = t2 - t
26+ win = np .arange (t , t2 ) # adjust for pythons 0 NOT SURE IF IT WORKS
27+
28+ # concatenate them
29+ train_x = np .concatenate ((train_A [:, :, win ], train_B [:, :, win ]), axis = 0 )
30+ test_x = np .concatenate ((test_A [:, :, win ], test_B [:, :, win ]), axis = 0 )
31+ train_y = np .array ([1 ] * train_A .shape [0 ] + [2 ] * train_B .shape [0 ])
32+ test_y = np .array ([1 ] * test_A .shape [0 ] + [2 ] * test_B .shape [0 ])
33+
34+ max_size_train = max (train_A .shape [0 ],train_B .shape [0 ]) # get the minimum size of data, ideally not
35+ max_size_test = max (test_A .shape [0 ],test_B .shape [0 ]) # get the minimum size of data, ideally not
36+
37+ #print("Calculating Averages")
38+
39+ train_x , train_y = average_trials (train_x , train_y , average_trials = 10 , max_sampling = max_size_train )
40+ test_x , test_y = average_trials (test_x , test_y , average_trials = 10 , max_sampling = max_size_test )
41+
42+ if np .ndim (train_x ) > 2 :
43+ train_x = train_x .reshape (
44+ train_x .shape [0 ], train_x .shape [1 ] * train_x .shape [2 ]
45+ )
46+ if np .ndim (test_x ) > 2 :
47+ test_x = test_x .reshape (
48+ test_x .shape [0 ], test_x .shape [1 ] * test_x .shape [2 ]
49+ )
50+ if np .ndim (train_y ) > 2 :
51+ train_y = train_y .reshape (
52+ train_y .shape [0 ], train_y .shape [1 ] * train_y .shape [2 ]
53+ )
54+ if np .ndim (test_y ) > 2 :
55+ test_y = test_y .reshape (
56+ test_y .shape [0 ], test_y .shape [1 ] * test_y .shape [2 ]
57+ )
58+
59+ if np .any (np .abs (train_x )) != 0 :
60+ classifier .fit (train_x , train_y )
61+
62+ # test it on same time points
63+ pred_y = classifier .predict (test_x )
64+ acc = roc_auc_score (test_y , pred_y )
65+ return {"AUC" : acc , "time" : t }
66+ else :
67+ return {"AUC" : 0 , "time" : t }
68+
69+
70+ def run_LDA (
71+ train_A ,
72+ train_B ,
73+ test_A ,
74+ test_B ,
75+ classifier ,
76+ window = 1 ,
77+ step = 1 ,
78+ ):
79+
80+ all_results = []
81+
82+ # Add progress bar to joblib.Parallel
83+ all_results = Parallel (n_jobs = - 1 )(
84+ delayed (process_window )(
85+ t ,
86+ window ,
87+ train_A ,
88+ train_B ,
89+ test_A ,
90+ test_B ,
91+ classifier ,
92+ )
93+ for t in tqdm (range (0 , train_A .shape [2 ] - window - 1 , step ))
94+ )
95+
96+ return all_results
97+
98+
99+ def prep_decoding_data_hierarchical (
100+ merged_train ,
101+ merged_test ,
102+ cat_a_spec ,
103+ cat_b_spec ,
104+ train_df ,
105+ test_df ,
106+ category_hierarchy ,
107+ word_column = "category_name" ,
108+ ):
109+ """
110+ Prepares epoched data for decoding based on potentially hierarchical categories.
111+
112+ Args:
113+ epoched_data: The MNE Epochs object containing all data.
114+ cat_a_spec: A string or list of strings specifying top-level categories
115+ (keys in category_hierarchy) or specific words for category A.
116+ cat_b_spec: A string or list of strings specifying top-level categories
117+ (keys in category_hierarchy) or specific words for category B.
118+ stim_df: Pandas DataFrame with stimulus information. Must include
119+ a column with the name specified in `word_column`.
120+ category_hierarchy: A potentially nested dictionary where keys are category
121+ names and values are lists/sets of words or further
122+ nested dictionaries of subcategories.
123+ word_column (str): The name of the column in stim_df containing the
124+ individual stimulus words to match against the hierarchy.
125+ Defaults to 'word'.
126+
127+ Returns:
128+ tuple: (data_a, data_b) containing the selected MNE Epochs objects
129+ for category A and category B, or empty Epochs selections if no
130+ data is found for a category.
131+ """
132+
133+ # --- Get Words for Each Category Specification ---
134+ words_a = get_words_in_categories (cat_a_spec , category_hierarchy )
135+ words_b = get_words_in_categories (cat_b_spec , category_hierarchy )
136+
137+ print (
138+ f"Category A Specification '{ cat_a_spec } ' maps to"
139+ f" { len (words_a )} words."
140+ )
141+ if words_a :
142+ print (f" First few A words: { words_a [:10 ]} ..." )
143+ print (
144+ f"Category B Specification '{ cat_b_spec } ' maps to"
145+ f" { len (words_b )} words."
146+ )
147+ if words_b :
148+ print (f" First few B words: { words_b [:10 ]} ..." )
149+
150+ # --- Filter stim_df to Find Matching Trials ---
151+ # Use .copy() to avoid SettingWithCopyWarning if stim_df is modified later outside the function
152+ train_df_a = train_df [train_df [word_column ].isin (words_a )]
153+ train_df_b = train_df [train_df [word_column ].isin (words_b )]
154+ test_df_a = test_df [test_df [word_column ].isin (words_a )]
155+ test_df_b = test_df [test_df [word_column ].isin (words_b )]
156+ print (train_df_a )
157+ print (train_df_b )
158+ print (test_df_a )
159+ print (test_df_b )
160+
161+ # --- Extract Data Using Original Epoch Indices ---
162+ train_indices_a = train_df_a .index
163+ train_indices_b = train_df_b .index
164+ test_indices_a = test_df_a .index
165+ test_indices_b = test_df_b .index
166+
167+ print (
168+ f"Found { len (train_indices_a )} train epochs matching Category A spec."
169+ )
170+ if len (train_indices_a ):
171+ print (f" First few A indices: { train_indices_a [:10 ]} ..." )
172+ else :
173+ print (
174+ "Warning: No train epochs found matching Category A specification."
175+ )
176+ print (
177+ f"Found { len (train_indices_b )} train epochs matching Category B spec."
178+ )
179+ if len (train_indices_b ):
180+ print (f" First few B indices: { train_indices_b [:10 ]} ..." )
181+ else :
182+ print (
183+ "Warning: No train epochs found matching Category B specification."
184+ )
185+ print (f"Found { len (test_indices_a )} test epochs matching Category A spec." )
186+ if len (test_indices_a ):
187+ print (f" First few A indices: { test_indices_a [:10 ]} ..." )
188+ else :
189+ print (
190+ "Warning: No test epochs found matching Category A specification."
191+ )
192+ print (f"Found { len (test_indices_b )} test epochs matching Category B spec." )
193+ if len (test_indices_b ):
194+ print (f" First few B indices: { test_indices_b [:10 ]} ..." )
195+ else :
196+ print (
197+ "Warning: No test epochs found matching Category B specification."
198+ )
199+
200+ return (
201+ merged_train [train_indices_a ],
202+ merged_train [train_indices_b ],
203+ merged_test [test_indices_a ],
204+ merged_test [test_indices_b ],
205+ )
206+
207+
208+ def get_words_in_categories (categories_spec , hierarchy ):
209+ """
210+ Collects all unique words associated with the given category names or specific words,
211+ searching recursively/iteratively through the nested hierarchy starting from the specified items.
212+
213+ Args:
214+ categories_spec (list): A list of strings, where each string is either a
215+ top-level category key from the hierarchy or a
216+ specific word to include directly.
217+ hierarchy (dict): The potentially nested dictionary defining categories.
218+ Values can be lists/sets of words or nested dictionaries.
219+
220+ Returns:
221+ list: A list of unique words found under the specified categories or
222+ included directly from the spec.
223+ """
224+ final_words = set ()
225+ items_to_process = list (categories_spec ) # Start with user-provided spec
226+
227+ while items_to_process :
228+ item = items_to_process .pop (0 )
229+
230+ if not isinstance (item , str ):
231+ print (
232+ f"Warning: Skipping non-string item in categories_spec: { item } "
233+ )
234+ continue
235+
236+ # Check if the item is a key in the *top level* of the hierarchy
237+ if item in hierarchy :
238+ # It's a category key, start traversal from its value
239+ value_queue = [
240+ hierarchy [item ]
241+ ] # Queue for values within this category branch
242+
243+ while value_queue :
244+ current_val = value_queue .pop (0 )
245+
246+ if isinstance (current_val , dict ):
247+ # If it's a sub-dictionary, add its values to the queue for processing
248+ for sub_val in current_val .values ():
249+ value_queue .append (sub_val )
250+ elif isinstance (current_val , (list , set , tuple )):
251+ # If it's a list/set, assume it contains words
252+ final_words .update (
253+ w for w in current_val if isinstance (w , str )
254+ )
255+ elif isinstance (current_val , str ):
256+ # If a string is found directly as a value (less common structure)
257+ final_words .add (current_val )
258+ # Ignore other data types found within the hierarchy values
259+
260+ else :
261+ # Item is not a top-level category key, assume it's a specific word
262+ final_words .add (item )
263+
264+ return list (final_words )
265+
266+
267+
268+ def average_trials (data , labels , average_trials = 5 ,max_sampling = 1000 ):
269+
270+ #print(f'Start Averaging {average_trials} Trials with Sampling {max_sampling}')
271+ if average_trials < 2 :
272+ averaged_data = data
273+ averaged_labels = labels
274+ else :
275+
276+ averaged_data = []
277+ averaged_labels = []
278+
279+ # Separate data based on labels
280+ unique_labels = np .unique (labels )
281+ # PARALELLIZE
282+ for label in unique_labels :
283+ label_data = data [labels == label ]
284+
285+ # Loop over the data and collect averages with substitution
286+ for _ in range (int (max_sampling )):
287+ # Sample with replacement
288+ indices = np .random .choice (label_data .shape [0 ], 5 , replace = True )
289+ batch_data = label_data [indices ]
290+
291+ # Compute average and append to list
292+ averaged_trial = np .mean (batch_data , axis = 0 )
293+ averaged_data .append (averaged_trial )
294+ averaged_labels .append (label )
295+
296+ return np .array (averaged_data ), np .array (averaged_labels )
0 commit comments