Skip to content

Commit d77d86d

Browse files
committed
new changes to fix API problems
1 parent 50c4566 commit d77d86d

File tree

6 files changed

+352
-28
lines changed

6 files changed

+352
-28
lines changed

code/data_processing/cc_qc.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,12 @@ def cc_qc(self, df, threshold, TS=False):
8686
CATEGORY = 2
8787
print(f"FOR TASK SWITCHING -> Average accuracy at or below 0.5 across conditions and CATEGORY set to 2")
8888

89-
problematic_conditions = QC_UTILS.cond_block_not_reported(raw, self.ACC_COLUMN_NAME, self.COND_COLUMN_NAME, self.INCORRECT_SYMBOL)
89+
problematic_conditions = QC_UTILS.cond_block_not_reported(
90+
raw,
91+
self.COND_COLUMN_NAME,
92+
self.ACC_COLUMN_NAME,
93+
self.INCORRECT_SYMBOL,
94+
)
9095

9196
if len(problematic_conditions) != 0:
9297
CATEGORY = 3

code/data_processing/plot_utils.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,25 @@ def af_nf_plot(self, df):
2727
tuple: A tuple containing two Axes objects (count_plot, response_time_plot).
2828
"""
2929
# Filter to drop practice data
30-
test = df[df['block'] == 'test'].copy()
30+
block_series = df['block'].astype(str).str.strip().str.lower()
31+
test_mask = block_series == 'test'
32+
test = df[test_mask].copy()
33+
34+
if test.empty:
35+
subject_series = df.get('subject_id')
36+
if subject_series is not None:
37+
subjects = sorted(
38+
{str(value).strip() for value in subject_series.dropna() if str(value).strip()}
39+
)
40+
else:
41+
subjects = []
42+
43+
unique_blocks = sorted({value for value in block_series.unique() if value and value != 'nan'})
44+
raise ValueError(
45+
"No 'test' block rows available for plotting. "
46+
f"Observed block labels: {unique_blocks or '<none>'}. "
47+
f"Subjects in frame: {subjects or '<unknown>'}"
48+
)
3149

3250
# Generate count plot
3351
plt.figure(figsize=(10, 6))
@@ -402,7 +420,27 @@ def fn_plot(self, df):
402420
Returns:
403421
tuple: The scatter/box plot and bar chart plot objects.
404422
"""
405-
test = df[df['block'] == 'test']
423+
block_series = df['block'].astype(str).str.strip().str.lower()
424+
test_mask = block_series == 'test'
425+
test = df[test_mask].copy()
426+
427+
if test.empty:
428+
subject_series = df.get('subject_id')
429+
if subject_series is not None:
430+
subjects = sorted(
431+
{str(value).strip() for value in subject_series.dropna() if str(value).strip()}
432+
)
433+
else:
434+
subjects = []
435+
436+
unique_blocks = sorted({value for value in block_series.unique() if value and value != 'nan'})
437+
raise ValueError(
438+
"No 'test' block rows available for MEM plotting. "
439+
f"Observed block labels: {unique_blocks or '<none>'}. "
440+
f"Subjects in frame: {subjects or '<unknown>'}"
441+
)
442+
443+
test['block'] = test['block'].astype(str).str.strip()
406444
test['correct_label'] = test['correct'].map({0: 'Incorrect', 1: 'Correct'})
407445

