Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions lda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,21 @@ def process_window(
train_y = np.array([1] * train_A.shape[0] + [2] * train_B.shape[0])
test_y = np.array([1] * test_A.shape[0] + [2] * test_B.shape[0])

max_size_train = max(train_A.shape[0],train_B.shape[0]) # get the minimum size of data, ideally not
max_size_test = max(test_A.shape[0],test_B.shape[0]) # get the minimum size of data, ideally not
max_size_train = max(
train_A.shape[0], train_B.shape[0]
) # get the minimum size of data, ideally not
max_size_test = max(
test_A.shape[0], test_B.shape[0]
) # get the minimum size of data, ideally not

#print("Calculating Averages")
# print("Calculating Averages")

train_x, train_y = average_trials(train_x, train_y, average_trials=10, max_sampling = max_size_train)
test_x, test_y = average_trials(test_x, test_y, average_trials=10, max_sampling = max_size_test)
train_x, train_y = average_trials(
train_x, train_y, average_trials=10, max_sampling=max_size_train
)
test_x, test_y = average_trials(
test_x, test_y, average_trials=10, max_sampling=max_size_test
)

if np.ndim(train_x) > 2:
train_x = train_x.reshape(
Expand Down Expand Up @@ -153,10 +161,6 @@ def prep_decoding_data_hierarchical(
train_df_b = train_df[train_df[word_column].isin(words_b)]
test_df_a = test_df[test_df[word_column].isin(words_a)]
test_df_b = test_df[test_df[word_column].isin(words_b)]
print(train_df_a)
print(train_df_b)
print(test_df_a)
print(test_df_b)

# --- Extract Data Using Original Epoch Indices ---
train_indices_a = train_df_a.index
Expand Down Expand Up @@ -264,10 +268,8 @@ def get_words_in_categories(categories_spec, hierarchy):
return list(final_words)


def average_trials(data, labels, average_trials=5, max_sampling=1000):

def average_trials(data, labels, average_trials=5,max_sampling=1000):

#print(f'Start Averaging {average_trials} Trials with Sampling {max_sampling}')
if average_trials < 2:
averaged_data = data
averaged_labels = labels
Expand All @@ -285,12 +287,14 @@ def average_trials(data, labels, average_trials=5,max_sampling=1000):
# Loop over the data and collect averages with substitution
for _ in range(int(max_sampling)):
# Sample with replacement
indices = np.random.choice(label_data.shape[0], 5, replace=True)
indices = np.random.choice(
label_data.shape[0], 5, replace=True
)
batch_data = label_data[indices]

# Compute average and append to list
averaged_trial = np.mean(batch_data, axis=0)
averaged_data.append(averaged_trial)
averaged_labels.append(label)

return np.array(averaged_data), np.array(averaged_labels)
return np.array(averaged_data), np.array(averaged_labels)
175 changes: 16 additions & 159 deletions preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,169 +1,27 @@
import argparse
import json
import os
import pickle
from pathlib import Path
from typing import Tuple, Dict

import mne
import pandas as pd

from preprocessing_utils import (
Configs,
compute_dropped_trials,
compute_whitening_matrix,
epoching,
get_preprocessing_parser,
make_configs_from_args,
save_data,
whiten,
)

# --------------------------------------------------------------------------
# helpers ------------------------------------------------------------------
# --------------------------------------------------------------------------


def _parse_float_tuple(text: str) -> Tuple[float | None, float]:
"""Parse 'None,0' or '-0.2,0' → (None, 0.0) or (-0.2, 0.0)."""
a, b = [x.strip() for x in text.split(",")]
return (None if a.lower() == "none" else float(a), float(b))


def _parse_reject(text: str) -> Dict[str, float]:
"""
Parse JSON or 'EEG001:1e-4,EEG002:1.2e-4' → {'EEG001': 1e-4, 'EEG002': 1.2e-4}
"""
if text.lstrip().startswith("{"):
return json.loads(text)
out: Dict[str, float] = {}
for pair in text.split(","):
ch, thr = pair.split(":")
out[ch.strip()] = float(thr)
return out


def _mvnn_arg(text: str) -> str | None:
"""Return 'epochs', 'time', or None (for the string 'none')."""
t = text.lower()
if t == "none":
return None
if t in {"epochs", "time"}:
return t
raise argparse.ArgumentTypeError(
"mvnn_dim must be 'epochs', 'time', or 'None'"
)


def _make_configs_from_args(args: argparse.Namespace) -> Configs:
"""Instantiate a Configs object from parsed CLI args."""
return Configs(
baseline=args.baseline,
tmin=args.tmin,
tmax=args.tmax,
sfreq=args.sfreq,
l_freq=args.l_freq,
h_freq=args.h_freq,
notch_freqs=args.notch_freqs,
mvnn_dim=args.mvnn_dim,
reject=args.reject,
)


# --------------------------------------------------------------------------
# CLI ----------------------------------------------------------------------
# --------------------------------------------------------------------------

parser = argparse.ArgumentParser(
prog="preprocessing.py",
description=(
"Preprocess EEG data for a specific subject: epoching, filtering, "
"MVNN, and saving."
),
)

# required
parser.add_argument(
"-s",
"--sub",
required=True,
type=int,
help="Subject number (e.g. 1 for sub-01)",
)

# Configs-related ----------------------------------------------------------
parser.add_argument(
"--tmin",
type=float,
default=-0.2,
help="Epoch start (seconds, default -0.2)",
)
parser.add_argument(
"--tmax", type=float, default=1.0, help="Epoch end (seconds, default 1.0)"
)
parser.add_argument(
"--baseline",
type=_parse_float_tuple,
default="None,0",
help=(
'Baseline tuple "None,0" or "-0.2,0" (use None for '
"no pre-stim baseline)"
),
)
parser.add_argument(
"--sfreq",
type=int,
default=250,
help="Target sampling rate after downsampling (Hz)",
)
parser.add_argument(
"--l_freq", type=float, help="Low cutoff for band-pass filter (Hz)"
)
parser.add_argument(
"--h_freq", type=float, help="High cutoff for band-pass filter (Hz)"
)
parser.add_argument(
"--notch_freqs",
nargs="+",
type=float,
help="One or more notch filter frequencies (Hz)",
)
parser.add_argument(
"--reject",
type=_parse_reject,
help=(
'Artifact-rejection dict; JSON or "CH1:thr,CH2:thr" '
'(mV, e.g. "EEG001:1e-4").'
),
)
parser.add_argument(
"--mvnn_dim",
type=_mvnn_arg,
default="epochs",
help="MVNN mode (off to skip whitening)",
)

# misc ---------------------------------------------------------------------
parser.add_argument(
"--project_dir",
default="/srv/eeg_reconstruction/shared/data/",
help="Root of the project directory tree",
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose output during preprocessing",
)

parser = get_preprocessing_parser()
ARGS = parser.parse_args()

# --------------------------------------------------------------------------
# main ---------------------------------------------------------------------
# --------------------------------------------------------------------------

mne.set_log_level("WARNING" if not ARGS.verbose else "INFO")

PROJECT_DIR = Path(ARGS.project_dir)
SUB = ARGS.sub
CONFIGS = _make_configs_from_args(ARGS)
CONFIGS = make_configs_from_args(ARGS)

OUTPUT_DIR = (
PROJECT_DIR / "preprocessed_data" / "Alljoined-1.6M" / f"sub-{SUB:02d}"
Expand Down Expand Up @@ -197,18 +55,20 @@ def _make_configs_from_args(args: argparse.Namespace) -> Configs:
)

# dropped-trial bookkeeping -------------------------------------------------
test_df = stim_order.query("partition == 'stim_test'")
train_df = stim_order.query("partition == 'stim_train'")

test_keep = compute_dropped_trials(epoched_test, test_df, verbose=ARGS.verbose)
train_keep = compute_dropped_trials(
epoched_train, train_df, verbose=ARGS.verbose
test_df = compute_dropped_trials(
epoched_test,
stim_order.query("partition == 'stim_test'"),
verbose=ARGS.verbose,
)
train_df = compute_dropped_trials(
epoched_train,
stim_order.query("partition == 'stim_train'"),
verbose=ARGS.verbose,
)


stim_order["dropped"] = True
stim_order.loc[test_keep, "dropped"] = False
stim_order.loc[train_keep, "dropped"] = False
stim_order.loc[test_df.index, "dropped"] = False
stim_order.loc[train_df.index, "dropped"] = False
stim_order.to_parquet(OUTPUT_DIR / "experiment_metadata.parquet")

# --------------------------------------------------------------------------
Expand All @@ -217,15 +77,12 @@ def _make_configs_from_args(args: argparse.Namespace) -> Configs:
whitening_mats = compute_whitening_matrix(
CONFIGS.mvnn_dim,
epoched_train,
stim_order.query("partition == 'stim_train'"),
train_df,
verbose=ARGS.verbose,
)
epoched_train = whiten(epoched_train, whitening_mats)
epoched_test = whiten(epoched_test, whitening_mats)

with open(OUTPUT_DIR / "mvnn_whitening_matrices.pkl", "wb") as f:
pickle.dump(whitening_mats, f)

# --------------------------------------------------------------------------
# save ---------------------------------------------------------------------
save_data(
Expand Down
Loading