diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 850599e36..291e63e0b 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1105,6 +1105,12 @@ ``` """ +sync_eyelink: bool = False + +remove_blink_saccades: bool = True +sync_eventtype_regex: str = ".*" +sync_eventtype_regex_et: str = "" + # ### SSP, ICA, and artifact regression regress_artifact: dict[str, Any] | None = None diff --git a/mne_bids_pipeline/steps/preprocessing/_05b_sync_eyelink.py b/mne_bids_pipeline/steps/preprocessing/_05b_sync_eyelink.py new file mode 100644 index 000000000..136aecf2a --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_05b_sync_eyelink.py @@ -0,0 +1,261 @@ +from types import SimpleNamespace +import mne +import os.path +import re +import numpy as np +from mne_bids import BIDSPath + +from ..._config_utils import ( + _bids_kwargs, + get_eeg_reference, + get_runs, + get_sessions, + get_subjects, +) +from ..._import_data import annotations_to_events, make_epochs +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._reject import _get_reject +from ..._report import _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs + + +def get_input_fnames_sync_eyelink( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> dict: + bids_basename = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + extension=".fif", + ) + + et_bids_basename = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + datatype="beh", + root=cfg.bids_root, + suffix="et", + check=False, + extension=".asc", + ) + + + et_edf_bids_basename = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + datatype="beh", + root=cfg.bids_root, + suffix="et", + check=False, + extension=".edf", + ) + + in_files = dict() + for run in cfg.runs: + key = f"raw_run-{run}" + in_files[key] = bids_basename.copy().update( + run=run, processing=cfg.processing, suffix="raw" + ) + _update_for_splits(in_files, key, single=True) + + + key = f"et_run-{run}" + in_files[key] = et_bids_basename.copy().update( + run=run + ) + _update_for_splits(in_files, key, single=True) # TODO: Find out if we need to add this or not + + key = f"et_edf_run-{run}" + in_files[key] = et_edf_bids_basename.copy().update( + run=run + ) + _update_for_splits(in_files, key, single=True) # TODO: Find out if we need to add this or not + + return in_files + + + +@failsafe_run( + get_input_fnames=get_input_fnames_sync_eyelink, +) +def sync_eyelink( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: dict, +) -> dict: + """Run Sync for Eyelink.""" + import matplotlib.pyplot as plt + + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] + et_fnames = [in_files.pop(f"et_run-{run}") for run in cfg.runs] + et_edf_fnames = [in_files.pop(f"et_edf_run-{run}") for run in cfg.runs] + + logger.info(**gen_log_kwargs(message=f"et_fnames {et_fnames}")) + out_files = dict() + bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) + out_files["eyelink"] = bids_basename.copy().update(processing="eyelink", suffix="raw") + del bids_basename + + + + for idx, (run, raw_fname,et_fname,et_edf_fname) in enumerate(zip(cfg.runs, raw_fnames,et_fnames,et_edf_fnames)): + msg = f"Syncing eyelink data (fake for now) {raw_fname.basename}" + logger.info(**gen_log_kwargs(message=msg)) + raw = mne.io.read_raw_fif(raw_fname, preload=True) + if not os.path.isfile(et_fname): + logger.info(**gen_log_kwargs(message=f"Couldn't find {et_fname} file, trying to call edf2asc.")) + if not os.path.isfile(et_edf_fname): + logger.error(**gen_log_kwargs(message=f"Also didn't find {et_edf_fname} file, one of both need to exist for ET sync.")) + import subprocess + subprocess.run(["edf2asc", et_edf_fname]) # TODO: Still needs to be tested + + raw_et = mne.io.read_raw_eyelink(et_fname,find_overlaps=True) + + # If the user did not specify a regular expression for the eye-tracking sync events, it is assumed that it's + # identical to the regex for the EEG sync events + if not cfg.sync_eventtype_regex_et: + cfg.sync_eventtype_regex_et = cfg.sync_eventtype_regex + + et_sync_times = [annotation["onset"] for annotation in raw_et.annotations if re.search(cfg.sync_eventtype_regex_et,annotation["description"])] + sync_times = [annotation["onset"] for annotation in raw.annotations if re.search(cfg.sync_eventtype_regex, annotation["description"])] + + assert len(et_sync_times) == len(sync_times),f"Detected eyetracking and EEG sync events were not of equal size ({len(et_sync_times)} vs {len(sync_times)}). Adjust your regular expressions via 'sync_eventtype_regex_et' and 'sync_eventtype_regex' accordingly" + #logger.info(**gen_log_kwargs(message=f"{et_sync_times}")) + #logger.info(**gen_log_kwargs(message=f"{sync_times}")) + + + # Check whether the eye-tracking data contains nan values. If yes replace them with zeros. + if np.isnan(raw_et.get_data()).any(): + + # Set all nan values in the eye-tracking data to 0 (to make resampling possible) + # TODO: Decide whether this is a good approach or whether interpolation (e.g. of blinks) is useful + # TODO: Decide about setting the values (e.g. for blinks) back to nan after synchronising the signals + np.nan_to_num(raw_et._data, copy=False, nan=0.0) + logger.info(**gen_log_kwargs(message=f"The eye-tracking data contained nan values. They were replaced with zeros.")) + + #mne.preprocessing.eyetracking.interpolate_blinks(raw_et, buffer=(0.05, 0.05), interpolate_gaze=True) + + + # Align the data + mne.preprocessing.realign_raw(raw, raw_et, sync_times, et_sync_times) + + + # Add ET data to EEG + raw.add_channels([raw_et], force_update_info=True) + raw._raw_extras.append(raw_et._raw_extras) + + # Also add ET annotations to EEG + raw.set_annotations(mne.annotations._combine_annotations(raw.annotations,raw_et.annotations,0,raw.first_samp,raw_et.first_samp,raw.info["sfreq"])) + + + msg = f"Saving synced data to disk." + logger.info(**gen_log_kwargs(message=msg)) + raw.save( + out_files["eyelink"], + overwrite=True, + split_naming="bids", # TODO: Find out if we need to add this or not + split_size=cfg._raw_split_size, # ??? + ) + # no idea what the split stuff is... + _update_for_splits(out_files, "eyelink") # TODO: Find out if we need to add this or not + + + + # Add to report + tags = ("sync", "eyelink") + title = "Synchronize Eyelink" + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + task=cfg.task, + ) as report: + + + caption = ( + f"The `realign_raw` function from MNE was used to align an Eyelink `asc` file to the M/EEG file." + f"The Eyelink-data was added as annotations and appended as new channels." + ) + fig = raw_et.plot(scalings=dict(eyegaze=1e3)) + report.add_figure( + fig=fig, + title="Eyelink data", + section=title, + caption=caption, + tags=tags[1], + replace=True, + ) + plt.close(fig) + del caption + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + + + + + +def get_config( + *, + config: SimpleNamespace, + subject: str, + session: str | None = None, +) -> SimpleNamespace: + #logger.info(**gen_log_kwargs(message=f"config {config}")) + + cfg = SimpleNamespace( + runs=get_runs(config=config, subject=subject), + remove_blink_saccades = config.remove_blink_saccades, + sync_eventtype_regex = config.sync_eventtype_regex, + sync_eventtype_regex_et = config.sync_eventtype_regex_et, + processing= "filt" if config.regress_artifact is None else "regress", + _raw_split_size=config._raw_split_size, + + **_bids_kwargs(config=config), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Sync Eyelink.""" + if not config.sync_eyelink: + msg = "Skipping, sync_eyelink is set to False …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + + with get_parallel_backend(config.exec_params): + parallel, run_func = parallel_func(sync_eyelink, exec_params=config.exec_params) + logs = parallel( + run_func( + cfg=get_config(config=config, subject=subject), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + ) + save_logs(config=config, logs=logs) + + + diff --git a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py index 598d2e308..cec972d77 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py @@ -256,7 +256,8 @@ def run_ica( fit_params=fit_params, max_iter=cfg.ica_max_iterations, ) - ica.fit(epochs, decim=cfg.ica_decim) + # TODO: This works for our pipeline (exclude eye-tracking data for ICA) but probably not in general + ica.fit(epochs.pick(picks="eeg"), decim=cfg.ica_decim) explained_var = ( ica.pca_explained_variance_[: ica.n_components_].sum() / ica.pca_explained_variance_.sum() @@ -349,7 +350,7 @@ def get_config( eog_channels=config.eog_channels, rest_epochs_duration=config.rest_epochs_duration, rest_epochs_overlap=config.rest_epochs_overlap, - processing="filt" if config.regress_artifact is None else "regress", + processing="eyelink" if config.sync_eyelink else "filt" if config.regress_artifact is None else "regress", _epochs_split_size=config._epochs_split_size, **_bids_kwargs(config=config), ) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py index 43f88032a..63eff384f 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -365,7 +365,7 @@ def get_config( eog_channels=config.eog_channels, rest_epochs_duration=config.rest_epochs_duration, rest_epochs_overlap=config.rest_epochs_overlap, - processing="filt" if config.regress_artifact is None else "regress", + processing="eyelink" if config.sync_eyelink else "filt" if config.regress_artifact is None else "regress", **_bids_kwargs(config=config), ) return cfg diff --git a/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py b/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py index b17816a7e..f39b3ba98 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py +++ b/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py @@ -249,7 +249,7 @@ def get_config( epochs_decim=config.epochs_decim, use_maxwell_filter=config.use_maxwell_filter, runs=get_runs(config=config, subject=subject), - processing="filt" if config.regress_artifact is None else "regress", + processing="eyelink" if config.sync_eyelink else "filt" if config.regress_artifact is None else "regress", **_bids_kwargs(config=config), ) return cfg diff --git a/mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py b/mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py index 47f717959..a52b3b0fe 100644 --- a/mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py +++ b/mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py @@ -335,7 +335,7 @@ def get_config( rest_epochs_overlap=config.rest_epochs_overlap, _epochs_split_size=config._epochs_split_size, runs=get_runs(config=config, subject=subject), - processing="filt" if config.regress_artifact is None else "regress", + processing="eyelink" if config.sync_eyelink else "filt" if config.regress_artifact is None else "regress", **_bids_kwargs(config=config), ) return cfg diff --git a/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py index 430c4cdd3..2d4d5c78a 100644 --- a/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py +++ b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py @@ -243,7 +243,7 @@ def get_config( cfg = SimpleNamespace( baseline=config.baseline, ica_reject=config.ica_reject, - processing="filt" if config.regress_artifact is None else "regress", + processing="eyelink" if config.sync_eyelink else "filt" if config.regress_artifact is None else "regress", _epochs_split_size=config._epochs_split_size, **_import_data_kwargs(config=config, subject=subject), ) diff --git a/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py b/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py index 6ab00dc12..b3139433c 100644 --- a/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py +++ b/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py @@ -149,7 +149,7 @@ def get_config( subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( - processing="filt" if config.regress_artifact is None else "regress", + processing="eyelink" if config.sync_eyelink else "filt" if config.regress_artifact is None else "regress", _epochs_split_size=config._epochs_split_size, **_import_data_kwargs(config=config, subject=subject), ) diff --git a/mne_bids_pipeline/steps/preprocessing/__init__.py b/mne_bids_pipeline/steps/preprocessing/__init__.py index f9072617c..33dc66f8c 100644 --- a/mne_bids_pipeline/steps/preprocessing/__init__.py +++ b/mne_bids_pipeline/steps/preprocessing/__init__.py @@ -6,6 +6,7 @@ _03_maxfilter, _04_frequency_filter, _05_regress_artifact, + _05b_sync_eyelink, _06a1_fit_ica, _06a2_find_ica_artifacts, _06b_run_ssp, @@ -21,6 +22,7 @@ _03_maxfilter, _04_frequency_filter, _05_regress_artifact, + _05b_sync_eyelink, _06a1_fit_ica, _06a2_find_ica_artifacts, _06b_run_ssp,