408446
# Scatter and box plot
@@ -459,7 +497,27 @@ def sm_plot(self, df):
459497
Returns:
460498
The scatter and box plot object.
461499
"""
462-
test = df[df['block'] == 'test']
500+
block_series = df['block'].astype(str).str.strip().str.lower()
501+
test_mask = block_series == 'test'
502+
test = df[test_mask].copy()
503+
504+
if test.empty:
505+
subject_series = df.get('subject_id')
506+
if subject_series is not None:
507+
subjects = sorted(
508+
{str(value).strip() for value in subject_series.dropna() if str(value).strip()}
509+
)
510+
else:
511+
subjects = []
512+
513+
unique_blocks = sorted({value for value in block_series.unique() if value and value != 'nan'})
514+
raise ValueError(
515+
"No 'test' block rows available for MEM plotting. "
516+
f"Observed block labels: {unique_blocks or '<none>'}. "
517+
f"Subjects in frame: {subjects or '<unknown>'}"
518+
)
519+
520+
test['block'] = test['block'].astype(str).str.strip()
463521
mapping = {'no': 'Incongruent', 'yes': 'Congruent'}
464522
test['target_congruent'] = test['target_congruent'].map(mapping)
465523
test['correct_label'] = test['correct'].map({0: 'Incorrect', 1: 'Correct'})
@@ -595,9 +653,6 @@ def dwl_plot(self, df):
595653

596654

597655

598-
599-
600-
601656

602657

603658

code/data_processing/utils.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,15 @@ def convert_to_csv(self, txt_dfs):
2121
# If file is empty or has no valid records, skip
2222
continue
2323

24-
# Normalize nested dicts into flat columns so downstream QC stays unchanged
25-
flattened_df = pd.json_normalize(records)
24+
normalized_records = [
25+
self._flatten_record(record) for record in records if isinstance(record, dict)
26+
]
27+
if not normalized_records:
28+
continue
29+
30+
flattened_df = pd.json_normalize(normalized_records, sep="_")
31+
flattened_df = self._harmonize_columns(flattened_df)
32+
flattened_df = self._normalize_semantics(flattened_df)
2633
new_dfs.append(flattened_df)
2734

2835
return new_dfs
@@ -82,6 +89,117 @@ def _collect_records(self, payload):
8289

8390
return []
8491

92+
def _flatten_record(self, record):
    """
    Flatten one (possibly nested) record dict into a single-level dict.

    Dict values stored under a wrapper key (``data`` or ``trialdata``, matched
    case-insensitively) are merged into the result without any prefix, so
    downstream QC modules see consistent column names like ``block`` and
    ``correct``. Any other dict value is flattened recursively and its keys
    are prefixed with ``<parent key>_``. Non-dict values are copied unchanged.
    When two entries resolve to the same flat key, the one processed later
    overwrites the earlier one.

    Args:
        record (dict): Record to flatten.

    Returns:
        dict: A new flat dict; ``record`` itself is not modified.
    """
    wrapper_keys = ("data", "trialdata")
    flattened = {}
    for key, value in record.items():
        if not isinstance(value, dict):
            flattened[key] = value
            continue

        nested = self._flatten_record(value)
        # Only string keys can name a wrapper; the isinstance guard keeps
        # non-string keys (e.g. ints) from crashing on ``.lower()``.
        if isinstance(key, str) and key.lower() in wrapper_keys:
            flattened.update(nested)
        else:
            flattened.update(
                {f"{key}_{nested_key}": nested_value
                 for nested_key, nested_value in nested.items()}
            )
    return flattened
109+
110+
def _harmonize_columns(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Strip known wrapper prefixes to restore historical column names and drop duplicates.

    Processing happens in three steps:

    1. Wrapper prefixes (``trialdata_``, ``data_``, ``payload_`` and the
       listed case variants) are removed from the front of each column name.
       Stripping repeats until no prefix matches, so stacked wrappers such as
       ``data_trialdata_block`` reduce to ``block`` regardless of the order in
       which the prefixes are listed. A prefix is never stripped if that
       would leave an empty column name.
    2. Known aliases (e.g. ``BlockName``, ``isCorrect``, ``Session``) are
       renamed to their canonical names.
    3. If several columns end up sharing a name they are collapsed into one
       column holding, per row, the first non-null value in column order.

    Args:
        df (pd.DataFrame): Frame whose columns should be harmonized.

    Returns:
        pd.DataFrame: Frame with harmonized, unique column names. ``df`` is
        not modified.
    """
    wrapper_prefixes = ("trialdata_", "data_", "payload_", "TrialData_", "trialData_")

    canonical_map = {
        "Block": "block",
        "BlockName": "block",
        "blockName": "block",
        "Block_Type": "block",
        "block_type": "block",
        "Condition": "condition",
        "Cond": "condition",
        "stim_condition": "condition",
        "Correct": "correct",
        "isCorrect": "correct",
        "Session": "session_number",
        "session": "session_number",
        "SessionID": "session_number",
        "Subject": "subject_id",
        "subject": "subject_id",
    }

    def strip_wrappers(col):
        """Remove every leading wrapper prefix from a column label."""
        if not isinstance(col, str):
            return col
        name = col
        changed = True
        while changed:
            changed = False
            for prefix in wrapper_prefixes:
                # Keep at least one character so no column ends up unnamed.
                if name.startswith(prefix) and len(name) > len(prefix):
                    name = name[len(prefix):]
                    changed = True
        return name

    def canonical_name(col):
        """Strip wrapper prefixes, then translate known aliases."""
        stripped = strip_wrappers(col)
        return canonical_map.get(stripped, stripped)

    harmonized = df.rename(columns=canonical_name)

    if not harmonized.columns.duplicated().any():
        return harmonized

    # If both original and harmonized columns exist, keep the first non-null values.
    deduped = {}
    for col in harmonized.columns.unique():
        # Select by position: label-based selection with repeated labels
        # would return every duplicate once per requested label.
        group = harmonized.loc[:, harmonized.columns == col]
        if group.shape[1] == 1:
            deduped[col] = group.iloc[:, 0]
        else:
            deduped[col] = group.bfill(axis=1).iloc[:, 0]
    return pd.DataFrame(deduped)
157+
158+
def _normalize_semantics(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Coerce critical columns (block/condition/correct/session/subject_id) into their
    historical dtypes and label space so downstream QC and persistence stay stable.

    Each column is only touched when present:

    * ``block``: every value is passed through ``self._standardize_block``.
    * ``condition``: string values are stripped of surrounding whitespace.
    * ``correct``: converted to numeric; the strings ``"true"``/``"false"``
      (any case, surrounding whitespace ignored) become 1/0 first so they are
      not lost, and anything else unparseable becomes NaN.
    * ``session_number``: converted to numeric, unparseable values become NaN.
    * ``subject_id``: non-null values become stripped strings; whole-number
      floats are rendered without the trailing ``.0``.

    Args:
        df (pd.DataFrame): Frame to normalize.

    Returns:
        pd.DataFrame: A normalized copy; ``df`` is not modified.
    """

    def strip_if_str(val):
        """Trim whitespace on strings, pass everything else through."""
        return val.strip() if isinstance(val, str) else val

    def bool_text_to_int(val):
        """Translate textual booleans to 1/0 so to_numeric keeps them."""
        if isinstance(val, str):
            token = val.strip().lower()
            if token == "true":
                return 1
            if token == "false":
                return 0
        return val

    def subject_to_str(val):
        """Render a subject identifier as a clean string, keeping nulls."""
        if not pd.notna(val):
            return val
        # An integer ID column that contains missing values is upcast to
        # float; avoid emitting "1001.0" for what was originally 1001.
        if isinstance(val, float) and val.is_integer():
            return str(int(val))
        return str(val).strip()

    normalized = df.copy()

    if "block" in normalized.columns:
        normalized["block"] = normalized["block"].map(self._standardize_block)

    if "condition" in normalized.columns:
        normalized["condition"] = normalized["condition"].apply(strip_if_str)

    if "correct" in normalized.columns:
        normalized["correct"] = pd.to_numeric(
            normalized["correct"].map(bool_text_to_int), errors="coerce"
        )

    if "session_number" in normalized.columns:
        normalized["session_number"] = pd.to_numeric(
            normalized["session_number"], errors="coerce"
        )

    if "subject_id" in normalized.columns:
        normalized["subject_id"] = normalized["subject_id"].apply(subject_to_str)

    return normalized
187+
188+
@staticmethod
def _standardize_block(value):
    """
    Map a raw block label onto the label space used by the QC modules.

    String labels are stripped and lower-cased, then:

    * labels starting with ``test`` become ``"test"``;
    * labels starting with ``prac`` (which includes ``practice``), or equal
      to ``train``/``training``, become ``"prac"``;
    * empty or whitespace-only labels become NaN;
    * any other label is returned in its cleaned (stripped, lower-case) form.

    Non-string values are returned unchanged.

    Args:
        value: Raw block label.

    Returns:
        The standardized label, a float NaN for blank strings, or ``value``
        itself when it is not a string.
    """
    if not isinstance(value, str):
        return value

    cleaned = value.strip().lower()
    if not cleaned:
        # A plain float NaN is recognised as missing by pandas and does not
        # require a numpy alias to be in scope in this module.
        return float("nan")
    if cleaned.startswith("test"):
        return "test"
    if cleaned.startswith("prac") or cleaned in ("training", "train"):
        return "prac"
    return cleaned
202+
85203
def save_csv(self):
86204
return None
87205

code/data_processing/wl_qc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def wl_qc(self, submission, version):
2323

2424
if self.CATEGORY == 3:
2525
print("One or more conditions are empty, status finalized at 3")
26-
return self.CATEGORY
26+
return df_all, self.CATEGORY
2727
# Assuming df_all is the DataFrame and self.CATEGORY exists in the class context
2828

2929
if (df_all['block'] == 'immediate').any():
@@ -48,7 +48,7 @@ def dwl_qc(self, submission, version):
4848

4949
if self.CATEGORY == 3:
5050
print("One or more conditions are empty, status finalized at 3")
51-
return self.CATEGORY
51+
return df_all, self.CATEGORY
5252
# Assuming df_all is the DataFrame and self.CATEGORY exists in the class context
5353

5454
if (df_all['block'] == 'delay').any():
@@ -90,4 +90,3 @@ def dwl_count_correct(df_all):
9090
.reindex(['delay'], fill_value=0)
9191
.to_frame().T) # one row: column 'delay'
9292
return out
93-

0 commit comments

Comments
 (0)