diff --git a/.circleci/config.yml b/.circleci/config.yml index 62f559ab4..54f0c1541 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -47,10 +47,10 @@ jobs: - run: name: Get Python running command: | - pip install --upgrade --progress-bar off pip setuptools - pip install --upgrade --progress-bar off "autoreject @ https://api.github.com/repos/autoreject/autoreject/zipball/master" "mne[hdf5] @ git+https://github.com/mne-tools/mne-python" "mne-bids[full] @ https://api.github.com/repos/mne-tools/mne-bids/zipball/main" numba + pip install --upgrade --progress-bar off pip + pip install --upgrade --progress-bar off "autoreject @ https://api.github.com/repos/autoreject/autoreject/zipball/master" "mne[hdf5] @ git+https://github.com/mne-tools/mne-python@main" "mne-bids[full] @ https://api.github.com/repos/mne-tools/mne-bids/zipball/main" numba pip install -ve .[tests] - pip install PyQt6 + pip install "PyQt6!=6.6.1" "PyQt6-Qt6!=6.6.1,!=6.6.2,!=6.6.3,!=6.7.0" - run: name: Check Qt command: | @@ -76,14 +76,15 @@ jobs: at: ~/ - restore_cache: keys: - - data-cache-ds000117-2 + - data-cache-ds000117-3 - bash_env - run: name: Get ds000117 command: | $DOWNLOAD_DATA ds000117 + - codecov/upload - save_cache: - key: data-cache-ds000117-2 + key: data-cache-ds000117-3 paths: - ~/mne_data/ds000117 @@ -118,6 +119,7 @@ jobs: name: Get ds001971 command: | $DOWNLOAD_DATA ds001971 + - codecov/upload - save_cache: key: data-cache-ds001971-2 paths: @@ -136,6 +138,7 @@ jobs: name: Get ds004107 command: | $DOWNLOAD_DATA ds004107 + - codecov/upload - save_cache: key: data-cache-ds004107-2 paths: @@ -148,14 +151,15 @@ jobs: at: ~/ - restore_cache: keys: - - data-cache-ds000246-2 + - data-cache-ds000246-3 - bash_env - run: name: Get ds000246 command: | $DOWNLOAD_DATA ds000246 + - codecov/upload - save_cache: - key: data-cache-ds000246-2 + key: data-cache-ds000246-3 paths: - ~/mne_data/ds000246 @@ -166,14 +170,15 @@ jobs: at: ~/ - restore_cache: keys: - - data-cache-ds000247-2 + - data-cache-ds000247-3 - bash_env - run: name: Get ds000247 command: | $DOWNLOAD_DATA ds000247 + - codecov/upload - save_cache: - key: data-cache-ds000247-2 + key: data-cache-ds000247-3 paths: - ~/mne_data/ds000247 @@ -190,6 +195,7 @@ jobs: name: Get ds000248 command: | $DOWNLOAD_DATA ds000248 + - codecov/upload - save_cache: key: data-cache-ds000248-4 paths: @@ -208,6 +214,7 @@ jobs: name: Get ds001810 command: | $DOWNLOAD_DATA ds001810 + - codecov/upload - save_cache: key: data-cache-ds001810-2 paths: @@ -226,6 +233,7 @@ jobs: name: Get ds003104 command: | $DOWNLOAD_DATA ds003104 + - codecov/upload - save_cache: key: data-cache-ds003104-2 paths: @@ -244,6 +252,7 @@ jobs: name: Get ds003392 command: | $DOWNLOAD_DATA ds003392 + - codecov/upload - save_cache: key: data-cache-ds003392-2 paths: @@ -256,14 +265,15 @@ jobs: at: ~/ - restore_cache: keys: - - data-cache-ds004229-2 + - data-cache-ds004229-103 - bash_env - run: name: Get ds004229 command: | $DOWNLOAD_DATA ds004229 + - codecov/upload - save_cache: - key: data-cache-ds004229-2 + key: data-cache-ds004229-103 paths: - ~/mne_data/ds004229 @@ -276,16 +286,35 @@ jobs: keys: - data-cache-eeg_matchingpennies-1 - bash_env - - gitconfig # email address is needed for datalad - run: name: Get eeg_matchingpennies command: | $DOWNLOAD_DATA eeg_matchingpennies + - codecov/upload - save_cache: key: data-cache-eeg_matchingpennies-1 paths: - ~/mne_data/eeg_matchingpennies + cache_MNE-phantom-KIT-data: + <<: *imageconfig + steps: + - attach_workspace: + at: ~/ + - restore_cache: + keys: + - data-cache-MNE-phantom-KIT-data-1 + - bash_env + - run: + name: Get MNE-phantom-KIT-data + command: | + $DOWNLOAD_DATA MNE-phantom-KIT-data + - codecov/upload + - save_cache: + key: data-cache-MNE-phantom-KIT-data-1 + paths: + - ~/mne_data/MNE-phantom-KIT-data + cache_ERP_CORE: <<: *imageconfig steps: @@ -299,6 +328,7 @@ jobs: name: Get ERP_CORE command: | $DOWNLOAD_DATA ERP_CORE + - codecov/upload - save_cache: key: data-cache-ERP_CORE-1 paths: @@ -312,16 +342,10 @@ jobs: - bash_env - restore_cache: keys: - - data-cache-ds000117-2 + - data-cache-ds000117-3 - run: name: test ds000117 - command: | - DS=ds000117 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds000117 - codecov/upload - store_test_results: path: ./test-results @@ -347,13 +371,7 @@ jobs: - data-cache-ds003775-2 - run: name: test ds003775 - command: | - DS=ds003775 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds003775 - codecov/upload - store_test_results: path: ./test-results @@ -379,13 +397,7 @@ jobs: - data-cache-ds001971-2 - run: name: test ds001971 - command: | - DS=ds001971 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds001971 - codecov/upload - store_test_results: path: ./test-results @@ -412,13 +424,7 @@ jobs: - data-cache-ds004107-2 - run: name: test ds004107 - command: | - DS=ds004107 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds004107 - codecov/upload - store_test_results: path: ./test-results @@ -441,18 +447,11 @@ jobs: - bash_env - restore_cache: keys: - - data-cache-ds000246-2 + - data-cache-ds000246-3 - run: name: test ds000246 no_output_timeout: 15m - command: | - DS=ds000246 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.tsv ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds000246 - codecov/upload - store_test_results: path: ./test-results @@ -476,16 +475,10 @@ jobs: - bash_env - restore_cache: keys: - - data-cache-ds000247-2 + - data-cache-ds000247-3 - run: name: test ds000247 - command: | - DS=ds000247 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds000247 - codecov/upload - store_test_results: path: ./test-results @@ -511,15 +504,8 @@ jobs: - data-cache-ds000248-4 - run: name: test ds000248_base - command: | - DS=ds000248_base - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.json ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.tsv ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + # Forces rerunning (cov and FLASH BEM) so don't check + command: $RUN_TESTS -r ds000248_base - codecov/upload - store_test_results: path: ./test-results @@ -547,14 +533,7 @@ jobs: - data-cache-ds000248-4 - run: name: test ds000248_ica - command: | - DS=ds000248_ica - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.tsv ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds000248_ica - codecov/upload - store_test_results: path: ./test-results @@ -581,15 +560,20 @@ jobs: - data-cache-ds000248-4 - run: name: test BEM from FLASH - command: | - $RUN_TESTS ds000248_FLASH_BEM - ls -al test-results/*.xml + command: $RUN_TESTS -r ds000248_FLASH_BEM - codecov/upload - store_test_results: path: ./test-results - store_artifacts: path: ./test-results destination: test-results + - store_artifacts: + path: /home/circleci/reports/ds000248_FLASH_BEM + destination: reports/ds000248_FLASH_BEM + - persist_to_workspace: + root: ~/ + paths: + - mne_data/derivatives/mne-bids-pipeline/ds000248_FLASH_BEM/*/*/*.html test_ds000248_T1_BEM: <<: *imageconfig @@ -603,15 +587,20 @@ jobs: - run: name: test BEM from T1 (watershed) no_output_timeout: 20m - command: | - $RUN_TESTS ds000248_T1_BEM - ls -al test-results/*.xml + command: $RUN_TESTS -r ds000248_T1_BEM - codecov/upload - store_test_results: path: ./test-results - store_artifacts: path: ./test-results destination: test-results + - store_artifacts: + path: /home/circleci/reports/ds000248_T1_BEM + destination: reports/ds000248_T1_BEM + - persist_to_workspace: + root: ~/ + paths: + - mne_data/derivatives/mne-bids-pipeline/ds000248_T1_BEM/*/*/*.html test_ds000248_coreg_surfaces: <<: *imageconfig @@ -624,9 +613,7 @@ jobs: - data-cache-ds000248-4 - run: name: test head surface creation for MNE coregistration - command: | - $RUN_TESTS ds000248_coreg_surfaces - ls -al test-results/*.xml + command: $RUN_TESTS -c -r ds000248_coreg_surfaces - codecov/upload - store_test_results: path: ./test-results @@ -645,13 +632,7 @@ jobs: - data-cache-ds000248-4 - run: name: test ds000248_no_mri - command: | - DS=ds000248_no_mri - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.html ~/reports/${DS}/ - cp ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds000248_no_mri - codecov/upload - store_test_results: path: ./test-results @@ -678,13 +659,7 @@ jobs: - data-cache-ds001810-2 - run: name: test ds001810 - command: | - DS=ds001810 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds001810 - codecov/upload - store_test_results: path: ./test-results @@ -710,13 +685,7 @@ jobs: - data-cache-ds003104-2 - run: name: test ds003104 - command: | - DS=ds003104 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds003104 - codecov/upload - store_test_results: path: ./test-results @@ -742,15 +711,7 @@ jobs: - data-cache-ds003392-2 - run: name: test ds003392 - command: | - DS=ds003392 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.json ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.tsv ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds003392 - codecov/upload - store_test_results: path: ./test-results @@ -776,18 +737,10 @@ jobs: - bash_env - restore_cache: keys: - - data-cache-ds004229-2 + - data-cache-ds004229-103 - run: name: test ds004229 - command: | - DS=ds004229 - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.json ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/*/*.tsv ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ds004229 - codecov/upload - store_test_results: path: ./test-results @@ -815,13 +768,7 @@ jobs: - data-cache-eeg_matchingpennies-1 - run: name: test eeg_matchingpennies - command: | - DS=eeg_matchingpennies - $RUN_TESTS ${DS} - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS eeg_matchingpennies - codecov/upload - store_test_results: path: ./test-results @@ -836,6 +783,32 @@ jobs: paths: - mne_data/derivatives/mne-bids-pipeline/eeg_matchingpennies/*/*/*.html + test_MNE-phantom-KIT-data: + <<: *imageconfig + steps: + - attach_workspace: + at: ~/ + - bash_env + - restore_cache: + keys: + - data-cache-MNE-phantom-KIT-data-1 + - run: + name: test MNE-phantom-KIT-data + command: $RUN_TESTS MNE-phantom-KIT-data + - codecov/upload + - store_test_results: + path: ./test-results + - store_artifacts: + path: ./test-results + destination: test-results + - store_artifacts: + path: /home/circleci/reports/MNE-phantom-KIT-data + destination: reports/MNE-phantom-KIT-data + - persist_to_workspace: + root: ~/ + paths: + - mne_data/derivatives/mne-bids-pipeline/MNE-phantom-KIT-data/*/*/*.html + test_ERP_CORE_N400: <<: *imageconfig resource_class: large @@ -856,13 +829,7 @@ jobs: google-chrome --version - run: name: test ERP CORE N400 - command: | - DS=ERP_CORE - $RUN_TESTS ${DS}_N400 - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*N400*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*N400*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ERP_CORE_N400 - codecov/upload - store_test_results: path: ./test-results @@ -892,13 +859,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE ERN - command: | - DS=ERP_CORE - $RUN_TESTS ${DS}_ERN - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*ERN*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*ERN*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ERP_CORE_ERN - codecov/upload - store_test_results: path: ./test-results @@ -928,13 +889,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE LRP - command: | - DS=ERP_CORE - $RUN_TESTS ${DS}_LRP - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*LRP*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*LRP*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ERP_CORE_LRP - codecov/upload - store_test_results: path: ./test-results @@ -964,13 +919,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE MMN - command: | - DS=ERP_CORE - $RUN_TESTS ${DS}_MMN - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*MMN*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*MMN*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ERP_CORE_MMN - codecov/upload - store_test_results: path: ./test-results @@ -1000,13 +949,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE N2pc - command: | - DS=ERP_CORE - $RUN_TESTS ${DS}_N2pc - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*N2pc*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*N2pc*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ERP_CORE_N2pc - codecov/upload - store_test_results: path: ./test-results @@ -1036,13 +979,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE N170 - command: | - DS=ERP_CORE - $RUN_TESTS ${DS}_N170 - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*N170*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*N170*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ERP_CORE_N170 - codecov/upload - store_test_results: path: ./test-results @@ -1072,13 +1009,7 @@ jobs: command: mkdir -p /home/circleci/.local/share/pyvista - run: name: test ERP CORE P3 - command: | - DS=ERP_CORE - $RUN_TESTS ${DS}_P3 - mkdir -p ~/reports/${DS} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*P3*.html ~/reports/${DS}/ - cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*P3*.xlsx ~/reports/${DS}/ - ls -al test-results/*.xml + command: $RUN_TESTS ERP_CORE_P3 - codecov/upload - store_test_results: path: ./test-results @@ -1099,9 +1030,15 @@ jobs: - attach_workspace: at: ~/ - bash_env + - run: + name: Install dependencies + command: | + pip install -ve .[docs] - run: name: Build documentation command: | + set -eo pipefail + ls ~/mne_data/derivatives/mne-bids-pipeline/*/*/*/*.html make doc - store_artifacts: path: docs/site @@ -1125,6 +1062,10 @@ jobs: at: ~/ - bash_env - gitconfig + - run: + name: Install dependencies + command: | + pip install -ve .[docs] - run: # This is a bit computationally inefficient, but it should be much # faster to "cp" directly on the machine rather than persist @@ -1302,6 +1243,15 @@ workflows: - cache_eeg_matchingpennies <<: *filter_tags + - cache_MNE-phantom-KIT-data: + requires: + - setup_env + <<: *filter_tags + - test_MNE-phantom-KIT-data: + requires: + - cache_MNE-phantom-KIT-data + <<: *filter_tags + - cache_ERP_CORE: requires: - setup_env @@ -1346,11 +1296,14 @@ workflows: - test_ds000248_base - test_ds000248_ica - test_ds000248_no_mri + - test_ds000248_T1_BEM + - test_ds000248_FLASH_BEM - test_ds001810 - test_ds003104 - test_ds003392 - test_ds004229 - test_eeg_matchingpennies + - test_MNE-phantom-KIT-data - test_ERP_CORE_N400 - test_ERP_CORE_ERN - test_ERP_CORE_LRP diff --git a/.circleci/remove_examples.sh b/.circleci/remove_examples.sh new file mode 100755 index 000000000..ee4004442 --- /dev/null +++ b/.circleci/remove_examples.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +set -eo pipefail + +VER=$1 +if [ -z "$VER" ]; then + echo "Usage: $0 " + exit 1 +fi +ROOT="$PWD/$VER/examples/" +if [ ! -d ${ROOT} ]; then + echo "Version directory does not exist or appears incorrect:" + echo + echo "$ROOT" + echo + echo "Are you on the gh-pages branch and is the ds000117 directory present?" + exit 1 +fi +if [ ! -d ${ROOT}ds000117 ]; then + echo "Directory does not exist:" + echo + echo "${ROOT}ds000117" + echo + echo "Assuming already pruned and exiting." + exit 0 +fi +echo "Pruning examples in ${ROOT} ..." + +find $ROOT -type d -name "*" | tail -n +2 | xargs rm -Rf +find $ROOT -name "*.html" -exec sed -i /^\Generated/,/^\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $/{d} {} \; +find $ROOT -name "*.html" -exec sed -i '/^ $/{d}' {} \; + +echo "Done" diff --git a/.circleci/run_dataset_and_copy_files.sh b/.circleci/run_dataset_and_copy_files.sh new file mode 100755 index 000000000..f8ae2d31f --- /dev/null +++ b/.circleci/run_dataset_and_copy_files.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +set -eo pipefail + +COPY_FILES="true" +RERUN_TEST="true" +while getopts "cr" option; do + echo $option + case $option in + c) + COPY_FILES="false";; + r) + RERUN_TEST="false";; + esac +done +shift "$(($OPTIND -1))" + +DS_RUN=$1 +if [[ -z $1 ]]; then + echo "Missing dataset argument" + exit 1 +fi +if [[ "$DS_RUN" == "ERP_CORE_"* ]]; then + DS="ERP_CORE" +else + DS="$1" +fi + +SECONDS=0 +pytest mne_bids_pipeline --junit-xml=test-results/junit-results.xml -k ${DS_RUN} +echo "Runtime: ${SECONDS} seconds" + +# rerun test (check caching)! +SECONDS=0 +RERUN_LIMIT=30 +if [[ "$RERUN_TEST" == "false" ]]; then + echo "Skipping rerun test" + RUN_TIME=0 +else + pytest mne_bids_pipeline --cov-append -k $DS_RUN + RUN_TIME=$SECONDS + echo "Runtime: ${RUN_TIME} seconds (should be <= $RERUN_LIMIT)" +fi +test $RUN_TIME -le $RERUN_LIMIT + +if [[ "$COPY_FILES" == "false" ]]; then + echo "Not copying files" + exit 0 +fi +mkdir -p ~/reports/${DS} +# these should always exist +cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.html ~/reports/${DS}/ +cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*.xlsx ~/reports/${DS}/ +# these are allowed to be optional +cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.json ~/reports/${DS}/ || : +cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.tsv ~/reports/${DS}/ || : +ls -al test-results/*.xml diff --git a/.circleci/setup_bash.sh b/.circleci/setup_bash.sh index 6bd99e4a0..dd21463e7 100755 --- a/.circleci/setup_bash.sh +++ b/.circleci/setup_bash.sh @@ -33,15 +33,13 @@ fi # Set up image sudo ln -s /usr/lib/x86_64-linux-gnu/libxcb-util.so.0 /usr/lib/x86_64-linux-gnu/libxcb-util.so.1 -wget -q -O- http://neuro.debian.net/lists/focal.us-tn.libre | sudo tee /etc/apt/sources.list.d/neurodebian.sources.list -sudo apt-key adv --recv-keys --keyserver hkps://keyserver.ubuntu.com 0xA5D32F012649A5A9 -echo "export RUN_TESTS=\"pytest mne_bids_pipeline --junit-xml=test-results/junit-results.xml -k\"" >> "$BASH_ENV" -echo "export DOWNLOAD_DATA=\"python -m mne_bids_pipeline._download\"" >> "$BASH_ENV" +echo "export RUN_TESTS=\".circleci/run_dataset_and_copy_files.sh\"" >> "$BASH_ENV" +echo "export DOWNLOAD_DATA=\"coverage run -m mne_bids_pipeline._download\"" >> "$BASH_ENV" # Similar CircleCI setup to mne-python (Xvfb, venv, minimal commands, env vars) wget -q https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/setup_xvfb.sh bash setup_xvfb.sh -sudo apt install -qq tcsh git-annex-standalone python3.10-venv python3-venv libxft2 +sudo apt install -qq tcsh python3.10-venv python3-venv libxft2 python3.10 -m venv ~/python_env wget -q https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/get_minimal_commands.sh source get_minimal_commands.sh @@ -57,6 +55,7 @@ echo 'export MPLBACKEND=Agg' >> "$BASH_ENV" echo "source ~/python_env/bin/activate" >> "$BASH_ENV" echo "export MNE_3D_OPTION_MULTI_SAMPLES=1" >> "$BASH_ENV" echo "export MNE_BIDS_PIPELINE_FORCE_TERMINAL=true" >> "$BASH_ENV" +echo "export FORCE_COLOR=1" >> "$BASH_ENV" # for rich to use color in logs mkdir -p ~/.local/bin if [[ ! -f ~/.local/bin/python ]]; then ln -s ~/python_env/bin/python ~/.local/bin/python diff --git a/.git_archival.txt b/.git_archival.txt new file mode 100644 index 000000000..7c5100942 --- /dev/null +++ b/.git_archival.txt @@ -0,0 +1,3 @@ +node: $Format:%H$ +node-date: $Format:%cI$ +describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..00a7b00c9 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +.git_archival.txt export-subst diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..4a47c7a99 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,15 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + labels: + - "dependabot" + commit-message: + prefix: "[dependabot]" diff --git a/.github/workflows/circleci-redirector.yml b/.github/workflows/circleci-redirector.yml index 87b6a38cb..812e116ce 100644 --- a/.github/workflows/circleci-redirector.yml +++ b/.github/workflows/circleci-redirector.yml @@ -10,5 +10,6 @@ jobs: uses: larsoner/circleci-artifacts-redirector-action@master with: repo-token: ${{ secrets.GITHUB_TOKEN }} + api-token: ${{ secrets.CIRCLECI_TOKEN }} artifact-path: 0/site/index.html circleci-jobs: build_docs diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1d5a786d7..ddd6fae1a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,9 +18,9 @@ jobs: package: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' - name: Install dependencies @@ -34,7 +34,7 @@ jobs: - name: Check env vars run: | echo "Triggered by: ${{ github.event_name }}" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: dist path: dist @@ -44,13 +44,12 @@ jobs: needs: package runs-on: ubuntu-latest if: github.event_name == 'release' + permissions: + id-token: write steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: dist path: dist - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 7f7ae75ea..879eab4f9 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -6,32 +6,70 @@ concurrency: on: [push, pull_request] jobs: - check-style: - name: Style - runs-on: "ubuntu-latest" - defaults: - run: - shell: bash -l {0} - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - - name: Install ruff and codespell - run: pip install ruff codespell tomli - - run: make ruff - - run: make codespell-error - - uses: psf/black@stable check-doc: - name: Doc consistency + name: Doc consistency and codespell runs-on: ubuntu-latest defaults: run: shell: bash -l {0} steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - - run: pip install --upgrade setuptools[toml] pip - - run: pip install --no-build-isolation -ve .[tests] + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - run: pip install --upgrade pip + - run: pip install -ve .[tests] codespell tomli + - run: make codespell-error - run: pytest mne_bids_pipeline -m "not dataset_test" - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 if: success() name: 'Upload coverage to CodeCov' + caching: + name: 'Caching on ${{ matrix.os }}' + timeout-minutes: 30 + continue-on-error: true + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + strategy: + matrix: + include: + - os: ubuntu-latest + - os: macos-latest + - os: windows-latest + env: + MNE_BIDS_PIPELINE_LEGACY_WINDOWS: "false" + PYTHONIOENCODING: 'utf8' # for Windows + steps: + - uses: actions/checkout@v4 + - uses: pyvista/setup-headless-display-action@main + with: + qt: true + pyvista: false + - uses: actions/setup-python@v5 + with: + python-version: "3.11" # no "multidict" wheels on 3.12 yet + - run: pip install -ve .[tests] + - uses: actions/cache@v4 + with: + key: ds001971 + path: ~/mne_data/ds001971 + id: ds001971-cache + - run: python -m mne_bids_pipeline._download ds001971 + if: steps.ds001971-cache.outputs.cache-hit != 'true' + - run: pytest --cov-append -k ds001971 mne_bids_pipeline/ + - run: pytest --cov-append -k ds001971 mne_bids_pipeline/ # uses "hash" method + timeout-minutes: 1 + - uses: actions/cache@v4 + with: + key: ds003392 + path: ~/mne_data/ds003392 + id: ds003392-cache + - run: python -m mne_bids_pipeline._download ds003392 + if: steps.ds003392-cache.outputs.cache-hit != 'true' + - run: pytest --cov-append -k ds003392 mne_bids_pipeline/ + - run: pytest --cov-append -k ds003392 mne_bids_pipeline/ # uses "mtime" method + timeout-minutes: 1 + - uses: codecov/codecov-action@v4 + if: success() diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 66dfbf03c..beebf2d67 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,22 +1,24 @@ --- -# Eventually we should use yamllint, too files: ^(.*\.(py|yaml))$ -exclude: ^(\.[^/]*cache/.*)$ +# We need to match the exclude list in pyproject.toml because pre-commit.ci +# passes filenames and these do not get passed through the tool.black filter +# for example +exclude: ^(\.[^/]*cache/.*|.*/freesurfer/contrib/.*)$ repos: - - repo: https://github.com/psf/black - rev: 22.10.0 - hooks: - - id: black - args: - - --safe - - --quiet - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.178 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.10 hooks: - id: ruff + args: ["--fix"] + - id: ruff-format - repo: https://github.com/codespell-project/codespell - rev: v2.2.2 + rev: v2.3.0 hooks: - id: codespell additional_dependencies: - tomli + - repo: https://github.com/adrienverge/yamllint.git + rev: v1.35.1 + hooks: + - id: yamllint + args: [--strict] diff --git a/BUILDING.md b/BUILDING.md deleted file mode 100644 index 41f1e327d..000000000 --- a/BUILDING.md +++ /dev/null @@ -1,9 +0,0 @@ -# Building a release - -* Tag a new release with `git` if necessary. -* Create `sdist` distribution: - - ```shell - pip install -q build - python -m build # will build sdist and wheel - ``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a403bc8f2..cea32e59d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,9 +1,9 @@ -# Overview +# Contributing to MNE-BIDS-Pipeline Contributors to MNE-BIDS-Pipeline are expected to follow our [Code of Conduct](https://github.com/mne-tools/.github/blob/main/CODE_OF_CONDUCT.md). -# Installation +## Installation First, you need to make sure you have MNE-Python installed and working on your system. See the [installation instructions](http://martinos.org/mne/stable/install_mne_python.html). @@ -11,17 +11,11 @@ Once this is done, you should be able to run this in a terminal: `$ python -c "import mne; mne.sys_info()"` -You can then install the following additional packages via `pip`. Note that +You can then install the following additional package via `pip`. Note that the URL points to the bleeding edge version of `mne_bids`: -`$ pip install datalad` `$ pip install https://github.com/mne-tools/mne-bids/zipball/main` -To get the test data, you need to install `git-annex` on your system. If you -installed MNE-Python via `conda`, you can simply call: - -`conda install -c conda-forge git-annex` - Now, get the pipeline through git: `$ git clone https://github.com/mne-tools/mne-bids-pipeline.git` @@ -32,9 +26,9 @@ If you do not know how to use git, download the pipeline as a zip file Finally, for source analysis you'll also need `FreeSurfer`, follow the instructions on [their website](https://surfer.nmr.mgh.harvard.edu/). -# Testing +## Testing -## Running the tests, and continuous integration +### Running the tests, and continuous integration The tests are run using `pytest`. You can run them by calling `pytest mne_bids_pipeline` to run @@ -48,7 +42,7 @@ For every pull request or merge into the `main` branch of the [CircleCI](https://circleci.com/gh/brainthemind/CogBrainDyn_MEG_Pipeline) will run tests as defined in `./circleci/config.yml`. -## Debugging +### Debugging To run the test in debugging mode, just pass `--pdb` to the `pytest` call as usual. This will place you in debugging mode on failure. @@ -56,9 +50,9 @@ See the [pdb help](https://docs.python.org/3/library/pdb.html#debugger-commands) for more commands. -## Config files +### Config files -Nested in the `/tests` directory is a `/configs` directory, which contains +Nested in the `tests` directory is a `configs` directory, which contains config files for specific test datasets. For example, the `config_ds001810.py` file specifies parameters only for the `ds001810` data, which should overwrite the more general parameters in the main `_config.py` file. diff --git a/Makefile b/Makefile index 8af267201..3199402d6 100644 --- a/Makefile +++ b/Makefile @@ -23,8 +23,6 @@ doc: check: which python - git-annex version - datalad --version openneuro-py --version mri_convert --version mne_bids --version @@ -33,10 +31,6 @@ check: trailing-spaces: find . -name "*.py" | xargs perl -pi -e 's/[ \t]*$$//' -ruff: - ruff . - @echo "ruff passed" - codespell: # running manually; auto-fix spelling mistakes @codespell --write-changes $(CODESPELL_DIRS) diff --git a/README.md b/README.md index a18948e22..075e1e2b8 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Structure (BIDS)](https://bids.neuroimaging.io/). * 👣 Data processing as a sequence of standard processing steps. * ⏩ Steps are cached to avoid unnecessary recomputation. * ⏏️ Data can be "ejected" from the pipeline at any stage. No lock-in! -* ☁️ Runs on your laptop, on a powerful server, or on a high-performance cluster via Dash. +* ☁️ Runs on your laptop, on a powerful server, or on a high-performance cluster via Dask. @@ -44,7 +44,7 @@ developed for this publication: > M. Jas, E. Larson, D. A. Engemann, J. Leppäkangas, S. Taulu, M. Hämäläinen, > A. Gramfort (2018). A reproducible MEG/EEG group study with the MNE software: > recommendations, quality assessments, and good practices. Frontiers in -> neuroscience, 12. https://doi.org/10.3389/fnins.2018.00530 +> neuroscience, 12. The current iteration is based on BIDS and relies on the extensions to BIDS for EEG and MEG. See the following two references: @@ -52,10 +52,10 @@ for EEG and MEG. See the following two references: > Pernet, C. R., Appelhoff, S., Gorgolewski, K. J., Flandin, G., > Phillips, C., Delorme, A., Oostenveld, R. (2019). EEG-BIDS, an extension > to the brain imaging data structure for electroencephalography. Scientific -> Data, 6, 103. https://doi.org/10.1038/s41597-019-0104-8 +> Data, 6, 103. > Niso, G., Gorgolewski, K. J., Bock, E., Brooks, T. L., Flandin, G., Gramfort, A., > Henson, R. N., Jas, M., Litvak, V., Moreau, J., Oostenveld, R., Schoffelen, J., > Tadel, F., Wexler, J., Baillet, S. (2018). MEG-BIDS, the brain imaging data > structure extended to magnetoencephalography. Scientific Data, 5, 180110. -> https://doi.org/10.1038/sdata.2018.110 +> diff --git a/docs/build-docs.sh b/docs/build-docs.sh index cac6ba96b..4a94ab206 100755 --- a/docs/build-docs.sh +++ b/docs/build-docs.sh @@ -1,4 +1,6 @@ -#!/bin/bash -e +#!/bin/bash + +set -eo pipefail STEP_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) @@ -8,6 +10,9 @@ python $STEP_DIR/source/examples/gen_examples.py echo "Generating pipeline table …" python $STEP_DIR/source/features/gen_steps.py +echo "Generating config docs …" +python $STEP_DIR/source/settings/gen_settings.py + echo "Building the documentation …" cd $STEP_DIR mkdocs build diff --git a/docs/hooks.py b/docs/hooks.py index ab7192f6e..41d73d3d4 100644 --- a/docs/hooks.py +++ b/docs/hooks.py @@ -1,19 +1,31 @@ +"""Custom hooks for MkDocs-Material.""" + import logging -from typing import Dict, Any +from typing import Any from mkdocs.config.defaults import MkDocsConfig -from mkdocs.structure.pages import Page from mkdocs.structure.files import Files +from mkdocs.structure.pages import Page + +from mne_bids_pipeline._docs import _ParseConfigSteps logger = logging.getLogger("mkdocs") config_updated = False +# This hack can be cleaned up once this is resolved: +# https://github.com/mkdocstrings/mkdocstrings/issues/615#issuecomment-1971568301 +def on_pre_build(config: MkDocsConfig) -> None: + """Monkey patch mkdocstrings-python jinja template to have global vars.""" + python_handler = config["plugins"]["mkdocstrings"].get_handler("python") + python_handler.env.globals["pipeline_steps"] = _ParseConfigSteps() + + # Ideally there would be a better hook, but it's unclear if context can # be obtained any earlier def on_template_context( - context: Dict[str, Any], + context: dict[str, Any], template_name: str, config: MkDocsConfig, ) -> None: @@ -46,6 +58,7 @@ def on_page_markdown( config: MkDocsConfig, files: Files, ) -> str: + """Replace emojis.""" if page.file.name == "index" and page.title == "Home": for rd, md in _EMOJI_MAP.items(): markdown = markdown.replace(rd, md) diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index b30f8df2f..2cba8080e 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -6,6 +6,7 @@ site_description: The MNE-BIDS-Pipeline is a full-flegded processing pipeline fo site_author: MNE-BIDS-Pipeline authors docs_dir: ./source site_dir: ./site +strict: true use_directory_urls: false # For easier navigation on CircleCI watch: # Additional directories to watch for changes during `mkdocs serve` - ../mne_bids_pipeline @@ -64,6 +65,10 @@ copyright: Copyright © MNE-BIDS-Pipeline authors extra_css: - css/extra.css +# https://squidfunk.github.io/mkdocs-material/reference/data-tables/ +extra_javascript: + - https://unpkg.com/tablesort@5.3.0/dist/tablesort.min.js + - javascripts/tablesort.js nav: - Home: index.md - Getting started: @@ -72,7 +77,7 @@ nav: - Preparations for source-level analyses: getting_started/freesurfer.md - Processing steps: - Overview: features/overview.md - - Detailed list of processing steps: features/steps.md + - List of processing steps: features/steps.md - Configuration options: - General settings: settings/general.md - Preprocessing: @@ -85,7 +90,7 @@ nav: - Epoching: settings/preprocessing/epochs.md - Artifact removal: - Stimulation artifact: settings/preprocessing/stim_artifact.md - - SSP & ICA: settings/preprocessing/ssp_ica.md + - SSP, ICA, and artifact regression: settings/preprocessing/ssp_ica.md - Amplitude-based artifact rejection: settings/preprocessing/artifacts.md - Sensor-level analysis: - Condition contrasts: settings/sensor/contrasts.md @@ -98,6 +103,11 @@ nav: - Source space & forward solution: settings/source/forward.md - Inverse solution: settings/source/inverse.md - Report generation: settings/reports/report_generation.md + - Caching: settings/caching.md + - Parallelization: settings/parallelization.md + - Logging: settings/logging.md + - Error handling: settings/error_handling.md + - Examples: - Examples Gallery: examples/examples.md - examples/ds003392.md @@ -106,11 +116,12 @@ nav: - examples/ds000247.md - examples/ds000248_base.md - examples/ds000248_ica.md - # - examples/ds000248_T1_BEM.md - # - examples/ds000248_FLASH_BEM.md + - examples/ds000248_T1_BEM.md + - examples/ds000248_FLASH_BEM.md - examples/ds000248_no_mri.md - examples/ds003104.md - examples/eeg_matchingpennies.md + - examples/MNE-phantom-KIT-data.md - examples/ds001810.md - examples/ds000117.md - examples/ds003775.md @@ -126,8 +137,14 @@ plugins: - tags: tags_file: tags.md - include-markdown + - exclude: + glob: + - "*.py" # Python scripts + - "*.inc" # includes - mkdocstrings: default_handler: python + enable_inventory: true + custom_templates: templates handlers: python: paths: # Where to find the packages and modules to import @@ -139,6 +156,9 @@ plugins: show_root_toc_entry: false show_root_full_path: false separate_signature: true + show_signature_annotations: true + unwrap_annotated: true + signature_crossrefs: true line_length: 88 # Black's default show_bases: false docstring_style: numpy @@ -150,8 +170,8 @@ markdown_extensions: - attr_list # Allows to turn any element into a button - pymdownx.details - pymdownx.emoji: - emoji_index: !!python/name:materialx.emoji.twemoji - emoji_generator: !!python/name:materialx.emoji.to_svg + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg - pymdownx.superfences: custom_fences: - name: mermaid @@ -167,6 +187,7 @@ markdown_extensions: repo_url_shorthand: true repo: mne-bids-pipeline user: mne-tools + - tables - toc: permalink: true # Add paragraph symbol to link to current headline - pymdownx.tabbed: diff --git a/docs/source/.gitignore b/docs/source/.gitignore index 77afb012b..b909cc5d3 100644 --- a/docs/source/.gitignore +++ b/docs/source/.gitignore @@ -1 +1,4 @@ features/steps.md +features/overview.md +settings/**/ +settings/*.md diff --git a/docs/source/changes.md b/docs/source/changes.md index 8d0af3b72..8101723ca 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -1,3 +1,13 @@ +{% include-markdown "./v1.9.md.inc" %} + +{% include-markdown "./v1.8.md.inc" %} + +{% include-markdown "./v1.7.md.inc" %} + +{% include-markdown "./v1.6.md.inc" %} + +{% include-markdown "./v1.5.md.inc" %} + {% include-markdown "./v1.4.md.inc" %} {% include-markdown "./v1.3.md.inc" %} diff --git a/docs/source/css/extra.css b/docs/source/css/extra.css index 7ca920b15..d66d71ec9 100644 --- a/docs/source/css/extra.css +++ b/docs/source/css/extra.css @@ -101,3 +101,9 @@ td p { .md-button { margin-top: 1rem !important; } + +/* Ensure the link in the "You're not viewing the latest stable version" banner is + readable in both dark and light theme. */ +:root { + --md-typeset-a-color: var(--md-default-fg-color); +} diff --git a/docs/source/doc-config.py b/docs/source/doc-config.py deleted file mode 100644 index e3846c3a7..000000000 --- a/docs/source/doc-config.py +++ /dev/null @@ -1,2 +0,0 @@ -bids_root = "/tmp" -ch_types = ["meg"] diff --git a/docs/source/examples/gen_examples.py b/docs/source/examples/gen_examples.py index 08aa7d94f..07bf58b4a 100755 --- a/docs/source/examples/gen_examples.py +++ b/docs/source/examples/gen_examples.py @@ -1,26 +1,29 @@ #!/usr/bin/env python -from collections import defaultdict +"""Generate documentation pages for our examples gallery.""" + import contextlib import logging import shutil -from pathlib import Path import sys -from typing import Union, Iterable +from collections import defaultdict +from collections.abc import Iterable +from pathlib import Path + +from tqdm import tqdm import mne_bids_pipeline -from mne_bids_pipeline._config_import import _import_config import mne_bids_pipeline.tests.datasets -from mne_bids_pipeline.tests.test_run import TEST_SUITE +from mne_bids_pipeline._config_import import _import_config from mne_bids_pipeline.tests.datasets import DATASET_OPTIONS -from tqdm import tqdm +from mne_bids_pipeline.tests.test_run import TEST_SUITE this_dir = Path(__file__).parent root = Path(mne_bids_pipeline.__file__).parent.resolve(strict=True) logger = logging.getLogger() -def _bool_to_icon(x: Union[bool, Iterable]) -> str: +def _bool_to_icon(x: bool | Iterable) -> str: if x: return "✅" else: @@ -61,6 +64,8 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: key = "Maxwell filter" funcs[key] = funcs[key] or config.use_maxwell_filter funcs["Frequency filter"] = config.l_freq or config.h_freq + key = "Artifact regression" + funcs[key] = funcs[key] or (config.regress_artifact is not None) key = "SSP" funcs[key] = funcs[key] or (config.spatial_filter == "ssp") key = "ICA" @@ -104,6 +109,18 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: datasets_without_html.append(dataset_name) continue + # For ERP_CORE, cut down on what we show otherwise our website is huge + if "ERP_CORE" in test_name: + show = ["015", "average"] + orig_fnames = html_report_fnames + html_report_fnames = [ + f + for f in html_report_fnames + if any(f"sub-{s}" in f.parts and f"sub-{s}" in f.name for s in show) + ] + assert len(html_report_fnames), orig_fnames + del orig_fnames + fname_iter = tqdm( html_report_fnames, desc=f" {test_name}", @@ -133,7 +150,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: ) if dataset_name in all_demonstrated: logger.warning( - f"Duplicate dataset name {test_dataset_name} -> {dataset_name}, " "skipping" + f"Duplicate dataset name {test_dataset_name} -> {dataset_name}, skipping" ) continue del test_dataset_options, test_dataset_name @@ -142,7 +159,8 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: logger.warning(f"Dataset {dataset_name} has no HTML report.") continue - options = DATASET_OPTIONS[dataset_options_key] + assert dataset_options_key in DATASET_OPTIONS, dataset_options_key + options = DATASET_OPTIONS[dataset_options_key].copy() # we modify locally report_str = "\n## Generated output\n\n" example_target_dir = this_dir / dataset_name @@ -198,20 +216,22 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: f"{fname.name} :fontawesome-solid-square-poll-vertical:\n\n" ) - if options["openneuro"]: + assert sum(key in options for key in ("openneuro", "web", "mne")) == 1 + if "openneuro" in options: url = f'https://openneuro.org/datasets/{options["openneuro"]}' - elif options["git"]: - url = options["git"] - elif options["web"]: + elif "web" in options: url = options["web"] else: - url = "" + assert "mne" in options + url = f"https://mne.tools/dev/generated/mne.datasets.{options['mne']}.data_path.html" # noqa: E501 source_str = ( f"## Dataset source\n\nThis dataset was acquired from " f"[{url}]({url})\n" ) - if options["openneuro"]: + if "openneuro" in options: + for key in ("include", "exclude"): + options[key] = options.get(key, []) download_str = ( f'\n??? example "How to download this dataset"\n' f" Run in your terminal:\n" @@ -240,7 +260,9 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: # TODO: For things like ERP_CORE_ERN, decoding_csp are not populated # properly by the root config - config_path = root / "tests" / "configs" / f"config_{dataset_name}.py" + config_path = ( + root / "tests" / "configs" / f"config_{dataset_name.replace('-', '_')}.py" + ) config = config_path.read_text(encoding="utf-8-sig").strip() descr_end_idx = config[2:].find('"""') config_descr = "# " + config[: descr_end_idx + 1].replace('"""', "").strip() @@ -257,6 +279,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: demonstrated_funcs = _gen_demonstrated_funcs(config_path) all_demonstrated[dataset_name] = demonstrated_funcs del config, config_options + # Add the subsection and table header funcs = [ "## Demonstrated features\n", "Feature | This example", @@ -275,7 +298,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: f.write(config_str) f.write(report_str) -# Finally, write our examples.html file +# Finally, write our examples.html file with a table of examples _example_header = """\ # Examples diff --git a/docs/source/features/gen_steps.py b/docs/source/features/gen_steps.py index fffc61ddf..ad4d7ae1c 100755 --- a/docs/source/features/gen_steps.py +++ b/docs/source/features/gen_steps.py @@ -3,22 +3,115 @@ import importlib from pathlib import Path + from mne_bids_pipeline._config_utils import _get_step_modules -pre = """\ -# Detailed lis of processing steps +autogen_header = f"""\ +[//]: # (AUTO-GENERATED, TO CHANGE EDIT {'/'.join(Path(__file__).parts[-4:])}) +""" + +steps_pre = f"""\ +{autogen_header} + +# List of processing steps The following table provides a concise summary of each processing step. The step names can be used to run individual steps or entire groups of steps by -passing their name(s) to `mne_bids_pipeline` via the `steps=...` argument. +passing their name(s) to `mne_bids_pipeline` via the `steps=...` argument. However, +we recommend using `mne_bids_pipeline config.py` to run the entire pipeline +instead to ensure that all steps affected by a given change are re-run. +""" # noqa: E501 + +overview_pre = f"""\ +{autogen_header} + +# MNE-BIDS-Pipeline overview + +MNE-BIDS-Pipeline processes your data in a sequential manner, i.e., one step +at a time. The next step is only run after the previous steps have been +successfully completed. There are, of course, exceptions; for example, if you +chose not to apply ICA or SSP, the spatial filtering steps will simply be omitted and +we'll directly move to the subsequent steps. See [the flowchart below](#flowchart) for +a visualization of the steps, or check out the +[list of processing steps](steps.md) for more information. + +All intermediate results are saved to disk for later inspection, and an +**extensive report** is generated. Analyses are conducted on individual (per-subject) +as well as group level. + +## Caching + +MNE-BIDS-Pipeline offers automated caching of intermediate results. This means that +running `mne_bids_pipeline config.py` once will generate all outputs, and running it +again will only re-run the steps that need rerunning based on: + +1. Changes to files on disk (e.g., updates to `bids_root` files), and +2. Changes to `config.py` + +This is particularly useful when you are developing your pipeline, as you can quickly +iterate over changes to your pipeline without having to re-run the entire pipeline +every time -- only the steps that need to be re-run will be executed. + +## Flowchart + +For more detailed information on each step, please refer to the [detailed list +of processing steps](steps.md). """ +icon_map = { + "Filesystem initialization and dataset inspection": ":open_file_folder:", + "Preprocessing": ":broom:", + "Sensor-space analysis": ":satellite:", + "Source-space analysis": ":brain:", + "FreeSurfer-related processing": ":person_surfing:", +} +out_dir = Path(__file__).parent + print("Generating steps …") step_modules = _get_step_modules() +char_start = ord("A") + +# In principle we could try to sort this out based on naming, but for now let's just +# set our hierarchy manually and update it when we move files around since that's easy +# (and rare) enough to do. +manual_order = { + "Preprocessing": ( + ("01", "02"), + ("02", "03"), + ("03", "04"), + ("04", "05"), + ("05", "06a1"), + ("06a1", "06a2"), + ("05", "06b"), + ("05", "07"), + # technically we could have the raw data flow here, but it doesn't really help + # ("05", "08a"), + # ("05", "08b"), + ("06a2", "08a"), + # Force the artifact-fitting and epoching steps on the same level, in this order + """\ + subgraph Z[" "] + direction LR + B06a1 + B07 + B06b + end + style Z fill:#0000,stroke-width:0px +""", + ("06b", "08b"), + ("07", "08a"), + ("07", "08b"), + ("08a", "09"), + ("08b", "09"), + ), +} # Construct the lines of steps.md -lines = [pre] +lines = [steps_pre] +overview_lines = [overview_pre] +used_titles = set() for di, (dir_, modules) in enumerate(step_modules.items(), 1): + # Steps if dir_ == "all": continue # this is an alias dir_module = importlib.import_module(f"mne_bids_pipeline.steps.{dir_}") @@ -28,7 +121,9 @@ dir_body = dir_body[1].strip() else: dir_body = "" - lines.append(f"## {di}. {dir_header}\n") + icon = icon_map[dir_header] + module_header = f"{di}. {icon} {dir_header}" + lines.append(f"## {module_header}\n") if dir_body: lines.append(f"{dir_body}\n") lines.append("| Step name | Description |") @@ -41,5 +136,67 @@ step_title = module.__doc__.split("\n")[0] lines.append(f"`{step_name}` | {step_title} |") lines.append("") -with open(Path(__file__).parent / "steps.md", "w") as fid: - fid.write("\n".join(lines)) + + # Overview + overview_lines.append( + f"""\ +### {module_header} + +
+Click to expand + +```mermaid +flowchart TD""" + ) + chr_pre = chr(char_start + di - 1) # A, B, C, ... + start = None + prev_idx = None + title_map = {} + for mi, module in enumerate(modules, 1): + step_title = module.__doc__.split("\n")[0].rstrip(".") + idx = module.__name__.split(".")[-1].split("_")[1] # 01, 05a, etc. + # Need to quote the title to deal with parens, and sanitize quotes + step_title = step_title.replace('"', "'") + assert step_title not in used_titles, f"Redundant title: {step_title}" + used_titles.add(step_title) + this_block = f'{chr_pre}{idx}["{step_title}"]' + # special case: manual order + title_map[idx] = step_title + if dir_header in manual_order: + continue + if mi == 1: + start = this_block + assert prev_idx is None + continue + if start is not None: + assert mi == 2, mi + overview_lines.append(f" {start} --> {this_block}") + start = None + else: + overview_lines.append(f" {chr_pre}{prev_idx} --> {this_block}") + prev_idx = idx + if dir_header in manual_order: + mapped = set() + for a_b in manual_order[dir_header]: + if isinstance(a_b, str): # insert directly + overview_lines.append(a_b) + continue + assert isinstance(a_b, tuple), type(a_b) + a_b = list(a_b) # allow modification + for ii, idx in enumerate(a_b): + assert idx in title_map, (dir_header, idx, sorted(title_map)) + if idx not in mapped: + mapped.add(idx) + a_b[ii] = f'{idx}["{title_map[idx]}"]' + overview_lines.append(f" {chr_pre}{a_b[0]} --> {chr_pre}{a_b[1]}") + all_steps = set( + sum( + [a_b for a_b in manual_order[dir_header] if not isinstance(a_b, str)], + (), + ) + ) + assert mapped == all_steps, all_steps.symmetric_difference(mapped) + overview_lines.append("```\n\n
\n") + +(out_dir / "steps.md").write_text("\n".join(lines), encoding="utf8") +(out_dir / "overview.md").write_text("\n".join(overview_lines), encoding="utf8") diff --git a/docs/source/features/overview.md b/docs/source/features/overview.md deleted file mode 100644 index 9fe044038..000000000 --- a/docs/source/features/overview.md +++ /dev/null @@ -1,53 +0,0 @@ -MNE-BIDS-Pipeline processes your data in a sequential manner, i.e., one step -at a time. The next step is only run after the previous steps have been -successfully completed. There are, of course, exceptions; for example, if you -chose not to apply ICA, the respective steps will simply be omitted and we'll -directly move to the subsequent steps. The following flow chart aims to give -you a brief overview of which steps are included in the pipeline, in which -order they are run, and how we group them together. - -!!! info - All intermediate results are saved to disk for later - inspection, and an **extensive report** is generated. - -!!! info - Analyses are conducted on individual (per-subject) as well as group level. - - -## :open_file_folder: Filesystem initialization and dataset inspection -```mermaid -flowchart TD - A1[initialize the target directories] --> A2[locate empty-room recordings] -``` - -## :broom: Preprocessing -```mermaid - flowchart TD - B1[Noisy & flat channel detection] --> B2[Maxwell filter] - B2 --> B3[Frequency filter] - B3 --> B4[Epoch creation] - B4 --> B5[SSP or ICA fitting] - B5 --> B6[Artifact removal via SSP or ICA] - B6 --> B7[Amplitude-based epoch rejection] -``` - -## :satellite: Sensor-space processing -```mermaid - flowchart TD - C1[ERP / ERF calculation] --> C2[MVPA: full epochs] - C2 --> C3[MVPA: time-by-time decoding] - C3 --> C4[Time-frequency decomposition] - C4 --> C5[MVPA: CSP] - C5 --> C6[Noise covariance estimation] - C6 --> C7[Grand average] -``` - -## :brain: Source-space processing -```mermaid - flowchart TD - D1[BEM surface creation] --> D2[BEM solution] - D2 --> D3[Source space creation] - D3 --> D4[Forward model creation] - D4 --> D5[Inverse solution] - D5 --> D6[Grand average] -``` diff --git a/docs/source/javascripts/tablesort.js b/docs/source/javascripts/tablesort.js new file mode 100644 index 000000000..2e9fd4e51 --- /dev/null +++ b/docs/source/javascripts/tablesort.js @@ -0,0 +1,6 @@ +document$.subscribe(function() { + var tables = document.querySelectorAll("article table:not([class])") + tables.forEach(function(table) { + new Tablesort(table) + }) + }) diff --git a/docs/source/settings/gen_settings.py b/docs/source/settings/gen_settings.py new file mode 100755 index 000000000..7dc56d32f --- /dev/null +++ b/docs/source/settings/gen_settings.py @@ -0,0 +1,201 @@ +"""Generate settings .md files.""" + +# Any changes to the overall structure need to be reflected in mkdocs.yml nav section. + +import re +from pathlib import Path + +from tqdm import tqdm + +import mne_bids_pipeline._config + +config_path = Path(mne_bids_pipeline._config.__file__) +settings_dir = Path(__file__).parent + +# Mapping between first two lower-case words in the section name and the desired +# file or folder name +section_to_file = { # .md will be added to the files + # root file + "general settings": "general", + # folder + "preprocessing": "preprocessing", + "break detection": "breaks", + "bad channel": "autobads", + "maxwell filter": "maxfilter", + "filtering": "filter", + "resampling": "resample", + "epoching": "epochs", + "filtering &": None, # just a header + "artifact removal": None, + "stimulation artifact": "stim_artifact", + "ssp, ica,": "ssp_ica", + "amplitude-based artifact": "artifacts", + # folder + "sensor-level analysis": "sensor", + "condition contrasts": "contrasts", + "decoding /": "mvpa", + "time-frequency analysis": "time_freq", + "group-level analysis": "group_level", + # folder + "source-level analysis": "source", + "general source": "general", + "bem surface": "bem", + "source space": "forward", + "inverse solution": "inverse", + # folder + "reports": "reports", + "report generation": "report_generation", + # root file + "caching": "caching", + # root file + "parallelization": "parallelization", + # root file + "logging": "logging", + # root file + "error handling": "error_handling", +} +# TODO: Make sure these are consistent, autogenerate some based on section names, +# and/or autogenerate based on inputs/outputs of actual functions. +section_tags = { + "general settings": (), + "preprocessing": (), + "filtering &": (), + "artifact removal": (), + "break detection": ("preprocessing", "artifact-removal", "raw", "events"), + "bad channel": ("preprocessing", "raw", "bad-channels"), + "maxwell filter": ("preprocessing", "maxwell-filter", "raw"), + "filtering": ("preprocessing", "frequency-filter", "raw"), + "resampling": ("preprocessing", "resampling", "decimation", "raw", "epochs"), + "epoching": ("preprocessing", "epochs", "events", "metadata", "resting-state"), + "stimulation artifact": ("preprocessing", "artifact-removal", "raw", "epochs"), + "ssp, ica,": ("preprocessing", "artifact-removal", "raw", "epochs", "ssp", "ica"), + "amplitude-based artifact": ("preprocessing", "artifact-removal", "epochs"), + "sensor-level analysis": (), + "condition contrasts": ("epochs", "evoked", "contrast"), + "decoding /": ("epochs", "evoked", "contrast", "decoding", "mvpa"), + "time-frequency analysis": ("epochs", "evoked", "time-frequency"), + "group-level analysis": ("evoked", "group-level"), + "source-level analysis": (), + "general source": ("inverse-solution",), + "bem surface": ("inverse-solution", "bem", "freesurfer"), + "source space": ("inverse-solution", "forward-model"), + "inverse solution": ("inverse-solution",), + "reports": (), + "report generation": ("report",), + "caching": ("cache",), + "parallelization": ("paralleliation", "dask", "out-of-core"), + "logging": ("logging", "error-handling"), + "error handling": ("error-handling",), +} + +extra_headers = { + "general settings": """\ +!!! info + Many settings in this section control the pipeline behavior very early in the + pipeline. Therefore, for most of them (e.g., `bids_root`) we do not list the + steps that directly depend on the setting. The options with drop-down step + lists (e.g., `random_state`) have more localized effects. +""" +} + +option_header = """\ +::: mne_bids_pipeline._config + options: + members:""" +prefix = """\ + - """ + +# We cannot use ast for this because it doesn't preserve comments. We could use +# something like redbaron, but our code is hopefully simple enough! +assign_re = re.compile( + "^" # The line starts, then is followed by + r"(\w+): " # annotation syntax (name captured by the first group), + "(?:" # then the rest of the line can be (in a non-capturing group): + ".+ = .+" # 1. a standard assignment + "|" # 2. or + r"Literal\[" # 3. the start of a multiline type annotation like "a: Literal[" + "|" # 4. or + r"\(" # 5. the start of a multiline 3.9+ type annotation like "a: (" + ")" # Then the end of our group + "$", # and immediately the end of the line. + re.MULTILINE, +) + + +def main(): + """Parse the configuration and generate the markdown documentation.""" + print(f"Parsing {config_path} to generate settings .md files.") + # max file-level depth is 2 even though we have 3 subsection levels + levels = [None, None] + current_path, current_lines = None, list() + text = config_path.read_text("utf-8") + lines = text.splitlines() + lines += ["# #"] # add a dummy line to trigger the last write + in_header = False + have_params = False + for li, line in enumerate(tqdm(lines)): + line = line.rstrip() + if line.startswith("# #"): # a new (sub)section / file + this_def = line[2:] + this_level = this_def.split()[0] + assert this_level.count("#") == len(this_level), this_level + this_level = this_level.count("#") - 1 + if this_level == 2: + # flatten preprocessing/filtering/filter to preprocessing/filter + # for example + this_level = 1 + assert this_level in (0, 1), (this_level, this_def) + this_def = this_def[this_level + 2 :] + levels[this_level] = this_def + # Write current lines and reset + if have_params: # more than just the header + assert current_path is not None, levels + if current_lines[0] == "": # this happens with tags + current_lines = current_lines[1:] + current_path.write_text("\n".join(current_lines + [""]), "utf-8") + have_params = False + if this_level == 0: + this_root = settings_dir + else: + this_root = settings_dir / f"{section_to_file[levels[0].lower()]}" + this_root.mkdir(exist_ok=True) + key = " ".join(this_def.split()[:2]).lower() + if key == "": + assert li == len(lines) - 1, (li, line) + continue # our dummy line + fname = section_to_file[key] + if fname is None: + current_path = None + else: + current_path = this_root / f"{fname}.md" + in_header = True + current_lines = list() + if len(section_tags[key]): + current_lines += ["---", "tags:"] + current_lines += [f" - {tag}" for tag in section_tags[key]] + current_lines += ["---"] + if key in extra_headers: + current_lines.extend(["", extra_headers[key]]) + continue + + if in_header: + if line == "": + in_header = False + if current_lines: + current_lines.append("") + current_lines.append(option_header) + else: + assert line == "#" or line.startswith("# "), (li, line) # a comment + current_lines.append(line[2:]) + continue + + # Could be an option + match = assign_re.match(line) + if match is not None: + have_params = True + current_lines.append(f"{prefix}{match.groups()[0]}") + continue + + +if __name__ == "__main__": + main() diff --git a/docs/source/settings/general.md b/docs/source/settings/general.md deleted file mode 100644 index d0814cce6..000000000 --- a/docs/source/settings/general.md +++ /dev/null @@ -1,47 +0,0 @@ -::: mne_bids_pipeline._config - options: - members: - - study_name - - bids_root - - deriv_root - - subjects_dir - - interactive - - sessions - - task - - task_is_rest - - runs - - exclude_runs - - crop_runs - - acq - - proc - - rec - - space - - subjects - - exclude_subjects - - process_empty_room - - process_rest - - ch_types - - data_type - - eog_channels - - eeg_bipolar_channels - - eeg_reference - - eeg_template_montage - - drop_channels - - reader_extra_params - - read_raw_bids_verbose - - analyze_channels - - plot_psd_for_runs - - n_jobs - - parallel_backend - - dask_open_dashboard - - dask_temp_dir - - dask_worker_memory_limit - - random_state - - shortest_event - - memory_location - - memory_file_method - - memory_verbose - - config_validation - - log_level - - mne_log_level - - on_error diff --git a/docs/source/settings/preprocessing/artifacts.md b/docs/source/settings/preprocessing/artifacts.md deleted file mode 100644 index 9a81b4d15..000000000 --- a/docs/source/settings/preprocessing/artifacts.md +++ /dev/null @@ -1,19 +0,0 @@ ---- -tags: - - preprocessing - - artifact-removal - - epochs ---- - -???+ info "Good Practice / Advice" - Have a look at your raw data and train yourself to detect a blink, a heart - beat and an eye movement. - You can do a quick average of blink data and check what the amplitude looks - like. - -::: mne_bids_pipeline._config - options: - members: - - reject - - reject_tmin - - reject_tmax diff --git a/docs/source/settings/preprocessing/autobads.md b/docs/source/settings/preprocessing/autobads.md deleted file mode 100644 index a118917a1..000000000 --- a/docs/source/settings/preprocessing/autobads.md +++ /dev/null @@ -1,27 +0,0 @@ ---- -tags: - - preprocessing - - raw - - bad-channels ---- - -!!! warning - This functionality will soon be removed from the pipeline, and - will be integrated into MNE-BIDS. - -"Bad", i.e. flat and overly noisy channels, can be automatically detected -using a procedure inspired by the commercial MaxFilter by Elekta. First, -a copy of the data is low-pass filtered at 40 Hz. Then, channels with -unusually low variability are flagged as "flat", while channels with -excessively high variability are flagged as "noisy". Flat and noisy channels -are marked as "bad" and excluded from subsequent analysis. See -:func:`mne.preprocssessing.find_bad_channels_maxwell` for more information -on this procedure. The list of bad channels detected through this procedure -will be merged with the list of bad channels already present in the dataset, -if any. - -::: mne_bids_pipeline._config - options: - members: - - find_flat_channels_meg - - find_noisy_channels_meg diff --git a/docs/source/settings/preprocessing/breaks.md b/docs/source/settings/preprocessing/breaks.md deleted file mode 100644 index 01e3159eb..000000000 --- a/docs/source/settings/preprocessing/breaks.md +++ /dev/null @@ -1,15 +0,0 @@ ---- -tags: - - preprocessing - - artifact-removal - - raw - - events ---- - -::: mne_bids_pipeline._config - options: - members: - - find_breaks - - min_break_duration - - t_break_annot_start_after_previous_event - - t_break_annot_stop_before_next_event diff --git a/docs/source/settings/preprocessing/epochs.md b/docs/source/settings/preprocessing/epochs.md deleted file mode 100644 index 02dd1f71d..000000000 --- a/docs/source/settings/preprocessing/epochs.md +++ /dev/null @@ -1,26 +0,0 @@ ---- -tags: - - preprocessing - - epochs - - events - - metadata - - resting-state ---- - -::: mne_bids_pipeline._config - options: - members: - - rename_events - - on_rename_missing_events - - event_repeated - - conditions - - epochs_tmin - - epochs_tmax - - baseline - - epochs_metadata_tmin - - epochs_metadata_tmax - - epochs_metadata_keep_first - - epochs_metadata_keep_last - - epochs_metadata_query - - rest_epochs_duration - - rest_epochs_overlap diff --git a/docs/source/settings/preprocessing/filter.md b/docs/source/settings/preprocessing/filter.md deleted file mode 100644 index 9d1301412..000000000 --- a/docs/source/settings/preprocessing/filter.md +++ /dev/null @@ -1,37 +0,0 @@ ---- -tags: - - preprocessing - - frequency-filter - - raw ---- - -It is typically better to set your filtering properties on the raw data so -as to avoid what we call border (or edge) effects. - -If you use this pipeline for evoked responses, you could consider -a low-pass filter cut-off of h_freq = 40 Hz -and possibly a high-pass filter cut-off of l_freq = 1 Hz -so you would preserve only the power in the 1Hz to 40 Hz band. -Note that highpass filtering is not necessarily recommended as it can -distort waveforms of evoked components, or simply wash out any low -frequency that can may contain brain signal. It can also act as -a replacement for baseline correction in Epochs. See below. - -If you use this pipeline for time-frequency analysis, a default filtering -could be a high-pass filter cut-off of l_freq = 1 Hz -a low-pass filter cut-off of h_freq = 120 Hz -so you would preserve only the power in the 1Hz to 120 Hz band. - -If you need more fancy analysis, you are already likely past this kind -of tips! 😇 - -::: mne_bids_pipeline._config - options: - members: - - l_freq - - h_freq - - l_trans_bandwidth - - h_trans_bandwidth - - notch_freq - - notch_trans_bandwidth - - notch_widths diff --git a/docs/source/settings/preprocessing/maxfilter.md b/docs/source/settings/preprocessing/maxfilter.md deleted file mode 100644 index 2bdd10b26..000000000 --- a/docs/source/settings/preprocessing/maxfilter.md +++ /dev/null @@ -1,25 +0,0 @@ ---- -tags: - - preprocessing - - maxwell-filter - - raw ---- - -::: mne_bids_pipeline._config - options: - members: - - use_maxwell_filter - - mf_st_duration - - mf_st_correlation - - mf_head_origin - - mf_destination - - mf_int_order - - mf_reference_run - - mf_cal_fname - - mf_ctc_fname - - mf_mc - - mf_mc_t_step_min - - mf_mc_t_window - - mf_mc_gof_limit - - mf_mc_dist_limit - - mf_filter_chpi diff --git a/docs/source/settings/preprocessing/resample.md b/docs/source/settings/preprocessing/resample.md deleted file mode 100644 index 6aa824a6c..000000000 --- a/docs/source/settings/preprocessing/resample.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -tags: - - preprocessing - - resampling - - decimation - - raw - - epochs ---- - -If you have acquired data with a very high sampling frequency (e.g. 2 kHz) -you will likely want to downsample to lighten up the size of the files you -are working with (pragmatics) -If you are interested in typical analysis (up to 120 Hz) you can typically -resample your data down to 500 Hz without preventing reliable time-frequency -exploration of your data. - -::: mne_bids_pipeline._config - options: - members: - - raw_resample_sfreq - - epochs_decim diff --git a/docs/source/settings/preprocessing/ssp_ica.md b/docs/source/settings/preprocessing/ssp_ica.md deleted file mode 100644 index b132ef4bf..000000000 --- a/docs/source/settings/preprocessing/ssp_ica.md +++ /dev/null @@ -1,32 +0,0 @@ ---- -tags: - - preprocessing - - artifact-removal - - raw - - epochs - - ssp - - ica ---- - -::: mne_bids_pipeline._config - options: - members: - - spatial_filter - - min_ecg_epochs - - min_eog_epochs - - n_proj_eog - - n_proj_ecg - - ssp_meg - - ecg_proj_from_average - - eog_proj_from_average - - ssp_reject_eog - - ssp_reject_ecg - - ssp_ecg_channel - - ica_reject - - ica_algorithm - - ica_l_freq - - ica_max_iterations - - ica_n_components - - ica_decim - - ica_ctps_ecg_threshold - - ica_eog_threshold diff --git a/docs/source/settings/preprocessing/stim_artifact.md b/docs/source/settings/preprocessing/stim_artifact.md deleted file mode 100644 index cbc142550..000000000 --- a/docs/source/settings/preprocessing/stim_artifact.md +++ /dev/null @@ -1,19 +0,0 @@ ---- -tags: - - preprocessing - - artifact-removal - - raw - - epochs ---- - -When using electric stimulation systems, e.g. for median nerve or index -stimulation, it is frequent to have a stimulation artifact. This option -allows to fix it by linear interpolation early in the pipeline on the raw -data. - -::: mne_bids_pipeline._config - options: - members: - - fix_stim_artifact - - stim_artifact_tmin - - stim_artifact_tmax diff --git a/docs/source/settings/reports/report_generation.md b/docs/source/settings/reports/report_generation.md deleted file mode 100644 index 2ccf520ed..000000000 --- a/docs/source/settings/reports/report_generation.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -tags: - - report ---- - -::: mne_bids_pipeline._config - options: - members: - - report_evoked_n_time_points - - report_stc_n_time_points diff --git a/docs/source/settings/sensor/contrasts.md b/docs/source/settings/sensor/contrasts.md deleted file mode 100644 index 576e45ee3..000000000 --- a/docs/source/settings/sensor/contrasts.md +++ /dev/null @@ -1,11 +0,0 @@ ---- -tags: - - epochs - - evoked - - contrast ---- - -::: mne_bids_pipeline._config - options: - members: - - contrasts diff --git a/docs/source/settings/sensor/group_level.md b/docs/source/settings/sensor/group_level.md deleted file mode 100644 index a330a9703..000000000 --- a/docs/source/settings/sensor/group_level.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -tags: - - evoked - - group-level ---- - -::: mne_bids_pipeline._config - options: - members: - - interpolate_bads_grand_average diff --git a/docs/source/settings/sensor/mvpa.md b/docs/source/settings/sensor/mvpa.md deleted file mode 100644 index 8131cd428..000000000 --- a/docs/source/settings/sensor/mvpa.md +++ /dev/null @@ -1,25 +0,0 @@ ---- -tags: - - epochs - - evoked - - contrast - - decoding ---- - -::: mne_bids_pipeline._config - options: - members: - - decode - - decoding_epochs_tmin - - decoding_epochs_tmax - - decoding_metric - - decoding_n_splits - - decoding_time_generalization - - decoding_time_generalization_decim - - decoding_csp - - decoding_csp_times - - decoding_csp_freqs - - n_boot - - cluster_forming_t_threshold - - cluster_n_permutations - - cluster_permutation_p_threshold diff --git a/docs/source/settings/sensor/time_freq.md b/docs/source/settings/sensor/time_freq.md deleted file mode 100644 index 492296dc0..000000000 --- a/docs/source/settings/sensor/time_freq.md +++ /dev/null @@ -1,18 +0,0 @@ ---- -tags: - - epochs - - evoked - - time-frequency ---- - -::: mne_bids_pipeline._config - options: - members: - - time_frequency_conditions - - time_frequency_freq_min - - time_frequency_freq_max - - time_frequency_cycles - - time_frequency_subtract_evoked - - time_frequency_baseline - - time_frequency_baseline_mode - - time_frequency_crop diff --git a/docs/source/settings/source/bem.md b/docs/source/settings/source/bem.md deleted file mode 100644 index f55972baf..000000000 --- a/docs/source/settings/source/bem.md +++ /dev/null @@ -1,16 +0,0 @@ ---- -tags: - - inverse-solution - - bem - - freesurfer ---- - -::: mne_bids_pipeline._config - options: - members: - - use_template_mri - - adjust_coreg - - bem_mri_images - - recreate_bem - - recreate_scalp_surface - - freesurfer_verbose diff --git a/docs/source/settings/source/forward.md b/docs/source/settings/source/forward.md deleted file mode 100644 index 8ce5c87ad..000000000 --- a/docs/source/settings/source/forward.md +++ /dev/null @@ -1,14 +0,0 @@ ---- -tags: - - inverse-solution - - forward-model ---- - -::: mne_bids_pipeline._config - options: - members: - - mri_t1_path_generator - - mri_landmarks_kind - - spacing - - mindist - - source_info_path_update diff --git a/docs/source/settings/source/general.md b/docs/source/settings/source/general.md deleted file mode 100644 index 09eac741f..000000000 --- a/docs/source/settings/source/general.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -tags: - - inverse-solution ---- - -::: mne_bids_pipeline._config - options: - members: - - run_source_estimation diff --git a/docs/source/settings/source/inverse.md b/docs/source/settings/source/inverse.md deleted file mode 100644 index 4a10f1aef..000000000 --- a/docs/source/settings/source/inverse.md +++ /dev/null @@ -1,14 +0,0 @@ ---- -tags: - - inverse-solution ---- - -::: mne_bids_pipeline._config - options: - members: - - loose - - depth - - inverse_method - - noise_cov - - source_info_path_update - - inverse_targets diff --git a/docs/source/v1.0.md.inc b/docs/source/v1.0.md.inc index c845f9f6a..9c4d8dd18 100644 --- a/docs/source/v1.0.md.inc +++ b/docs/source/v1.0.md.inc @@ -7,7 +7,7 @@ the release, so we are bumping the patch version number rather than the major version number. - The `N_JOBS` parameter has been renamed to - [`n_jobs`](mne_bids_pipeline._config.n_jobs) for consistency + [`n_jobs`][mne_bids_pipeline._config.n_jobs] for consistency (#694 by @larsoner) ### :bug: Bug fixes @@ -109,7 +109,7 @@ Changes were only tracked starting April 15, 2021. [`ssp_reject_ecg`][mne_bids_pipeline._config.ssp_reject_ecg]. (#392 by @agramfort, @dengemann, @apmellot and @hoechenberger) -- You can now use autoreject for exclusing artifacts before SSP estimation via +- You can now use autoreject for excluding artifacts before SSP estimation via the `autoreject_global` option in [`ssp_reject_eog`][mne_bids_pipeline._config.ssp_reject_eog] and [`ssp_reject_ecg`][mne_bids_pipeline._config.ssp_reject_ecg]. (#396 by @agramfort, @dengemann, diff --git a/docs/source/v1.4.md.inc b/docs/source/v1.4.md.inc index dfc79ed60..e2a207fac 100644 --- a/docs/source/v1.4.md.inc +++ b/docs/source/v1.4.md.inc @@ -1,10 +1,10 @@ -## v1.4.0 (unreleased) +## v1.4.0 (2023-07-04) ### :new: New features & enhancements - Add movement compensation and cHPI filtering to the Maxwell filtering step, along with additional configuration options (#747 by @larsoner) -- Add option to specify [`ssp_ecg_channel`]([mne_bids_pipeline._config.ssp_ecg_channel) to override the default value (#747 by @larsoner) -- Add option [`read_raw_bids_verbose`]([mne_bids_pipeline._config.read_raw_bids_verbose) to set the verbosity level when using `read_raw_bids` to suppress known warnings (#749 by @larsoner) +- Add option to specify [`ssp_ecg_channel`][mne_bids_pipeline._config.ssp_ecg_channel] to override the default value (#747 by @larsoner) +- Add option [`read_raw_bids_verbose`][mne_bids_pipeline._config.read_raw_bids_verbose] to set the verbosity level when using `read_raw_bids` to suppress known warnings (#749 by @larsoner) [//]: # (### :warning: Behavior changes) @@ -16,8 +16,8 @@ ### :bug: Bug fixes -- Fix bug when [`mf_reference_run != runs[0]`]([mne_bids_pipeline._config.mf_reference_run) (#742 by @larsoner) -- Fix bug with too many JSON files found during empty room matching (#743 by @allermat) -- Fix bug with outdated info on ch_types config option (#745 by @allermat) +- Fix bug when [`mf_reference_run != runs[0]`][mne_bids_pipeline._config.mf_reference_run] (#742 by @larsoner) +- Fix bug with too many JSON files found during empty-room discovery (#743 by @allermat) - Fix bug where SSP projectors were not added to the report (#747 by @larsoner) -- Fix bug with documentation issue on data_type config option (#751 by @allermat) \ No newline at end of file +- Fix documentation of `data_type` configuration option (#751 by @allermat) +- Fix documentation of `ch_types` configuration option (#745 by @allermat) diff --git a/docs/source/v1.5.md.inc b/docs/source/v1.5.md.inc new file mode 100644 index 000000000..6ef152c1e --- /dev/null +++ b/docs/source/v1.5.md.inc @@ -0,0 +1,58 @@ +## v1.5.0 (2023-11-30) + +This release contains a number of very important bug fixes that address problems related to decoding, time-frequency analysis, and inverse modeling. +All users are encouraged to update. + +### :new: New features & enhancements + +- Added `deriv_root` argument to CLI (#773 by @vferat) +- Added support for annotating bad segments based on head movement velocity (#757 by @larsoner) +- Added examples of T1 and FLASH BEM to website (#758 by @larsoner) +- Added support for extended SSS (eSSS) in Maxwell filtering (#762 by @larsoner) +- Output logging spacing improved (#764 by @larsoner) +- Added caching of sensor and source average steps (#765 by @larsoner) +- Improved logging of coregistration distances (#769 by @larsoner) +- Input validation has been improved by leveraging [pydantic](https://docs.pydantic.dev) (#779 by @larsoner) +- Reduced logging when reports are created and saved (#799 by @hoechenberger) +- Added [`"picard-extended_infomax"`][mne_bids_pipeline._config.ica_algorithm] ICA algorithm to perform "extended Infomax"-like ICA decomposition using Picard (#801 by @hoechenberger) +- Added support for using "local" [`autoreject`](https://autoreject.github.io) to find (and repair) bad channels on a + per-epoch basis as the last preprocessing step; this can be enabled by setting [`reject`][mne_bids_pipeline._config.reject] + to `"autoreject_local"`. The behavior can further be controlled via the new setting + [`autoreject_n_interpolate`][mne_bids_pipeline._config.autoreject_n_interpolate]. (#807 by @hoechenberger) +- Added support for "local" [`autoreject`](https://autoreject.github.io) to remove bad epochs + before submitting the data to ICA fitting. This can be enabled by setting [`ica_reject`][mne_bids_pipeline._config.ica_reject] + to `"autoreject_local"`. (#810, #816 by @hoechenberger) +- The new setting [`decoding_which_epochs`][mne_bids_pipeline._config.decoding_which_epochs] controls which epochs (e.g., uncleaned, after ICA/SSP, cleaned) shall be used for decoding. (#819 by @hoechenberger) +- Website documentation tables can now be sorted (e.g., to find examples that use a specific feature) (#808 by @larsoner) + +### :warning: Behavior changes + +- The default cache directory is now `_cache` within the derivatives folder when using `memory_location=True`, set [`memory_subdir="joblib"`][mne_bids_pipeline._config.memory_subdir] to get the behavior from v1.4 (#778 by @larsoner) +- Before cleaning epochs via ICA, we used to reject any epochs execeeding the [`ica_reject`][mne_bids_pipeline._config.ica_reject] + criteria. However, this may lead to the unnecessary exclusion of epochs that could have been salvaged through ICA cleaning. Now, + we only apply `ica_reject` to the epochs used for ICA fitting. After the experimental epochs have been cleaned with ICA + (`preprocessing/apply_ica` step), any remaining large-amplitude artifacts can be removed via + [`reject`][mne_bids_pipeline._config.reject], which is used in the last preprocessing step, `preprocessing/ptp_reject`. (#806 by @hoechenberger) +- MVPA / decoding used to be performed on un-cleaned epochs in the past. Now, cleaned epochs will be used by default (please also see the "Bug fixes" section below). (#796 by @hoechenberger) + +[//]: # (- Whatever (#000 by @whoever)) + +### :medical_symbol: Code health + +- Fixed doc build errors and dependency specifications (#755 by @larsoner) +- Ensure `memory_file_method = "hash"` is tested (#768 by @larsoner) +- Enable [pre-commit.ci](https://pre-commit.ci) (#774 by @larsoner) +- Use `pooch` for web downloads (#775 by @larsoner) +- Ensure compatibility with MNE-Python 1.6 (#800 by @hoechenberger) +- Updated testing dataset for ds004229 v1.0.3 (#808 by @larsoner) + +### :bug: Bug fixes + +- Fixed bug where cache would not invalidate properly based on output file changes and steps could be incorrectly skipped. All steps will automatically rerun to accommodate the new, safer caching scheme (#756 by @larsoner) +- Fixed bug with parallelization across runs for Maxwell filtering (#761 by @larsoner) +- Fixed bug where head position files were not written with a proper suffix and extension (#761 by @larsoner) +- Fixed bug where default values for `decoding_csp_times` and `decoding_csp_freqs` were not set dynamically based on the config parameters (#779 by @larsoner) +- Fixed bug where the MNE logger verbosity was not respected inside parallel jobs (#813 by @larsoner) +- A number of processing steps erroneously **always** operated on un-cleaned epochs (`sensor/decoding_full_epochs`, `sensor/decoding_time_by_time`, `sensor/decoding_csp`); or operated on un-cleaned epochs (without PTP rejection) if no ICA or SSP was requested (`sensor/ime_frequency`, `sensor/make_cov`) The bug in `sensor/make_cov` could propagate to the source level, as the covariance matrix is used for inverse modeling. (#796 by @hoechenberger) +- Bad channels may have been submitted to MVPA (full epochs decoding, time-by-time decoding, CSP-based decoding) when not using Maxwell filtering + (i.e., usually only EEG data was affected). This has now been fixed and data from bad channels is omitted from decoding. (#817 by @hoechenberger) diff --git a/docs/source/v1.6.md.inc b/docs/source/v1.6.md.inc new file mode 100644 index 000000000..b29871c11 --- /dev/null +++ b/docs/source/v1.6.md.inc @@ -0,0 +1,44 @@ +## v1.6.0 (2024-03-01) + +:new: New features & enhancements + +- Added [`regress_artifact`][mne_bids_pipeline._config.regress_artifact] to allow artifact regression (e.g., of MEG reference sensors in KIT systems) (#837 by @larsoner) +- Chosen `reject` parameters are now saved in the generated HTML reports (#839 by @larsoner) +- Added saving of clean raw data in addition to epochs (#840 by @larsoner) +- Added saving of detected blink and cardiac events used to calculate SSP projectors (#840 by @larsoner) +- Added [`noise_cov_method`][mne_bids_pipeline._config.noise_cov_method] to allow for the use of methods other than `"shrunk"` for noise covariance estimation (#854 by @larsoner) +- Added option to pass `image_kwargs` to [`mne.Report.add_epochs`] to allow adjusting e.g. `"vmin"` and `"vmax"` of the epochs image in the report via [`report_add_epochs_image_kwargs`][mne_bids_pipeline._config.report_add_epochs_image_kwargs]. This feature requires MNE-Python 1.7 or newer. (#848 by @SophieHerbst) +- Split ICA fitting and artifact detection into separate steps. This means that now, ICA is split into a total of three consecutive steps: fitting, artifact detection, and the actual data cleaning step ("applying ICA"). This makes it easier to experiment with different settings for artifact detection without needing to re-fit ICA. (#865 by @larsoner) +- The configuration used for the pipeline is now saved in a separate spreadsheet in the `.xlsx` log file (#869 by @larsoner) + +[//]: # (### :warning: Behavior changes) + +[//]: # (- Whatever (#000 by @whoever)) + +### :package: Requirements + +- MNE-BIDS-Pipeline now requires Python 3.9 or newer. (#825 by @hoechenberger) + +### :bug: Bug fixes + +- Fixed minor issues with path handling for cross-talk and calibration files (#834 by @larsoner) +- Fixed EEG `reject` use for `ch_types = ["meg", "eeg"]` in epoch cleaning (#839 by @larsoner) +- Fixed bug where implicit `mf_reference_run` could change across invocations of `mne_bids_pipeline`, breaking caching (#839 by @larsoner) +- Fixed `--no-cache` behavior having no effect (#839 by @larsoner) +- Fixed Maxwell filtering failures when [`find_noisy_channels_meg = False`][mne_bids_pipeline._config.find_noisy_channels_meg]` is used (#847 by @larsoner) +- Fixed raw, empty-room, and custom noise covariances calculation, previously they could errantly be calculated on data without ICA or SSP applied (#840 by @larsoner) +- Fixed multiple channel type handling (e.g., MEG and EEG) in decoding (#853 by @larsoner) +- Changed the default for [`ica_n_components`][mne_bids_pipeline._config.ica_n_components] from `0.8` (too conservative) to `None` to match MNE-Python's default (#853 by @larsoner) +- Prevent events table for the average subject overflowing in reports (#854 by @larsoner) +- Fixed split file behavior for Epochs when using ICA (#855 by @larsoner) +- Fixed a bug where users could not set `_components.tsv` as it would be detected as a cache miss and overwritten on next pipeline run (#865 by @larsoner) + +### :medical_symbol: Code health + +- The package build backend has been switched from `setuptools` to `hatchling`. (#825 by @hoechenberger) +- Removed dependencies on `datalad` and `git-annex` for testing (#867 by @larsoner) +- Code formatting now uses `ruff format` instead of `black` (#834, #838 by @larsoner) +- Code caching is now tested using GitHub Actions (#836 by @larsoner) +- Steps in the documentation are now automatically parsed into flowcharts (#859 by @larsoner) +- New configuration options are now automatically added to the docs (#863 by @larsoner) +- Configuration options now have relevant steps listed in the docs (#866 by @larsoner) diff --git a/docs/source/v1.7.md.inc b/docs/source/v1.7.md.inc new file mode 100644 index 000000000..3db0f1dfd --- /dev/null +++ b/docs/source/v1.7.md.inc @@ -0,0 +1,29 @@ +## v1.7.0 (2024-03-13) + +### :new: New features & enhancements + +- Improved logging message during cache invalidation: We now print the selected + [`memory_file_method`][mne_bids_pipeline._config.memory_file_method] ("hash" or "mtime"). + Previously, we'd always print "hash". (#876 by @hoechenberger) + +[//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :warning: Behavior changes) + +[//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :package: Requirements) + +[//]: # (- Whatever (#000 by @whoever)) + +### :bug: Bug fixes + +- Fixed an error when using [`analyze_channels`][mne_bids_pipeline._config.analyze_channels] with EEG data, where e.g. ERP creation didn't work. (#883 by @hoechenberger) + +[//]: # (- Whatever (#000 by @whoever)) + +### :medical_symbol: Code health + +- We enabled stricter linting to guarantee a consistently high code quality! (#872 by @hoechenberger) + +[//]: # (- Whatever (#000 by @whoever)) diff --git a/docs/source/v1.8.md.inc b/docs/source/v1.8.md.inc new file mode 100644 index 000000000..d4bb7f867 --- /dev/null +++ b/docs/source/v1.8.md.inc @@ -0,0 +1,27 @@ +## v1.8.0 (2024-03-20) + +### :new: New features & enhancements + +- Disabling CSP time-frequency mode is now supported by passing an empty list to [`decoding_csp_times`][mne_bids_pipeline._config.decoding_csp_times] (#890 by @whoever) + +[//]: # (### :warning: Behavior changes) + +[//]: # (- Whatever (#000 by @whoever)) + +### :package: Requirements + +- MNE-BIDS-Pipeline now explicitly depends on `annotated-types` (#886 by @hoechenberger) + +[//]: # (- Whatever (#000 by @whoever)) + +### :bug: Bug fixes + +- Fix handling of Maxwell-filtered data in CSP (#890 by @larsoner) +- Avoid recomputation / cache miss when the same empty-room file is matched to multiple subjects (#890 by @larsoner) + +### :medical_symbol: Code health + +- We removed the unused settings `shortest_event` and `study_name`. They were relics of early days of the pipeline + and haven't been in use for a long time. (#888, #889 by @hoechenberger and @larsoner) + +[//]: # (- Whatever (#000 by @whoever)) diff --git a/docs/source/v1.9.md.inc b/docs/source/v1.9.md.inc new file mode 100644 index 000000000..55e9f5ac8 --- /dev/null +++ b/docs/source/v1.9.md.inc @@ -0,0 +1,59 @@ +## v1.9.0 + +### :new: New features & enhancements + +- Added number of subject to `sub-average` report (#902, #910 by @SophieHerbst) +- The type annotations in the default configuration file are now easier to read: We + replaced `Union[X, Y]` with `X | Y` and `Optional[X]` with `X | None`. (#908, #911 by @hoechenberger) +- Epochs metadata creation now supports variable time windows by specifying the names of events via + [`epochs_metadata_tmin`][mne_bids_pipeline._config.epochs_metadata_tmin] and + [`epochs_metadata_tmax`][mne_bids_pipeline._config.epochs_metadata_tmax]. (#873 by @hoechenberger) +- If you requested processing of non-existing subjects, we will now provide a more helpful error message. (#928 by @hoechenberger) +- We improved the logging output for automnated epochs rejection and cleaning via ICA and SSP. (#936, #937 by @hoechenberger) +- ECG and EOG signals created during ICA artifact detection are now saved to disk. (#938 by @hoechenberger) + +### :warning: Behavior changes + +- All ICA HTML reports have been consolidated in the standard subject `*_report.html` + file instead of producing separate files. (#899 by @larsoner) +- Changed default for `source_info_path_update` to `None`. In `_04_make_forward.py` + and `_05_make_inverse.py`, we retrieve the info from the file from which + the `noise_cov` is computed. (#919 by @SophieHerbst) +- The [`depth`][mne_bids_pipeline._config.depth] parameter doesn't accept `None` + anymore. Please use `0` instead. (#915 by @hoechenberger) +- When using automated bad channel detection, now indicate the generated `*_bads.tsv` files whether a channel + had previously already been marked as bad in the dataset. Resulting entries in the TSV file may now look like: + `"pre-existing (before MNE-BIDS-pipeline was run) & auto-noisy"` (previously: only `"auto-noisy"`). (#930 by @hoechenberger) +- The `ica_ctps_ecg_threshold` has been renamed to [`ica_ecg_threshold`][mne_bids_pipeline._config.ica_ecg_threshold]. (#935 by @hoechenberger) +- We changed the behavior when setting an EEG montage: + - When applying the montage, we now also check for channel aliases (e.g. `M1 -> TP9`). + - If the data contains a channel that is not present in the montage, we now abort with an exception (previously, we emitted a warning). + This is to prevent silent errors. To proceed in this situation, select a different montage, or drop the respective channels via + the [`drop_channels`][mne_bids_pipeline._config.drop_channels] configuration option. (#960 by @hoechenberger) + +### :package: Requirements + +- The minimum required version of MNE-Python is now 1.7.0. +- We dropped support for Python 3.9. You now need Python 3.10 or newer. (#908 by @hoechenberger) + +### :book: Documentation + +- We removed the `Execution` section from configuration options documentation and + replaced it with new, more explicit sections (namely, Caching, Parallelization, + Logging, and Error handling), and enhanced documentation. (#914 by @hoechenberger, #916 by @SophieHerbst) + +### :bug: Bug fixes + +- When running the pipeline with [`find_bad_channels_meg`][mne_bids_pipeline._config. find_bad_channels_meg] enabled, + then disabling it and running the pipeline again, the pipeline would incorrectly still use automatically detected + bad channels from the first pipeline run. Now, we ensure that the original bad channels would be used and the + related section is removed from the report in this case. (#902 by @larsoner) +- Fixed group-average decoding statistics were not updated in some cases, even if relevant configuration options had been changed. (#902 by @larsoner) +- Fixed a compatibility bug with joblib 1.4.0. (#899 by @larsoner) +- Fixed how "original" raw data is included in the report. Previously, bad channels, subject, and experimenter name would not + be displayed correctly. (#930 by @hoechenberger) +- In the report's table of contents, don't put the run numbers in quotation marks. (#933 by @hoechenberger) + +### :medical_symbol: Code health and infrastructure + +- Use GitHub's `dependabot` service to automatically keep GitHub Actions up-to-date. (#893 by @hoechenberger) diff --git a/docs/source/vX.Y.md.inc b/docs/source/vX.Y.md.inc new file mode 100644 index 000000000..36bf65f57 --- /dev/null +++ b/docs/source/vX.Y.md.inc @@ -0,0 +1,23 @@ +[//]: # (Don't forget to add this to changes.md as an include!) + +## vX.Y.0 (unreleased) + +[//]: # (### :new: New features & enhancements) + +[//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :warning: Behavior changes) + +[//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :package: Requirements) + +[//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :bug: Bug fixes) + +[//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :medical_symbol: Code health) + +[//]: # (- Whatever (#000 by @whoever)) diff --git a/docs/templates/python/material/attribute.html.jinja b/docs/templates/python/material/attribute.html.jinja new file mode 100644 index 000000000..caad07cbf --- /dev/null +++ b/docs/templates/python/material/attribute.html.jinja @@ -0,0 +1,101 @@ +{# Modified from https://github.com/mkdocstrings/python/blob/master/src/mkdocstrings_handlers/python/templates/material/_base/attribute.html #} +{# Updated 2024/05/20. See "START NEW CODE" for block that is new. #} + +{{ log.debug("Rendering " + attribute.path) }} + +
+{% with html_id = attribute.path %} + + {% if root %} + {% set show_full_path = config.show_root_full_path %} + {% set root_members = True %} + {% elif root_members %} + {% set show_full_path = config.show_root_members_full_path or config.show_object_full_path %} + {% set root_members = False %} + {% else %} + {% set show_full_path = config.show_object_full_path %} + {% endif %} + + {% if not root or config.show_root_heading %} + + {% filter heading(heading_level, + role="data" if attribute.parent.kind.value == "module" else "attr", + id=html_id, + class="doc doc-heading", + toc_label=attribute.name) %} + + {% block heading scoped %} + {% if config.separate_signature %} + {% if show_full_path %}{{ attribute.path }}{% else %}{{ attribute.name }}{% endif %} + {% else %} + {% filter highlight(language="python", inline=True) %} + {% if show_full_path %}{{ attribute.path }}{% else %}{{ attribute.name }}{% endif %} + {% if attribute.annotation %}: {{ attribute.annotation }}{% endif %} + {% if attribute.value %} = {{ attribute.value }}{% endif %} + {% endfilter %} + {% endif %} + {% endblock heading %} + + {% block labels scoped %} + {% with labels = attribute.labels %} + {% include "labels.html" with context %} + {% endwith %} + {% endblock labels %} + + {% endfilter %} + + {% block signature scoped %} + {% if config.separate_signature %} + {% filter highlight(language="python", inline=False) %} + {% filter format_code(config.line_length) %} + {% if show_full_path %}{{ attribute.path }}{% else %}{{ attribute.name }}{% endif %} + {% if attribute.annotation %}: {{ attribute.annotation|safe }}{% endif %} + {% if attribute.value %} = {{ attribute.value|safe }}{% endif %} + {% endfilter %} + {% endfilter %} + {% endif %} + {% endblock signature %} + + {% else %} + {% if config.show_root_toc_entry %} + {% filter heading(heading_level, + role="data" if attribute.parent.kind.value == "module" else "attr", + id=html_id, + toc_label=attribute.path if config.show_root_full_path else attribute.name, + hidden=True) %} + {% endfilter %} + {% endif %} + {% set heading_level = heading_level - 1 %} + {% endif %} + +
+ {% block contents scoped %} + {% block docstring scoped %} + {% with docstring_sections = attribute.docstring.parsed %} + {% include "docstring.html" with context %} + {% endwith %} + {% endblock docstring %} + {% endblock contents %} + + {# START NEW CODE #} + {% if pipeline_steps(attribute.name) %} + {# https://squidfunk.github.io/mkdocs-material/reference/admonitions/#collapsible-blocks #} +
+ Pipeline steps using this setting +

+ The following steps are directly affected by changes to + {{ attribute.name }}: +

+
    + {% for step in pipeline_steps(attribute.name) %} +
  • {{ step }}
  • + {% endfor %} +
+
+ {% endif %} + {# END NEW CODE #} + +
+ +{% endwith %} +
diff --git a/ignore_words.txt b/ignore_words.txt index e69de29bb..1f7391f92 100644 --- a/ignore_words.txt +++ b/ignore_words.txt @@ -0,0 +1 @@ +master diff --git a/mne_bids_pipeline/__init__.py b/mne_bids_pipeline/__init__.py index 39d9f4177..2474edb8a 100644 --- a/mne_bids_pipeline/__init__.py +++ b/mne_bids_pipeline/__init__.py @@ -1,7 +1,7 @@ -from importlib.metadata import version, PackageNotFoundError +from importlib.metadata import PackageNotFoundError, version try: __version__ = version("mne_bids_pipeline") -except PackageNotFoundError: +except PackageNotFoundError: # pragma: no cover # package is not installed __version__ = "0.0.0" diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 9250489ea..aba6fd2ce 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1,32 +1,23 @@ # Default settings for data processing and analysis. -from typing import Optional, Union, Iterable, List, Tuple, Dict, Callable, Literal +from collections.abc import Callable, Sequence +from typing import Annotated, Any, Literal -from numpy.typing import ArrayLike - -import mne +from annotated_types import Ge, Interval, Len, MinLen +from mne import Covariance from mne_bids import BIDSPath -import numpy as np - -from mne_bids_pipeline.typing import PathLike, ArbitraryContrast - - -############################################################################### -# Config parameters -# ----------------- -study_name: str = "" -""" -Specify the name of your study. It will be used to populate filenames for -saving the analysis results. +from mne_bids_pipeline.typing import ( + ArbitraryContrast, + DigMontageType, + FloatArrayLike, + PathLike, +) -???+ example "Example" - ```python - study_name = 'my-study' - ``` -""" +# %% +# # General settings -bids_root: Optional[PathLike] = None +bids_root: PathLike | None = None """ Specify the BIDS root directory. Pass an empty string or ```None` to use the value specified in the `BIDS_ROOT` environment variable instead. @@ -39,7 +30,7 @@ ``` """ -deriv_root: Optional[PathLike] = None +deriv_root: PathLike | None = None """ The root of the derivatives directory in which the pipeline will store the processing results. If `None`, this will be @@ -50,7 +41,7 @@ set [`subjects_dir`][mne_bids_pipeline._config.subjects_dir] as well. """ -subjects_dir: Optional[PathLike] = None +subjects_dir: PathLike | None = None """ Path to the directory that contains the FreeSurfer reconstructions of all subjects. Specifically, this defines the `SUBJECTS_DIR` that is used by @@ -82,7 +73,7 @@ Enabling interactive mode deactivates parallel processing. """ -sessions: Union[List, Literal["all"]] = "all" +sessions: list | Literal["all"] = "all" """ The sessions to process. If `'all'`, will process all sessions found in the BIDS dataset. @@ -93,13 +84,18 @@ The task to process. """ -runs: Union[Iterable, Literal["all"]] = "all" +task_is_rest: bool = False +""" +Whether the task should be treated as resting-state data. +""" + +runs: Sequence | Literal["all"] = "all" """ The runs to process. If `'all'`, will process all runs found in the BIDS dataset. """ -exclude_runs: Optional[Dict[str, List[str]]] = None +exclude_runs: dict[str, list[str]] | None = None """ Specify runs to exclude from analysis, for each participant individually. @@ -115,42 +111,34 @@ did not understand the instructions, etc.). """ -crop_runs: Optional[Tuple[float, float]] = None +crop_runs: tuple[float, float] | None = None """ Crop the raw data of each run to the specified time interval `[tmin, tmax]`, in seconds. The runs will be cropped before Maxwell or frequency filtering is applied. If `None`, do not crop the data. """ -acq: Optional[str] = None +acq: str | None = None """ The BIDS `acquisition` entity. """ -proc: Optional[str] = None +proc: str | None = None """ The BIDS `processing` entity. """ -rec: Optional[str] = None +rec: str | None = None """ The BIDS `recording` entity. """ -space: Optional[str] = None +space: str | None = None """ The BIDS `space` entity. """ -plot_psd_for_runs: Union[Literal["all"], Iterable[str]] = "all" -""" -For which runs to add a power spectral density (PSD) plot to the generated -report. This can take a considerable amount of time if you have many long -runs. In this case, specify the runs, or pass an empty list to disable raw PSD -plotting. -""" - -subjects: Union[Iterable[str], Literal["all"]] = "all" +subjects: Sequence[str] | Literal["all"] = "all" """ Subjects to analyze. If `'all'`, include all subjects. To only include a subset of subjects, pass a list of their identifiers. Even @@ -170,7 +158,7 @@ ``` """ -exclude_subjects: Iterable[str] = [] +exclude_subjects: Sequence[str] = [] """ Specify subjects to exclude from analysis. The MEG empty-room mock-subject is automatically excluded from regular analysis. @@ -200,7 +188,7 @@ covariance (via `noise_cov='rest'`). """ -ch_types: Iterable[Literal["meg", "mag", "grad", "eeg"]] = [] +ch_types: Annotated[Sequence[Literal["meg", "mag", "grad", "eeg"]], Len(1, 4)] = [] """ The channel types to consider. @@ -217,7 +205,7 @@ ``` """ -data_type: Optional[Literal["meg", "eeg"]] = None +data_type: Literal["meg", "eeg"] | None = None """ The BIDS data type. @@ -251,7 +239,7 @@ ``` """ -eog_channels: Optional[Iterable[str]] = None +eog_channels: Sequence[str] | None = None """ Specify EOG channels to use, or create virtual EOG channels. @@ -286,7 +274,7 @@ ``` """ -eeg_bipolar_channels: Optional[Dict[str, Tuple[str, str]]] = None +eeg_bipolar_channels: dict[str, tuple[str, str]] | None = None """ Combine two channels into a bipolar channel, whose signal is the **difference** between the two combined channels, and add it to the data. @@ -319,7 +307,7 @@ ``` """ -eeg_reference: Union[Literal["average"], str, Iterable["str"]] = "average" +eeg_reference: Literal["average"] | str | Sequence["str"] = "average" """ The EEG reference to use. If `average`, will use the average reference, i.e. the average across all channels. If a string, must be the name of a single @@ -342,7 +330,7 @@ ``` """ -eeg_template_montage: Optional[str] = None +eeg_template_montage: str | DigMontageType | None = None """ In situations where you wish to process EEG data and no individual digitization points (measured channel locations) are available, you can apply @@ -358,6 +346,13 @@ You can find an overview of supported template montages at https://mne.tools/stable/generated/mne.channels.make_standard_montage.html +!!! warning + If the data contains channel names that are not part of the template montage, the + pipeline run will fail with an error message. You must either pick a different + montage or remove those channels via + [`drop_channels`][mne_bids_pipeline._config.drop_channels] to continue. + + ???+ example "Example" Do not apply template montage: ```python @@ -370,11 +365,12 @@ ``` """ -drop_channels: Iterable[str] = [] +drop_channels: Sequence[str] = [] """ Names of channels to remove from the data. This can be useful, for example, if you have added a new bipolar channel via `eeg_bipolar_channels` and now wish -to remove the anode, cathode, or both. +to remove the anode, cathode, or both; or if your selected EEG template montage +doesn't contain coordinates for some channels. ???+ example "Example" Exclude channels `Fp1` and `Cz` from processing: @@ -383,9 +379,9 @@ ``` """ -analyze_channels: Union[ - Literal["all"], Literal["ch_types"], Iterable["str"] -] = "ch_types" +analyze_channels: Literal["all", "ch_types"] | Annotated[Sequence["str"], MinLen(1)] = ( + "ch_types" +) """ The names of the channels to analyze during ERP/ERF and time-frequency analysis steps. For certain paradigms, e.g. EEG ERP research, it is common to constrain @@ -416,7 +412,7 @@ ``` """ -read_raw_bids_verbose: Optional[Literal["error"]] = None +read_raw_bids_verbose: Literal["error"] | None = None """ Verbosity level to pass to `read_raw_bids(..., verbose=read_raw_bids_verbose)`. If you know your dataset will contain files that are not perfectly BIDS @@ -424,9 +420,26 @@ `'error'` to suppress warnings emitted by read_raw_bids. """ -############################################################################### -# BREAK DETECTION -# --------------- +plot_psd_for_runs: Literal["all"] | Sequence[str] = "all" +""" +For which runs to add a power spectral density (PSD) plot to the generated +report. This can take a considerable amount of time if you have many long +runs. In this case, specify the runs, or pass an empty list to disable raw PSD +plotting. +""" + +random_state: int | None = 42 +""" +You can specify the seed of the random number generator (RNG). +This setting is passed to the ICA algorithm and to the decoding function, +ensuring reproducible results. Set to `None` to avoid setting the RNG +to a defined state. +""" + +# %% +# # Preprocessing + +# ## Break detection find_breaks: bool = False """ @@ -525,10 +538,23 @@ ``` """ -############################################################################### -# MAXWELL FILTER PARAMETERS -# ------------------------- -# done in 01-import_and_maxfilter.py +# %% +# ## Bad channel detection +# +# !!! warning +# This functionality will soon be removed from the pipeline, and +# will be integrated into MNE-BIDS. +# +# "Bad", i.e. flat and overly noisy channels, can be automatically detected +# using a procedure inspired by the commercial MaxFilter by Elekta. First, +# a copy of the data is low-pass filtered at 40 Hz. Then, channels with +# unusually low variability are flagged as "flat", while channels with +# excessively high variability are flagged as "noisy". Flat and noisy channels +# are marked as "bad" and excluded from subsequent analysis. See +# :func:`mne.preprocssessing.find_bad_channels_maxwell` for more information +# on this procedure. The list of bad channels detected through this procedure +# will be merged with the list of bad channels already present in the dataset, +# if any. find_flat_channels_meg: bool = False """ @@ -541,6 +567,9 @@ Auto-detect "noisy" channels and mark them as bad. """ +# %% +# ## Maxwell filter + use_maxwell_filter: bool = False """ Whether or not to use Maxwell filtering to preprocess the data. @@ -553,7 +582,7 @@ before applying Maxwell filter. """ -mf_st_duration: Optional[float] = None +mf_st_duration: float | None = None """ There are two kinds of Maxwell filtering: SSS (signal space separation) and tSSS (temporal signal space separation) @@ -594,7 +623,7 @@ ``` """ -mf_head_origin: Union[Literal["auto"], ArrayLike] = "auto" +mf_head_origin: Literal["auto"] | FloatArrayLike = "auto" """ `mf_head_origin` : array-like, shape (3,) | 'auto' Origin of internal and external multipolar moment space in meters. @@ -609,7 +638,7 @@ ``` """ -mf_destination: Union[Literal["reference_run"], ArrayLike] = "reference_run" +mf_destination: Literal["reference_run"] | FloatArrayLike = "reference_run" """ Despite all possible care to avoid movements in the MEG, the participant will likely slowly drift down from the Dewar or slightly shift the head @@ -618,7 +647,7 @@ 1. Choose a reference run. Often one from the middle of the recording session is a good choice. Set `mf_destination = "reference_run" and then set - [`config.mf_reference_run`](mne_bids_pipeline._config.mf_reference_run). + [`config.mf_reference_run`][mne_bids_pipeline._config.mf_reference_run]. This will result in a device-to-head transformation that differs between subjects. 2. Choose a standard position in the MEG coordinate frame. For this, pass @@ -640,10 +669,10 @@ is expected. """ -mf_reference_run: Optional[str] = None +mf_reference_run: str | None = None """ Which run to take as the reference for adjusting the head position of all -runs when [`mf_destination="reference_run"`](mne_bids_pipeline._config.mf_destination). +runs when [`mf_destination="reference_run"`][mne_bids_pipeline._config.mf_destination]. If `None`, pick the first run. ???+ example "Example" @@ -652,7 +681,7 @@ ``` """ -mf_cal_fname: Optional[str] = None +mf_cal_fname: str | None = None """ !!! warning This parameter should only be used for BIDS datasets that don't store @@ -666,7 +695,7 @@ ``` """ # noqa : E501 -mf_ctc_fname: Optional[str] = None +mf_ctc_fname: str | None = None """ Path to the Maxwell Filter cross-talk file. If `None`, the recommended location is used. @@ -681,6 +710,16 @@ ``` """ # noqa : E501 +mf_esss: int = 0 +""" +Number of extended SSS (eSSS) basis projectors to use from empty-room data. +""" + +mf_esss_reject: dict[str, float] | None = None +""" +Rejection parameters to use when computing the extended SSS (eSSS) basis. +""" + mf_mc: bool = False """ If True, perform movement compensation on the data. @@ -691,7 +730,7 @@ Minimum time step to use during cHPI coil amplitude estimation. """ -mf_mc_t_window: Union[float, Literal["auto"]] = "auto" +mf_mc_t_window: float | Literal["auto"] = "auto" """ The window to use during cHPI coil amplitude estimation and in cHPI filtering. Can be "auto" to autodetect a reasonable value or a float (in seconds). @@ -707,66 +746,76 @@ Minimum distance (m) to accept for cHPI position fitting. """ -mf_filter_chpi: Optional[bool] = None -""" -Use mne.chpi.filter_chpi after Maxwell filtering. Can be None to use -the same value as [`mf_mc`][mne_bids_pipeline._config.mf_mc]. -Only used when [`use_maxwell_filter=True`][mne_bids_pipeline._config.use_maxwell_filter] -""" # noqa: E501 - -############################################################################### -# STIMULATION ARTIFACT -# -------------------- -# used in 01-import_and_maxfilter.py - -fix_stim_artifact: bool = False +mf_mc_rotation_velocity_limit: float | None = None """ -Apply interpolation to fix stimulation artifact. - -???+ example "Example" - ```python - fix_stim_artifact = False - ``` +The rotation velocity limit (degrees/second) to use when annotating +movement-compensated data. If `None`, no annotations will be added. """ -stim_artifact_tmin: float = 0.0 +mf_mc_translation_velocity_limit: float | None = None """ -Start time of the interpolation window in seconds. - -???+ example "Example" - ```python - stim_artifact_tmin = 0. # on stim onset - ``` +The translation velocity limit (meters/second) to use when annotating +movement-compensated data. If `None`, no annotations will be added. """ -stim_artifact_tmax: float = 0.01 +mf_filter_chpi: bool | None = None """ -End time of the interpolation window in seconds. +Use mne.chpi.filter_chpi after Maxwell filtering. Can be None to use +the same value as [`mf_mc`][mne_bids_pipeline._config.mf_mc]. +Only used when [`use_maxwell_filter=True`][mne_bids_pipeline._config.use_maxwell_filter] +""" # noqa: E501 -???+ example "Example" - ```python - stim_artifact_tmax = 0.01 # up to 10ms post-stimulation - ``` -""" +# ## Filtering & resampling -############################################################################### -# FREQUENCY FILTERING & RESAMPLING -# -------------------------------- -# done in 02-frequency_filter.py +# ### Filtering +# +# It is typically better to set your filtering properties on the raw data so +# as to avoid what we call border (or edge) effects. +# +# If you use this pipeline for evoked responses, you could consider +# a low-pass filter cut-off of h_freq = 40 Hz +# and possibly a high-pass filter cut-off of l_freq = 1 Hz +# so you would preserve only the power in the 1Hz to 40 Hz band. +# Note that highpass filtering is not necessarily recommended as it can +# distort waveforms of evoked components, or simply wash out any low +# frequency that can may contain brain signal. It can also act as +# a replacement for baseline correction in Epochs. See below. +# +# If you use this pipeline for time-frequency analysis, a default filtering +# could be a high-pass filter cut-off of l_freq = 1 Hz +# a low-pass filter cut-off of h_freq = 120 Hz +# so you would preserve only the power in the 1Hz to 120 Hz band. +# +# If you need more fancy analysis, you are already likely past this kind +# of tips! 😇 -l_freq: Optional[float] = None +l_freq: float | None = None """ The low-frequency cut-off in the highpass filtering step. Keep it `None` if no highpass filtering should be applied. """ -h_freq: Optional[float] = 40.0 +h_freq: float | None = 40.0 """ The high-frequency cut-off in the lowpass filtering step. Keep it `None` if no lowpass filtering should be applied. """ -notch_freq: Optional[Union[float, Iterable[float]]] = None +l_trans_bandwidth: float | Literal["auto"] = "auto" +""" +Specifies the transition bandwidth of the +highpass filter. By default it's `'auto'` and uses default MNE +parameters. +""" + +h_trans_bandwidth: float | Literal["auto"] = "auto" +""" +Specifies the transition bandwidth of the +lowpass filter. By default it's `'auto'` and uses default MNE +parameters. +""" + +notch_freq: float | Sequence[float] | None = None """ Notch filter frequency. More than one frequency can be supplied, e.g. to remove harmonics. Keep it `None` if no notch filter should be applied. @@ -785,31 +834,26 @@ ``` """ -l_trans_bandwidth: Union[float, Literal["auto"]] = "auto" -""" -Specifies the transition bandwidth of the -highpass filter. By default it's `'auto'` and uses default MNE -parameters. -""" - -h_trans_bandwidth: Union[float, Literal["auto"]] = "auto" -""" -Specifies the transition bandwidth of the -lowpass filter. By default it's `'auto'` and uses default MNE -parameters. -""" - notch_trans_bandwidth: float = 1.0 """ Specifies the transition bandwidth of the notch filter. The default is `1.`. """ -notch_widths: Optional[Union[float, Iterable[float]]] = None +notch_widths: float | Sequence[float] | None = None """ Specifies the width of each stop band. `None` uses the MNE default. """ -raw_resample_sfreq: Optional[float] = None +# ### Resampling +# +# If you have acquired data with a very high sampling frequency (e.g. 2 kHz) +# you will likely want to downsample to lighten up the size of the files you +# are working with (pragmatics) +# If you are interested in typical analysis (up to 120 Hz) you can typically +# resample your data down to 500 Hz without preventing reliable time-frequency +# exploration of your data. + +raw_resample_sfreq: float | None = None """ Specifies at which sampling frequency the data should be resampled. If `None`, then no resampling will be done. @@ -821,10 +865,6 @@ ``` """ -############################################################################### -# DECIMATION -# ---------- - epochs_decim: int = 1 """ Says how much to decimate data at the epochs level. @@ -843,9 +883,7 @@ """ -############################################################################### -# RENAME EXPERIMENTAL EVENTS -# -------------------------- +# ## Epoching rename_events: dict = dict() """ @@ -871,10 +909,6 @@ to only get a warning instead, or `'ignore'` to ignore it completely. """ -############################################################################### -# HANDLING OF REPEATED EVENTS -# --------------------------- - event_repeated: Literal["error", "drop", "merge"] = "error" """ How to handle repeated events. We call events "repeated" if more than one event @@ -890,25 +924,35 @@ April 1st, 2021. """ -############################################################################### -# EPOCHING -# -------- - -epochs_metadata_tmin: Optional[float] = None +epochs_metadata_tmin: float | str | list[str] | None = None """ -The beginning of the time window for metadata generation, in seconds, -relative to the time-locked event of the respective epoch. This may be less -than or larger than the epoch's first time point. If `None`, use the first -time point of the epoch. +The beginning of the time window used for epochs metadata generation. This setting +controls the `tmin` value passed to +[`mne.epochs.make_metadata`](https://mne.tools/stable/generated/mne.epochs.make_metadata.html). + +If a float, the time in seconds relative to the time-locked event of the respective +epoch. Negative indicate times before, positive values indicate times after the +time-locked event. + +If a string or a list of strings, the name(s) of events marking the start of time +window. + +If `None`, use the first time point of the epoch. + +???+ info + Note that `None` here behaves differently than `tmin=None` in + `mne.epochs.make_metadata`. To achieve the same behavior, pass the name(s) of the + time-locked events instead. + """ -epochs_metadata_tmax: Optional[float] = None +epochs_metadata_tmax: float | str | list[str] | None = None """ Same as `epochs_metadata_tmin`, but specifying the **end** of the time window for metadata generation. """ -epochs_metadata_keep_first: Optional[Iterable[str]] = None +epochs_metadata_keep_first: Sequence[str] | None = None """ Event groupings using hierarchical event descriptors (HEDs) for which to store the time of the **first** occurrence of any event of this group in a new column @@ -936,14 +980,14 @@ and `first_stimulus`. """ -epochs_metadata_keep_last: Optional[Iterable[str]] = None +epochs_metadata_keep_last: Sequence[str] | None = None """ Same as `epochs_metadata_keep_first`, but for keeping the **last** occurrence of matching event types. The columns indicating the event types will be named with a `last_` instead of a `first_` prefix. """ -epochs_metadata_query: Optional[str] = None +epochs_metadata_query: str | None = None """ A [metadata query][https://mne.tools/stable/auto_tutorials/epochs/30_epochs_metadata.html] specifying which epochs to keep. If the query fails because it refers to an @@ -956,7 +1000,7 @@ ``` """ # noqa: E501 -conditions: Optional[Union[Iterable[str], Dict[str, str]]] = None +conditions: Sequence[str] | dict[str, str] | None = None """ The time-locked events based on which to create evoked responses. This can either be name of the experimental condition as specified in the @@ -1008,23 +1052,18 @@ ``` """ -task_is_rest: bool = False -""" -Whether the task should be treated as resting-state data. -""" - -rest_epochs_duration: Optional[float] = None +rest_epochs_duration: float | None = None """ Duration of epochs in seconds. """ -rest_epochs_overlap: Optional[float] = None +rest_epochs_overlap: float | None = None """ Overlap between epochs in seconds. This is used if the task is `'rest'` and when the annotations do not contain any stimulation or behavior events. """ -baseline: Optional[Tuple[Optional[float], Optional[float]]] = (None, 0) +baseline: tuple[float | None, float | None] | None = (None, 0) """ Specifies which time interval to use for baseline correction of epochs; if `None`, no baseline correction is applied. @@ -1035,74 +1074,66 @@ ``` """ -contrasts: Iterable[Union[Tuple[str, str], ArbitraryContrast]] = [] -""" -The conditions to contrast via a subtraction of ERPs / ERFs. The list elements -can either be tuples or dictionaries (or a mix of both). Each element in the -list corresponds to a single contrast. - -A tuple specifies a one-vs-one contrast, where the second condition is -subtracted from the first. - -If a dictionary, must contain the following keys: - -- `name`: a custom name of the contrast -- `conditions`: the conditions to contrast -- `weights`: the weights associated with each condition. +# ## Artifact removal -Pass an empty list to avoid calculation of any contrasts. +# ### Stimulation artifact +# +# When using electric stimulation systems, e.g. for median nerve or index +# stimulation, it is frequent to have a stimulation artifact. This option +# allows to fix it by linear interpolation early in the pipeline on the raw +# data. -For the contrasts to be computed, the appropriate conditions must have been -epoched, and therefore the conditions should either match or be subsets of -`conditions` above. +fix_stim_artifact: bool = False +""" +Apply interpolation to fix stimulation artifact. ???+ example "Example" - Contrast the "left" and the "right" conditions by calculating - `left - right` at every time point of the evoked responses: ```python - contrasts = [('left', 'right')] # Note we pass a tuple inside the list! + fix_stim_artifact = False ``` +""" - Contrast the "left" and the "right" conditions within the "auditory" and - the "visual" modality, and "auditory" vs "visual" regardless of side: +stim_artifact_tmin: float = 0.0 +""" +Start time of the interpolation window in seconds. + +???+ example "Example" ```python - contrasts = [('auditory/left', 'auditory/right'), - ('visual/left', 'visual/right'), - ('auditory', 'visual')] + stim_artifact_tmin = 0. # on stim onset ``` +""" - Contrast the "left" and the "right" regardless of side, and compute an - arbitrary contrast with a gradient of weights: +stim_artifact_tmax: float = 0.01 +""" +End time of the interpolation window in seconds. + +???+ example "Example" ```python - contrasts = [ - ('auditory/left', 'auditory/right'), - { - 'name': 'gradedContrast', - 'conditions': [ - 'auditory/left', - 'auditory/right', - 'visual/left', - 'visual/right' - ], - 'weights': [-1.5, -.5, .5, 1.5] - } - ] + stim_artifact_tmax = 0.01 # up to 10ms post-stimulation ``` """ -############################################################################### -# ARTIFACT REMOVAL -# ---------------- -# -# You can choose between ICA and SSP to remove eye and heart artifacts. -# SSP: https://mne-tools.github.io/stable/auto_tutorials/plot_artifacts_correction_ssp.html?highlight=ssp # noqa -# ICA: https://mne-tools.github.io/stable/auto_tutorials/plot_artifacts_correction_ica.html?highlight=ica # noqa -# if you choose ICA, run steps 5a and 6a -# if you choose SSP, run steps 5b and 6b -# -# Currently you cannot use both. +# ### SSP, ICA, and artifact regression + +regress_artifact: dict[str, Any] | None = None +""" +Keyword arguments to pass to the `mne.preprocessing.EOGRegression` model used +in `mne.preprocessing.regress_artifact`. If `None`, no time-domain regression will +be applied. Note that any channels picked in `regress_artifact["picks_artifact"]` will +have the same time-domain filters applied to them as the experimental data. + +Artifact regression is applied before SSP or ICA. + +???+ example "Example" + For example, if you have MEG reference channel data recorded in three + miscellaneous channels, you could do: -spatial_filter: Optional[Literal["ssp", "ica"]] = None + ```python + regress_artifact = {"picks": "meg", "picks_artifact": ["MISC 001", "MISC 002", "MISC 003"]} + ``` +""" # noqa: E501 + +spatial_filter: Literal["ssp", "ica"] | None = None """ Whether to use a spatial filter to detect and remove artifacts. The BIDS Pipeline offers the use of signal-space projection (SSP) and independent @@ -1117,27 +1148,34 @@ EOG and ECG activity will be omitted during the signal reconstruction step in order to remove the artifacts. The ICA procedure can be configured in various ways using the configuration options you can find below. + +!!! warning "ICA requires manual intervention!" + After the automatic ICA component detection step, review each subject's + `*_report.html` report file check if the set of ICA components to be removed + is correct. Adjustments should be made to the `*_proc-ica_components.tsv` + file, which will then be used in the step that is applied during ICA. + + ICA component order can be considered arbitrary, so any time the ICA is + re-fit – i.e., if you change any parameters that affect steps prior to + ICA fitting – this file will need to be updated! """ -min_ecg_epochs: int = 5 +min_ecg_epochs: Annotated[int, Ge(1)] = 5 """ -Minimal number of ECG epochs needed to compute SSP or ICA rejection. +Minimal number of ECG epochs needed to compute SSP projectors. """ -min_eog_epochs: int = 5 +min_eog_epochs: Annotated[int, Ge(1)] = 5 """ -Minimal number of EOG epochs needed to compute SSP or ICA rejection. +Minimal number of EOG epochs needed to compute SSP projectors. """ - -# Rejection based on SSP -# ~~~~~~~~~~~~~~~~~~~~~~ -n_proj_eog: Dict[str, float] = dict(n_mag=1, n_grad=1, n_eeg=1) +n_proj_eog: dict[str, float] = dict(n_mag=1, n_grad=1, n_eeg=1) """ Number of SSP vectors to create for EOG artifacts for each channel type. """ -n_proj_ecg: Dict[str, float] = dict(n_mag=1, n_grad=1, n_eeg=1) +n_proj_ecg: dict[str, float] = dict(n_mag=1, n_grad=1, n_eeg=1) """ Number of SSP vectors to create for ECG artifacts for each channel type. """ @@ -1165,7 +1203,7 @@ `'separate'` otherwise. """ -ssp_reject_ecg: Optional[Union[Dict[str, float], Literal["autoreject_global"]]] = None +ssp_reject_ecg: dict[str, float] | Literal["autoreject_global"] | None = None """ Peak-to-peak amplitude limits of the ECG epochs to exclude from SSP fitting. This allows you to remove strong transient artifacts, which could negatively @@ -1183,7 +1221,7 @@ ``` """ -ssp_reject_eog: Optional[Union[Dict[str, float], Literal["autoreject_global"]]] = None +ssp_reject_eog: dict[str, float] | Literal["autoreject_global"] | None = None """ Peak-to-peak amplitude limits of the EOG epochs to exclude from SSP fitting. This allows you to remove strong transient artifacts, which could negatively @@ -1201,45 +1239,71 @@ ``` """ -ssp_ecg_channel: Optional[str] = None +ssp_ecg_channel: str | None = None """ Channel to use for ECG SSP. Can be useful when the autodetected ECG channel is not reliable. """ -# Rejection based on ICA -# ~~~~~~~~~~~~~~~~~~~~~~ -ica_reject: Optional[Dict[str, float]] = None -""" -Peak-to-peak amplitude limits to exclude epochs from ICA fitting. - -This allows you to remove strong transient artifacts, which could negatively -affect ICA performance. - -This will also be applied to ECG and EOG epochs created during preprocessing. - -The BIDS Pipeline will automatically try to detect EOG and ECG artifacts in -your data, and remove them. For this to work properly, it is recommended -to **not** specify rejection thresholds for EOG and ECG channels here – -otherwise, ICA won't be able to "see" these artifacts. - -If `None` (default), do not apply artifact rejection. If a dictionary, -manually specify peak-to-peak rejection thresholds (see examples). +ica_reject: dict[str, float] | Literal["autoreject_local"] | None = None +""" +Peak-to-peak amplitude limits to exclude epochs from ICA fitting. This allows you to +remove strong transient artifacts from the epochs used for fitting ICA, which could +negatively affect ICA performance. + +The parameter values are the same as for [`reject`][mne_bids_pipeline._config.reject], +but `"autoreject_global"` is not supported. `"autoreject_local"` here behaves +differently, too: it is only used to exclude bad epochs from ICA fitting; we do not +perform any interpolation. + +???+ info + We don't support `"autoreject_global"` here (as opposed to + [`reject`][mne_bids_pipeline._config.reject]) because in the past, we found that + rejection thresholds were too strict before running ICA, i.e., too many epochs + got rejected. `"autoreject_local"`, however, usually performed nicely. + The `autoreject` documentation + [recommends][https://autoreject.github.io/stable/auto_examples/plot_autoreject_workflow.html] + running local `autoreject` before and after ICA, which can be achieved by setting + both, `ica_reject` and [`reject`][mne_bids_pipeline._config.reject], to + `"autoreject_local"`. + +If passing a dictionary, the rejection limits will also be applied to the ECG and EOG +epochs created to find heart beats and ocular artifacts. + +???+ info + MNE-BIDS-Pipeline will automatically try to detect EOG and ECG artifacts in + your data, and remove them. For this to work properly, it is recommended + to **not** specify rejection thresholds for EOG and ECG channels here – + otherwise, ICA won't be able to "see" these artifacts. + +???+ info + This setting is applied only to the epochs that are used for **fitting** ICA. The + goal is to make it easier for ICA to produce a good decomposition. After fitting, + ICA is applied to the epochs to be analyzed, usually with one or more components + removed (as to remove artifacts). But even after ICA cleaning, some epochs may still + contain large-amplitude artifacts. Those epochs can then be rejected by using + the [`reject`][mne_bids_pipeline._config.reject] parameter. ???+ example "Example" ```python ica_reject = {'grad': 10e-10, 'mag': 20e-12, 'eeg': 400e-6} ica_reject = {'grad': 15e-10} - ica_reject = None # no rejection + ica_reject = None # no rejection before fitting ICA + ica_reject = "autoreject_global" # find global (per channel type) PTP thresholds before fitting ICA + ica_reject = "autoreject_local" # find local (per channel) thresholds and repair epochs before fitting ICA ``` -""" +""" # noqa: E501 -ica_algorithm: Literal["picard", "fastica", "extended_infomax"] = "picard" +ica_algorithm: Literal[ + "picard", "fastica", "extended_infomax", "picard-extended_infomax" +] = "picard" """ -The ICA algorithm to use. +The ICA algorithm to use. `"picard-extended_infomax"` operates `picard` such that the +generated ICA decomposition is identical to the one generated by the extended Infomax +algorithm (but may converge in less time). """ -ica_l_freq: Optional[float] = 1.0 +ica_l_freq: float | None = 1.0 """ The cutoff frequency of the high-pass filter to apply before running ICA. Using a relatively high cutoff like 1 Hz will remove slow drifts from the @@ -1273,7 +1337,7 @@ limit may be too low to achieve convergence. """ -ica_n_components: Optional[Union[float, int]] = 0.8 +ica_n_components: float | int | None = None """ MNE conducts ICA as a sort of a two-step procedure: First, a PCA is run on the data (trying to exclude zero-valued components in rank-deficient @@ -1292,12 +1356,13 @@ explained variance less than the value specified here will be passed to ICA. -If `None`, **all** principal components will be used. +If `None` (default), `0.999999` will be used to avoid issues when working with +rank-deficient data. This setting may drastically alter the time required to compute ICA. """ -ica_decim: Optional[int] = None +ica_decim: int | None = None """ The decimation parameter to compute ICA. If 5 it means that 1 every 5 sample is used by ICA solver. The higher the faster @@ -1305,9 +1370,20 @@ `1` or `None` to not perform any decimation. """ -ica_ctps_ecg_threshold: float = 0.1 +ica_use_ecg_detection: bool = True +""" +Whether to use the MNE ECG detection on the ICA components. +""" + +ica_ecg_threshold: float = 0.1 +""" +The cross-trial phase statistics (CTPS) threshold parameter used for detecting +ECG-related ICs. +""" + +ica_use_eog_detection: bool = True """ -The threshold parameter passed to `find_bads_ecg` method. +Whether to use the MNE EOG detection on the ICA components. """ ica_eog_threshold: float = 3.0 @@ -1317,93 +1393,225 @@ false-alarm rate increases dramatically. """ -# Rejection based on peak-to-peak amplitude -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -reject: Optional[Union[Dict[str, float], Literal["autoreject_global"]]] = None + +# From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 +ica_use_icalabel: bool = False +""" +Whether to use MNE-ICALabel to automatically label ICA components. Only available for +EEG data. +!!! info + Using MNE-ICALabel mandates that you also set: + ```python + eeg_reference = "average" + ica_l_freq = 1 + h_freq = 100 + ``` +""" + +icalabel_include: Annotated[Sequence[Literal["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"]], Len(1, 7)] = ["brain","other"] +""" +Which independent components (ICs) to keep based on the labels given by ICLabel. +Possible labels are "brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other". +""" + +# ### Amplitude-based artifact rejection +# +# ???+ info "Good Practice / Advice" +# Have a look at your raw data and train yourself to detect a blink, a heart +# beat and an eye movement. +# You can do a quick average of blink data and check what the amplitude looks +# like. + +reject: dict[str, float] | Literal["autoreject_global", "autoreject_local"] | None = ( + None +) """ Peak-to-peak amplitude limits to mark epochs as bad. This allows you to remove epochs with strong transient artifacts. -If `None` (default), do not apply artifact rejection. If a dictionary, -manually specify rejection thresholds (see examples). If -`'autoreject_global'`, use [`autoreject`](https://autoreject.github.io) to find -suitable "global" rejection thresholds for each channel type, i.e. `autoreject` -will generate a dictionary with (hopefully!) optimal thresholds for each -channel type. +!!! info + The rejection is performed **after** SSP or ICA, if any of those methods + is used. To reject epochs **before** fitting ICA, see the + [`ica_reject`][mne_bids_pipeline._config.ica_reject] setting. + +If `None` (default), do not apply artifact rejection. +If a dictionary, manually specify rejection thresholds (see examples). The thresholds provided here must be at least as stringent as those in [`ica_reject`][mne_bids_pipeline._config.ica_reject] if using ICA. In case of `'autoreject_global'`, thresholds for any channel that do not meet this requirement will be automatically replaced with those used in `ica_reject`. -!!! info - The rejection is performed **after** SSP or ICA, if any of those methods - is used. To reject epochs **before** fitting ICA, see the - [`ica_reject`][mne_bids_pipeline._config.ica_reject] setting. - -If `None` (default), do not apply automated rejection. If a dictionary, -manually specify rejection thresholds (see examples). If `'auto'`, use -[`autoreject`](https://autoreject.github.io) to find suitable "global" -rejection thresholds for each channel type, i.e. `autoreject` will generate -a dictionary with (hopefully!) optimal thresholds for each channel type. Note -that using `autoreject` can be a time-consuming process. +If `"autoreject_global"`, use [`autoreject`](https://autoreject.github.io) to find +suitable "global" rejection thresholds for each channel type, i.e., `autoreject` +will generate a dictionary with (hopefully!) optimal thresholds for each +channel type. -!!! info - `autoreject` basically offers two modes of operation: "global" and - "local". In "global" mode, it will try to estimate one rejection - threshold **per channel type.** In "local" mode, it will generate - thresholds **for each individual channel.** Currently, the BIDS Pipeline - only supports the "global" mode. +If `"autoreject_local"`, use "local" `autoreject` to detect (and potentially repair) bad +channels in each epoch. Use [`autoreject_n_interpolate`][mne_bids_pipeline._config.autoreject_n_interpolate] +to control how many channels are allowed to be bad before an epoch gets dropped. ???+ example "Example" ```python - reject = {'grad': 4000e-13, 'mag': 4e-12, 'eog': 150e-6} - reject = {'eeg': 100e-6, 'eog': 250e-6} + reject = {"grad": 4000e-13, 'mag': 4e-12, 'eog': 150e-6} + reject = {"eeg": 100e-6, "eog": 250e-6} reject = None # no rejection based on PTP amplitude + reject = "autoreject_global" # find global (per channel type) PTP thresholds + reject = "autoreject_local" # find local (per channel) thresholds and repair epochs ``` -""" +""" # noqa: E501 -reject_tmin: Optional[float] = None +reject_tmin: float | None = None """ Start of the time window used to reject epochs. If `None`, the window will -start with the first time point. +start with the first time point. Has no effect if +[`reject`][mne_bids_pipeline._config.reject] has been set to `"autoreject_local"`. + ???+ example "Example" ```python reject_tmin = -0.1 # 100 ms before event onset. ``` """ -reject_tmax: Optional[float] = None +reject_tmax: float | None = None """ End of the time window used to reject epochs. If `None`, the window will end -with the last time point. +with the last time point. Has no effect if +[`reject`][mne_bids_pipeline._config.reject] has been set to `"autoreject_local"`. + ???+ example "Example" ```python reject_tmax = 0.3 # 300 ms after event onset. ``` """ -############################################################################### -# DECODING -# -------- - -decode: bool = True -""" -Whether to perform decoding (MVPA) on the specified -[`contrasts`][mne_bids_pipeline._config.contrasts]. Classifiers will be trained -on entire epochs ("full-epochs decoding"), and separately on each time point -("time-by-time decoding"), trying to learn how to distinguish the contrasting -conditions. +autoreject_n_interpolate: FloatArrayLike = [4, 8, 16] """ +The maximum number of bad channels in an epoch that `autoreject` local will try to +interpolate. The optimal number among this list will be estimated using a +cross-validation procedure; this means that the more elements are provided here, the +longer the `autoreject` run will take. If the number of bad channels in an epoch +exceeds this value, the channels won't be interpolated and the epoch will be dropped. -decoding_epochs_tmin: Optional[float] = 0.0 -""" -The first time sample to use for full epochs decoding. By default it starts -at 0. If `None`,, it starts at the beginning of the epoch. -""" +!!! info + This setting only takes effect if [`reject`][mne_bids_pipeline._config.reject] has + been set to `"autoreject_local"`. -decoding_epochs_tmax: Optional[float] = None +!!! info + Channels marked as globally bad in the BIDS dataset (in `*_channels.tsv)`) will not + be considered (i.e., will remain marked as bad and not analyzed by autoreject). +""" + +# %% +# # Sensor-level analysis + +# ## Condition contrasts + +contrasts: Sequence[tuple[str, str] | ArbitraryContrast] = [] +""" +The conditions to contrast via a subtraction of ERPs / ERFs. The list elements +can either be tuples or dictionaries (or a mix of both). Each element in the +list corresponds to a single contrast. + +A tuple specifies a one-vs-one contrast, where the second condition is +subtracted from the first. + +If a dictionary, must contain the following keys: + +- `name`: a custom name of the contrast +- `conditions`: the conditions to contrast +- `weights`: the weights associated with each condition. + +Pass an empty list to avoid calculation of any contrasts. + +For the contrasts to be computed, the appropriate conditions must have been +epoched, and therefore the conditions should either match or be subsets of +`conditions` above. + +???+ example "Example" + Contrast the "left" and the "right" conditions by calculating + `left - right` at every time point of the evoked responses: + ```python + contrasts = [('left', 'right')] # Note we pass a tuple inside the list! + ``` + + Contrast the "left" and the "right" conditions within the "auditory" and + the "visual" modality, and "auditory" vs "visual" regardless of side: + ```python + contrasts = [('auditory/left', 'auditory/right'), + ('visual/left', 'visual/right'), + ('auditory', 'visual')] + ``` + + Contrast the "left" and the "right" regardless of side, and compute an + arbitrary contrast with a gradient of weights: + ```python + contrasts = [ + ('auditory/left', 'auditory/right'), + { + 'name': 'gradedContrast', + 'conditions': [ + 'auditory/left', + 'auditory/right', + 'visual/left', + 'visual/right' + ], + 'weights': [-1.5, -.5, .5, 1.5] + } + ] + ``` +""" + +# ## Decoding / MVPA + +decode: bool = True +""" +Whether to perform decoding (MVPA) on the specified +[`contrasts`][mne_bids_pipeline._config.contrasts]. Classifiers will be trained +on entire epochs ("full-epochs decoding"), and separately on each time point +("time-by-time decoding"), trying to learn how to distinguish the contrasting +conditions. +""" + +decoding_which_epochs: Literal["uncleaned", "after_ica", "after_ssp", "cleaned"] = ( + "cleaned" +) +""" +This setting controls which epochs will be fed into the decoding algorithms. + +!!! info + Decoding is a very powerful tool that often can deal with noisy data surprisingly + well. Depending on the specific type of data, artifacts, and analysis performed, + decoding performance may even improve with less pre-processed data, as + processing steps such as ICA or SSP often remove parts of the signal, too, in + addition to noise. By default, MNE-BIDS-Pipeline uses cleaned epochs for decoding, + but you may choose to use entirely uncleaned epochs, or epochs before the final + PTP-based rejection or Autoreject step. + +!!! info + No other sensor- and source-level processing steps will be affected by this setting + and use cleaned epochs only. + +If `"uncleaned"`, use the "raw" epochs before any ICA / SSP, PTP-based, or Autoreject +cleaning (epochs with the filename `*_epo.fif`, without a `proc-` part). + +If `"after_ica"` or `"after_ssp"`, use the epochs that were cleaned via ICA or SSP, but +before a followup cleaning through PTP-based rejection or Autorejct (epochs with the +filename `*proc-ica_epo.fif` or `*proc-ssp_epo.fif`). + +If `"cleaned"`, use the epochs after ICA / SSP and the following cleaning through +PTP-based rejection or Autoreject (epochs with the filename `*proc-clean_epo.fif`). +""" + +decoding_epochs_tmin: float | None = 0.0 +""" +The first time sample to use for full epochs decoding. By default it starts +at 0. If `None`,, it starts at the beginning of the epoch. Does not affect time-by-time +decoding. +""" + +decoding_epochs_tmax: float | None = None """ The last time sample to use for full epochs decoding. By default it is set to None so it ends at the end of the epoch. @@ -1418,9 +1626,9 @@ you don't need to be worried about **exactly** balancing class sizes. """ -decoding_n_splits: int = 5 +decoding_n_splits: Annotated[int, Ge(2)] = 5 """ -The number of folds (also called "splits") to use in the cross-validation +The number of folds (also called "splits") to use in the K-fold cross-validation scheme. """ @@ -1454,13 +1662,85 @@ resolution in the resulting matrix. """ +decoding_csp: bool = False +""" +Whether to run decoding via Common Spatial Patterns (CSP) analysis on the +data. CSP takes as input data covariances that are estimated on different +time and frequency ranges. This allows to obtain decoding scores defined over +time and frequency. +""" + +decoding_csp_times: FloatArrayLike | None = None +""" +The edges of the time bins to use for CSP decoding. +Must contain at least two elements. By default, 5 equally-spaced bins are +created across the non-negative time range of the epochs. +All specified time points must be contained in the epochs interval. +If an empty list, do not perform **time-frequency** analysis, and only run CSP on +**frequency** data. + +???+ example "Example" + Create 3 equidistant time bins (0–0.2, 0.2–0.4, 0.4–0.6 sec): + ```python + decoding_csp_times = np.linspace(start=0, stop=0.6, num=4) + ``` + Create 2 time bins of different durations (0–0.4, 0.4–0.6 sec): + ```python + decoding_csp_times = [0, 0.4, 0.6] + ``` +""" + +decoding_csp_freqs: dict[str, FloatArrayLike] | None = None +""" +The edges of the frequency bins to use for CSP decoding. + +This parameter must be a dictionary with: +- keys specifying the unique identifier or "name" to use for the frequency + range to be treated jointly during statistical testing (such as "alpha" or + "beta"), and +- values must be list-like objects containing at least two scalar values, + specifying the edges of the respective frequency bin(s), e.g., `[8, 12]`. + +Defaults to two frequency bins, one from +[`time_frequency_freq_min`][mne_bids_pipeline._config.time_frequency_freq_min] +to the midpoint between this value and +[`time_frequency_freq_max`][mne_bids_pipeline._config.time_frequency_freq_max]; +and the other from that midpoint to `time_frequency_freq_max`. +???+ example "Example" + Create two frequency bins, one for 4–8 Hz, and another for 8–14 Hz, which + will be clustered together during statistical testing (in the + time-frequency plane): + ```python + decoding_csp_freqs = { + 'custom_range': [4, 8, 14] + } + ``` + Create the same two frequency bins, but treat them separately during + statistical testing (i.e., temporal clustering only): + ```python + decoding_csp_freqs = { + 'theta': [4, 8], + 'alpha': [8, 14] + } + ``` + Create 5 equidistant frequency bins from 4 to 14 Hz: + ```python + decoding_csp_freqs = { + 'custom_range': np.linspace( + start=4, + stop=14, + num=5+1 # We need one more to account for the endpoint! + ) + } +""" + n_boot: int = 5000 """ The number of bootstrap resamples when estimating the standard error and confidence interval of the mean decoding scores. """ -cluster_forming_t_threshold: Optional[float] = None +cluster_forming_t_threshold: float | None = None """ The t-value threshold to use for forming clusters in the cluster-based permutation test run on the the time-by-time decoding scores. @@ -1479,7 +1759,7 @@ test to determine the significance of the decoding scores across participants. """ -cluster_permutation_p_threshold: float = 0.05 +cluster_permutation_p_threshold: Annotated[float, Interval(gt=0, lt=1)] = 0.05 """ The alpha level (p-value, p threshold) to use for rejecting the null hypothesis that the clusters show no significant difference between conditions. This is @@ -1490,28 +1770,9 @@ [`cluster_forming_t_threshold`][mne_bids_pipeline._config.cluster_forming_t_threshold]. """ -############################################################################### -# GROUP AVERAGE SENSORS -# --------------------- +# ## Time-frequency analysis -interpolate_bads_grand_average: bool = True -""" -Interpolate bad sensors in each dataset before calculating the grand -average. This parameter is passed to the `mne.grand_average` function via -the keyword argument `interpolate_bads`. It requires to have channel -locations set. - -???+ example "Example" - ```python - interpolate_bads_grand_average = True - ``` -""" - -############################################################################### -# TIME-FREQUENCY -# -------------- - -time_frequency_conditions: Iterable[str] = [] +time_frequency_conditions: Sequence[str] = [] """ The conditions to compute time-frequency decomposition on. @@ -1521,7 +1782,7 @@ ``` """ -time_frequency_freq_min: Optional[float] = 8 +time_frequency_freq_min: float | None = 8 """ Minimum frequency for the time frequency analysis, in Hz. ???+ example "Example" @@ -1530,7 +1791,7 @@ ``` """ -time_frequency_freq_max: Optional[float] = 40 +time_frequency_freq_max: float | None = 40 """ Maximum frequency for the time frequency analysis, in Hz. ???+ example "Example" @@ -1539,7 +1800,7 @@ ``` """ -time_frequency_cycles: Optional[Union[float, ArrayLike]] = None +time_frequency_cycles: float | FloatArrayLike | None = None """ The number of cycles to use in the Morlet wavelet. This can be a single number or one per frequency, where frequencies are calculated via @@ -1550,7 +1811,7 @@ time_frequency_subtract_evoked: bool = False """ -Whether to subtract the evoked signal (averaged across all epochs) from the +Whether to subtract the evoked response (averaged across all epochs) from the epochs before passing them to time-frequency analysis. Set this to `True` to highlight induced activity. @@ -1558,91 +1819,7 @@ This also applies to CSP analysis. """ -############################################################################### -# TIME-FREQUENCY CSP -# ------------------ - -decoding_csp: bool = False -""" -Whether to run decoding via Common Spatial Patterns (CSP) analysis on the -data. CSP takes as input data covariances that are estimated on different -time and frequency ranges. This allows to obtain decoding scores defined over -time and frequency. -""" - -decoding_csp_times: Optional[ArrayLike] = np.linspace( - max(0, epochs_tmin), epochs_tmax, num=6 -) -""" -The edges of the time bins to use for CSP decoding. -Must contain at least two elements. By default, 5 equally-spaced bins are -created across the non-negative time range of the epochs. -All specified time points must be contained in the epochs interval. -If `None`, do not perform **time-frequency** analysis, and only run CSP on -**frequency** data. - -???+ example "Example" - Create 3 equidistant time bins (0–0.2, 0.2–0.4, 0.4–0.6 sec): - ```python - decoding_csp_times = np.linspace(start=0, stop=0.6, num=4) - ``` - Create 2 time bins of different durations (0–0.4, 0.4–0.6 sec): - ```python - decoding_csp_times = [0, 0.4, 0.6] - ``` -""" - -decoding_csp_freqs: Dict[str, ArrayLike] = { - "custom": [ - time_frequency_freq_min, - (time_frequency_freq_max + time_frequency_freq_min) / 2, # noqa: E501 - time_frequency_freq_max, - ] -} -""" -The edges of the frequency bins to use for CSP decoding. - -This parameter must be a dictionary with: -- keys specifying the unique identifier or "name" to use for the frequency - range to be treated jointly during statistical testing (such as "alpha" or - "beta"), and -- values must be list-like objects containing at least two scalar values, - specifying the edges of the respective frequency bin(s), e.g., `[8, 12]`. - -Defaults to two frequency bins, one from -[`time_frequency_freq_min`][mne_bids_pipeline._config.time_frequency_freq_min] -to the midpoint between this value and -[`time_frequency_freq_max`][mne_bids_pipeline._config.time_frequency_freq_max]; -and the other from that midpoint to `time_frequency_freq_max`. -???+ example "Example" - Create two frequency bins, one for 4–8 Hz, and another for 8–14 Hz, which - will be clustered together during statistical testing (in the - time-frequency plane): - ```python - decoding_csp_freqs = { - 'custom_range': [4, 8, 14] - } - ``` - Create the same two frequency bins, but treat them separately during - statistical testing (i.e., temporal clustering only): - ```python - decoding_csp_freqs = { - 'theta': [4, 8], - 'alpha': [8, 14] - } - ``` - Create 5 equidistant frequency bins from 4 to 14 Hz: - ```python - decoding_csp_freqs = { - 'custom_range': np.linspace( - start=4, - stop=14, - num=5+1 # We need one more to account for the endpoint! - ) - } -""" - -time_frequency_baseline: Optional[Tuple[float, float]] = None +time_frequency_baseline: tuple[float, float] | None = None """ Baseline period to use for the time-frequency analysis. If `None`, no baseline. ???+ example "Example" @@ -1661,7 +1838,7 @@ ``` """ -time_frequency_crop: Optional[dict] = None +time_frequency_crop: dict | None = None """ Period and frequency range to crop the time-frequency analysis to. If `None`, no cropping. @@ -1672,17 +1849,34 @@ ``` """ -############################################################################### -# SOURCE ESTIMATION PARAMETERS -# ---------------------------- -# +# ## Group-level analysis + +interpolate_bads_grand_average: bool = True +""" +Interpolate bad sensors in each dataset before calculating the grand +average. This parameter is passed to the `mne.grand_average` function via +the keyword argument `interpolate_bads`. It requires to have channel +locations set. + +???+ example "Example" + ```python + interpolate_bads_grand_average = True + ``` +""" + +# %% +# # Source-level analysis + +# ## General source analysis settings run_source_estimation: bool = True """ Whether to run source estimation processing steps if not explicitly requested. """ -use_template_mri: Optional[str] = None +# ## BEM surface + +use_template_mri: str | None = None """ Whether to use a template MRI subject such as FreeSurfer's `fsaverage` subject. This may come in handy if you don't have individual MR scans of your @@ -1754,7 +1948,9 @@ Whether to print the complete output of FreeSurfer commands. Note that if `False`, no FreeSurfer output might be displayed at all!""" -mri_t1_path_generator: Optional[Callable[[BIDSPath], BIDSPath]] = None +# ## Source space & forward solution + +mri_t1_path_generator: Callable[[BIDSPath], BIDSPath] | None = None """ To perform source-level analyses, the Pipeline needs to generate a transformation matrix that translates coordinates from MEG and EEG sensor @@ -1814,7 +2010,7 @@ def get_t1_from_meeg(bids_path): ``` """ -mri_landmarks_kind: Optional[Callable[[BIDSPath], str]] = None +mri_landmarks_kind: Callable[[BIDSPath], str] | None = None """ This config option allows to look for specific landmarks in the json sidecar file of the T1 MRI file. This can be useful when we have different @@ -1831,7 +2027,7 @@ def mri_landmarks_kind(bids_path): ``` """ -spacing: Union[Literal["oct5", "oct6", "ico4", "ico5", "all"], int] = "oct6" +spacing: Literal["oct5", "oct6", "ico4", "ico5", "all"] | int = "oct6" """ The spacing to use. Can be `'ico#'` for a recursively subdivided icosahedron, `'oct#'` for a recursively subdivided octahedron, @@ -1846,24 +2042,35 @@ def mri_landmarks_kind(bids_path): Exclude points closer than this distance (mm) to the bounding surface. """ -loose: Union[float, Literal["auto"]] = 0.2 +# ## Inverse solution + +loose: Annotated[float, Interval(ge=0, le=1)] | Literal["auto"] = 0.2 """ -Value that weights the source variances of the dipole components -that are parallel (tangential) to the cortical surface. If `0`, then the -inverse solution is computed with **fixed orientation.** -If `1`, it corresponds to **free orientation.** -The default value, `'auto'`, is set to `0.2` for surface-oriented source -spaces, and to `1.0` for volumetric, discrete, or mixed source spaces, -unless `fixed is True` in which case the value 0. is used. +A value between 0 and 1 that weights the source variances of the dipole components +that are parallel (tangential) to the cortical surface. + +If `0`, then the inverse solution is computed with **fixed orientation**, i.e., +only dipole components perpendicular to the cortical surface are considered. + +If `1`, it corresponds to **free orientation**, i.e., dipole components with any +orientation are considered. + +The default value, `0.2`, is suitable for surface-oriented source spaces. + +For volume or mixed source spaces, choose `1.0`. + +!!! info + Support for modeling volume and mixed source spaces will be added in a future + version of MNE-BIDS-Pipeline. """ -depth: Optional[Union[float, dict]] = 0.8 +depth: Annotated[float, Interval(ge=0, le=1)] | dict = 0.8 """ -If float (default 0.8), it acts as the depth weighting exponent (`exp`) -to use (must be between 0 and 1). None is equivalent to 0, meaning no -depth weighting is performed. Can also be a `dict` containing additional -keyword arguments to pass to :func:`mne.forward.compute_depth_prior` -(see docstring for details and defaults). +If a number, it acts as the depth weighting exponent to use +(must be between `0` and`1`), with`0` meaning no depth weighting is performed. + +Can also be a dictionary containing additional keyword arguments to pass to +`mne.forward.compute_depth_prior` (see docstring for details and defaults). """ inverse_method: Literal["MNE", "dSPM", "sLORETA", "eLORETA"] = "dSPM" @@ -1872,11 +2079,11 @@ def mri_landmarks_kind(bids_path): solution. """ -noise_cov: Union[ - Tuple[Optional[float], Optional[float]], - Literal["emptyroom", "rest", "ad-hoc"], - Callable[[BIDSPath], mne.Covariance], -] = (None, 0) +noise_cov: ( + tuple[float | None, float | None] + | Literal["emptyroom", "rest", "ad-hoc"] + | Callable[[BIDSPath], Covariance] +) = (None, 0) """ Specify how to estimate the noise covariance matrix, which is used in inverse modeling. @@ -1940,15 +2147,33 @@ def noise_cov(bids_path): ``` """ -source_info_path_update: Optional[Dict[str, str]] = dict(suffix="ave") +noise_cov_method: Literal[ + "shrunk", + "empirical", + "diagonal_fixed", + "oas", + "ledoit_wolf", + "factor_analysis", + "shrinkage", + "pca", + "auto", +] = "shrunk" +""" +The noise covariance estimation method to use. See the MNE-Python documentation +of `mne.compute_covariance` for details. """ -When computing the forward and inverse solutions, by default the pipeline -retrieves the `mne.Info` object from the cleaned evoked data. However, in -certain situations you may wish to use a different `Info`. +source_info_path_update: dict[str, str] | None = None +""" +When computing the forward and inverse solutions, it is important to +provide the `mne.Info` object from the data on which the noise covariance was +computed, to avoid problems resulting from mismatching ranks. This parameter allows you to explicitly specify from which file to retrieve the `mne.Info` object. Use this parameter to supply a dictionary to `BIDSPath.update()` during the forward and inverse processing steps. +If set to `None` (default), the info will be retrieved either from the raw +file specified in `noise_cov`, or the cleaned evoked +(if `noise_cov` is None or `ad-hoc`). ???+ example "Example" Use the `Info` object stored in the cleaned epochs: @@ -1956,9 +2181,20 @@ def noise_cov(bids_path): source_info_path_update = {'processing': 'clean', 'suffix': 'epo'} ``` + + Use the `Info` object stored in a raw file (e.g. resting state): + ```python + source_info_path_update = {'processing': 'clean', + 'suffix': 'raw', + 'task': 'rest'} + ``` + If you set `noise_cov = 'rest'` and `source_path_info = None`, + then the behavior is identical to that above + (it will automatically use the resting state data). + """ -inverse_targets: List[Literal["evoked"]] = ["evoked"] +inverse_targets: list[Literal["evoked"]] = ["evoked"] """ On which data to apply the inverse operator. Currently, the only supported @@ -1977,11 +2213,12 @@ def noise_cov(bids_path): ``` """ -############################################################################### -# Report generation -# ----------------- +# %% +# # Reports + +# ## Report generation -report_evoked_n_time_points: Optional[int] = None +report_evoked_n_time_points: int | None = None """ Specifies the number of time points to display for each evoked in the report. If `None`, it defaults to the current default in MNE-Python. @@ -1993,7 +2230,7 @@ def noise_cov(bids_path): ``` """ -report_stc_n_time_points: Optional[int] = None +report_stc_n_time_points: int | None = None """ Specifies the number of time points to display for each source estimates in the report. If `None`, it defaults to the current default in MNE-Python. @@ -2005,9 +2242,65 @@ def noise_cov(bids_path): ``` """ -############################################################################### -# Execution -# --------- +report_add_epochs_image_kwargs: dict | None = None +""" +Specifies the limits for the color scales of the epochs_image in the report. +If `None`, it defaults to the current default in MNE-Python. + +???+ example "Example" + Set vmin and vmax to the epochs rejection thresholds (with unit conversion): + + ```python + report_add_epochs_image_kwargs = { + "grad": {"vmin": 0, "vmax": 1e13 * reject["grad"]}, # fT/cm + "mag": {"vmin": 0, "vmax": 1e15 * reject["mag"]}, # fT + } + ``` +""" + +# %% +# # Caching +# +# Per default, the pipeline output is cached (temporarily stored), +# to avoid unnecessary reruns of previously computed steps. +# Yet, for consistency, changes in configuration parameters trigger +# automatic reruns of previous steps. +# !!! info +# To force rerunning a given step, run the pipeline with the option: `--no-cache`. + +memory_location: PathLike | bool | None = True +""" +If not None (or False), caching will be enabled and the cache files will be +stored in the given directory. The default (True) will use a +`"_cache"` subdirectory (name configurable via the +[`memory_subdir`][mne_bids_pipeline._config.memory_subdir] +variable) in the BIDS derivative root of the dataset. +""" + +memory_subdir: str = "_cache" +""" +The caching directory name to use if `memory_location` is `True`. +""" + +memory_file_method: Literal["mtime", "hash"] = "mtime" +""" +The method to use for cache invalidation (i.e., detecting changes). Using the +"modified time" reported by the filesystem (`'mtime'`, default) is very fast +but requires that the filesystem supports proper mtime reporting. Using file +hashes (`'hash'`) is slower and requires reading all input files but should +work on any filesystem. +""" + +memory_verbose: int = 0 +""" +The verbosity to use when using memory. The default (0) does not print, while +1 will print the function calls that will be cached. See the documentation for +the joblib.Memory class for more information.""" + +# %% +# # Parallelization +# +# These options control parallel processing (e.g., multiple subjects at once), n_jobs: int = 1 """ @@ -2029,7 +2322,7 @@ def noise_cov(bids_path): Ignored if `parallel_backend` is not `'dask'`. """ -dask_temp_dir: Optional[PathLike] = None +dask_temp_dir: PathLike | None = None """ The temporary directory to use by Dask. Dask places lock-files in this directory, and also uses it to "spill" RAM contents to disk if the amount of @@ -2047,19 +2340,10 @@ def noise_cov(bids_path): The maximum amount of RAM per Dask worker. """ -random_state: Optional[int] = 42 -""" -You can specify the seed of the random number generator (RNG). -This setting is passed to the ICA algorithm and to the decoding function, -ensuring reproducible results. Set to `None` to avoid setting the RNG -to a defined state. -""" - -shortest_event: int = 1 -""" -Minimum number of samples an event must last. If the -duration is less than this, an exception will be raised. -""" +# %% +# # Logging +# +# These options control how much logging output is produced. log_level: Literal["info", "error"] = "info" """ @@ -2071,6 +2355,13 @@ def noise_cov(bids_path): Set the MNE-Python logging verbosity. """ + +# %% +# # Error handling +# +# These options control how errors while processing the data or the configuration file +# are handled. + on_error: Literal["continue", "abort", "debug"] = "abort" """ Whether to abort processing as soon as an error occurs, continue with all other @@ -2081,28 +2372,6 @@ def noise_cov(bids_path): Enabling debug mode deactivates parallel processing. """ -memory_location: Optional[Union[PathLike, bool]] = True -""" -If not None (or False), caching will be enabled and the cache files will be -stored in the given directory. The default (True) will use a -`'joblib'` subdirectory in the BIDS derivative root of the dataset. -""" - -memory_file_method: Literal["mtime", "hash"] = "mtime" -""" -The method to use for cache invalidation (i.e., detecting changes). Using the -"modified time" reported by the filesystem (`'mtime'`, default) is very fast -but requires that the filesystem supports proper mtime reporting. Using file -hashes (`'hash'`) is slower and requires reading all input files but should -work on any filesystem. -""" - -memory_verbose: int = 0 -""" -The verbosity to use when using memory. The default (0) does not print, while -1 will print the function calls that will be cached. See the documentation for -the joblib.Memory class for more information.""" - config_validation: Literal["raise", "warn", "ignore"] = "raise" """ How strictly to validate the configuration. Errors are always raised for diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 8755c32c6..e947e58f0 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -4,29 +4,41 @@ import importlib import os import pathlib +from dataclasses import field +from functools import partial from types import SimpleNamespace -from typing import Optional, List import matplotlib -import numpy as np import mne -from mne.utils import _check_option, _validate_type +import numpy as np +from pydantic import BaseModel, ConfigDict, ValidationError -from ._logging import logger, gen_log_kwargs +from ._logging import gen_log_kwargs, logger from .typing import PathLike def _import_config( *, - config_path: Optional[PathLike], - overrides: Optional[SimpleNamespace] = None, + config_path: PathLike | None, + overrides: SimpleNamespace | None = None, check: bool = True, log: bool = True, ) -> SimpleNamespace: """Import the default config and the user's config.""" # Get the default config = _get_default_config() + # Public names users generally will have in their config valid_names = [d for d in dir(config) if not d.startswith("_")] + # Names that we will reduce the SimpleConfig to before returning + # (see _update_with_user_config) + keep_names = [d for d in dir(config) if not d.startswith("__")] + [ + "config_path", + "PIPELINE_NAME", + "VERSION", + "CODE_URL", + "_raw_split_size", + "_epochs_split_size", + ] # Update with user config user_names = _update_with_user_config( @@ -36,16 +48,31 @@ def _import_config( log=log, ) + extra_exec_params_keys = () + extra_config = os.getenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", "") + if extra_config: + msg = f"With testing config: {extra_config}" + logger.info(**gen_log_kwargs(message=msg, emoji="override")) + _update_config_from_path( + config=config, + config_path=extra_config, + ) + extra_exec_params_keys = ("_n_jobs",) + keep_names.extend(extra_exec_params_keys) + # Check it if check: - _check_config(config) + _check_config(config, config_path) _check_misspellings_removals( - config, valid_names=valid_names, user_names=user_names, log=log, + config_validation=config.config_validation, ) + # Finally, reduce to our actual supported params (all keep_names should be present) + config = SimpleNamespace(**{k: getattr(config, k) for k in keep_names}) + # Take some standard actions mne.set_log_level(verbose=config.mne_log_level.upper()) @@ -64,12 +91,13 @@ def _import_config( "interactive", # Caching "memory_location", + "memory_subdir", "memory_verbose", "memory_file_method", # Misc "deriv_root", "config_path", - ) + ) + extra_exec_params_keys in_both = {"deriv_root"} exec_params = SimpleNamespace(**{k: getattr(config, k) for k in keys}) for k in keys: @@ -89,7 +117,7 @@ def _get_default_config(): ignore_keys = { name.asname or name.name for element in tree.body - if isinstance(element, (ast.Import, ast.ImportFrom)) + if isinstance(element, ast.Import | ast.ImportFrom) for name in element.names } config = SimpleNamespace( @@ -102,13 +130,39 @@ def _get_default_config(): return config +def _update_config_from_path( + *, + config: SimpleNamespace, + config_path: PathLike, +): + user_names = list() + config_path = pathlib.Path(config_path).expanduser().resolve(strict=True) + # Import configuration from an arbitrary path without having to fiddle + # with `sys.path`. + spec = importlib.util.spec_from_file_location( + name="custom_config", location=config_path + ) + custom_cfg = importlib.util.module_from_spec(spec) + spec.loader.exec_module(custom_cfg) + for key in dir(custom_cfg): + if not key.startswith("__"): + # don't validate private vars, but do add to config + # (e.g., so that our hidden _raw_split_size is included) + if not key.startswith("_"): + user_names.append(key) + val = getattr(custom_cfg, key) + logger.debug(f"Overwriting: {key} -> {val}") + setattr(config, key, val) + return user_names + + def _update_with_user_config( *, config: SimpleNamespace, # modified in-place - config_path: Optional[PathLike], - overrides: Optional[SimpleNamespace], + config_path: PathLike | None, + overrides: SimpleNamespace | None, log: bool = False, -) -> List[str]: +) -> list[str]: # 1. Basics and hidden vars from . import __version__ @@ -121,23 +175,12 @@ def _update_with_user_config( # 2. User config user_names = list() if config_path is not None: - config_path = pathlib.Path(config_path).expanduser().resolve(strict=True) - # Import configuration from an arbitrary path without having to fiddle - # with `sys.path`. - spec = importlib.util.spec_from_file_location( - name="custom_config", location=config_path + user_names.extend( + _update_config_from_path( + config=config, + config_path=config_path, + ) ) - custom_cfg = importlib.util.module_from_spec(spec) - spec.loader.exec_module(custom_cfg) - for key in dir(custom_cfg): - if not key.startswith("__"): - # don't validate private vars, but do add to config - # (e.g., so that our hidden _raw_split_size is included) - if not key.startswith("_"): - user_names.append(key) - val = getattr(custom_cfg, key) - logger.debug("Overwriting: %s -> %s" % (key, val)) - setattr(config, key, val) config.config_path = config_path # 3. Overrides via command-line switches @@ -147,9 +190,7 @@ def _update_with_user_config( val = getattr(overrides, name) if log: msg = f"Overriding config.{name} = {repr(val)}" - logger.info( - **gen_log_kwargs(message=msg, step="", emoji="override", box="╶╴") - ) + logger.info(**gen_log_kwargs(message=msg, emoji="override")) setattr(config, name, val) # 4. Env vars and other triaging @@ -168,7 +209,7 @@ def _update_with_user_config( config.deriv_root = pathlib.Path(config.deriv_root).expanduser().resolve() # 5. Consistency - log_kwargs = dict(emoji="override", box=" ", step="") + log_kwargs = dict(emoji="override") if config.interactive: if log and config.on_error != "debug": msg = 'Setting config.on_error="debug" because of interactive mode' @@ -191,10 +232,11 @@ def _update_with_user_config( return user_names -def _check_config(config: SimpleNamespace) -> None: - # TODO: Use pydantic to do these validations - # https://github.com/mne-tools/mne-bids-pipeline/issues/646 - _check_option("config.parallel_backend", config.parallel_backend, ("dask", "loky")) +def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None: + _pydantic_validate(config=config, config_path=config_path) + + # Eventually all of these could be pydantic-validated, but for now we'll + # just change the ones that are easy config.bids_root.resolve(strict=True) @@ -207,12 +249,6 @@ def _check_config(config: SimpleNamespace) -> None: reject = config.reject ica_reject = config.ica_reject if config.spatial_filter == "ica": - _check_option( - "config.ica_algorithm", - config.ica_algorithm, - ("picard", "fastica", "extended_infomax"), - ) - if config.ica_l_freq < 1: raise ValueError( "You requested to high-pass filter the data before ICA with " @@ -235,7 +271,7 @@ def _check_config(config: SimpleNamespace) -> None: if ( ica_reject is not None and reject is not None - and reject != "autoreject_global" + and reject not in ["autoreject_global", "autoreject_local"] ): for ch_type in reject: if ch_type in ica_reject and reject[ch_type] > ica_reject[ch_type]: @@ -246,30 +282,6 @@ def _check_config(config: SimpleNamespace) -> None: f'ica_reject["{ch_type}"] ({ica_reject[ch_type]})' ) - if not config.ch_types: - raise ValueError("Please specify ch_types in your configuration.") - - _VALID_TYPES = ("meg", "mag", "grad", "eeg") - if any(ch_type not in _VALID_TYPES for ch_type in config.ch_types): - raise ValueError( - "Invalid channel type passed. Please adjust `ch_types` in your " - f"configuration, got {config.ch_types} but supported types are " - f"{_VALID_TYPES}" - ) - - _check_option("config.on_error", config.on_error, ("continue", "abort", "debug")) - _check_option( - "config.memory_file_method", config.memory_file_method, ("mtime", "hash") - ) - - if isinstance(config.noise_cov, str): - _check_option( - "config.noise_cov", - config.noise_cov, - ("emptyroom", "ad-hoc", "rest"), - extra="when a string", - ) - if config.noise_cov == "emptyroom" and "eeg" in config.ch_types: raise ValueError( "You requested to process data that contains EEG channels. In " @@ -286,10 +298,6 @@ def _check_config(config: SimpleNamespace) -> None: "Please set process_empty_room = True" ) - _check_option( - "config.bem_mri_images", config.bem_mri_images, ("FLASH", "T1", "auto") - ) - bl = config.baseline if bl is not None: if (bl[0] is not None and bl[0] < config.epochs_tmin) or ( @@ -306,16 +314,7 @@ def _check_config(config: SimpleNamespace) -> None: f"but you set baseline={bl}" ) - # check decoding parameters - if config.decoding_n_splits < 2: - raise ValueError("decoding_n_splits should be at least 2.") - # check cluster permutation parameters - if not 0 < config.cluster_permutation_p_threshold < 1: - raise ValueError( - "cluster_permutation_p_threshold should be in the (0, 1) interval." - ) - if config.cluster_n_permutations < 10 / config.cluster_permutation_p_threshold: raise ValueError( "cluster_n_permutations is not big enough to calculate " @@ -330,32 +329,7 @@ def _check_config(config: SimpleNamespace) -> None: "This is only allowed for resting-state analysis." ) - _check_option( - "config.on_rename_missing_events", - config.on_rename_missing_events, - ("raise", "warn", "ignore"), - ) - - _validate_type(config.n_jobs, int, "n_jobs") - - _check_option( - "config.config_validation", - config.config_validation, - ("raise", "warn", "ignore"), - ) - - _validate_type( - config.mf_destination, - (str, list, tuple, np.ndarray), - "config.mf_destination", - ) - if isinstance(config.mf_destination, str): - _check_option( - "config.mf_destination", - config.mf_destination, - ("reference_run",), - ) - else: + if not isinstance(config.mf_destination, str): destination = np.array(config.mf_destination, float) if destination.shape != (4, 4): raise ValueError( @@ -363,6 +337,84 @@ def _check_config(config: SimpleNamespace) -> None: f"but got shape {destination.shape}" ) +# From: https://github.com/mne-tools/mne-bids-pipeline/pull/812 + # MNE-ICALabel + if config.ica_use_icalabel: + if config.ica_l_freq != 1.0 or config.h_freq != 100.0: + raise ValueError( + f"When using MNE-ICALabel, you must set ica_l_freq=1 and h_freq=100, " + f"but got: ica_l_freq={config.ica_l_freq} and h_freq={config.h_freq}" + ) + + if config.eeg_reference != "average": + raise ValueError( + f'When using MNE-ICALabel, you must set eeg_reference="average", but ' + f"got: eeg_reference={config.eeg_reference}" + ) + +def _default_factory(key, val): + # convert a default to a default factory if needed, having an explicit + # allowlist of non-empty ones + allowlist = [ + {"n_mag": 1, "n_grad": 1, "n_eeg": 1}, # n_proj_* + {"custom": (8, 24.0, 40)}, # decoding_csp_freqs + ["evoked"], # inverse_targets + [4, 8, 16], # autoreject_n_interpolate + #["brain", "muscle artifact", "eye blink", "heart beat", "line noise", "channel noise", "other"], # icalabel_include + ["brain","other"] + ] + for typ in (dict, list): + if isinstance(val, typ): + try: + idx = allowlist.index(val) + except ValueError: + assert val == typ(), (key, val) + default_factory = typ + else: + if typ is dict: + default_factory = partial(typ, **allowlist[idx]) + else: + assert typ is list + default_factory = partial(typ, allowlist[idx]) + return field(default_factory=default_factory) + return val + + +def _pydantic_validate( + config: SimpleNamespace, + config_path: PathLike | None, +): + """Create dataclass from config type hints and validate with pydantic.""" + # https://docs.pydantic.dev/latest/usage/dataclasses/ + from . import _config as root_config + + # Modify annotations to add nested strict parsing + annotations = dict() + attrs = dict() + for key, annot in root_config.__annotations__.items(): + annotations[key] = annot + attrs[key] = _default_factory(key, root_config.__dict__[key]) + name = "user configuration" + if config_path is not None: + name += f" from {config_path}" + model_config = ConfigDict( + arbitrary_types_allowed=True, # needed in 2.6.0 to allow DigMontage for example + validate_assignment=True, + strict=True, # do not allow float for int for example + extra="forbid", + ) + UserConfig = type( + name, + (BaseModel,), + {"__annotations__": annotations, "model_config": model_config, **attrs}, + ) + # Now use pydantic to automagically validate + user_vals = {key: val for key, val in config.__dict__.items() if key in annotations} + try: + UserConfig.model_validate(user_vals) + except ValidationError as err: + raise ValueError(str(err)) from None + _REMOVED_NAMES = { "debug": dict( @@ -379,15 +431,18 @@ def _check_config(config: SimpleNamespace) -> None: "N_JOBS": dict( new_name="n_jobs", ), + "ica_ctps_ecg_threshold": dict( + new_name="ica_ecg_threshold", + ), } def _check_misspellings_removals( - config: SimpleNamespace, *, - valid_names: List[str], - user_names: List[str], + valid_names: list[str], + user_names: list[str], log: bool, + config_validation: str, ) -> None: # for each name in the user names, check if it's in the valid names but # the correct one is not defined @@ -396,7 +451,7 @@ def _check_misspellings_removals( if user_name not in valid_names: # find the closest match closest_match = difflib.get_close_matches(user_name, valid_names, n=1) - msg = f"Found a variable named {repr(user_name)} in your custom " "config," + msg = f"Found a variable named {repr(user_name)} in your custom config," if closest_match and closest_match[0] not in user_names: this_msg = ( f"{msg} did you mean {repr(closest_match[0])}? " @@ -404,7 +459,7 @@ def _check_misspellings_removals( "the variable to reduce ambiguity and avoid this message, " "or set config.config_validation to 'warn' or 'ignore'." ) - _handle_config_error(this_msg, log, config) + _handle_config_error(this_msg, log, config_validation) if user_name in _REMOVED_NAMES: new = _REMOVED_NAMES[user_name]["new_name"] if new not in user_names: @@ -415,16 +470,16 @@ def _check_misspellings_removals( f"{msg} this variable has been removed as a valid " f"config option, {instead}." ) - _handle_config_error(this_msg, log, config) + _handle_config_error(this_msg, log, config_validation) def _handle_config_error( msg: str, log: bool, - config: SimpleNamespace, + config_validation: str, ) -> None: - if config.config_validation == "raise": + if config_validation == "raise": raise ValueError(msg) - elif config.config_validation == "warn": + elif config_validation == "warn": if log: - logger.warning(**gen_log_kwargs(message=msg, step="", emoji="🛟")) + logger.warning(**gen_log_kwargs(message=msg, emoji="🛟")) diff --git a/mne_bids_pipeline/_config_template.py b/mne_bids_pipeline/_config_template.py index 9954811ad..9c5a0ff29 100644 --- a/mne_bids_pipeline/_config_template.py +++ b/mne_bids_pipeline/_config_template.py @@ -1,8 +1,6 @@ from pathlib import Path -from typing import List - -from ._logging import logger, gen_log_kwargs +from ._logging import gen_log_kwargs, logger CONFIG_SOURCE_PATH = Path(__file__).parent / "_config.py" @@ -17,8 +15,8 @@ def create_template_config( raise FileExistsError(f"The specified path already exists: {target_path}") # Create a template by commenting out most of the lines in _config.py - config: List[str] = [] - with open(CONFIG_SOURCE_PATH, "r", encoding="utf-8") as f: + config: list[str] = [] + with open(CONFIG_SOURCE_PATH, encoding="utf-8") as f: for line in f: line = ( line if line.startswith(("#", "\n", "import", "from")) else f"# {line}" @@ -27,7 +25,7 @@ def create_template_config( target_path.write_text("".join(config), encoding="utf-8") message = f"Successfully created template configuration file at: " f"{target_path}" - logger.info(**gen_log_kwargs(message=message, emoji="✅", step="")) + logger.info(**gen_log_kwargs(message=message, emoji="✅")) message = "Please edit the file before running the pipeline." - logger.info(**gen_log_kwargs(message=message, emoji="💡", step="")) + logger.info(**gen_log_kwargs(message=message, emoji="💡")) diff --git a/mne_bids_pipeline/_config_utils.py b/mne_bids_pipeline/_config_utils.py index a93e57504..0160aeccb 100644 --- a/mne_bids_pipeline/_config_utils.py +++ b/mne_bids_pipeline/_config_utils.py @@ -3,21 +3,22 @@ import copy import functools import pathlib -from typing import List, Optional, Union, Iterable, Tuple, Dict, TypeVar, Literal, Any -from types import SimpleNamespace, ModuleType +from collections.abc import Iterable +from types import ModuleType, SimpleNamespace +from typing import Any, Literal, TypeVar -import numpy as np import mne import mne_bids +import numpy as np from mne_bids import BIDSPath -from ._logging import logger, gen_log_kwargs +from ._logging import gen_log_kwargs, logger from .typing import ArbitraryContrast try: - _keys_arbitrary_contrast = set(ArbitraryContrast.__required_keys__) + _set_keys_arbitrary_contrast = set(ArbitraryContrast.__required_keys__) except Exception: - _keys_arbitrary_contrast = set(ArbitraryContrast.__annotations__.keys()) + _set_keys_arbitrary_contrast = set(ArbitraryContrast.__annotations__.keys()) def get_fs_subjects_dir(config: SimpleNamespace) -> pathlib.Path: @@ -47,8 +48,8 @@ def get_fs_subject(config: SimpleNamespace, subject: str) -> str: return f"sub-{subject}" -@functools.lru_cache(maxsize=None) -def _get_entity_vals_cached(*args, **kwargs) -> List[str]: +@functools.cache +def _get_entity_vals_cached(*args, **kwargs) -> list[str]: return mne_bids.get_entity_vals(*args, **kwargs) @@ -73,18 +74,18 @@ def get_datatype(config: SimpleNamespace) -> Literal["meg", "eeg"]: ) -@functools.lru_cache(maxsize=None) +@functools.cache def _get_datatypes_cached(root): return mne_bids.get_datatypes(root=root) -def _get_ignore_datatypes(config: SimpleNamespace) -> Tuple[str]: - _all_datatypes: List[str] = _get_datatypes_cached(root=config.bids_root) +def _get_ignore_datatypes(config: SimpleNamespace) -> tuple[str]: + _all_datatypes: list[str] = _get_datatypes_cached(root=config.bids_root) _ignore_datatypes = set(_all_datatypes) - set([get_datatype(config)]) return tuple(sorted(_ignore_datatypes)) -def get_subjects(config: SimpleNamespace) -> List[str]: +def get_subjects(config: SimpleNamespace) -> list[str]: _valid_subjects = _get_entity_vals_cached( root=config.bids_root, entity_key="subject", @@ -94,15 +95,24 @@ def get_subjects(config: SimpleNamespace) -> List[str]: s = _valid_subjects else: s = config.subjects + missing_subjects = set(s) - set(_valid_subjects) + if missing_subjects: + raise FileNotFoundError( + "The following requested subjects were not found in the dataset: " + f"{', '.join(missing_subjects)}" + ) - subjects = set(s) - set(config.exclude_subjects) - # Drop empty-room subject. - subjects = subjects - set(["emptyroom"]) + # Preserve order and remove excluded subjects + subjects = [ + subject + for subject in s + if subject not in config.exclude_subjects and subject != "emptyroom" + ] - return sorted(subjects) + return subjects -def get_sessions(config: SimpleNamespace) -> Union[List[None], List[str]]: +def get_sessions(config: SimpleNamespace) -> list[None] | list[str]: sessions = copy.deepcopy(config.sessions) _all_sessions = _get_entity_vals_cached( root=config.bids_root, @@ -120,8 +130,8 @@ def get_sessions(config: SimpleNamespace) -> Union[List[None], List[str]]: def get_runs_all_subjects( config: SimpleNamespace, -) -> Dict[str, Union[List[None], List[str]]]: - """Gives the mapping between subjects and their runs. +) -> dict[str, list[None] | list[str]]: + """Give the mapping between subjects and their runs. Returns ------- @@ -142,10 +152,10 @@ def get_runs_all_subjects( ) -@functools.lru_cache(maxsize=None) +@functools.cache def _get_runs_all_subjects_cached( - **config_dict: Dict[str, Any], -) -> Dict[str, Union[List[None], List[str]]]: + **config_dict: dict[str, Any], +) -> dict[str, list[None] | list[str]]: config = SimpleNamespace(**config_dict) # Sometimes we check list equivalence for ch_types, so convert it back config.ch_types = list(config.ch_types) @@ -172,10 +182,20 @@ def _get_runs_all_subjects_cached( return subj_runs -def get_intersect_run(config: SimpleNamespace) -> List[str]: - """Returns the intersection of all the runs of all subjects.""" +def get_intersect_run(config: SimpleNamespace) -> list[str]: + """Return the intersection of all the runs of all subjects.""" subj_runs = get_runs_all_subjects(config) - return list(set.intersection(*map(set, subj_runs.values()))) + # Do not use something like: + # list(set.intersection(*map(set, subj_runs.values()))) + # as it will not preserve order. Instead just be explicit and preserve order. + # We could use "sorted", but it's probably better to use the order provided by + # the user (if they want to put `runs=["02", "01"]` etc. it's better to use "02") + all_runs = list() + for runs in subj_runs.values(): + for run in runs: + if run not in all_runs: + all_runs.append(run) + return all_runs def get_runs( @@ -183,8 +203,8 @@ def get_runs( config: SimpleNamespace, subject: str, verbose: bool = False, -) -> Union[List[str], List[None]]: - """Returns a list of runs in the BIDS input data. +) -> list[str] | list[None]: + """Return a list of runs in the BIDS input data. Parameters ---------- @@ -239,14 +259,20 @@ def get_runs_tasks( *, config: SimpleNamespace, subject: str, - session: Optional[str], - include_noise: bool = True, -) -> List[Tuple[str]]: + session: str | None, + which: tuple[str] = ("runs", "noise", "rest"), +) -> list[tuple[str]]: """Get (run, task) tuples for all runs plus (maybe) rest.""" from ._import_data import _get_noise_path, _get_rest_path - runs = get_runs(config=config, subject=subject) - tasks = [get_task(config=config)] * len(runs) + assert isinstance(which, tuple) + assert all(isinstance(inc, str) for inc in which) + assert all(inc in ("runs", "noise", "rest") for inc in which) + runs = list() + tasks = list() + if "runs" in which: + runs.extend(get_runs(config=config, subject=subject)) + tasks.extend([get_task(config=config)] * len(runs)) kwargs = dict( cfg=config, subject=subject, @@ -254,10 +280,10 @@ def get_runs_tasks( kind="orig", add_bads=False, ) - if _get_rest_path(**kwargs): + if "rest" in which and _get_rest_path(**kwargs): runs.append(None) tasks.append("rest") - if include_noise: + if "noise" in which: mf_reference_run = get_mf_reference_run(config=config) if _get_noise_path(mf_reference_run=mf_reference_run, **kwargs): runs.append(None) @@ -290,7 +316,7 @@ def get_mf_reference_run(config: SimpleNamespace) -> str: ) -def get_task(config: SimpleNamespace) -> Optional[str]: +def get_task(config: SimpleNamespace) -> str | None: task = config.task if task: return task @@ -305,7 +331,7 @@ def get_task(config: SimpleNamespace) -> Optional[str]: return _valid_tasks[0] -def get_channels_to_analyze(info: mne.Info, config: SimpleNamespace) -> List[str]: +def get_channels_to_analyze(info: mne.Info, config: SimpleNamespace) -> list[str]: # Return names of the channels of the channel types we wish to analyze. # We also include channels marked as "bad" here. # `exclude=[]`: keep "bad" channels, too. @@ -347,15 +373,19 @@ def get_mf_cal_fname( *, config: SimpleNamespace, subject: str, session: str ) -> pathlib.Path: if config.mf_cal_fname is None: - mf_cal_fpath = BIDSPath( + bids_path = BIDSPath( subject=subject, session=session, suffix="meg", datatype="meg", root=config.bids_root, - ).meg_calibration_fpath + ).match()[0] + mf_cal_fpath = bids_path.meg_calibration_fpath if mf_cal_fpath is None: - raise ValueError("Could not find Maxwell Filter Calibration " "file.") + raise ValueError( + "Could not determine Maxwell Filter Calibration file from BIDS " + f"definition for file {bids_path}." + ) else: mf_cal_fpath = pathlib.Path(config.mf_cal_fname).expanduser().absolute() if not mf_cal_fpath.exists(): @@ -379,7 +409,7 @@ def get_mf_ctc_fname( root=config.bids_root, ).meg_crosstalk_fpath if mf_ctc_fpath is None: - raise ValueError("Could not find Maxwell Filter cross-talk " "file.") + raise ValueError("Could not find Maxwell Filter cross-talk file.") else: mf_ctc_fpath = pathlib.Path(config.mf_ctc_fname).expanduser().absolute() if not mf_ctc_fpath.exists(): @@ -392,44 +422,32 @@ def get_mf_ctc_fname( RawEpochsEvokedT = TypeVar( - "RawEpochsEvokedT", bound=Union[mne.io.BaseRaw, mne.BaseEpochs, mne.Evoked] + "RawEpochsEvokedT", bound=mne.io.BaseRaw | mne.BaseEpochs | mne.Evoked ) def _restrict_analyze_channels( inst: RawEpochsEvokedT, cfg: SimpleNamespace ) -> RawEpochsEvokedT: - if cfg.analyze_channels: - analyze_channels = cfg.analyze_channels - if cfg.analyze_channels == "ch_types": - analyze_channels = cfg.ch_types - inst.apply_proj() - # We special-case the average reference here to work around a situation - # where e.g. `analyze_channels` might contain only a single channel: - # `concatenate_epochs` below will then fail when trying to create / - # apply the projection. We can avoid this by removing an existing - # average reference projection here, and applying the average reference - # directly – without going through a projector. - elif "eeg" in cfg.ch_types and cfg.eeg_reference == "average": - inst.set_eeg_reference("average") - else: - inst.apply_proj() - inst.pick(analyze_channels) - return inst - - -def _get_scalp_in_files(cfg: SimpleNamespace) -> Dict[str, pathlib.Path]: - subject_path = pathlib.Path(cfg.subjects_dir) / cfg.fs_subject - seghead = subject_path / "surf" / "lh.seghead" - in_files = dict() - if seghead.is_file(): - in_files["seghead"] = seghead + analyze_channels = cfg.analyze_channels + if cfg.analyze_channels == "ch_types": + analyze_channels = cfg.ch_types + inst.apply_proj() + # We special-case the average reference here to work around a situation + # where e.g. `analyze_channels` might contain only a single channel: + # `concatenate_epochs` below will then fail when trying to create / + # apply the projection. We can avoid this by removing an existing + # average reference projection here, and applying the average reference + # directly – without going through a projector. + elif "eeg" in cfg.ch_types and cfg.eeg_reference == "average": + inst.set_eeg_reference("average") else: - in_files["t1"] = subject_path / "mri" / "T1.mgz" - return in_files + inst.apply_proj() + inst.pick(analyze_channels) + return inst -def _get_bem_conductivity(cfg: SimpleNamespace) -> Tuple[Tuple[float], str]: +def _get_bem_conductivity(cfg: SimpleNamespace) -> tuple[tuple[float], str]: if cfg.fs_subject in ("fsaverage", cfg.use_template_mri): conductivity = None # should never be used tag = "5120-5120-5120" @@ -447,7 +465,7 @@ def _meg_in_ch_types(ch_types: str) -> bool: def get_noise_cov_bids_path( - cfg: SimpleNamespace, subject: str, session: Optional[str] + cfg: SimpleNamespace, subject: str, session: str | None ) -> BIDSPath: """Retrieve the path to the noise covariance file. @@ -471,7 +489,7 @@ def get_noise_cov_bids_path( task=cfg.task, acquisition=cfg.acq, run=None, - processing=cfg.proc, + processing="clean", recording=cfg.rec, space=cfg.space, suffix="cov", @@ -512,7 +530,7 @@ def get_all_contrasts(config: SimpleNamespace) -> Iterable[ArbitraryContrast]: return normalized_contrasts -def get_decoding_contrasts(config: SimpleNamespace) -> Iterable[Tuple[str, str]]: +def get_decoding_contrasts(config: SimpleNamespace) -> Iterable[tuple[str, str]]: _validate_contrasts(config.contrasts) normalized_contrasts = [] for contrast in config.contrasts: @@ -532,9 +550,22 @@ def get_decoding_contrasts(config: SimpleNamespace) -> Iterable[Tuple[str, str]] return normalized_contrasts +# Map _config.decoding_which_epochs to a BIDS proc- entity +_EPOCHS_DESCRIPTION_TO_PROC_MAP = { + "uncleaned": None, + "after_ica": "ica", + "after_ssp": "ssp", + "cleaned": "clean", +} + + +def _get_decoding_proc(config: SimpleNamespace) -> str | None: + return _EPOCHS_DESCRIPTION_TO_PROC_MAP[config.decoding_which_epochs] + + def get_eeg_reference( config: SimpleNamespace, -) -> Union[Literal["average"], Iterable[str]]: +) -> Literal["average"] | Iterable[str]: if config.eeg_reference == "average": return config.eeg_reference elif isinstance(config.eeg_reference, str): @@ -549,7 +580,7 @@ def _validate_contrasts(contrasts: SimpleNamespace) -> None: if len(contrast) != 2: raise ValueError("Contrasts' tuples MUST be two conditions") elif isinstance(contrast, dict): - if not _keys_arbitrary_contrast.issubset(set(contrast.keys())): + if not _set_keys_arbitrary_contrast.issubset(set(contrast.keys())): raise ValueError(f"Missing key(s) in contrast {contrast}") if len(contrast["conditions"]) != len(contrast["weights"]): raise ValueError( @@ -560,12 +591,8 @@ def _validate_contrasts(contrasts: SimpleNamespace) -> None: raise ValueError("Contrasts must be tuples or well-formed dicts") -def _get_step_modules() -> Dict[str, Tuple[ModuleType]]: - from .steps import init - from .steps import preprocessing - from .steps import sensor - from .steps import source - from .steps import freesurfer +def _get_step_modules() -> dict[str, tuple[ModuleType]]: + from .steps import freesurfer, init, preprocessing, sensor, source INIT_STEPS = init._STEPS PREPROCESSING_STEPS = preprocessing._STEPS @@ -609,3 +636,31 @@ def _bids_kwargs(*, config: SimpleNamespace) -> dict: def _do_mf_autobad(*, cfg: SimpleNamespace) -> bool: return cfg.find_noisy_channels_meg or cfg.find_flat_channels_meg + + +# Adapted from MNE-Python +def _pl(x, *, non_pl="", pl="s"): + """Determine if plural should be used.""" + len_x = x if isinstance(x, int | np.generic) else len(x) + return non_pl if len_x == 1 else pl + + +def _proj_path( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> BIDSPath: + return BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + root=cfg.deriv_root, + extension=".fif", + suffix="proj", + check=False, + ) diff --git a/mne_bids_pipeline/_decoding.py b/mne_bids_pipeline/_decoding.py index 5b8ebab43..9c34ded27 100644 --- a/mne_bids_pipeline/_decoding.py +++ b/mne_bids_pipeline/_decoding.py @@ -1,8 +1,12 @@ +import mne import numpy as np -from sklearn.linear_model import LogisticRegression from joblib import parallel_backend - from mne.utils import _validate_type +from sklearn.base import BaseEstimator +from sklearn.decomposition import PCA +from sklearn.linear_model import LogisticRegression + +from ._logging import gen_log_kwargs, logger class LogReg(LogisticRegression): @@ -13,19 +17,49 @@ def fit(self, *args, **kwargs): return super().fit(*args, **kwargs) -def _handle_csp_args(decoding_csp_times, decoding_csp_freqs, decoding_metric): - _validate_type(decoding_csp_times, (list, tuple, np.ndarray), "decoding_csp_times") - if len(decoding_csp_times) < 2: - raise ValueError("decoding_csp_times should contain at least 2 values.") +def _handle_csp_args( + decoding_csp_times, + decoding_csp_freqs, + decoding_metric, + *, + epochs_tmin, + epochs_tmax, + time_frequency_freq_min, + time_frequency_freq_max, +): + _validate_type( + decoding_csp_times, (None, list, tuple, np.ndarray), "decoding_csp_times" + ) + if decoding_csp_times is None: + decoding_csp_times = np.linspace(max(0, epochs_tmin), epochs_tmax, num=6) + else: + decoding_csp_times = np.array(decoding_csp_times, float) + if decoding_csp_times.ndim != 1 or len(decoding_csp_times) == 1: + raise ValueError( + "decoding_csp_times should be 1 dimensional and contain at least 2 values " + "to define time intervals, or be empty to disable time-frequency mode, got " + f"shape {decoding_csp_times.shape}" + ) if not np.array_equal(decoding_csp_times, np.sort(decoding_csp_times)): ValueError("decoding_csp_times should be sorted.") + time_bins = np.c_[decoding_csp_times[:-1], decoding_csp_times[1:]] + assert time_bins.ndim == 2 and time_bins.shape[1] == 2, time_bins.shape + if decoding_metric != "roc_auc": raise ValueError( f'CSP decoding currently only supports the "roc_auc" ' f"decoding metric, but received " f'decoding_metric="{decoding_metric}"' ) - _validate_type(decoding_csp_freqs, dict, "config.decoding_csp_freqs") + _validate_type(decoding_csp_freqs, (None, dict), "config.decoding_csp_freqs") + if decoding_csp_freqs is None: + decoding_csp_freqs = { + "custom": ( + time_frequency_freq_min, + (time_frequency_freq_max + time_frequency_freq_min) / 2, # noqa: E501 + time_frequency_freq_max, + ), + } freq_name_to_bins_map = dict() for freq_range_name, edges in decoding_csp_freqs.items(): _validate_type(freq_range_name, str, "config.decoding_csp_freqs key") @@ -49,4 +83,26 @@ def _handle_csp_args(decoding_csp_times, decoding_csp_freqs, decoding_metric): freq_bins = list(zip(edges[:-1], edges[1:])) freq_name_to_bins_map[freq_range_name] = freq_bins - return freq_name_to_bins_map + return freq_name_to_bins_map, time_bins + + +def _decoding_preproc_steps( + subject: str, + session: str | None, + epochs: mne.Epochs, + pca: bool = True, +) -> list[BaseEstimator]: + scaler = mne.decoding.Scaler(epochs.info) + steps = [scaler] + if pca: + ranks = mne.compute_rank(inst=epochs, rank="info") + rank = sum(ranks.values()) + msg = f"Reducing data dimension via PCA; new rank: {rank} (from {ranks})." + logger.info(**gen_log_kwargs(message=msg)) + steps.append( + mne.decoding.UnsupervisedSpatialFilter( + PCA(rank, whiten=True), + average=False, + ) + ) + return steps diff --git a/mne_bids_pipeline/_docs.py b/mne_bids_pipeline/_docs.py new file mode 100644 index 000000000..440b34690 --- /dev/null +++ b/mne_bids_pipeline/_docs.py @@ -0,0 +1,260 @@ +import ast +import inspect +import re +from collections import defaultdict +from pathlib import Path + +from tqdm import tqdm + +from . import _config_utils, _import_data + +_CONFIG_RE = re.compile(r"config\.([a-zA-Z_]+)") + +_NO_CONFIG = { + "freesurfer/_01_recon_all", +} +_IGNORE_OPTIONS = { + "PIPELINE_NAME", + "VERSION", + "CODE_URL", +} +# We don't need to parse the config itself, just the steps +_MANUAL_KWS = { + "source/_04_make_forward:get_config:t1_bids_path": ("mri_t1_path_generator",), + "source/_04_make_forward:get_config:landmarks_kind": ("mri_landmarks_kind",), + "preprocessing/_01_data_quality:get_config:extra_kwargs": ( + "mf_cal_fname", + "mf_ctc_fname", + "mf_head_origin", + "find_flat_channels_meg", + "find_noisy_channels_meg", + ), +} +# Some don't show up so force them to be empty +_EXECUTION_OPTIONS = ( + # Eventually we could deduplicate these with the execution.md list + "n_jobs", + "parallel_backend", + "dask_open_dashboard", + "dask_temp_dir", + "dask_worker_memory_limit", + "log_level", + "mne_log_level", + "on_error", + "memory_location", + "memory_file_method", + "memory_subdir", + "memory_verbose", + "config_validation", + "interactive", +) +_FORCE_EMPTY = _EXECUTION_OPTIONS + ( + # Plus some BIDS one we don't detect because _bids_kwargs etc. above, + # which we could cross-check against the general.md list. A notable + # exception is random_state, since this does have more localized effects. + # These are used a lot at the very beginning, so adding them will lead + # to long lists. Instead, let's just mention at the top of General that + # messing with basic BIDS params will affect almost every step. + "bids_root", + "deriv_root", + "subjects_dir", + "sessions", + "acq", + "proc", + "rec", + "space", + "task", + "runs", + "exclude_runs", + "subjects", + "crop_runs", + "process_empty_room", + "process_rest", + "eeg_bipolar_channels", + "eeg_reference", + "eeg_template_montage", + "drop_channels", + "reader_extra_params", + "read_raw_bids_verbose", + "plot_psd_for_runs", + "shortest_event", + "find_breaks", + "min_break_duration", + "t_break_annot_start_after_previous_event", + "t_break_annot_stop_before_next_event", + "rename_events", + "on_rename_missing_events", + "mf_reference_run", # TODO: Make clearer that this changes a lot + "fix_stim_artifact", + "stim_artifact_tmin", + "stim_artifact_tmax", + # And some that we force to be empty because they affect too many things + # and what they affect is an incomplete list anyway + "exclude_subjects", + "ch_types", + "task_is_rest", + "data_type", +) +# Eventually we could parse AST to get these, but this is simple enough +_EXTRA_FUNCS = { + "_bids_kwargs": ("get_task",), + "_import_data_kwargs": ("get_mf_reference_run",), + "get_runs": ("get_runs_all_subjects",), +} + + +class _ParseConfigSteps: + def __init__(self, force_empty=None): + self._force_empty = _FORCE_EMPTY if force_empty is None else force_empty + self.steps = defaultdict(list) + # Add a few helper functions + for func in ( + _config_utils.get_eeg_reference, + _config_utils.get_all_contrasts, + _config_utils.get_decoding_contrasts, + _config_utils.get_fs_subject, + _config_utils.get_fs_subjects_dir, + _config_utils.get_mf_cal_fname, + _config_utils.get_mf_ctc_fname, + ): + this_list = [] + for attr in ast.walk(ast.parse(inspect.getsource(func))): + if not isinstance(attr, ast.Attribute): + continue + if not (isinstance(attr.value, ast.Name) and attr.value.id == "config"): + continue + if attr.attr not in this_list: + this_list.append(attr.attr) + _MANUAL_KWS[func.__name__] = tuple(this_list) + + for module in tqdm( + sum(_config_utils._get_step_modules().values(), tuple()), + desc="Generating option->step mapping", + ): + step = "/".join(module.__name__.split(".")[-2:]) + found = False # found at least one? + # Walk the module file for "get_config*" functions (can be multiple!) + for func in ast.walk(ast.parse(Path(module.__file__).read_text("utf-8"))): + if not isinstance(func, ast.FunctionDef): + continue + where = f"{step}:{func.name}" + # Also look at config.* args in main(), e.g. config.recreate_bem + # and config.recreate_scalp_surface + if func.name == "main": + for call in ast.walk(func): + if not isinstance(call, ast.Call): + continue + for keyword in call.keywords: + if not isinstance(keyword.value, ast.Attribute): + continue + if keyword.value.value.id != "config": + continue + if keyword.value.attr in ("exec_params",): + continue + self._add_step_option(step, keyword.value.attr) + # Also look for root-level conditionals like use_maxwell_filter + # or spatial_filter + for cond in ast.iter_child_nodes(func): + # is a conditional + if not isinstance(cond, ast.If): + continue + # has a return statement + if not any(isinstance(c, ast.Return) for c in ast.walk(cond)): + continue + # look at all attributes in the conditional + for attr in ast.walk(cond.test): + if not isinstance(attr, ast.Attribute): + continue + if attr.value.id != "config": + continue + self._add_step_option(step, attr.attr) + # Now look at get_config* functions + if not func.name.startswith("get_config"): + continue + found = True + for call in ast.walk(func): + if not isinstance(call, ast.Call): + continue + if call.func.id != "SimpleNamespace": + continue + break + else: + raise RuntimeError(f"Could not find SimpleNamespace in {func}") + assert call.args == [] + for keyword in call.keywords: + if isinstance(keyword.value, ast.Call): + key = keyword.value.func.id + if key in _MANUAL_KWS: + for option in _MANUAL_KWS[key]: + self._add_step_option(step, option) + continue + if keyword.value.func.id == "_sanitize_callable": + assert len(keyword.value.args) == 1 + assert isinstance(keyword.value.args[0], ast.Attribute) + assert keyword.value.args[0].value.id == "config" + self._add_step_option(step, keyword.value.args[0].attr) + continue + if key not in ( + "_bids_kwargs", + "_import_data_kwargs", + "get_runs", + "get_subjects", + "get_sessions", + ): + raise RuntimeError( + f"{where} cannot handle call {keyword.value.func.id=} " + f"for {key}" + ) + # Get the source and regex for config values + if key == "_import_data_kwargs": + funcs = [getattr(_import_data, key)] + else: + funcs = [getattr(_config_utils, key)] + for func_name in _EXTRA_FUNCS.get(key, ()): + funcs.append(getattr(_config_utils, func_name)) + for fi, func in enumerate(funcs): + source = inspect.getsource(func) + assert "config: SimpleNamespace" in source, key + if fi == 0: + for func_name in _EXTRA_FUNCS.get(key, ()): + assert f"{func_name}(" in source, (key, func_name) + attrs = _CONFIG_RE.findall(source) + assert len(attrs), f"No config.* found in source of {key}" + for attr in attrs: + self._add_step_option(step, attr) + continue + if isinstance(keyword.value, ast.Name): + key = f"{where}:{keyword.value.id}" + if key in _MANUAL_KWS: + for option in _MANUAL_KWS[f"{where}:{keyword.value.id}"]: + self._add_step_option(step, option) + continue + raise RuntimeError(f"{where} cannot handle Name {key=}") + if isinstance(keyword.value, ast.IfExp): # conditional + if keyword.arg == "processing": # inline conditional for proc + continue + if not isinstance(keyword.value, ast.Attribute): + raise RuntimeError( + f"{where} cannot handle type {keyword.value=}" + ) + option = keyword.value.attr + if option in _IGNORE_OPTIONS: + continue + assert keyword.value.value.id == "config", f"{where} {keyword.value.value.id}" # noqa: E501 # fmt: skip + self._add_step_option(step, option) + if step in _NO_CONFIG: + assert not found, f"Found unexpected get_config* in {step}" + else: + assert found, f"Could not find get_config* in {step}" + for key in self._force_empty: + self.steps[key] = list() + for key, val in self.steps.items(): + assert len(val) == len(set(val)), f"{key} {val}" + self.steps = {k: tuple(v) for k, v in self.steps.items()} # no defaultdict + + def _add_step_option(self, step, option): + if step not in self.steps[option]: + self.steps[option].append(step) + + def __call__(self, option: str) -> list[str]: + return self.steps[option] diff --git a/mne_bids_pipeline/_download.py b/mne_bids_pipeline/_download.py index bec670897..2a5308868 100644 --- a/mne_bids_pipeline/_download.py +++ b/mne_bids_pipeline/_download.py @@ -1,4 +1,5 @@ """Download test data.""" + import argparse from pathlib import Path @@ -9,46 +10,30 @@ DEFAULT_DATA_DIR = Path("~/mne_data").expanduser() -def _download_via_datalad(*, ds_name: str, ds_path: Path): - import datalad.api as dl - - print('datalad installing "{}"'.format(ds_name)) - git_url = DATASET_OPTIONS[ds_name]["git"] - dataset = dl.install(path=ds_path, source=git_url) - - # XXX: git-annex bug: - # https://github.com/datalad/datalad/issues/3583 - # if datalad fails, use "get" twice, or set `n_jobs=1` - if ds_name == "ds003104": - n_jobs = 16 - else: - n_jobs = 1 - - for to_get in DATASET_OPTIONS[ds_name]["include"]: - print('datalad get data "{}" for "{}"'.format(to_get, ds_name)) - dataset.get(to_get, jobs=n_jobs) - - def _download_via_openneuro(*, ds_name: str, ds_path: Path): import openneuro + options = DATASET_OPTIONS[ds_name] + assert "hash" not in options + openneuro.download( - dataset=DATASET_OPTIONS[ds_name]["openneuro"], + dataset=options["openneuro"], target_dir=ds_path, - include=DATASET_OPTIONS[ds_name]["include"], - exclude=DATASET_OPTIONS[ds_name]["exclude"], + include=options.get("include", []), + exclude=options.get("exclude", []), verify_size=False, ) def _download_from_web(*, ds_name: str, ds_path: Path): """Retrieve Zip archives from a web URL.""" - import cgi - import zipfile - import httpx - from tqdm import tqdm + import pooch - url = DATASET_OPTIONS[ds_name]["web"] + options = DATASET_OPTIONS[ds_name] + url = options["web"] + known_hash = options["hash"] + assert "exclude" not in options + assert "include" not in options if ds_path.exists(): print( "Dataset directory already exists; remove it if you wish to " @@ -57,76 +42,41 @@ def _download_from_web(*, ds_name: str, ds_path: Path): return ds_path.mkdir(parents=True, exist_ok=True) + path = ds_path.parent.resolve(strict=True) + fname = f"{ds_name}.zip" + pooch.retrieve( + url=url, + path=path, + fname=fname, + processor=pooch.Unzip(extract_dir="."), # relative to path + progressbar=True, + known_hash=known_hash, + ) + (path / f"{ds_name}.zip").unlink() + - with httpx.Client(follow_redirects=True) as client: - with client.stream("GET", url=url) as response: - if not response.is_error: - pass # All good! - else: - raise RuntimeError( - f"Error {response.status_code} when trying " f"to download {url}" - ) - - header = response.headers["content-disposition"] - _, params = cgi.parse_header(header) - # where to store the archive - outfile = ds_path / params["filename"] - remote_file_size = int(response.headers["content-length"]) - - with open(outfile, mode="wb") as f: - with tqdm( - desc=params["filename"], - initial=0, - total=remote_file_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - leave=False, - ) as progress: - num_bytes_downloaded = response.num_bytes_downloaded - - for chunk in response.iter_bytes(): - f.write(chunk) - progress.update( - response.num_bytes_downloaded - num_bytes_downloaded - ) - num_bytes_downloaded = response.num_bytes_downloaded - - assert outfile.suffix == ".zip" - - with zipfile.ZipFile(outfile) as zip: - for zip_info in zip.infolist(): - path_in_zip = Path(zip_info.filename) - # omit top-level directory from Zip archive - target_path = str(Path(*path_in_zip.parts[1:])) - if str(target_path) in (".", ".."): - continue - if zip_info.filename.endswith("/"): - (ds_path / target_path).mkdir(parents=True, exist_ok=True) - continue - zip_info.filename = target_path - print(f"Extracting: {target_path}") - zip.extract(zip_info, ds_path) - - outfile.unlink() +def _download_via_mne(*, ds_name: str, ds_path: Path): + assert ds_path.stem == ds_name, ds_path + getattr(mne.datasets, DATASET_OPTIONS[ds_name]["mne"]).data_path( + ds_path.parent, + verbose=True, + ) def _download(*, ds_name: str, ds_path: Path): - openneuro_name = DATASET_OPTIONS[ds_name]["openneuro"] - git_url = DATASET_OPTIONS[ds_name]["git"] - osf_node = DATASET_OPTIONS[ds_name]["osf"] - web_url = DATASET_OPTIONS[ds_name]["web"] + options = DATASET_OPTIONS[ds_name] + openneuro_name = options.get("openneuro", "") + web_url = options.get("web", "") + mne_mod = options.get("mne", "") + assert sum(bool(x) for x in (openneuro_name, web_url, mne_mod)) == 1 if openneuro_name: download_func = _download_via_openneuro - elif git_url: - download_func = _download_via_datalad - elif osf_node: - raise RuntimeError("OSF downloads are currently not supported.") - elif web_url: - download_func = _download_from_web + elif mne_mod: + download_func = _download_via_mne else: - raise ValueError("No download location was specified.") + assert web_url + download_func = _download_from_web download_func(ds_name=ds_name, ds_path=ds_path) diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py index ce1964604..494101055 100644 --- a/mne_bids_pipeline/_import_data.py +++ b/mne_bids_pipeline/_import_data.py @@ -1,21 +1,22 @@ +from collections.abc import Iterable from types import SimpleNamespace -from typing import Dict, Optional, Iterable, Union, List, Literal +from typing import Literal import mne -from mne.utils import _pl -from mne_bids import BIDSPath, read_raw_bids, get_bids_path_from_fname import numpy as np import pandas as pd +from mne_bids import BIDSPath, get_bids_path_from_fname, read_raw_bids from ._config_utils import ( + _bids_kwargs, + _do_mf_autobad, + _pl, + get_datatype, get_mf_reference_run, get_runs, - get_datatype, get_task, - _bids_kwargs, - _do_mf_autobad, ) -from ._io import _read_json, _empty_room_match_path +from ._io import _read_json from ._logging import gen_log_kwargs, logger from ._run import _update_for_splits from .typing import PathLike @@ -25,17 +26,17 @@ def make_epochs( *, task: str, subject: str, - session: Optional[str], + session: str | None, raw: mne.io.BaseRaw, - event_id: Optional[Union[Dict[str, int], Literal["auto"]]], - conditions: Union[Iterable[str], Dict[str, str]], + event_id: dict[str, int] | Literal["auto"] | None, + conditions: Iterable[str] | dict[str, str], tmin: float, tmax: float, - metadata_tmin: Optional[float], - metadata_tmax: Optional[float], - metadata_keep_first: Optional[Iterable[str]], - metadata_keep_last: Optional[Iterable[str]], - metadata_query: Optional[str], + metadata_tmin: float | None, + metadata_tmax: float | None, + metadata_keep_first: Iterable[str] | None, + metadata_keep_last: Iterable[str] | None, + metadata_query: str | None, event_repeated: Literal["error", "drop", "merge"], epochs_decim: int, task_is_rest: bool, @@ -147,12 +148,12 @@ def make_epochs( return epochs -def annotations_to_events(*, raw_paths: List[PathLike]) -> Dict[str, int]: +def annotations_to_events(*, raw_paths: list[PathLike]) -> dict[str, int]: """Generate a unique event name -> event code mapping. The mapping can that can be used across all passed raws. """ - event_names: List[str] = [] + event_names: list[str] = [] for raw_fname in raw_paths: raw = mne.io.read_raw_fif(raw_fname) _, event_id = mne.events_from_annotations(raw=raw) @@ -172,8 +173,8 @@ def _rename_events_func( cfg: SimpleNamespace, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, ) -> None: """Rename events (actually, annotations descriptions) in ``raw``. @@ -225,7 +226,7 @@ def _load_data(cfg: SimpleNamespace, bids_path: BIDSPath) -> mne.io.BaseRaw: subject = bids_path.subject raw = read_raw_bids( bids_path=bids_path, - extra_params=cfg.reader_extra_params, + extra_params=cfg.reader_extra_params or {}, verbose=cfg.read_raw_bids_verbose, ) @@ -255,7 +256,7 @@ def _drop_channels_func( cfg: SimpleNamespace, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], + session: str | None, ) -> None: """Drop channels from the data. @@ -271,8 +272,8 @@ def _create_bipolar_channels( cfg: SimpleNamespace, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, ) -> None: """Create a channel from a bipolar referencing scheme.. @@ -317,22 +318,18 @@ def _set_eeg_montage( cfg: SimpleNamespace, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, ) -> None: """Set an EEG template montage if requested. Modifies ``raw`` in-place. """ montage = cfg.eeg_template_montage - is_mne_montage = isinstance(montage, mne.channels.montage.DigMontage) - montage_name = "custom_montage" if is_mne_montage else montage if cfg.datatype == "eeg" and montage: msg = f"Setting EEG channel locations to template montage: " f"{montage}." logger.info(**gen_log_kwargs(message=msg)) - if not is_mne_montage: - montage = mne.channels.make_standard_montage(montage_name) - raw.set_montage(montage, match_case=False, on_missing="warn") + raw.set_montage(montage, match_case=False, match_alias=True) def _fix_stim_artifact_func(cfg: SimpleNamespace, raw: mne.io.BaseRaw) -> None: @@ -355,8 +352,8 @@ def import_experimental_data( *, cfg: SimpleNamespace, bids_path_in: BIDSPath, - bids_path_bads_in: Optional[BIDSPath], - data_is_rest: Optional[bool], + bids_path_bads_in: BIDSPath | None, + data_is_rest: bool | None, ) -> mne.io.BaseRaw: """Run the data import. @@ -402,6 +399,7 @@ def import_experimental_data( _fix_stim_artifact_func(cfg=cfg, raw=raw) if bids_path_bads_in is not None: + run = "rest" if data_is_rest else run # improve logging bads = _read_bads_tsv(cfg=cfg, bids_path_bads=bids_path_bads_in) msg = f"Marking {len(bads)} channel{_pl(bads)} as bad." logger.info(**gen_log_kwargs(message=msg)) @@ -415,9 +413,9 @@ def import_er_data( *, cfg: SimpleNamespace, bids_path_er_in: BIDSPath, - bids_path_ref_in: Optional[BIDSPath], - bids_path_er_bads_in: Optional[BIDSPath], - bids_path_ref_bads_in: Optional[BIDSPath], + bids_path_ref_in: BIDSPath | None, + bids_path_er_bads_in: BIDSPath | None, + bids_path_ref_bads_in: BIDSPath | None, prepare_maxwell_filter: bool, ) -> mne.io.BaseRaw: """Import empty-room data. @@ -434,6 +432,8 @@ def import_er_data( The BIDS path to the empty room bad channels file. bids_path_ref_bads_in The BIDS path to the reference data bad channels file. + prepare_maxwell_filter + Whether to prepare the empty-room data for Maxwell filtering. Returns ------- @@ -449,7 +449,6 @@ def import_er_data( cfg=cfg, bids_path_bads=bids_path_er_bads_in, ) - raw_er.pick_types(meg=True, exclude=[]) # Don't deal with ref for now (initial data quality / auto bad step) if bids_path_ref_in is None: @@ -458,7 +457,7 @@ def import_er_data( # Load reference run plus its auto-bads raw_ref = read_raw_bids( bids_path_ref_in, - extra_params=cfg.reader_extra_params, + extra_params=cfg.reader_extra_params or {}, verbose=cfg.read_raw_bids_verbose, ) if bids_path_ref_bads_in is not None: @@ -492,8 +491,8 @@ def _find_breaks_func( cfg, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, ) -> None: if not cfg.find_breaks: return @@ -524,10 +523,10 @@ def _get_bids_path_in( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - kind: Literal["orig", "sss"] = "orig", + session: str | None, + run: str | None, + task: str | None, + kind: Literal["orig", "sss", "filt"] = "orig", ) -> BIDSPath: # b/c can be used before this is updated path_kwargs = dict( @@ -541,13 +540,13 @@ def _get_bids_path_in( datatype=get_datatype(config=cfg), check=False, ) - if kind == "sss": + if kind != "orig": + assert kind in ("sss", "filt"), kind path_kwargs["root"] = cfg.deriv_root path_kwargs["suffix"] = "raw" path_kwargs["extension"] = ".fif" - path_kwargs["processing"] = "sss" + path_kwargs["processing"] = kind else: - assert kind == "orig", kind path_kwargs["root"] = cfg.bids_root path_kwargs["suffix"] = None path_kwargs["extension"] = None @@ -560,13 +559,13 @@ def _get_run_path( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - kind: Literal["orig", "sss"], - add_bads: Optional[bool] = None, + session: str | None, + run: str | None, + task: str | None, + kind: Literal["orig", "sss", "filt"], + add_bads: bool | None = None, allow_missing: bool = False, - key: Optional[str] = None, + key: str | None = None, ) -> dict: bids_path_in = _get_bids_path_in( cfg=cfg, @@ -583,6 +582,8 @@ def _get_run_path( add_bads=add_bads, kind=kind, allow_missing=allow_missing, + subject=subject, + session=session, ) @@ -590,9 +591,9 @@ def _get_rest_path( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - kind: Literal["orig", "sss"], - add_bads: Optional[bool] = None, + session: str | None, + kind: Literal["orig", "sss", "filt"], + add_bads: bool | None = None, ) -> dict: if not (cfg.process_rest and not cfg.task_is_rest): return dict() @@ -612,14 +613,15 @@ def _get_noise_path( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - kind: Literal["orig", "sss"], - mf_reference_run: Optional[str], - add_bads: Optional[bool] = None, + session: str | None, + kind: Literal["orig", "sss", "filt"], + mf_reference_run: str | None, + add_bads: bool | None = None, ) -> dict: if not (cfg.process_empty_room and get_datatype(config=cfg) == "meg"): return dict() - if kind == "sss": + if kind != "orig": + assert kind in ("sss", "filt") raw_fname = _get_bids_path_in( cfg=cfg, subject=subject, @@ -648,6 +650,8 @@ def _get_noise_path( add_bads=add_bads, kind=kind, allow_missing=True, + subject=subject, + session=session, ) @@ -655,12 +659,12 @@ def _get_run_rest_noise_path( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - kind: Literal["orig", "sss"], - mf_reference_run: Optional[str], - add_bads: Optional[bool] = None, + session: str | None, + run: str | None, + task: str | None, + kind: Literal["orig", "sss", "filt"], + mf_reference_run: str | None, + add_bads: bool | None = None, ) -> dict: kwargs = dict( cfg=cfg, @@ -680,10 +684,11 @@ def _get_run_rest_noise_path( def _get_mf_reference_run_path( + *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - add_bads: bool, + session: str | None, + add_bads: bool | None = None, ) -> dict: return _get_run_path( cfg=cfg, @@ -697,14 +702,22 @@ def _get_mf_reference_run_path( ) +def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath: + return run_path.copy().update( + extension=".json", suffix="emptyroommatch", root=cfg.deriv_root + ) + + def _path_dict( *, cfg: SimpleNamespace, bids_path_in: BIDSPath, - add_bads: Optional[bool] = None, - kind: Literal["orig", "sss"], + add_bads: bool | None = None, + kind: Literal["orig", "sss", "filt"], allow_missing: bool, - key: Optional[str] = None, + key: str | None = None, + subject: str, + session: str | None, ) -> dict: if add_bads is None: add_bads = kind == "orig" and _do_mf_autobad(cfg=cfg) @@ -715,35 +728,30 @@ def _path_dict( if allow_missing and not in_files[key].fpath.exists(): return dict() if add_bads: - bads_tsv_fname = _bads_path(cfg=cfg, bids_path_in=bids_path_in) + bads_tsv_fname = _bads_path( + cfg=cfg, + bids_path_in=bids_path_in, + subject=subject, + session=session, + ) if bads_tsv_fname.fpath.is_file() or not allow_missing: in_files[f"{key}-bads"] = bads_tsv_fname return in_files -def _auto_scores_path( - *, - cfg: SimpleNamespace, - bids_path_in: BIDSPath, -) -> BIDSPath: - return bids_path_in.copy().update( - suffix="scores", - extension=".json", - root=cfg.deriv_root, - split=None, - check=False, - ) - - def _bads_path( *, cfg: SimpleNamespace, bids_path_in: BIDSPath, + subject: str, + session: str | None, ) -> BIDSPath: return bids_path_in.copy().update( suffix="bads", extension=".tsv", root=cfg.deriv_root, + subject=subject, + session=session, split=None, check=False, ) @@ -753,7 +761,7 @@ def _read_bads_tsv( *, cfg: SimpleNamespace, bids_path_bads: BIDSPath, -) -> List[str]: +) -> list[str]: bads_tsv = pd.read_csv(bids_path_bads.fpath, sep="\t", header=0) return bads_tsv[bads_tsv.columns[0]].tolist() @@ -802,3 +810,15 @@ def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict: runs=get_runs(config=config, subject=subject), # XXX needs to accept session! **_bids_kwargs(config=config), ) + + +def _read_raw_msg( + bids_path_in: BIDSPath, + run: str | None, + task: str | None, +) -> tuple[str]: + if run is None and task in ("noise", "rest"): + run_type = dict(rest="resting-state", noise="empty-room")[task] + else: + run_type = "experimental" + return f"Reading {run_type} recording: {bids_path_in.basename}", run_type diff --git a/mne_bids_pipeline/_io.py b/mne_bids_pipeline/_io.py index 0b7485f76..dc894cb6b 100644 --- a/mne_bids_pipeline/_io.py +++ b/mne_bids_pipeline/_io.py @@ -1,9 +1,6 @@ """I/O helpers.""" -from types import SimpleNamespace - import json_tricks -from mne_bids import BIDSPath from .typing import PathLike @@ -14,11 +11,5 @@ def _write_json(fname: PathLike, data: dict) -> None: def _read_json(fname: PathLike) -> dict: - with open(fname, "r", encoding="utf-8") as f: + with open(fname, encoding="utf-8") as f: return json_tricks.load(f) - - -def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath: - return run_path.copy().update( - extension=".json", suffix="emptyroommatch", root=cfg.deriv_root - ) diff --git a/mne_bids_pipeline/_logging.py b/mne_bids_pipeline/_logging.py index fc6085d6a..4d561a488 100644 --- a/mne_bids_pipeline/_logging.py +++ b/mne_bids_pipeline/_logging.py @@ -1,9 +1,9 @@ """Logging.""" + import datetime import inspect import logging import os -from typing import Optional, Union import rich.console import rich.theme @@ -27,25 +27,39 @@ def _console(self): force_terminal = os.getenv("MNE_BIDS_PIPELINE_FORCE_TERMINAL", None) if force_terminal is not None: force_terminal = force_terminal.lower() in ("true", "1") - kwargs = dict(soft_wrap=True, force_terminal=force_terminal) + legacy_windows = os.getenv("MNE_BIDS_PIPELINE_LEGACY_WINDOWS", None) + if legacy_windows is not None: + legacy_windows = legacy_windows.lower() in ("true", "1") + kwargs = dict( + soft_wrap=True, + force_terminal=force_terminal, + legacy_windows=legacy_windows, + ) kwargs["theme"] = rich.theme.Theme( dict( default="white", + # Rule + title="bold green", # Prefixes asctime="green", - step="bold cyan", + prefix="bold cyan", # Messages debug="dim", - info="bold", - warning="bold magenta", - error="bold red", + info="", + warning="magenta", + error="red", ) ) self.__console = rich.console.Console(**kwargs) return self.__console - def rule(self, title="", *, align="center"): - self.__console.rule(title=title, characters="─", style="rule.line", align=align) + def title(self, title): + # Align left with ASCTIME offset + title = f"[title]┌────────┬ {title}[/]" + self._console.rule(title=title, characters="─", style="title", align="left") + + def end(self, msg=""): + self._console.print(f"[title]└────────┴ {msg}[/]") @property def level(self): @@ -56,48 +70,37 @@ def level(self, level): level = int(level) self._level = level - def debug(self, msg: str, *, extra: Optional[LogKwargsT] = None) -> None: + def debug(self, msg: str, *, extra: LogKwargsT | None = None) -> None: self._log_message(kind="debug", msg=msg, **(extra or {})) - def info(self, msg: str, *, extra: Optional[LogKwargsT] = None) -> None: + def info(self, msg: str, *, extra: LogKwargsT | None = None) -> None: self._log_message(kind="info", msg=msg, **(extra or {})) - def warning(self, msg: str, *, extra: Optional[LogKwargsT] = None) -> None: + def warning(self, msg: str, *, extra: LogKwargsT | None = None) -> None: self._log_message(kind="warning", msg=msg, **(extra or {})) - def error(self, msg: str, *, extra: Optional[LogKwargsT] = None) -> None: + def error(self, msg: str, *, extra: LogKwargsT | None = None) -> None: self._log_message(kind="error", msg=msg, **(extra or {})) def _log_message( self, kind: str, msg: str, - subject: Optional[Union[str, int]] = None, - session: Optional[Union[str, int]] = None, - run: Optional[Union[str, int]] = None, - step: Optional[str] = None, + subject: str | int | None = None, + session: str | int | None = None, + run: str | int | None = None, emoji: str = "", - box: str = "", ): this_level = getattr(logging, kind.upper()) if this_level < self.level: return - if not subject: - subject = "" - if not session: - session = "" - if not run: - run = "" - if not step: - step = "" - if step and emoji: - step = f"{emoji} {step}" - asctime = datetime.datetime.now().strftime("[%H:%M:%S]") - msg = ( - f"[asctime]{asctime}[/asctime] " - f"[step]{box}{step}{subject}{session}{run}[/step]" - f"[{kind}]{msg}[/{kind}]" - ) + # Construct str + essr = [x for x in [emoji, subject, session, run] if x] + essr = " ".join(essr) + if essr: + essr += " " + asctime = datetime.datetime.now().strftime("│%H:%M:%S│") + msg = f"[asctime]{asctime} [/][prefix]{essr}[/][{kind}]{msg}[/]" self._console.print(msg) @@ -107,17 +110,14 @@ def _log_message( def gen_log_kwargs( message: str, *, - subject: Optional[Union[str, int]] = None, - session: Optional[Union[str, int]] = None, - run: Optional[Union[str, int]] = None, - task: Optional[str] = None, - step: Optional[str] = None, + subject: str | int | None = None, + session: str | int | None = None, + run: str | int | None = None, + task: str | None = None, emoji: str = "⏳️", - box: str = "│ ", ) -> LogKwargsT: - from ._run import _get_step_path, _short_step_path - # Try to figure these out + assert isinstance(message, str), type(message) stack = inspect.stack() up_locals = stack[1].frame.f_locals if subject is None: @@ -130,23 +130,14 @@ def gen_log_kwargs( task = task or up_locals.get("task", None) if task in ("noise", "rest"): run = task - if step is None: - step_path = _get_step_path(stack) - if step_path: - step = _short_step_path(_get_step_path()) - else: - step = "" # Do some nice formatting if subject is not None: - subject = f" sub-{subject}" + subject = f"sub-{subject}" if session is not None: - session = f" ses-{session}" + session = f"ses-{session}" if run is not None: - run = f" run-{run}" - if step != "": - # need an extra space - message = f" {message}" + run = f"run-{run}" # Choose some to be our standards emoji = dict( @@ -154,10 +145,7 @@ def gen_log_kwargs( skip="⏩", override="❌", ).get(emoji, emoji) - extra = { - "step": f"{emoji} {step}", - "box": box, - } + extra = {"emoji": emoji} if subject: extra["subject"] = subject if session: @@ -170,3 +158,11 @@ def gen_log_kwargs( "extra": extra, } return kwargs + + +def _linkfile(uri): + return f"[link=file://{uri}]{uri}[/link]" + + +def _is_testing() -> bool: + return os.getenv("_MNE_BIDS_STUDY_TESTING", "") == "true" diff --git a/mne_bids_pipeline/_main.py b/mne_bids_pipeline/_main.py index 7bfbb392a..634018a13 100755 --- a/mne_bids_pipeline/_main.py +++ b/mne_bids_pipeline/_main.py @@ -1,16 +1,16 @@ import argparse import pathlib -from textwrap import dedent import time -from typing import List +from textwrap import dedent from types import ModuleType, SimpleNamespace import numpy as np -from ._config_utils import _get_step_modules from ._config_import import _import_config from ._config_template import create_template_config -from ._logging import logger, gen_log_kwargs +from ._config_utils import _get_step_modules +from ._logging import gen_log_kwargs, logger +from ._parallel import get_parallel_backend from ._run import _short_step_path @@ -36,7 +36,7 @@ def main(): metavar="FILE", help="Create a template configuration file with the specified name. " "If specified, all other parameters will be ignored.", - ), + ) parser.add_argument( "--steps", dest="steps", @@ -58,6 +58,18 @@ def main(): default=None, help="BIDS root directory of the data to process.", ) + parser.add_argument( + "--deriv_root", + dest="deriv_root", + default=None, + help=dedent( + """\ + The root of the derivatives directory + in which the pipeline will store the processing results. + If unspecified, this will be derivatives/mne-bids-pipeline + inside the BIDS root.""" + ), + ) parser.add_argument( "--subject", dest="subject", default=None, help="The subject to process." ) @@ -82,7 +94,11 @@ def main(): help="Enable interactive mode.", ) parser.add_argument( - "--debug", dest="debug", action="store_true", help="Enable debugging on error." + "--debug", + "--pdb", + dest="debug", + action="store_true", + help="Enable debugging on error.", ) parser.add_argument( "--no-cache", @@ -114,6 +130,7 @@ def main(): ) steps = options.steps root_dir = options.root_dir + deriv_root = options.deriv_root subject, session = options.subject, options.session task, run = options.task, options.run n_jobs = options.n_jobs @@ -128,7 +145,6 @@ def main(): steps = (steps,) on_error = "debug" if debug else None - cache = "1" if cache else "0" processing_stages = [] processing_steps = [] @@ -147,6 +163,10 @@ def main(): overrides = SimpleNamespace() if root_dir: overrides.bids_root = pathlib.Path(root_dir).expanduser().resolve(strict=True) + if deriv_root: + overrides.deriv_root = ( + pathlib.Path(deriv_root).expanduser().resolve(strict=False) + ) if subject: overrides.subjects = [subject] if session: @@ -164,7 +184,7 @@ def main(): if not cache: overrides.memory_location = False - step_modules: List[ModuleType] = [] + step_modules: list[ModuleType] = [] STEP_MODULES = _get_step_modules() for stage, step in zip(processing_stages, processing_steps): if stage not in STEP_MODULES.keys(): @@ -193,22 +213,24 @@ def main(): # them twice. step_modules = [*STEP_MODULES["init"], *step_modules] - msg = "Welcome aboard the MNE BIDS Pipeline!" - logger.info(**gen_log_kwargs(message=msg, emoji="👋", box="╶╴", step="")) + logger.title("Welcome aboard MNE-BIDS-Pipeline! 👋") msg = f"Using configuration: {config}" - logger.info(**gen_log_kwargs(message=msg, emoji="🧾", box="╶╴", step="")) - + __mne_bids_pipeline_step__ = pathlib.Path(__file__) # used for logging + logger.info(**gen_log_kwargs(message=msg, emoji="📝")) config_imported = _import_config( config_path=config_path, overrides=overrides, ) - for si, step_module in enumerate(step_modules): + # Initialize dask now + with get_parallel_backend(config_imported.exec_params): + pass + del __mne_bids_pipeline_step__ + logger.end() + + for step_module in step_modules: start = time.time() step = _short_step_path(pathlib.Path(step_module.__file__)) - if si == 0: - logger.rule() - msg = "Now running 👇" - logger.info(**gen_log_kwargs(message=msg, box="┌╴", emoji="🚀", step=step)) + logger.title(title=f"{step}") step_module.main(config=config_imported) elapsed = time.time() - start hours, remainder = divmod(elapsed, 3600) @@ -221,6 +243,4 @@ def main(): elapsed = f"{minutes}m {elapsed}" if hours: elapsed = f"{hours}h {elapsed}" - msg = f"Done running 👆 [{elapsed}]" - logger.info(**gen_log_kwargs(message=msg, box="└╴", emoji="🎉", step=step)) - logger.rule() + logger.end(f"done ({elapsed})") diff --git a/mne_bids_pipeline/_parallel.py b/mne_bids_pipeline/_parallel.py index c2f9430ae..acee195c0 100644 --- a/mne_bids_pipeline/_parallel.py +++ b/mne_bids_pipeline/_parallel.py @@ -1,19 +1,32 @@ """Parallelization.""" -from typing import Literal, Callable +from collections.abc import Callable from types import SimpleNamespace +from typing import Literal import joblib +from mne.utils import logger as mne_logger +from mne.utils import use_log_level -from ._logging import logger +from ._logging import _is_testing, gen_log_kwargs, logger -def get_n_jobs(*, exec_params: SimpleNamespace) -> int: +def get_n_jobs(*, exec_params: SimpleNamespace, log_override: bool = False) -> int: n_jobs = exec_params.n_jobs if n_jobs < 0: n_cores = joblib.cpu_count() n_jobs = min(n_cores + n_jobs + 1, n_cores) + # Shim to allow overriding n_jobs for specific steps + if _is_testing() and hasattr(exec_params, "_n_jobs"): + from ._run import _get_step_path, _short_step_path + + step_path = _short_step_path(_get_step_path()) + orig_n_jobs = n_jobs + n_jobs = exec_params._n_jobs.get(step_path, n_jobs) + if log_override and n_jobs != orig_n_jobs: + msg = f"Overriding n_jobs: {orig_n_jobs}→{n_jobs}" + logger.info(**gen_log_kwargs(message=msg, emoji="override")) return n_jobs @@ -30,14 +43,16 @@ def setup_dask_client(*, exec_params: SimpleNamespace) -> None: return n_workers = get_n_jobs(exec_params=exec_params) - logger.info(f"👾 Initializing Dask client with {n_workers} workers …") + msg = f"Dask initializing with {n_workers} workers …" + logger.info(**gen_log_kwargs(message=msg, emoji="👾")) if exec_params.dask_temp_dir is None: this_dask_temp_dir = exec_params.deriv_root / ".dask-worker-space" else: this_dask_temp_dir = exec_params.dask_temp_dir - logger.info(f"📂 Temporary directory is: {this_dask_temp_dir}") + msg = f"Dask temporary directory: {this_dask_temp_dir}" + logger.info(**gen_log_kwargs(message=msg, emoji="📂")) dask.config.set( { "temporary-directory": this_dask_temp_dir, @@ -61,10 +76,8 @@ def setup_dask_client(*, exec_params: SimpleNamespace) -> None: client.auto_restart = False # don't restart killed workers dashboard_url = client.dashboard_link - logger.info( - f"⏱ The Dask client is ready. Open {dashboard_url} " - f"to monitor the workers.\n" - ) + msg = "Dask client dashboard: " f"[link={dashboard_url}]{dashboard_url}[/link]" + logger.info(**gen_log_kwargs(message=msg, emoji="🌎")) if exec_params.dask_open_dashboard: import webbrowser @@ -76,29 +89,37 @@ def setup_dask_client(*, exec_params: SimpleNamespace) -> None: def get_parallel_backend_name( - *, exec_params: SimpleNamespace + *, + exec_params: SimpleNamespace, ) -> Literal["dask", "loky"]: if ( exec_params.parallel_backend == "loky" or get_n_jobs(exec_params=exec_params) == 1 ): - return "loky" + backend = "loky" elif exec_params.parallel_backend == "dask": # Disable interactive plotting backend import matplotlib matplotlib.use("Agg") - return "dask" + backend = "dask" else: # TODO: Move to value validation step raise ValueError(f"Unknown parallel backend: {exec_params.parallel_backend}") + return backend + def get_parallel_backend(exec_params: SimpleNamespace) -> joblib.parallel_backend: import joblib backend = get_parallel_backend_name(exec_params=exec_params) - kwargs = {"n_jobs": get_n_jobs(exec_params=exec_params)} + kwargs = { + "n_jobs": get_n_jobs( + exec_params=exec_params, + log_override=True, + ) + } if backend == "loky": kwargs["inner_max_num_threads"] = 1 @@ -109,19 +130,21 @@ def get_parallel_backend(exec_params: SimpleNamespace) -> joblib.parallel_backen def parallel_func(func: Callable, *, exec_params: SimpleNamespace): - if get_parallel_backend_name(exec_params=exec_params) == "loky": - if get_n_jobs(exec_params=exec_params) == 1: - my_func = func - parallel = list - else: - from joblib import Parallel, delayed - - parallel = Parallel() - my_func = delayed(func) - else: # Dask + if ( + get_parallel_backend_name(exec_params=exec_params) == "loky" + and get_n_jobs(exec_params=exec_params) == 1 + ): + my_func = func + parallel = list + else: # Dask or n_jobs > 1 from joblib import Parallel, delayed parallel = Parallel() - my_func = delayed(func) + + def run_verbose(*args, verbose=mne_logger.level, **kwargs): + with use_log_level(verbose=verbose): + return func(*args, **kwargs) + + my_func = delayed(run_verbose) return parallel, my_func diff --git a/mne_bids_pipeline/_reject.py b/mne_bids_pipeline/_reject.py index 5b3729dc2..3837daa97 100644 --- a/mne_bids_pipeline/_reject.py +++ b/mne_bids_pipeline/_reject.py @@ -1,21 +1,22 @@ """Rejection.""" -from typing import Optional, Union, Iterable, Dict, Literal +from collections.abc import Iterable +from typing import Literal import mne -from ._logging import logger, gen_log_kwargs +from ._logging import gen_log_kwargs, logger def _get_reject( *, subject: str, - session: Optional[str], - reject: Union[Dict[str, float], Literal["autoreject_global"]], + session: str | None, + reject: dict[str, float] | Literal["autoreject_global"], ch_types: Iterable[Literal["meg", "mag", "grad", "eeg"]], param: str, - epochs: Optional[mne.BaseEpochs] = None, -) -> Dict[str, float]: + epochs: mne.BaseEpochs | None = None, +) -> dict[str, float]: if reject is None: return dict() @@ -44,11 +45,11 @@ def _get_reject( # Only keep thresholds for channel types of interest reject = reject.copy() - if ch_types == ["eeg"]: - ch_types_to_remove = ("mag", "grad") - else: - ch_types_to_remove = ("eeg",) - + ch_types_to_remove = list() + if "meg" not in ch_types: + ch_types_to_remove.extend(("mag", "grad")) + if "eeg" not in ch_types: + ch_types_to_remove.append("eeg") for ch_type in ch_types_to_remove: try: del reject[ch_type] diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index eefbca167..fbab290d0 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -1,26 +1,26 @@ import contextlib +import re +import traceback from functools import lru_cache from io import StringIO -import os.path as op -from pathlib import Path -from typing import Optional, List, Literal +from textwrap import indent from types import SimpleNamespace +from typing import Literal -from filelock import FileLock import matplotlib.transforms +import mne import numpy as np import pandas as pd -from scipy.io import loadmat - -import mne +from filelock import FileLock from mne.io import BaseRaw from mne.utils import _pl from mne_bids import BIDSPath from mne_bids.stats import count_events +from scipy.io import loadmat -from ._config_utils import sanitize_cond_name, get_subjects, _restrict_analyze_channels +from ._config_utils import get_all_contrasts from ._decoding import _handle_csp_args -from ._logging import logger, gen_log_kwargs +from ._logging import _linkfile, gen_log_kwargs, logger @contextlib.contextmanager @@ -29,30 +29,35 @@ def _open_report( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str] = None, - task: Optional[str] = None, + session: str | None, + run: str | None = None, + task: str | None = None, + fname_report: BIDSPath | None = None, + name: str = "report", ): - fname_report = BIDSPath( - subject=subject, - session=session, - # Report is across all runs, but for logging purposes it's helpful - # to pass the run and task for gen_log_kwargs - run=None, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - extension=".h5", - datatype=cfg.datatype, - root=cfg.deriv_root, - suffix="report", - check=False, - ).fpath + if fname_report is None: + fname_report = BIDSPath( + subject=subject, + session=session, + # Report is across all runs, but for logging purposes it's helpful + # to pass the run and task for gen_log_kwargs + run=None, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + extension=".h5", + datatype=cfg.datatype, + root=cfg.deriv_root, + suffix="report", + check=False, + ) + fname_report = fname_report.fpath + assert fname_report.suffix == ".h5", fname_report.suffix # prevent parallel file access with FileLock(f"{fname_report}.lock"), _agg_backend(): if not fname_report.is_file(): - msg = "Initializing report HDF5 file" + msg = f"Initializing {name} HDF5 file" logger.info(**gen_log_kwargs(message=msg)) report = _gen_empty_report( cfg=cfg, @@ -64,26 +69,26 @@ def _open_report( report = mne.open_report(fname_report) except Exception as exc: raise exc.__class__( - f"Could not open report HDF5 file:\n{fname_report}\n" - f"Got error:\n{exc}\nPerhaps you need to delete it?" + f"Could not open {name} HDF5 file:\n{fname_report}, " + "Perhaps you need to delete it? Got error:\n\n" + f'{indent(traceback.format_exc(), " ")}' ) from None try: yield report finally: try: - msg = "Adding config and sys info to report" - logger.info(**gen_log_kwargs(message=msg)) _finalize( report=report, exec_params=exec_params, subject=subject, session=session, run=run, + task=task, ) except Exception as exc: logger.warning(f"Failed: {exc}") fname_report_html = fname_report.with_suffix(".html") - msg = f"Saving report: {fname_report_html}" + msg = f"Saving {name}: {_linkfile(fname_report_html)}" logger.info(**gen_log_kwargs(message=msg)) report.save(fname_report, overwrite=True) report.save(fname_report_html, overwrite=True, open_browser=False) @@ -125,8 +130,8 @@ def _open_report( def _plot_full_epochs_decoding_scores( - contrast_names: List[str], - scores: List[np.ndarray], + contrast_names: list[str], + scores: list[np.ndarray], metric: str, kind: Literal["single-subject", "grand-average"] = "single-subject", ): @@ -204,7 +209,7 @@ def _plot_mean_cv_score(x, **kwargs): g.set_xlabels("") fig = g.fig - return fig, caption + return fig, caption, data def _plot_time_by_time_decoding_scores( @@ -361,8 +366,8 @@ def plot_time_by_time_decoding_t_values(decoding_data): # We squeeze() to make Matplotlib happy. all_times = decoding_data["cluster_all_times"].squeeze() all_t_values = decoding_data["cluster_all_t_values"].squeeze() - t_threshold = decoding_data["cluster_t_threshold"] - decim = decoding_data["decim"] + t_threshold = decoding_data["cluster_t_threshold"].item() + decim = decoding_data["decim"].item() fig, ax = plt.subplots(constrained_layout=True) ax.plot(all_times, all_t_values, ls="-", color="black", label="observed $t$-values") @@ -448,7 +453,7 @@ def _plot_decoding_time_generalization( def _gen_empty_report( - *, cfg: SimpleNamespace, subject: str, session: Optional[str] + *, cfg: SimpleNamespace, subject: str, session: str | None ) -> mne.Report: title = f"sub-{subject}" if session is not None: @@ -456,49 +461,33 @@ def _gen_empty_report( if cfg.task is not None: title += f", task-{cfg.task}" - report = mne.Report(title=title, raw_psd=True) + report = mne.Report(title=title, raw_psd=True, verbose=False) return report -def _contrasts_to_names(contrasts: List[List[str]]) -> List[str]: +def _contrasts_to_names(contrasts: list[list[str]]) -> list[str]: return [f"{c[0]} vs.\n{c[1]}" for c in contrasts] def add_event_counts( - *, cfg, subject: Optional[str], session: Optional[str], report: mne.Report + *, cfg, subject: str | None, session: str | None, report: mne.Report ) -> None: try: df_events = count_events(BIDSPath(root=cfg.bids_root, session=session)) except ValueError: msg = "Could not read events." logger.warning(**gen_log_kwargs(message=msg)) - df_events = None + return + logger.info(**gen_log_kwargs(message="Adding event counts to report …")) if df_events is not None: - css_classes = ("table", "table-striped", "table-borderless", "table-hover") + df_events.reset_index(drop=False, inplace=True, col_level=1) report.add_html( - f'
\n' - f"{df_events.to_html(classes=css_classes, border=0)}\n" - f"
", + _df_bootstrap_table(df=df_events, data_id="events"), title="Event counts", tags=("events",), replace=True, ) - css = ( - ".event-counts {\n" - " display: -webkit-box;\n" - " display: -ms-flexbox;\n" - " display: -webkit-flex;\n" - " display: flex;\n" - " justify-content: center;\n" - " text-align: center;\n" - "}\n\n" - "th, td {\n" - " text-align: center;\n" - "}\n" - ) - if css not in report.include: - report.add_custom_css(css=css) def _finalize( @@ -506,14 +495,19 @@ def _finalize( report: mne.Report, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, + task: str | None, ) -> None: """Add system information and the pipeline configuration to the report.""" # ensure they are always appended titles = ["Configuration file", "System information"] for title in titles: report.remove(title=title, remove_all=True) + # Print this exactly once + if _cached_sys_info.cache_info()[-1] == 0: # never run + msg = "Adding config and sys info to report" + logger.info(**gen_log_kwargs(message=msg)) # No longer need replace=True in these report.add_code( code=exec_params.config_path, @@ -549,330 +543,13 @@ def _all_conditions(*, cfg): conditions = list(cfg.conditions.keys()) else: conditions = cfg.conditions.copy() - conditions.extend([contrast["name"] for contrast in cfg.all_contrasts]) + all_contrasts = get_all_contrasts(cfg) + conditions.extend([contrast["name"] for contrast in all_contrasts]) return conditions -def run_report_average_sensor( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], -) -> None: - msg = "Generating grand average report …" - logger.info(**gen_log_kwargs(message=msg)) - assert matplotlib.get_backend() == "agg", matplotlib.get_backend() - - evoked_fname = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix="ave", - extension=".fif", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - - title = f"sub-{subject}" - if session is not None: - title += f", ses-{session}" - if cfg.task is not None: - title += f", task-{cfg.task}" - - all_evokeds = mne.read_evokeds(evoked_fname) - for evoked in all_evokeds: - _restrict_analyze_channels(evoked, cfg) - conditions = _all_conditions(cfg=cfg) - assert len(conditions) == len(all_evokeds) - all_evokeds = {cond: evoked for cond, evoked in zip(conditions, all_evokeds)} - - with _open_report( - cfg=cfg, exec_params=exec_params, subject=subject, session=session - ) as report: - ####################################################################### - # - # Add event stats. - # - add_event_counts( - cfg=cfg, - report=report, - subject=subject, - session=session, - ) - - ####################################################################### - # - # Visualize evoked responses. - # - if all_evokeds: - msg = ( - f"Adding {len(all_evokeds)} evoked signals and contrasts to " - "the report." - ) - else: - msg = "No evoked conditions or contrasts found." - logger.info(**gen_log_kwargs(message=msg)) - for condition, evoked in all_evokeds.items(): - tags = ("evoked", _sanitize_cond_tag(condition)) - if condition in cfg.conditions: - title = f"Condition: {condition}" - else: # It's a contrast of two conditions. - title = f"Contrast: {condition}" - tags = tags + ("contrast",) - - report.add_evokeds( - evokeds=evoked, - titles=title, - projs=False, - tags=tags, - n_time_points=cfg.report_evoked_n_time_points, - # captions=evoked.comment, # TODO upstream - replace=True, - n_jobs=1, # don't auto parallelize - ) - - ####################################################################### - # - # Visualize decoding results. - # - if cfg.decode and cfg.decoding_contrasts: - msg = "Adding decoding results." - logger.info(**gen_log_kwargs(message=msg)) - add_decoding_grand_average(session=session, cfg=cfg, report=report) - - if cfg.decode and cfg.decoding_csp: - # No need for a separate message here because these are very quick - # and the general message above is sufficient - add_csp_grand_average(session=session, cfg=cfg, report=report) - - -def run_report_average_source( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], -) -> None: - ####################################################################### - # - # Visualize forward solution, inverse operator, and inverse solutions. - # - evoked_fname = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix="ave", - extension=".fif", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - evokeds = mne.read_evokeds(evoked_fname) - method = cfg.inverse_method - inverse_str = method - hemi_str = "hemi" # MNE will auto-append '-lh' and '-rh'. - morph_str = "morph2fsaverage" - conditions = _all_conditions(cfg=cfg) - with _open_report( - cfg=cfg, exec_params=exec_params, subject=subject, session=session - ) as report: - for condition, evoked in zip(conditions, evokeds): - tags = ( - "source-estimate", - _sanitize_cond_tag(condition), - ) - if condition in cfg.conditions: - title = f"Average: {condition}" - else: # It's a contrast of two conditions. - title = f"Average contrast: {condition}" - tags = tags + ("contrast",) - cond_str = sanitize_cond_name(condition) - fname_stc_avg = evoked_fname.copy().update( - suffix=f"{cond_str}+{inverse_str}+{morph_str}+{hemi_str}", - extension=None, - ) - if not Path(f"{fname_stc_avg.fpath}-lh.stc").exists(): - continue - report.add_stc( - stc=fname_stc_avg, - title=title, - subject="fsaverage", - subjects_dir=cfg.fs_subjects_dir, - n_time_points=cfg.report_stc_n_time_points, - tags=tags, - replace=True, - ) - - -def add_decoding_grand_average( - *, - session: Optional[str], - cfg: SimpleNamespace, - report: mne.Report, -): - """Add decoding results to the grand average report.""" - import matplotlib.pyplot as plt # nested import to help joblib - - bids_path = BIDSPath( - subject="average", - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix="ave", - extension=".fif", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - - # Full-epochs decoding - all_decoding_scores = [] - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f"{cond_1}+{cond_2}".replace(op.sep, "") - processing = f"{a_vs_b}+FullEpochs+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") - fname_decoding = bids_path.copy().update( - processing=processing, suffix="decoding", extension=".mat" - ) - decoding_data = loadmat(fname_decoding) - all_decoding_scores.append(np.atleast_1d(decoding_data["scores"].squeeze())) - del fname_decoding, processing, a_vs_b, decoding_data - - fig, caption = _plot_full_epochs_decoding_scores( - contrast_names=_contrasts_to_names(cfg.decoding_contrasts), - scores=all_decoding_scores, - metric=cfg.decoding_metric, - kind="grand-average", - ) - title = f"Full-epochs decoding: {cond_1} vs. {cond_2}" - report.add_figure( - fig=fig, - title=title, - section="Decoding: full-epochs", - caption=caption, - tags=( - "epochs", - "contrast", - "decoding", - *[ - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}" - for cond_1, cond_2 in cfg.decoding_contrasts - ], - ), - replace=True, - ) - # close figure to save memory - plt.close(fig) - del fig, caption, title - - # Time-by-time decoding - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f"{cond_1}+{cond_2}".replace(op.sep, "") - section = "Decoding: time-by-time" - tags = ( - "epochs", - "contrast", - "decoding", - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", - ) - processing = f"{a_vs_b}+TimeByTime+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") - fname_decoding = bids_path.copy().update( - processing=processing, suffix="decoding", extension=".mat" - ) - decoding_data = loadmat(fname_decoding) - del fname_decoding, processing, a_vs_b - - # Plot scores - fig = _plot_time_by_time_decoding_scores_gavg( - cfg=cfg, - decoding_data=decoding_data, - ) - caption = ( - f'Based on N={decoding_data["N"].squeeze()} ' - f"subjects. Standard error and confidence interval " - f"of the mean were bootstrapped with {cfg.n_boot} " - f"resamples. CI must not be used for statistical inference here, " - f"as it is not corrected for multiple testing." - ) - if len(get_subjects(cfg)) > 1: - caption += ( - f" Time periods with decoding performance significantly above " - f"chance, if any, were derived with a one-tailed " - f"cluster-based permutation test " - f'({decoding_data["cluster_n_permutations"].squeeze()} ' - f"permutations) and are highlighted in yellow." - ) - title = f"Decoding over time: {cond_1} vs. {cond_2}" - report.add_figure( - fig=fig, - title=title, - caption=caption, - section=section, - tags=tags, - replace=True, - ) - plt.close(fig) - - # Plot t-values used to form clusters - if len(get_subjects(cfg)) > 1: - fig = plot_time_by_time_decoding_t_values(decoding_data=decoding_data) - t_threshold = np.round(decoding_data["cluster_t_threshold"], 3).item() - caption = ( - f"Observed t-values. Time points with " - f"t-values > {t_threshold} were used to form clusters." - ) - report.add_figure( - fig=fig, - title=f"t-values across time: {cond_1} vs. {cond_2}", - caption=caption, - section=section, - tags=tags, - replace=True, - ) - plt.close(fig) - - if cfg.decoding_time_generalization: - fig = _plot_decoding_time_generalization( - decoding_data=decoding_data, - metric=cfg.decoding_metric, - kind="grand-average", - ) - caption = ( - f"Time generalization (generalization across time, GAT): " - f"each classifier is trained on each time point, and tested " - f"on all other time points. The results were averaged across " - f'N={decoding_data["N"].item()} subjects.' - ) - title = f"Time generalization: {cond_1} vs. {cond_2}" - report.add_figure( - fig=fig, - title=title, - caption=caption, - section=section, - tags=tags, - replace=True, - ) - plt.close(fig) - - def _sanitize_cond_tag(cond): - return cond.lower().replace(" ", "-") + return str(cond).lower().replace(" ", "-") def _imshow_tf( @@ -913,281 +590,257 @@ def _imshow_tf( def add_csp_grand_average( *, - session: str, cfg: SimpleNamespace, + subject: str, + session: str, report: mne.Report, + cond_1: str, + cond_2: str, + fname_csp_freq_results: BIDSPath, + fname_csp_cluster_results: pd.DataFrame | None, ): """Add CSP decoding results to the grand average report.""" import matplotlib.pyplot as plt # nested import to help joblib - bids_path = BIDSPath( - subject="average", - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix="decoding", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - # First, plot decoding scores across frequency bins (entire epochs). - section = "Decoding: CSP" - freq_name_to_bins_map = _handle_csp_args( + section = f"Decoding: CSP, N = {len(cfg.subjects)}" + freq_name_to_bins_map, _ = _handle_csp_args( cfg.decoding_csp_times, cfg.decoding_csp_freqs, cfg.decoding_metric, + epochs_tmin=cfg.epochs_tmin, + epochs_tmax=cfg.epochs_tmax, + time_frequency_freq_min=cfg.time_frequency_freq_min, + time_frequency_freq_max=cfg.time_frequency_freq_max, ) - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f"{cond_1}+{cond_2}".replace(op.sep, "") - processing = f"{a_vs_b}+CSP+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") - fname_csp_freq_results = bids_path.copy().update( - processing=processing, - extension=".xlsx", - ) - csp_freq_results = pd.read_excel( - fname_csp_freq_results, sheet_name="CSP Frequency" - ) - freq_bin_starts = list() - freq_bin_widths = list() - decoding_scores = list() - error_bars = list() - for freq_range_name, freq_bins in freq_name_to_bins_map.items(): - results = csp_freq_results.loc[ - csp_freq_results["freq_range_name"] == freq_range_name, : - ] - results.reset_index(drop=True, inplace=True) - assert len(results) == len(freq_bins) - for bi, freq_bin in enumerate(freq_bins): - freq_bin_starts.append(freq_bin[0]) - freq_bin_widths.append(np.diff(freq_bin)[0]) - decoding_scores.append(results["mean"][bi]) - cis_lower = results["mean_ci_lower"][bi] - cis_upper = results["mean_ci_upper"][bi] - error_bars_lower = decoding_scores[-1] - cis_lower - error_bars_upper = cis_upper - decoding_scores[-1] - error_bars.append(np.stack([error_bars_lower, error_bars_upper])) - assert len(error_bars[-1]) == 2 # lower, upper - del cis_lower, cis_upper, error_bars_lower, error_bars_upper - error_bars = np.array(error_bars, float).T - - if cfg.decoding_metric == "roc_auc": - metric = "ROC AUC" - - fig, ax = plt.subplots(constrained_layout=True) - ax.bar( - x=freq_bin_starts, - width=freq_bin_widths, - height=decoding_scores, - align="edge", - yerr=error_bars, - edgecolor="black", - ) - ax.set_ylim([0, 1.02]) - offset = matplotlib.transforms.offset_copy( - ax.transData, fig, 0, 5, units="points" - ) - for freq_range_name, freq_bins in freq_name_to_bins_map.items(): - start = freq_bins[0][0] - stop = freq_bins[-1][1] - width = stop - start - ax.text( - x=start + width / 2, - y=0.0, - transform=offset, - s=freq_range_name, - ha="center", - va="bottom", - ) - ax.axhline(0.5, color="black", linestyle="--", label="chance") - ax.legend() - ax.set_xlabel("Frequency (Hz)") - ax.set_ylabel(f"Mean decoding score ({metric})") - tags = ( - "epochs", - "contrast", - "decoding", - "csp", - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", - ) - title = f"CSP decoding: {cond_1} vs. {cond_2}" - report.add_figure( - fig=fig, - title=title, - section=section, - caption="Mean decoding scores. Error bars represent " - "bootstrapped 95% confidence intervals.", - tags=tags, - replace=True, + freq_bin_starts = list() + freq_bin_widths = list() + decoding_scores = list() + error_bars = list() + csp_freq_results = pd.read_excel(fname_csp_freq_results, sheet_name="CSP Frequency") + for freq_range_name, freq_bins in freq_name_to_bins_map.items(): + results = csp_freq_results.loc[ + csp_freq_results["freq_range_name"] == freq_range_name, : + ] + results.reset_index(drop=True, inplace=True) + assert len(results) == len(freq_bins) + for bi, freq_bin in enumerate(freq_bins): + freq_bin_starts.append(freq_bin[0]) + freq_bin_widths.append(np.diff(freq_bin)[0]) + decoding_scores.append(results["mean"][bi]) + cis_lower = results["mean_ci_lower"][bi] + cis_upper = results["mean_ci_upper"][bi] + error_bars_lower = decoding_scores[-1] - cis_lower + error_bars_upper = cis_upper - decoding_scores[-1] + error_bars.append(np.stack([error_bars_lower, error_bars_upper])) + assert len(error_bars[-1]) == 2 # lower, upper + del cis_lower, cis_upper, error_bars_lower, error_bars_upper + error_bars = np.array(error_bars, float).T + + if cfg.decoding_metric == "roc_auc": + metric = "ROC AUC" + + fig, ax = plt.subplots(constrained_layout=True) + ax.bar( + x=freq_bin_starts, + width=freq_bin_widths, + height=decoding_scores, + align="edge", + yerr=error_bars, + edgecolor="black", + ) + ax.set_ylim([0, 1.02]) + offset = matplotlib.transforms.offset_copy(ax.transData, fig, 0, 5, units="points") + for freq_range_name, freq_bins in freq_name_to_bins_map.items(): + start = freq_bins[0][0] + stop = freq_bins[-1][1] + width = stop - start + ax.text( + x=start + width / 2, + y=0.0, + transform=offset, + s=freq_range_name, + ha="center", + va="bottom", ) + ax.axhline(0.5, color="black", linestyle="--", label="chance") + ax.legend() + ax.set_xlabel("Frequency (Hz)") + ax.set_ylabel(f"Mean decoding score ({metric})") + tags = ( + "epochs", + "contrast", + "decoding", + "csp", + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", + ) + title = f"CSP decoding: {cond_1} vs. {cond_2}" + report.add_figure( + fig=fig, + title=title, + section=section, + caption="Mean decoding scores. Error bars represent " + "bootstrapped 95% confidence intervals.", + tags=tags, + replace=True, + ) # Now, plot decoding scores across time-frequency bins. - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - a_vs_b = f"{cond_1}+{cond_2}".replace(op.sep, "") - processing = f"{a_vs_b}+CSP+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") - fname_csp_cluster_results = bids_path.copy().update( - processing=processing, - extension=".mat", + if fname_csp_cluster_results is None: + return + csp_cluster_results = loadmat(fname_csp_cluster_results) + fig, ax = plt.subplots( + nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True + ) + n_clu = 0 + cbar = None + lims = [np.inf, -np.inf, np.inf, -np.inf] + for freq_range_name, bins in freq_name_to_bins_map.items(): + results = csp_cluster_results[freq_range_name][0][0] + mean_crossval_scores = results["mean_crossval_scores"].ravel() + # t_vals = results['t_vals'] + clusters = results["clusters"] + cluster_p_vals = np.atleast_1d(results["cluster_p_vals"].squeeze()) + tmin = results["time_bin_edges"].ravel() + tmin, tmax = tmin[:-1], tmin[1:] + fmin = results["freq_bin_edges"].ravel() + fmin, fmax = fmin[:-1], fmin[1:] + lims[0] = min(lims[0], tmin.min()) + lims[1] = max(lims[1], tmax.max()) + lims[2] = min(lims[2], fmin.min()) + lims[3] = max(lims[3], fmax.max()) + # replicate, matching time-frequency order during clustering + fmin, fmax = np.tile(fmin, len(tmin)), np.tile(fmax, len(tmax)) + tmin, tmax = np.repeat(tmin, len(bins)), np.repeat(tmax, len(bins)) + assert fmin.shape == fmax.shape == tmin.shape == tmax.shape + assert fmin.shape == mean_crossval_scores.shape + cluster_t_threshold = results["cluster_t_threshold"].ravel().item() + + significant_cluster_idx = np.where( + cluster_p_vals < cfg.cluster_permutation_p_threshold + )[0] + significant_clusters = clusters[significant_cluster_idx] + n_clu += len(significant_cluster_idx) + + # XXX Add support for more metrics + assert cfg.decoding_metric == "roc_auc" + metric = "ROC AUC" + vmax = ( + max( + np.abs(mean_crossval_scores.min() - 0.5), + np.abs(mean_crossval_scores.max() - 0.5), + ) + + 0.5 ) - csp_cluster_results = loadmat(fname_csp_cluster_results) - - fig, ax = plt.subplots( - nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True + vmin = 0.5 - (vmax - 0.5) + # For diverging gray colormap, we need to combine two existing + # colormaps, as there is no diverging colormap with gray/black at + # both endpoints. + from matplotlib.cm import gray, gray_r + from matplotlib.colors import ListedColormap + + black_to_white = gray(np.linspace(start=0, stop=1, endpoint=False, num=128)) + white_to_black = gray_r(np.linspace(start=0, stop=1, endpoint=False, num=128)) + black_to_white_to_black = np.vstack((black_to_white, white_to_black)) + diverging_gray_cmap = ListedColormap( + black_to_white_to_black, name="DivergingGray" ) - n_clu = 0 - cbar = None - lims = [np.inf, -np.inf, np.inf, -np.inf] - for freq_range_name, bins in freq_name_to_bins_map.items(): - results = csp_cluster_results[freq_range_name][0][0] - mean_crossval_scores = results["mean_crossval_scores"].ravel() - # t_vals = results['t_vals'] - clusters = results["clusters"] - cluster_p_vals = np.atleast_1d(results["cluster_p_vals"].squeeze()) - tmin = results["time_bin_edges"].ravel() - tmin, tmax = tmin[:-1], tmin[1:] - fmin = results["freq_bin_edges"].ravel() - fmin, fmax = fmin[:-1], fmin[1:] - lims[0] = min(lims[0], tmin.min()) - lims[1] = max(lims[1], tmax.max()) - lims[2] = min(lims[2], fmin.min()) - lims[3] = max(lims[3], fmax.max()) - # replicate, matching time-frequency order during clustering - fmin, fmax = np.tile(fmin, len(tmin)), np.tile(fmax, len(tmax)) - tmin, tmax = np.repeat(tmin, len(bins)), np.repeat(tmax, len(bins)) - assert fmin.shape == fmax.shape == tmin.shape == tmax.shape - assert fmin.shape == mean_crossval_scores.shape - cluster_t_threshold = results["cluster_t_threshold"].ravel().item() - - significant_cluster_idx = np.where( - cluster_p_vals < cfg.cluster_permutation_p_threshold - )[0] - significant_clusters = clusters[significant_cluster_idx] - n_clu += len(significant_cluster_idx) - - # XXX Add support for more metrics - assert cfg.decoding_metric == "roc_auc" - metric = "ROC AUC" - vmax = ( - max( - np.abs(mean_crossval_scores.min() - 0.5), - np.abs(mean_crossval_scores.max() - 0.5), - ) - + 0.5 - ) - vmin = 0.5 - (vmax - 0.5) - # For diverging gray colormap, we need to combine two existing - # colormaps, as there is no diverging colormap with gray/black at - # both endpoints. - from matplotlib.cm import gray, gray_r - from matplotlib.colors import ListedColormap - - black_to_white = gray(np.linspace(start=0, stop=1, endpoint=False, num=128)) - white_to_black = gray_r( - np.linspace(start=0, stop=1, endpoint=False, num=128) - ) - black_to_white_to_black = np.vstack((black_to_white, white_to_black)) - diverging_gray_cmap = ListedColormap( - black_to_white_to_black, name="DivergingGray" - ) - cmap_gray = diverging_gray_cmap - img = _imshow_tf( - mean_crossval_scores, - ax[0], - tmin=tmin, - tmax=tmax, - fmin=fmin, - fmax=fmax, - vmin=vmin, - vmax=vmax, - ) - if cbar is None: - ax[0].set_xlabel("Time (s)") - ax[0].set_ylabel("Frequency (Hz)") - ax[1].set_xlabel("Time (s)") - cbar = fig.colorbar( - ax=ax[1], shrink=0.75, orientation="vertical", mappable=img - ) - cbar.set_label(f"Mean decoding score ({metric})") - offset = matplotlib.transforms.offset_copy( - ax[0].transData, fig, 6, 0, units="points" - ) - ax[0].text( - tmin.min(), - 0.5 * fmin.min() + 0.5 * fmax.max(), - freq_range_name, - transform=offset, - ha="left", - va="center", - rotation=90, + cmap_gray = diverging_gray_cmap + img = _imshow_tf( + mean_crossval_scores, + ax[0], + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + vmin=vmin, + vmax=vmax, + ) + if cbar is None: + ax[0].set_xlabel("Time (s)") + ax[0].set_ylabel("Frequency (Hz)") + ax[1].set_xlabel("Time (s)") + cbar = fig.colorbar( + ax=ax[1], shrink=0.75, orientation="vertical", mappable=img ) + cbar.set_label(f"Mean decoding score ({metric})") + offset = matplotlib.transforms.offset_copy( + ax[0].transData, fig, 6, 0, units="points" + ) + ax[0].text( + tmin.min(), + 0.5 * fmin.min() + 0.5 * fmax.max(), + freq_range_name, + transform=offset, + ha="left", + va="center", + rotation=90, + ) - if len(significant_clusters): - # Create a masked array that only shows the T-values for - # time-frequency bins that belong to significant clusters. - if len(significant_clusters) == 1: - mask = ~significant_clusters[0].astype(bool) - else: - mask = ~np.logical_or(*significant_clusters) - mask = mask.ravel() + if len(significant_clusters): + # Create a masked array that only shows the T-values for + # time-frequency bins that belong to significant clusters. + if len(significant_clusters) == 1: + mask = ~significant_clusters[0].astype(bool) else: - mask = np.ones(mean_crossval_scores.shape, dtype=bool) - _imshow_tf( - mean_crossval_scores, - ax[1], - tmin=tmin, - tmax=tmax, - fmin=fmin, - fmax=fmax, - vmin=vmin, - vmax=vmax, - mask=mask, - cmap_masked=cmap_gray, - ) - - ax[0].set_xlim(lims[:2]) - ax[0].set_ylim(lims[2:]) - ax[0].set_title("Scores") - ax[1].set_title("Masked") - tags = ( - "epochs", - "contrast", - "decoding", - "csp", - f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", - ) - title = f"CSP TF decoding: {cond_1} vs. {cond_2}" - report.add_figure( - fig=fig, - title=title, - section=section, - caption=f"Found {n_clu} " - f"cluster{_pl(n_clu)} with " - f"p < {cfg.cluster_permutation_p_threshold} " - f"(clustering bins with absolute t-values > " - f"{round(cluster_t_threshold, 3)}).", - tags=tags, - replace=True, + mask = ~np.logical_or(*significant_clusters) + mask = mask.ravel() + else: + mask = np.ones(mean_crossval_scores.shape, dtype=bool) + _imshow_tf( + mean_crossval_scores, + ax[1], + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + vmin=vmin, + vmax=vmax, + mask=mask, + cmap_masked=cmap_gray, ) + ax[0].set_xlim(lims[:2]) + ax[0].set_ylim(lims[2:]) + ax[0].set_title("Scores") + ax[1].set_title("Masked") + tags = ( + "epochs", + "contrast", + "decoding", + "csp", + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", + ) + title = f"CSP TF decoding: {cond_1} vs. {cond_2}" + report.add_figure( + fig=fig, + title=title, + section=section, + caption=f"Found {n_clu} " + f"cluster{_pl(n_clu)} with " + f"p < {cfg.cluster_permutation_p_threshold} " + f"(clustering bins with absolute t-values > " + f"{round(cluster_t_threshold, 3)}).", + tags=tags, + replace=True, + ) + @contextlib.contextmanager def _agg_backend(): import matplotlib backend = matplotlib.get_backend() - matplotlib.use("Agg", force=True) + matplotlib.use("agg", force=True) try: yield finally: - matplotlib.use(backend, force=True) + if backend.lower() != "agg": + import matplotlib.pyplot as plt + + plt.close("all") + matplotlib.use(backend, force=True) def _add_raw( @@ -1195,12 +848,13 @@ def _add_raw( cfg: SimpleNamespace, report: mne.report.Report, bids_path_in: BIDSPath, + raw: BaseRaw, title: str, tags: tuple = (), - raw: Optional[BaseRaw] = None, + extra_html: str | None = None, ): if bids_path_in.run is not None: - title += f", run {repr(bids_path_in.run)}" + title += f", run {bids_path_in.run}" elif bids_path_in.task in ("noise", "rest"): title += f", {bids_path_in.task}" plot_raw_psd = ( @@ -1208,13 +862,103 @@ def _add_raw( or bids_path_in.run in cfg.plot_psd_for_runs or bids_path_in.task in cfg.plot_psd_for_runs ) + tags = ("raw", f"run-{bids_path_in.run}") + tags with mne.use_log_level("error"): report.add_raw( - raw=raw or bids_path_in, + raw=raw, title=title, butterfly=5, psd=plot_raw_psd, - tags=("raw", f"run-{bids_path_in.run}") + tags, + tags=tags, # caption=bids_path_in.basename, # TODO upstream replace=True, ) + if extra_html is not None: + report.add_html( + extra_html, + title=title, + tags=tags, + section=title, + replace=True, + ) + + +def _render_bem( + *, + cfg: SimpleNamespace, + report: mne.report.Report, + subject: str, + session: str | None, +): + logger.info(**gen_log_kwargs(message="Rendering MRI slices with BEM contours.")) + report.add_bem( + subject=cfg.fs_subject, + subjects_dir=cfg.fs_subjects_dir, + title="BEM", + width=256, + decim=8, + replace=True, + n_jobs=1, # prevent automatic parallelization + ) + + +# Copied from mne/report/report.py + +try: + from mne.report.report import _df_bootstrap_table +except ImportError: # MNE < 1.7 + + def _df_bootstrap_table(*, df, data_id): + html = df.to_html( + border=0, + index=False, + show_dimensions=True, + justify="unset", + float_format=lambda x: f"{x:.3f}", + classes="table table-hover table-striped table-sm table-responsive small", + na_rep="", + ) + htmls = html.split("\n") + header_pattern = "(.*)" + + for idx, html in enumerate(htmls): + if "' + ) + continue + + col_headers = re.findall(pattern=header_pattern, string=html) + if col_headers: + # Make columns sortable + assert len(col_headers) == 1 + col_header = col_headers[0] + htmls[idx] = html.replace( + "", + f'', + ) + + html = "\n".join(htmls) + return html diff --git a/mne_bids_pipeline/_run.py b/mne_bids_pipeline/_run.py index af4e63b4c..8975037ed 100644 --- a/mne_bids_pipeline/_run.py +++ b/mne_bids_pipeline/_run.py @@ -4,28 +4,30 @@ import functools import hashlib import inspect -import os import pathlib import pdb import sys -import traceback import time -from typing import Callable, Optional, Dict, List +import traceback +from collections.abc import Callable from types import SimpleNamespace +from typing import Literal -from filelock import FileLock -from joblib import Memory import json_tricks import pandas as pd +from filelock import FileLock +from joblib import Memory from mne_bids import BIDSPath from ._config_utils import get_task -from ._logging import logger, gen_log_kwargs +from ._logging import _is_testing, gen_log_kwargs, logger def failsafe_run( - get_input_fnames: Optional[Callable] = None, - get_output_fnames: Optional[Callable] = None, + *, + get_input_fnames: Callable | None = None, + get_output_fnames: Callable | None = None, + require_output: bool = True, ) -> Callable: def failsafe_run_decorator(func): @functools.wraps(func) # Preserve "identity" of original function @@ -37,15 +39,13 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): exec_params=exec_params, get_input_fnames=get_input_fnames, get_output_fnames=get_output_fnames, + require_output=require_output, + func_name=f"{__mne_bids_pipeline_step__}::{func.__name__}", ) - kwargs_copy = copy.deepcopy(kwargs) t0 = time.time() - kwargs_copy["cfg"] = json_tricks.dumps( - kwargs_copy["cfg"], sort_keys=False, indent=4 - ) log_info = pd.concat( [ - pd.Series(kwargs_copy, dtype=object), + pd.Series(kwargs, dtype=object), pd.Series(index=["time", "success", "error_message"], dtype=object), ] ) @@ -58,10 +58,10 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): log_info["error_message"] = "" except Exception as e: # Only keep what gen_log_kwargs() can handle - kwargs_copy = { - k: v - for k, v in kwargs_copy.items() - if k in ("subject", "session", "task", "run") + kwargs_log = { + k: kwargs[k] + for k in ("subject", "session", "task", "run") + if k in kwargs } message = ( f"A critical error occurred. " f"The error message was: {str(e)}" @@ -72,29 +72,29 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): # Find the limit / step where the error occurred step_dir = pathlib.Path(__file__).parent / "steps" tb = traceback.extract_tb(e.__traceback__) - for fi, frame in enumerate(inspect.stack()): + for fi, frame in enumerate(tb): is_step = pathlib.Path(frame.filename).parent.parent == step_dir del frame if is_step: # omit everything before the "step" dir, which will # generally be stuff from this file and joblib - tb = tb[-fi:] + tb = tb[fi:] break tb = "".join(traceback.format_list(tb)) if on_error == "abort": message += f"\n\nAborting pipeline run. The traceback is:\n\n{tb}" - if os.getenv("_MNE_BIDS_STUDY_TESTING", "") == "true": + if _is_testing(): raise logger.error( - **gen_log_kwargs(message=message, **kwargs_copy, emoji="❌") + **gen_log_kwargs(message=message, **kwargs_log, emoji="❌") ) sys.exit(1) elif on_error == "debug": message += "\n\nStarting post-mortem debugger." logger.error( - **gen_log_kwargs(message=message, **kwargs_copy, emoji="🐛") + **gen_log_kwargs(message=message, **kwargs_log, emoji="🐛") ) extype, value, tb = sys.exc_info() print(tb) @@ -103,7 +103,7 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): else: message += "\n\nContinuing pipeline run." logger.error( - **gen_log_kwargs(message=message, **kwargs_copy, emoji="🔂") + **gen_log_kwargs(message=message, **kwargs_log, emoji="🔂") ) log_info["time"] = round(time.time() - t0, ndigits=1) return log_info @@ -121,10 +121,18 @@ def hash_file_path(path: pathlib.Path) -> str: class ConditionalStepMemory: - def __init__(self, *, exec_params, get_input_fnames, get_output_fnames): + def __init__( + self, + *, + exec_params: SimpleNamespace, + get_input_fnames: Callable | None, + get_output_fnames: Callable | None, + require_output: bool, + func_name: str, + ): memory_location = exec_params.memory_location if memory_location is True: - use_location = exec_params.deriv_root / "joblib" + use_location = exec_params.deriv_root / exec_params.memory_subdir elif not memory_location: use_location = None else: @@ -139,6 +147,8 @@ def __init__(self, *, exec_params, get_input_fnames, get_output_fnames): self.get_input_fnames = get_input_fnames self.get_output_fnames = get_output_fnames self.memory_file_method = exec_params.memory_file_method + self.require_output = require_output + self.func_name = func_name def cache(self, func): def wrapper(*args, **kwargs): @@ -163,20 +173,7 @@ def wrapper(*args, **kwargs): # If this is ever true, we'll need to improve the logic below assert not (unknown_inputs and force_run) - def hash_(k, v): - if isinstance(v, BIDSPath): - v = v.fpath - assert isinstance( - v, pathlib.Path - ), f'Bad type {type(v)}: in_files["{k}"] = {v}' - assert v.exists(), f'missing in_files["{k}"] = {v}' - if self.memory_file_method == "mtime": - this_hash = v.lstat().st_mtime - else: - assert self.memory_file_method == "hash" # guaranteed - this_hash = hash_file_path(v) - return (str(v), this_hash) - + hash_ = functools.partial(_path_to_str_hash, method=self.memory_file_method) hashes = [] for k, v in in_files.items(): hashes.append(hash_(k, v)) @@ -211,9 +208,12 @@ def hash_(k, v): memorized_func = self.memory.cache(func, ignore=self.ignore) msg = emoji = None short_circuit = False - subject = kwargs.get("subject", None) - session = kwargs.get("session", None) - run = kwargs.get("run", None) + # Used for logging automatically + subject = kwargs.get("subject", None) # noqa + session = kwargs.get("session", None) # noqa + run = kwargs.get("run", None) # noqa + task = kwargs.get("task", None) # noqa + bad_out_files = False try: done = memorized_func.check_call_in_cache(*args, **kwargs) except Exception: @@ -229,9 +229,31 @@ def hash_(k, v): msg = "Computation forced despite existing cached result …" emoji = "🔂" else: - msg = "Computation unnecessary (cached) …" - emoji = "cache" - # When out_files is not None, we should check if the output files + # Check our output file hashes + # Need to make a copy of kwargs["in_files"] in particular + use_kwargs = copy.deepcopy(kwargs) + out_files_hashes = memorized_func(*args, **use_kwargs) + for key, (fname, this_hash) in out_files_hashes.items(): + fname = pathlib.Path(fname) + if not fname.exists(): + msg = f"Output file missing: {fname}, will recompute …" + emoji = "🧩" + bad_out_files = True + break + got_hash = hash_(key, fname, kind="out")[1] + if this_hash != got_hash: + msg = ( + f"Output file {self.memory_file_method} mismatch for " + f"{fname} ({this_hash} != {got_hash}), will " + "recompute …" + ) + emoji = "🚫" + bad_out_files = True + break + else: + msg = "Computation unnecessary (cached) …" + emoji = "cache" + # When out_files_expected is not None, we should check if the output files # exist and stop if they do (e.g., in bem surface or coreg surface # creation) elif out_files is not None: @@ -246,40 +268,32 @@ def hash_(k, v): msg = "Computation unnecessary (output files exist) …" emoji = "🔍" short_circuit = True + del out_files + if msg is not None: - step = _short_step_path(pathlib.Path(inspect.getfile(func))) - logger.info( - **gen_log_kwargs( - message=msg, - subject=subject, - session=session, - run=run, - emoji=emoji, - step=step, - ) - ) + logger.info(**gen_log_kwargs(message=msg, emoji=emoji)) if short_circuit: return # https://joblib.readthedocs.io/en/latest/memory.html#joblib.memory.MemorizedFunc.call # noqa: E501 - if force_run or unknown_inputs: - out_files, _ = memorized_func.call(*args, **kwargs) + if force_run or unknown_inputs or bad_out_files: + # Joblib 1.4.0 only returns the output, but 1.3.2 returns both. + # Fortunately we can use tuple-ness to tell the difference (we always + # return None or a dict) + out_files = memorized_func.call(*args, **kwargs) + if isinstance(out_files, tuple): + out_files = out_files[0] else: out_files = memorized_func(*args, **kwargs) - assert isinstance(out_files, dict), type(out_files) - out_files_missing_msg = "\n".join( - f"- {key}={fname}" - for key, fname in out_files.items() - if not pathlib.Path(fname).exists() - ) - if out_files_missing_msg: - raise ValueError( - "Missing at least one output file: \n" - + out_files_missing_msg - + "\n" - + "This should not happen unless some files " - "have been manually moved or deleted. You " - "need to flush your cache to fix this." + if self.require_output: + assert isinstance(out_files, dict) and len(out_files), ( + f"Internal error: step must return non-empty out_files dict, got " + f"{type(out_files).__name__} for:\n{self.func_name}" + ) + else: + assert out_files is None, ( + f"Internal error: step must return None, got {type(out_files)} " + f"for:\n{self.func_name}" ) return wrapper @@ -288,7 +302,7 @@ def clear(self) -> None: self.memory.clear() -def save_logs(*, config: SimpleNamespace, logs) -> None: # TODO add type +def save_logs(*, config: SimpleNamespace, logs: list[pd.Series]) -> None: fname = config.deriv_root / f"task-{get_task(config)}_log.xlsx" # Get the script from which the function is called for logging @@ -296,15 +310,7 @@ def save_logs(*, config: SimpleNamespace, logs) -> None: # TODO add type sheet_name = sheet_name[-30:] # shorten due to limit of excel format df = pd.DataFrame(logs) - - columns = df.columns - if "cfg" in columns: - columns = list(columns) - idx = columns.index("cfg") - del columns[idx] - columns.insert(-3, "cfg") # put it before time, success & err cols - - df = df[columns] + del logs with FileLock(fname.with_suffix(fname.suffix + ".lock")): append = fname.exists() @@ -314,13 +320,33 @@ def save_logs(*, config: SimpleNamespace, logs) -> None: # TODO add type mode="a" if append else "w", if_sheet_exists="replace" if append else None, ) + assert isinstance(config, SimpleNamespace), type(config) + cf_df = dict() + for key, val in config.__dict__.items(): + # We need to be careful about functions, json_tricks does not work with them + if inspect.isfunction(val): + new_val = "" + if func_file := inspect.getfile(val): + new_val += f"{func_file}:" + if getattr(val, "__qualname__", None): + new_val += val.__qualname__ + val = "custom callable" if not new_val else new_val + val = json_tricks.dumps(val, indent=4, sort_keys=False) + # 32767 char limit per cell (could split over lines but if something is + # this long, you'll probably get the gist from the first 32k chars) + if len(val) > 32767: + val = val[:32765] + " …" + cf_df[key] = val + cf_df = pd.DataFrame([cf_df], dtype=object) with writer: + # Config first then the data + cf_df.to_excel(writer, sheet_name="config", index=False) df.to_excel(writer, sheet_name=sheet_name, index=False) def _update_for_splits( - files_dict: Dict[str, BIDSPath], - key: str, + files_dict: dict[str, BIDSPath] | BIDSPath, + key: str | None, *, single: bool = False, allow_missing: bool = False, @@ -362,7 +388,7 @@ def _sanitize_callable(val): def _get_step_path( - stack: Optional[List[inspect.FrameInfo]] = None, + stack: list[inspect.FrameInfo] | None = None, ) -> pathlib.Path: if stack is None: stack = inspect.stack() @@ -372,8 +398,10 @@ def _get_step_path( if "steps" in fname.parts: return fname else: # pragma: no cover - if frame.function == "__mne_bids_pipeline_failsafe_wrapper__": + try: return frame.frame.f_locals["__mne_bids_pipeline_step__"] + except KeyError: + pass else: # pragma: no cover paths = "\n".join(paths) raise RuntimeError(f"Could not find step path in call stack:\n{paths}") @@ -381,3 +409,54 @@ def _get_step_path( def _short_step_path(step_path: pathlib.Path) -> str: return f"{step_path.parent.name}/{step_path.stem}" + + +def _prep_out_files( + *, + exec_params: SimpleNamespace, + out_files: dict[str, BIDSPath], + check_relative: pathlib.Path | None = None, + bids_only: bool = True, +): + if check_relative is None: + check_relative = exec_params.deriv_root + for key, fname in out_files.items(): + # Sanity check that we only ever write to the derivatives directory + if bids_only: + assert isinstance(fname, BIDSPath), (type(fname), fname) + # raw and epochs can split on write, and .save should check for us now, so + # we only need to check *other* types (these should never split) + if isinstance(fname, BIDSPath) and fname.suffix not in ("raw", "epo"): + assert fname.split is None, fname + fname = pathlib.Path(fname) + if not fname.is_relative_to(check_relative): + raise RuntimeError( + f"Output BIDSPath not relative to expected root {check_relative}:" + f"\n{fname}" + ) + out_files[key] = _path_to_str_hash( + key, + fname, + method=exec_params.memory_file_method, + kind="out", + ) + return out_files + + +def _path_to_str_hash( + k: str, + v: BIDSPath | pathlib.Path, + *, + method: Literal["mtime", "hash"], + kind: str = "in", +): + if isinstance(v, BIDSPath): + v = v.fpath + assert isinstance(v, pathlib.Path), f'Bad type {type(v)}: {kind}_files["{k}"] = {v}' + assert v.exists(), f'missing {kind}_files["{k}"] = {v}' + if method == "mtime": + this_hash = v.stat().st_mtime + else: + assert method == "hash" # guaranteed + this_hash = hash_file_path(v) + return (str(v), this_hash) diff --git a/mne_bids_pipeline/_viz.py b/mne_bids_pipeline/_viz.py index 8e49af509..4055ab7c4 100644 --- a/mne_bids_pipeline/_viz.py +++ b/mne_bids_pipeline/_viz.py @@ -1,10 +1,9 @@ -from typing import List import numpy as np import pandas as pd from matplotlib.figure import Figure -def plot_auto_scores(auto_scores, *, ch_types) -> List[Figure]: +def plot_auto_scores(auto_scores, *, ch_types) -> list[Figure]: # Plot scores of automated bad channel detection. import matplotlib.pyplot as plt import seaborn as sns @@ -15,7 +14,7 @@ def plot_auto_scores(auto_scores, *, ch_types) -> List[Figure]: ch_types_[idx] = "grad" ch_types_.insert(idx + 1, "mag") - figs: List[Figure] = [] + figs: list[Figure] = [] for ch_type in ch_types_: # Only select the data for mag or grad channels. ch_subset = auto_scores["ch_types"] == ch_type diff --git a/mne_bids_pipeline/steps/freesurfer/_01_recon_all.py b/mne_bids_pipeline/steps/freesurfer/_01_recon_all.py index ee803c800..1889867cf 100755 --- a/mne_bids_pipeline/steps/freesurfer/_01_recon_all.py +++ b/mne_bids_pipeline/steps/freesurfer/_01_recon_all.py @@ -3,6 +3,7 @@ This will run FreeSurfer's ``recon-all --all`` if necessary. """ + import os import shutil import sys @@ -11,8 +12,8 @@ from mne.utils import run_subprocess from ..._config_utils import get_fs_subjects_dir, get_subjects -from ..._logging import logger, gen_log_kwargs -from ..._parallel import parallel_func, get_parallel_backend +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func fs_bids_app = Path(__file__).parent / "contrib" / "run.py" diff --git a/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py b/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py index 2be2d786e..a76c037ef 100644 --- a/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py +++ b/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py @@ -4,24 +4,35 @@ Use FreeSurfer's ``mkheadsurf`` and related utilities to make head surfaces suitable for coregistration. """ + from pathlib import Path from types import SimpleNamespace import mne.bem from ..._config_utils import ( - get_fs_subjects_dir, get_fs_subject, + get_fs_subjects_dir, get_subjects, - _get_scalp_in_files, ) -from ..._logging import logger, gen_log_kwargs -from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._run import _prep_out_files, failsafe_run fs_bids_app = Path(__file__).parent / "contrib" / "run.py" +def _get_scalp_in_files(cfg: SimpleNamespace) -> dict[str, Path]: + subject_path = Path(cfg.fs_subjects_dir) / cfg.fs_subject + seghead = subject_path / "surf" / "lh.seghead" + in_files = dict() + if seghead.is_file(): + in_files["seghead"] = seghead + else: + in_files["t1"] = subject_path / "mri" / "T1.mgz" + return in_files + + def get_input_fnames_coreg_surfaces( *, cfg: SimpleNamespace, @@ -32,7 +43,7 @@ def get_input_fnames_coreg_surfaces( def get_output_fnames_coreg_surfaces(*, cfg: SimpleNamespace, subject: str) -> dict: out_files = dict() - subject_path = Path(cfg.subjects_dir) / cfg.fs_subject + subject_path = Path(cfg.fs_subjects_dir) / cfg.fs_subject out_files["seghead"] = subject_path / "surf" / "lh.seghead" for key in ("dense", "medium", "sparse"): out_files[f"head-{key}"] = ( @@ -57,19 +68,23 @@ def make_coreg_surfaces( in_files.pop("t1" if "t1" in in_files else "seghead") mne.bem.make_scalp_surfaces( subject=cfg.fs_subject, - subjects_dir=cfg.subjects_dir, + subjects_dir=cfg.fs_subjects_dir, force=True, overwrite=True, ) out_files = get_output_fnames_coreg_surfaces(cfg=cfg, subject=subject) - return out_files + return _prep_out_files( + exec_params=exec_params, + out_files=out_files, + check_relative=cfg.fs_subjects_dir, + bids_only=False, + ) def get_config(*, config, subject) -> SimpleNamespace: cfg = SimpleNamespace( - subject=subject, fs_subject=get_fs_subject(config, subject), - subjects_dir=get_fs_subjects_dir(config), + fs_subjects_dir=get_fs_subjects_dir(config), ) return cfg diff --git a/mne_bids_pipeline/steps/freesurfer/__init__.py b/mne_bids_pipeline/steps/freesurfer/__init__.py index 84e37008a..7f4d9d088 100644 --- a/mne_bids_pipeline/steps/freesurfer/__init__.py +++ b/mne_bids_pipeline/steps/freesurfer/__init__.py @@ -3,7 +3,6 @@ Surface reconstruction via FreeSurfer. These steps are not run by default. """ -from . import _01_recon_all -from . import _02_coreg_surfaces +from . import _01_recon_all, _02_coreg_surfaces _STEPS = (_01_recon_all, _02_coreg_surfaces) diff --git a/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py b/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py index 40a879374..e779b1382 100644 --- a/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py +++ b/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py @@ -3,24 +3,23 @@ Initialize the derivatives directory. """ -from typing import Optional +from pathlib import Path from types import SimpleNamespace from mne_bids.config import BIDS_VERSION from mne_bids.utils import _write_json -from ..._config_utils import get_subjects, get_sessions, _bids_kwargs +from ..._config_utils import _bids_kwargs, get_sessions, get_subjects from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run +from ..._run import _prep_out_files, failsafe_run -def init_dataset(cfg) -> None: +@failsafe_run() +def init_dataset(cfg: SimpleNamespace, exec_params: SimpleNamespace) -> dict[str, Path]: """Prepare the pipeline directory in /derivatives.""" - fname_json = cfg.deriv_root / "dataset_description.json" - if fname_json.is_file(): - return # already exists - msg = "Initializing output directories." - logger.info(**gen_log_kwargs(message=msg)) + out_files = dict() + out_files["json"] = cfg.deriv_root / "dataset_description.json" + logger.info(**gen_log_kwargs(message="Initializing output directories.")) cfg.deriv_root.mkdir(exist_ok=True, parents=True) @@ -37,16 +36,18 @@ def init_dataset(cfg) -> None: "URL": "n/a", } - _write_json(fname_json, ds_json, overwrite=True) + _write_json(out_files["json"], ds_json, overwrite=True) + return _prep_out_files( + exec_params=exec_params, out_files=out_files, bids_only=False + ) -@failsafe_run() def init_subject_dirs( *, cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> None: """Create processing data output directories for individual participants.""" out_dir = cfg.deriv_root / f"sub-{subject}" @@ -72,9 +73,9 @@ def get_config( def main(*, config): """Initialize the output directories.""" - init_dataset(cfg=get_config(config=config)) + init_dataset(cfg=get_config(config=config), exec_params=config.exec_params) # Don't bother with parallelization here as I/O operations are generally - # not well paralellized (and this should be very fast anyway) + # not well parallelized (and this should be very fast anyway) for subject in get_subjects(config): for session in get_sessions(config): init_subject_dirs( diff --git a/mne_bids_pipeline/steps/init/_02_find_empty_room.py b/mne_bids_pipeline/steps/init/_02_find_empty_room.py index 15428bcea..02515021a 100644 --- a/mne_bids_pipeline/steps/init/_02_find_empty_room.py +++ b/mne_bids_pipeline/steps/init/_02_find_empty_room.py @@ -1,26 +1,26 @@ """Find empty-room data matches.""" from types import SimpleNamespace -from typing import Dict, Optional -from mne.utils import _pl from mne_bids import BIDSPath from ..._config_utils import ( + _bids_kwargs, + _pl, get_datatype, + get_mf_reference_run, get_sessions, get_subjects, - get_mf_reference_run, - _bids_kwargs, ) -from ..._io import _empty_room_match_path, _write_json +from ..._import_data import _empty_room_match_path +from ..._io import _write_json from ..._logging import gen_log_kwargs, logger -from ..._run import _update_for_splits, failsafe_run, save_logs +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs def get_input_fnames_find_empty_room( - *, subject: str, session: Optional[str], run: Optional[str], cfg: SimpleNamespace -) -> Dict[str, BIDSPath]: + *, subject: str, session: str | None, run: str | None, cfg: SimpleNamespace +) -> dict[str, BIDSPath]: """Get paths of files required by find_empty_room function.""" bids_path_in = BIDSPath( subject=subject, @@ -35,7 +35,7 @@ def get_input_fnames_find_empty_room( root=cfg.bids_root, check=False, ) - in_files: Dict[str, BIDSPath] = dict() + in_files: dict[str, BIDSPath] = dict() in_files[f"raw_run-{run}"] = bids_path_in _update_for_splits(in_files, f"raw_run-{run}", single=True) if hasattr(bids_path_in, "find_matching_sidecar"): @@ -62,10 +62,10 @@ def find_empty_room( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - in_files: Dict[str, BIDSPath], -) -> Dict[str, BIDSPath]: + session: str | None, + run: str | None, + in_files: dict[str, BIDSPath], +) -> dict[str, BIDSPath]: raw_path = in_files.pop(f"raw_run-{run}") in_files.pop("sidecar", None) try: @@ -96,7 +96,7 @@ def find_empty_room( out_files = dict() out_files["empty_room_match"] = _empty_room_match_path(raw_path, cfg) _write_json(out_files["empty_room_match"], dict(fname=fname)) - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( diff --git a/mne_bids_pipeline/steps/init/__init__.py b/mne_bids_pipeline/steps/init/__init__.py index 72a80cf13..6435ffdfe 100644 --- a/mne_bids_pipeline/steps/init/__init__.py +++ b/mne_bids_pipeline/steps/init/__init__.py @@ -1,7 +1,6 @@ """Filesystem initialization and dataset inspection.""" -from . import _01_init_derivatives_dir -from . import _02_find_empty_room +from . import _01_init_derivatives_dir, _02_find_empty_room _STEPS = ( _01_init_derivatives_dir, diff --git a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py index 3f66bebe1..86a2c6b6e 100644 --- a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py +++ b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py @@ -1,36 +1,33 @@ """Assess data quality and find bad (and flat) channels.""" from types import SimpleNamespace -from typing import Optional - -import pandas as pd import mne -from mne.utils import _pl -from mne_bids import BIDSPath +import pandas as pd from ..._config_utils import ( + _do_mf_autobad, + _pl, get_mf_cal_fname, get_mf_ctc_fname, - get_subjects, - get_sessions, get_runs_tasks, - _do_mf_autobad, + get_sessions, + get_subjects, ) from ..._import_data import ( - _get_run_rest_noise_path, - _get_mf_reference_run_path, - import_experimental_data, - import_er_data, _bads_path, - _auto_scores_path, + _get_mf_reference_run_path, + _get_run_rest_noise_path, _import_data_kwargs, + _read_raw_msg, + import_er_data, + import_experimental_data, ) from ..._io import _write_json from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _add_raw -from ..._run import failsafe_run, save_logs +from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _add_raw, _open_report +from ..._run import _prep_out_files, failsafe_run, save_logs from ..._viz import plot_auto_scores @@ -38,9 +35,9 @@ def get_input_fnames_data_quality( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], + session: str | None, + run: str | None, + task: str | None, ) -> dict: """Get paths of files required by assess_data_quality function.""" kwargs = dict( @@ -70,9 +67,9 @@ def assess_data_quality( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], + session: str | None, + run: str | None, + task: str | None, in_files: dict, ) -> dict: """Assess data quality and find and mark bad channels.""" @@ -81,27 +78,111 @@ def assess_data_quality( out_files = dict() key = f"raw_task-{task}_run-{run}" bids_path_in = in_files.pop(key) + if key == "raw_task-noise_run-None": + bids_path_ref_in = in_files.pop("raw_ref_run", None) + else: + bids_path_ref_in = None + msg, _ = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task) + logger.info(**gen_log_kwargs(message=msg)) + if run is None and task == "noise": + raw = import_er_data( + cfg=cfg, + bids_path_er_in=bids_path_in, + bids_path_er_bads_in=None, + bids_path_ref_in=bids_path_ref_in, + bids_path_ref_bads_in=None, + prepare_maxwell_filter=True, + ) + else: + data_is_rest = run is None and task == "rest" + raw = import_experimental_data( + bids_path_in=bids_path_in, + bids_path_bads_in=None, + cfg=cfg, + data_is_rest=data_is_rest, + ) + preexisting_bads = sorted(raw.info["bads"]) + if _do_mf_autobad(cfg=cfg): - if key == "raw_task-noise_run-None": - bids_path_ref_in = in_files.pop("raw_ref_run") - else: - bids_path_ref_in = None - auto_scores = _find_bads_maxwell( + ( + auto_noisy_chs, + auto_flat_chs, + auto_scores, + ) = _find_bads_maxwell( cfg=cfg, exec_params=exec_params, - bids_path_in=bids_path_in, - bids_path_ref_in=bids_path_ref_in, + raw=raw, subject=subject, session=session, run=run, task=task, - out_files=out_files, ) + bads = sorted(set(raw.info["bads"] + auto_noisy_chs + auto_flat_chs)) + msg = f"Found {len(bads)} bad channel{_pl(bads)}." + raw.info["bads"] = bads + del bads + logger.info(**gen_log_kwargs(message=msg)) else: - auto_scores = None + auto_scores = auto_noisy_chs = auto_flat_chs = None del key + # Always output the scores and bads TSV + out_files["auto_scores"] = bids_path_in.copy().update( + suffix="scores", + extension=".json", + root=cfg.deriv_root, + split=None, + check=False, + session=session, + subject=subject, + ) + _write_json(out_files["auto_scores"], auto_scores) + + # Write the bad channels to disk. + out_files["bads_tsv"] = _bads_path( + cfg=cfg, + bids_path_in=bids_path_in, + subject=subject, + session=session, + ) + bads_for_tsv = [] + reasons = [] + + if auto_flat_chs: + for ch in auto_flat_chs: + reason = ( + "pre-existing (before MNE-BIDS-pipeline was run) & auto-flat" + if ch in preexisting_bads + else "auto-flat" + ) + bads_for_tsv.append(ch) + reasons.append(reason) + + if auto_noisy_chs: + for ch in auto_noisy_chs: + reason = ( + "pre-existing (before MNE-BIDS-pipeline was run) & auto-noisy" + if ch in preexisting_bads + else "auto-noisy" + ) + bads_for_tsv.append(ch) + reasons.append(reason) + + if preexisting_bads: + for ch in preexisting_bads: + if ch in bads_for_tsv: + continue + bads_for_tsv.append(ch) + reasons.append("pre-existing (before MNE-BIDS-pipeline was run)") + + tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons)) + tsv_data = tsv_data.sort_values(by="name") + tsv_data.to_csv(out_files["bads_tsv"], sep="\t", index=False) + # Report + # Restore bads to their original state so they will show up in the report + raw.info["bads"] = preexisting_bads + with _open_report( cfg=cfg, exec_params=exec_params, @@ -118,9 +199,11 @@ def assess_data_quality( cfg=cfg, report=report, bids_path_in=bids_path_in, + raw=raw, title=f"Raw ({kind})", tags=("data-quality",), ) + title = f"Bad channel detection: {run}" if cfg.find_noisy_channels_meg: assert auto_scores is not None msg = "Adding noisy channel detection to report" @@ -132,60 +215,42 @@ def assess_data_quality( fig=figs, caption=captions, section="Data quality", - title=f"Bad channel detection: {run}", + title=title, tags=tags, replace=True, ) for fig in figs: plt.close(fig) + else: + report.remove(title=title) assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def _find_bads_maxwell( *, cfg: SimpleNamespace, exec_params: SimpleNamespace, - bids_path_in: BIDSPath, - bids_path_ref_in: Optional[BIDSPath], + raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - out_files: dict, + session: str | None, + run: str | None, + task: str | None, ): - if cfg.find_flat_channels_meg and not cfg.find_noisy_channels_meg: - msg = "Finding flat channels." - elif cfg.find_noisy_channels_meg and not cfg.find_flat_channels_meg: - msg = "Finding noisy channels using Maxwell filtering." + if cfg.find_flat_channels_meg: + if cfg.find_noisy_channels_meg: + msg = "Finding flat channels and noisy channels using Maxwell filtering." + else: + msg = "Finding flat channels." else: - msg = "Finding flat channels and noisy channels using " "Maxwell filtering." + assert cfg.find_noisy_channels_meg + msg = "Finding noisy channels using Maxwell filtering." logger.info(**gen_log_kwargs(message=msg)) - if run is None and task == "noise": - raw = import_er_data( - cfg=cfg, - bids_path_er_in=bids_path_in, - bids_path_er_bads_in=None, - bids_path_ref_in=bids_path_ref_in, - bids_path_ref_bads_in=None, - prepare_maxwell_filter=True, - ) - else: - data_is_rest = run is None and task == "rest" - raw = import_experimental_data( - bids_path_in=bids_path_in, - bids_path_bads_in=None, - cfg=cfg, - data_is_rest=data_is_rest, - ) - # Filter the data manually before passing it to find_bad_channels_maxwell() # This reduces memory usage, as we can control the number of jobs used # during filtering. - preexisting_bads = raw.info["bads"].copy() - bads = preexisting_bads.copy() raw_filt = raw.copy().filter(l_freq=None, h_freq=40, n_jobs=1) ( auto_noisy_chs, @@ -211,7 +276,8 @@ def _find_bads_maxwell( else: msg = "Found no flat channels." logger.info(**gen_log_kwargs(message=msg)) - bads.extend(auto_flat_chs) + else: + auto_flat_chs = [] if cfg.find_noisy_channels_meg: if auto_noisy_chs: @@ -224,51 +290,8 @@ def _find_bads_maxwell( msg = "Found no noisy channels." logger.info(**gen_log_kwargs(message=msg)) - bads.extend(auto_noisy_chs) - - bads = sorted(set(bads)) - msg = f"Found {len(bads)} channel{_pl(bads)} as bad." - raw.info["bads"] = bads - del bads - logger.info(**gen_log_kwargs(message=msg)) - - if cfg.find_noisy_channels_meg: - out_files["auto_scores"] = _auto_scores_path( - cfg=cfg, - bids_path_in=bids_path_in, - ) - if not out_files["auto_scores"].fpath.parent.exists(): - out_files["auto_scores"].fpath.parent.mkdir(parents=True) - _write_json(out_files["auto_scores"], auto_scores) - - # Write the bad channels to disk. - out_files["bads_tsv"] = _bads_path( - cfg=cfg, - bids_path_in=bids_path_in, - ) - bads_for_tsv = [] - reasons = [] - - if cfg.find_flat_channels_meg: - bads_for_tsv.extend(auto_flat_chs) - reasons.extend(["auto-flat"] * len(auto_flat_chs)) - preexisting_bads = set(preexisting_bads) - set(auto_flat_chs) - - if cfg.find_noisy_channels_meg: - bads_for_tsv.extend(auto_noisy_chs) - reasons.extend(["auto-noisy"] * len(auto_noisy_chs)) - preexisting_bads = set(preexisting_bads) - set(auto_noisy_chs) - - preexisting_bads = list(preexisting_bads) - if preexisting_bads: - bads_for_tsv.extend(preexisting_bads) - reasons.extend( - ["pre-existing (before MNE-BIDS-pipeline was run)"] * len(preexisting_bads) - ) - - tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons)) - tsv_data = tsv_data.sort_values(by="name") - tsv_data.to_csv(out_files["bads_tsv"], sep="\t", index=False) + else: + auto_noisy_chs = [] # Interaction if exec_params.interactive and cfg.find_noisy_channels_meg: @@ -277,17 +300,18 @@ def _find_bads_maxwell( plot_auto_scores(auto_scores, ch_types=cfg.ch_types) plt.show() - return auto_scores + return auto_noisy_chs, auto_flat_chs, auto_scores def get_config( *, config: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> SimpleNamespace: extra_kwargs = dict() if config.find_noisy_channels_meg or config.find_flat_channels_meg: + # If these change, need to update hooks.py in doc build extra_kwargs["mf_cal_fname"] = get_mf_cal_fname( config=config, subject=subject, @@ -311,7 +335,7 @@ def get_config( def main(*, config: SimpleNamespace) -> None: - """Run maxwell_filter.""" + """Run assess_data_quality.""" with get_parallel_backend(config.exec_params): parallel, run_func = parallel_func( assess_data_quality, exec_params=config.exec_params diff --git a/mne_bids_pipeline/steps/preprocessing/_02_head_pos.py b/mne_bids_pipeline/steps/preprocessing/_02_head_pos.py index ba3e9fbac..de8996338 100644 --- a/mne_bids_pipeline/steps/preprocessing/_02_head_pos.py +++ b/mne_bids_pipeline/steps/preprocessing/_02_head_pos.py @@ -1,33 +1,32 @@ """Estimate head positions.""" -from typing import Optional from types import SimpleNamespace import mne from ..._config_utils import ( - get_subjects, - get_sessions, get_runs_tasks, + get_sessions, + get_subjects, ) from ..._import_data import ( - import_experimental_data, _get_run_rest_noise_path, _import_data_kwargs, + import_experimental_data, ) from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend +from ..._parallel import get_parallel_backend, parallel_func from ..._report import _open_report -from ..._run import failsafe_run, save_logs +from ..._run import _prep_out_files, failsafe_run, save_logs def get_input_fnames_head_pos( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], + session: str | None, + run: str | None, + task: str | None, ) -> dict: """Get paths of files required by run_head_pos function.""" return _get_run_rest_noise_path( @@ -49,9 +48,9 @@ def run_head_pos( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], + session: str | None, + run: str | None, + task: str | None, in_files: dict, ) -> dict: import matplotlib.pyplot as plt @@ -62,7 +61,8 @@ def run_head_pos( out_files = dict() key = f"raw_run-{run}-pos" out_files[key] = bids_path_in.copy().update( - extension=".pos", + suffix="headpos", + extension=".txt", root=cfg.deriv_root, check=False, ) @@ -140,14 +140,14 @@ def run_head_pos( plt.close(fig) del bids_path_in assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( *, config: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> SimpleNamespace: cfg = SimpleNamespace( mf_mc_t_step_min=config.mf_mc_t_step_min, @@ -183,7 +183,7 @@ def main(*, config: SimpleNamespace) -> None: config=config, subject=subject, session=session, - include_noise=False, + which=("runs", "rest"), ) ) diff --git a/mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py b/mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py index 255f6bc60..e1e178395 100644 --- a/mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py +++ b/mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py @@ -15,41 +15,175 @@ """ import gc -from typing import Optional +from copy import deepcopy from types import SimpleNamespace -import numpy as np import mne +import numpy as np from mne_bids import read_raw_bids from ..._config_utils import ( + _pl, get_mf_cal_fname, get_mf_ctc_fname, - get_subjects, - get_sessions, get_runs_tasks, + get_sessions, + get_subjects, ) from ..._import_data import ( - import_experimental_data, - import_er_data, + _get_mf_reference_run_path, _get_run_path, _get_run_rest_noise_path, - _get_mf_reference_run_path, _import_data_kwargs, + import_er_data, + import_experimental_data, ) from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _add_raw -from ..._run import failsafe_run, save_logs, _update_for_splits +from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _add_raw, _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs + + +# %% eSSS +def get_input_fnames_esss( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> dict: + kwargs = dict( + cfg=cfg, + subject=subject, + session=session, + ) + in_files = _get_run_rest_noise_path( + run=None, + task="noise", + kind="orig", + mf_reference_run=cfg.mf_reference_run, + **kwargs, + ) + in_files.update(_get_mf_reference_run_path(**kwargs)) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_esss, +) +def compute_esss_proj( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: dict, +) -> dict: + import matplotlib.pyplot as plt + + run, task = None, "noise" + in_key = f"raw_task-{task}_run-{run}" + bids_path_in = in_files.pop(in_key) + bids_path_bads_in = in_files.pop(f"{in_key}-bads", None) # noqa + bids_path_ref_in = in_files.pop("raw_ref_run") + bids_path_ref_bads_in = in_files.pop("raw_ref_run-bads", None) + raw_noise = import_er_data( + cfg=cfg, + bids_path_er_in=bids_path_in, + bids_path_ref_in=bids_path_ref_in, + # TODO: This must match below, so we don't pass it + # bids_path_er_bads_in=bids_path_bads_in, + bids_path_er_bads_in=None, + bids_path_ref_bads_in=bids_path_ref_bads_in, + prepare_maxwell_filter=True, + ) + logger.info( + **gen_log_kwargs( + f"Computing eSSS basis with {cfg.mf_esss} component{_pl(cfg.mf_esss)}" + ) + ) + projs = mne.compute_proj_raw( + raw_noise, + n_grad=cfg.mf_esss, + n_mag=cfg.mf_esss, + reject=cfg.mf_esss_reject, + meg="combined", + ) + out_files = dict() + out_files["esss_basis"] = bids_path_in.copy().update( + subject=subject, # need these in the case of an empty room match + session=session, + run=run, + task=task, + suffix="esssproj", + split=None, + extension=".fif", + root=cfg.deriv_root, + check=False, + ) + mne.write_proj(out_files["esss_basis"], projs, overwrite=True) + + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) as report: + msg = "Adding eSSS projectors to report." + logger.info(**gen_log_kwargs(message=msg)) + kinds_picks = list() + for kind in ("mag", "grad"): + picks = mne.pick_types(raw_noise.info, meg=kind, exclude="bads") + if not len(picks): + continue + kinds_picks.append([kind, picks]) + n_row, n_col = len(kinds_picks), cfg.mf_esss + fig, axes = plt.subplots( + n_row, + n_col, + figsize=(n_col + 0.5, n_row + 0.5), + constrained_layout=True, + squeeze=False, + ) + # TODO: plot_projs_topomap doesn't handle meg="combined" well: + # https://github.com/mne-tools/mne-python/pull/11792 + for ax_row, (kind, picks) in zip(axes, kinds_picks): + info = mne.pick_info(raw_noise.info, picks) + ch_names = info["ch_names"] + these_projs = deepcopy(projs) + for proj in these_projs: + sub_idx = [proj["data"]["col_names"].index(name) for name in ch_names] + proj["data"]["data"] = proj["data"]["data"][:, sub_idx] + proj["data"]["col_names"] = ch_names + mne.viz.plot_projs_topomap( + these_projs, + info=info, + axes=ax_row, + ) + for ai, ax in enumerate(ax_row): + ax.set_title(f"{kind} {ai + 1}") + report.add_figure( + fig, + title="eSSS projectors", + tags=("sss", "raw"), + replace=True, + ) + plt.close(fig) + + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +# %% maxwell_filter def get_input_fnames_maxwell_filter( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], + session: str | None, + run: str | None, + task: str | None, ) -> dict: """Get paths of files required by maxwell_filter function.""" kwargs = dict( @@ -64,6 +198,8 @@ def get_input_fnames_maxwell_filter( mf_reference_run=cfg.mf_reference_run, **kwargs, ) + in_key = f"raw_task-{task}_run-{run}" + assert in_key in in_files # head positions if cfg.mf_mc: if run is None and task == "noise": @@ -77,16 +213,51 @@ def get_input_fnames_maxwell_filter( kind="orig", **kwargs, )[f"raw_task-{pos_task}_run-{pos_run}"] - in_files[f"raw_task-{task}_run-{run}-pos"] = path.update( - extension=".pos", + in_files[f"{in_key}-pos"] = path.update( + suffix="headpos", + extension=".txt", root=cfg.deriv_root, check=False, task=pos_task, run=pos_run, ) + if cfg.mf_esss: + in_files["esss_basis"] = ( + in_files[in_key] + .copy() + .update( + subject=subject, + session=session, + run=None, + task="noise", + suffix="esssproj", + split=None, + extension=".fif", + root=cfg.deriv_root, + check=False, + ) + ) + # reference run (used for `destination` and also bad channels for noise) - in_files.update(_get_mf_reference_run_path(add_bads=True, **kwargs)) + # use add_bads=None here to mean "add if autobad is turned on" + in_files.update(_get_mf_reference_run_path(**kwargs)) + + is_rest_noise = run is None and task in ("noise", "rest") + if is_rest_noise: + key = "raw_ref_run_sss" + in_files[key] = ( + in_files["raw_ref_run"] + .copy() + .update( + processing="sss", + suffix="raw", + extension=".fif", + root=cfg.deriv_root, + check=False, + ) + ) + _update_for_splits(in_files, key, single=True) # standard files in_files["mf_cal_fname"] = cfg.mf_cal_fname @@ -102,9 +273,9 @@ def run_maxwell_filter( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], + session: str | None, + run: str | None, + task: str | None, in_files: dict, ) -> dict: if cfg.proc and "sss" in cfg.proc and cfg.use_maxwell_filter: @@ -139,14 +310,10 @@ def run_maxwell_filter( extension=".fif", root=cfg.deriv_root, check=False, + split=None, ) bids_path_out = bids_path_in.copy().update(**bids_path_out_kwargs) - # Now take everything from the bids_path_in and overwrite the parameters - subject = bids_path_in.subject # noqa: F841 - session = bids_path_in.session # noqa: F841 - run = bids_path_in.run - out_files = dict() # Load dev_head_t and digitization points from MaxFilter reference run. msg = f"Loading reference run: {cfg.mf_reference_run}." @@ -167,15 +334,23 @@ def run_maxwell_filter( # Maxwell-filter experimental data. apply_msg = "Applying " + extra = list() if cfg.mf_st_duration: apply_msg += f"tSSS ({cfg.mf_st_duration} sec, corr={cfg.mf_st_correlation})" else: apply_msg += "SSS" if cfg.mf_mc: - apply_msg += " with MC" + extra.append("MC") head_pos = mne.chpi.read_head_pos(in_files.pop(f"{in_key}-pos")) else: head_pos = None + if cfg.mf_esss: + extra.append("eSSS") + extended_proj = mne.read_proj(in_files.pop("esss_basis")) + else: + extended_proj = () + if extra: + apply_msg += " with " + "/".join(extra) apply_msg += " to" mf_kws = dict( @@ -187,6 +362,7 @@ def run_maxwell_filter( coord_frame="head", destination=destination, head_pos=head_pos, + extended_proj=extended_proj, ) logger.info(**gen_log_kwargs(message=f"{apply_msg} {recording_type} data")) @@ -249,7 +425,7 @@ def run_maxwell_filter( # copy the bad channel selection from the reference run over to # the resting-state recording. - bids_path_ref_sss = bids_path_ref_in.copy().update(**bids_path_out_kwargs) + bids_path_ref_sss = in_files.pop("raw_ref_run_sss") raw_exp = mne.io.read_raw_fif(bids_path_ref_sss) rank_exp = mne.compute_rank(raw_exp, rank="info")["meg"] rank_noise = mne.compute_rank(raw_sss, rank="info")["meg"] @@ -282,6 +458,42 @@ def run_maxwell_filter( t_window=cfg.mf_mc_t_window, ) + if cfg.mf_mc and ( + cfg.mf_mc_rotation_velocity_limit is not None + or cfg.mf_mc_translation_velocity_limit is not None + ): + movement_annot, _ = mne.preprocessing.annotate_movement( + raw_sss, + pos=head_pos, + rotation_velocity_limit=cfg.mf_mc_rotation_velocity_limit, + translation_velocity_limit=cfg.mf_mc_translation_velocity_limit, + ) + perc_time = 100 / raw_sss.times[-1] + extra_html = list() + for kind, unit in (("translation", "m"), ("rotation", "°")): + limit = getattr(cfg, f"mf_mc_{kind}_velocity_limit") + if limit is None: + continue + desc = (f"BAD_mov_{kind[:5]}_vel",) + tot_time = np.sum( + movement_annot.duration[movement_annot.description == desc] + ) + perc = perc_time * tot_time + logger_meth = logger.warning if perc > 20 else logger.info + msg = ( + f"{kind.capitalize()} velocity exceeded {limit} {unit}/s " + f"limit for {tot_time:0.1f} s ({perc:0.1f}%)" + ) + logger_meth(**gen_log_kwargs(message=msg)) + extra_html.append(f"
  • {msg}
  • ") + extra_html = ( + "

    The raw data were annotated with the following movement-related bad " + f"segment annotations:

      {''.join(extra_html)}

    " + ) + raw_sss.set_annotations(raw_sss.annotations + movement_annot) + else: + movement_annot = extra_html = None + out_files["sss_raw"] = bids_path_out msg = f"Writing {out_files['sss_raw'].fpath.relative_to(cfg.deriv_root)}" logger.info(**gen_log_kwargs(message=msg)) @@ -307,6 +519,7 @@ def run_maxwell_filter( ) as report: msg = "Adding Maxwell filtered raw data to report." logger.info(**gen_log_kwargs(message=msg)) + _add_raw( cfg=cfg, report=report, @@ -314,17 +527,32 @@ def run_maxwell_filter( title="Raw (maxwell filtered)", tags=("sss",), raw=raw_sss, + extra_html=extra_html, ) assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) -def get_config( +def get_config_esss( *, config: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, +) -> SimpleNamespace: + cfg = SimpleNamespace( + mf_esss=config.mf_esss, + mf_esss_reject=config.mf_esss_reject, + **_import_data_kwargs(config=config, subject=subject), + ) + return cfg + + +def get_config_maxwell_filter( + *, + config: SimpleNamespace, + subject: str, + session: str | None, ) -> SimpleNamespace: cfg = SimpleNamespace( mf_cal_fname=get_mf_cal_fname( @@ -345,6 +573,9 @@ def get_config( mf_destination=config.mf_destination, mf_int_order=config.mf_int_order, mf_mc_t_window=config.mf_mc_t_window, + mf_mc_rotation_velocity_limit=config.mf_mc_rotation_velocity_limit, + mf_mc_translation_velocity_limit=config.mf_mc_translation_velocity_limit, + mf_esss=config.mf_esss, **_import_data_kwargs(config=config, subject=subject), ) return cfg @@ -358,25 +589,55 @@ def main(*, config: SimpleNamespace) -> None: return with get_parallel_backend(config.exec_params): + logs = list() + # First step: compute eSSS projectors + if config.mf_esss: + parallel, run_func = parallel_func( + compute_esss_proj, exec_params=config.exec_params + ) + logs += parallel( + run_func( + cfg=get_config_esss( + config=config, + subject=subject, + session=session, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + ) + + # Second: maxwell_filter parallel, run_func = parallel_func( run_maxwell_filter, exec_params=config.exec_params ) - logs = parallel( - run_func( - cfg=get_config(config=config, subject=subject, session=session), - exec_params=config.exec_params, - subject=subject, - session=session, - run=run, - task=task, - ) - for subject in get_subjects(config) - for session in get_sessions(config) - for run, task in get_runs_tasks( - config=config, - subject=subject, - session=session, + # We need to guarantee that the reference_run completes before the + # noise/rest runs are processed, so we split the loops. + for which in [("runs",), ("noise", "rest")]: + logs += parallel( + run_func( + cfg=get_config_maxwell_filter( + config=config, + subject=subject, + session=session, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + for run, task in get_runs_tasks( + config=config, + subject=subject, + session=session, + which=which, + ) ) - ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py b/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py index 50a3f3da0..aec4b609a 100644 --- a/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py +++ b/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py @@ -14,36 +14,40 @@ If config.interactive = True plots raw data and power spectral density. """ # noqa: E501 -import numpy as np +from collections.abc import Iterable from types import SimpleNamespace -from typing import Optional, Union, Literal, Iterable +from typing import Literal import mne +import numpy as np +from mne.io.pick import _picks_to_idx +from mne.preprocessing import EOGRegression from ..._config_utils import ( - get_sessions, get_runs_tasks, + get_sessions, get_subjects, ) from ..._import_data import ( - import_experimental_data, - import_er_data, _get_run_rest_noise_path, _import_data_kwargs, + _read_raw_msg, + import_er_data, + import_experimental_data, ) from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _add_raw -from ..._run import failsafe_run, save_logs, _update_for_splits +from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _add_raw, _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs def get_input_fnames_frequency_filter( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], + task: str | None, ) -> dict: """Get paths of files required by filter_data function.""" kind = "sss" if cfg.use_maxwell_filter else "orig" @@ -61,13 +65,14 @@ def get_input_fnames_frequency_filter( def notch_filter( raw: mne.io.BaseRaw, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], - freqs: Optional[Union[float, Iterable[float]]], - trans_bandwidth: Union[float, Literal["auto"]], - notch_widths: Optional[Union[float, Iterable[float]]], + task: str | None, + freqs: float | Iterable[float] | None, + trans_bandwidth: float | Literal["auto"], + notch_widths: float | Iterable[float] | None, run_type: Literal["experimental", "empty-room", "resting-state"], + picks: np.ndarray | None, ) -> None: """Filter data channels (MEG and EEG).""" if freqs is None: @@ -85,20 +90,22 @@ def notch_filter( trans_bandwidth=trans_bandwidth, notch_widths=notch_widths, n_jobs=1, + picks=picks, ) def bandpass_filter( raw: mne.io.BaseRaw, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], - l_freq: Optional[float], - h_freq: Optional[float], - l_trans_bandwidth: Union[float, Literal["auto"]], - h_trans_bandwidth: Union[float, Literal["auto"]], + task: str | None, + l_freq: float | None, + h_freq: float | None, + l_trans_bandwidth: float | Literal["auto"], + h_trans_bandwidth: float | Literal["auto"], run_type: Literal["experimental", "empty-room", "resting-state"], + picks: np.ndarray | None, ) -> None: """Filter data channels (MEG and EEG).""" if l_freq is not None and h_freq is None: @@ -121,15 +128,16 @@ def bandpass_filter( l_trans_bandwidth=l_trans_bandwidth, h_trans_bandwidth=h_trans_bandwidth, n_jobs=1, + picks=picks, ) def resample( raw: mne.io.BaseRaw, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], + task: str | None, sfreq: float, run_type: Literal["experimental", "empty-room", "resting-state"], ) -> None: @@ -149,9 +157,9 @@ def filter_data( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], + task: str | None, in_files: dict, ) -> dict: """Filter data from a single subject.""" @@ -159,21 +167,15 @@ def filter_data( in_key = f"raw_task-{task}_run-{run}" bids_path_in = in_files.pop(in_key) bids_path_bads_in = in_files.pop(f"{in_key}-bads", None) - - if run is None and task in ("noise", "rest"): - run_type = dict(rest="resting-state", noise="empty-room")[task] - else: - run_type = "experimental" - + msg, run_type = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task) + logger.info(**gen_log_kwargs(message=msg)) if cfg.use_maxwell_filter: - msg = f"Reading {run_type} recording: " f"{bids_path_in.basename}" - logger.info(**gen_log_kwargs(message=msg)) raw = mne.io.read_raw_fif(bids_path_in) elif run is None and task == "noise": raw = import_er_data( cfg=cfg, bids_path_er_in=bids_path_in, - bids_path_ref_in=in_files.pop("raw_ref_run"), + bids_path_ref_in=in_files.pop("raw_ref_run", None), bids_path_er_bads_in=bids_path_bads_in, # take bads from this run (0) bids_path_ref_bads_in=in_files.pop("raw_ref_run-bads", None), @@ -190,14 +192,29 @@ def filter_data( out_files[in_key] = bids_path_in.copy().update( root=cfg.deriv_root, + subject=subject, # save under subject's directory so all files are there + session=session, processing="filt", extension=".fif", suffix="raw", split=None, task=task, run=run, + check=False, ) + if cfg.regress_artifact is None: + picks = None + else: + # Need to figure out the correct picks to use + model = EOGRegression(**cfg.regress_artifact) + picks_regress = _picks_to_idx( + raw.info, model.picks, none="data", exclude=model.exclude + ) + picks_artifact = _picks_to_idx(raw.info, model.picks_artifact) + picks_data = _picks_to_idx(raw.info, "data", exclude=()) # raw.filter default + picks = np.unique(np.r_[picks_regress, picks_artifact, picks_data]) + raw.load_data() notch_filter( raw=raw, @@ -209,6 +226,7 @@ def filter_data( trans_bandwidth=cfg.notch_trans_bandwidth, notch_widths=cfg.notch_widths, run_type=run_type, + picks=picks, ) bandpass_filter( raw=raw, @@ -221,6 +239,7 @@ def filter_data( h_trans_bandwidth=cfg.h_trans_bandwidth, l_trans_bandwidth=cfg.l_trans_bandwidth, run_type=run_type, + picks=picks, ) resample( raw=raw, @@ -232,6 +251,9 @@ def filter_data( run_type=run_type, ) + # For example, might need to create + # derivatives/mne-bids-pipeline/sub-emptyroom/ses-20230412/meg + out_files[in_key].fpath.parent.mkdir(exist_ok=True, parents=True) raw.save( out_files[in_key], overwrite=True, @@ -265,7 +287,7 @@ def filter_data( ) assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( @@ -282,6 +304,7 @@ def get_config( notch_trans_bandwidth=config.notch_trans_bandwidth, notch_widths=config.notch_widths, raw_resample_sfreq=config.raw_resample_sfreq, + regress_artifact=config.regress_artifact, **_import_data_kwargs(config=config, subject=subject), ) return cfg diff --git a/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py b/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py new file mode 100644 index 000000000..d52f78ed1 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py @@ -0,0 +1,166 @@ +"""Temporal regression for artifact removal.""" + +from types import SimpleNamespace + +import mne +from mne.io.pick import _picks_to_idx +from mne.preprocessing import EOGRegression + +from ..._config_utils import ( + get_runs_tasks, + get_sessions, + get_subjects, +) +from ..._import_data import _get_run_rest_noise_path, _import_data_kwargs, _read_raw_msg +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _add_raw, _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs + + +def get_input_fnames_regress_artifact( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, +) -> dict: + """Get paths of files required by regress_artifact function.""" + out = _get_run_rest_noise_path( + cfg=cfg, + subject=subject, + session=session, + run=run, + task=task, + kind="filt", + mf_reference_run=cfg.mf_reference_run, + ) + assert len(out) + return out + + +@failsafe_run( + get_input_fnames=get_input_fnames_regress_artifact, +) +def run_regress_artifact( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, + in_files: dict, +) -> dict: + model = EOGRegression(proj=False, **cfg.regress_artifact) + out_files = dict() + in_key = f"raw_task-{task}_run-{run}" + bids_path_in = in_files.pop(in_key) + out_files[in_key] = bids_path_in.copy().update(processing="regress") + msg, _ = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task) + logger.info(**gen_log_kwargs(message=msg)) + raw = mne.io.read_raw_fif(bids_path_in).load_data() + projs = raw.info["projs"] + raw.del_proj() + model.fit(raw) + all_types = raw.get_channel_types() + picks = _picks_to_idx(raw.info, model.picks, none="data", exclude=model.exclude) + ch_types = set(all_types[pick] for pick in picks) + del picks + out_files["regress"] = bids_path_in.copy().update( + processing=None, + split=None, + suffix="regress", + extension=".h5", + ) + model.apply(raw, copy=False) + if projs: + raw.add_proj(projs) + raw.save(out_files[in_key], overwrite=True, split_size=cfg._raw_split_size) + _update_for_splits(out_files, in_key) + model.save(out_files["regress"], overwrite=True) + assert len(in_files) == 0, in_files.keys() + + # Report + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) as report: + msg = "Adding regressed raw data to report" + logger.info(**gen_log_kwargs(message=msg)) + figs, captions = list(), list() + for kind in ("mag", "grad", "eeg"): + if kind not in ch_types: + continue + figs.append(model.plot(ch_type=kind)) + captions.append(f"Run {run}: {kind}") + if figs: + report.add_figure( + fig=figs, + caption=captions, + title="Regression weights", + tags=("raw", f"run-{run}", "regression"), + replace=True, + ) + _add_raw( + cfg=cfg, + report=report, + bids_path_in=out_files[in_key], + title="Raw (regression)", + tags=("regression",), + raw=raw, + ) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, +) -> SimpleNamespace: + cfg = SimpleNamespace( + regress_artifact=config.regress_artifact, + **_import_data_kwargs(config=config, subject=subject), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Run artifact regression.""" + if config.regress_artifact is None: + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + parallel, run_func = parallel_func( + run_regress_artifact, exec_params=config.exec_params + ) + + logs = parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + for run, task in get_runs_tasks( + config=config, + subject=subject, + session=session, + ) + ) + + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py new file mode 100644 index 000000000..39b4e59a2 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py @@ -0,0 +1,399 @@ +"""Fit ICA. + +This fits Independent Component Analysis (ICA) on high-pass filtered raw data, +temporarily creating task-related epochs. The epochs created here are used for +the purpose of fitting ICA only, and will not enter any other processing steps. + +Before performing ICA, we reject epochs based on peak-to-peak amplitude above +the 'ica_reject' limits to remove high-amplitude non-biological artifacts +(e.g., voltage or flux spikes). +""" + +from types import SimpleNamespace + +import autoreject +import mne +import numpy as np +from mne.preprocessing import ICA +from mne_bids import BIDSPath + +from ..._config_utils import ( + _bids_kwargs, + get_eeg_reference, + get_runs, + get_sessions, + get_subjects, +) +from ..._import_data import annotations_to_events, make_epochs +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._reject import _get_reject +from ..._report import _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs + + +def get_input_fnames_run_ica( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> dict: + bids_basename = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + extension=".fif", + ) + in_files = dict() + for run in cfg.runs: + key = f"raw_run-{run}" + in_files[key] = bids_basename.copy().update( + run=run, processing=cfg.processing, suffix="raw" + ) + _update_for_splits(in_files, key, single=True) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_run_ica, +) +def run_ica( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: dict, +) -> dict: + """Run ICA.""" + import matplotlib.pyplot as plt + + if cfg.ica_use_icalabel: + # The ICALabel network was trained on extended-Infomax ICA decompositions fit + # on data flltered between 1 and 100 Hz. + assert cfg.ica_algorithm in ["picard-extended_infomax", "extended_infomax"] + assert cfg.ica_l_freq == 1.0 + assert cfg.h_freq == 100.0 + assert cfg.eeg_reference == "average" + + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] + out_files = dict() + bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) + out_files["ica"] = bids_basename.copy().update(processing="icafit", suffix="ica") + out_files["epochs"] = ( + out_files["ica"].copy().update(suffix="epo", processing="icafit") + ) + del bids_basename + + # Generate a list of raw data paths (i.e., paths of individual runs) + # we want to create epochs from. + + # Generate a unique event name -> event code mapping that can be used + # across all runs. + event_name_to_code_map = annotations_to_events(raw_paths=raw_fnames) + + epochs = None + for idx, (run, raw_fname) in enumerate(zip(cfg.runs, raw_fnames)): + msg = f"Processing raw data from {raw_fname.basename}" + logger.info(**gen_log_kwargs(message=msg)) + raw = mne.io.read_raw_fif(raw_fname, preload=True) + + # Produce high-pass filtered version of the data for ICA. + # Sanity check – make sure we're using the correct data! + if cfg.raw_resample_sfreq is not None: + assert np.allclose(raw.info["sfreq"], cfg.raw_resample_sfreq) + if cfg.l_freq is not None: + assert np.allclose(raw.info["highpass"], cfg.l_freq) + + if idx == 0: + if cfg.ica_l_freq is None: + msg = ( + f"Not applying high-pass filter (data is already filtered, " + f'cutoff: {raw.info["highpass"]} Hz).' + ) + logger.info(**gen_log_kwargs(message=msg)) + else: + msg = f"Applying high-pass filter with {cfg.ica_l_freq} Hz cutoff …" + logger.info(**gen_log_kwargs(message=msg)) + raw.filter(l_freq=cfg.ica_l_freq, h_freq=None, n_jobs=1) + + # Only keep the subset of the mapping that applies to the current run + event_id = event_name_to_code_map.copy() + for event_name in event_id.copy().keys(): + if event_name not in raw.annotations.description: + del event_id[event_name] + + if idx == 0: + msg = "Creating task-related epochs …" + logger.info(**gen_log_kwargs(message=msg)) + these_epochs = make_epochs( + subject=subject, + session=session, + task=cfg.task, + conditions=cfg.conditions, + raw=raw, + event_id=event_id, + tmin=cfg.epochs_tmin, + tmax=cfg.epochs_tmax, + metadata_tmin=cfg.epochs_metadata_tmin, + metadata_tmax=cfg.epochs_metadata_tmax, + metadata_keep_first=cfg.epochs_metadata_keep_first, + metadata_keep_last=cfg.epochs_metadata_keep_last, + metadata_query=cfg.epochs_metadata_query, + event_repeated=cfg.event_repeated, + epochs_decim=cfg.epochs_decim, + task_is_rest=cfg.task_is_rest, + rest_epochs_duration=cfg.rest_epochs_duration, + rest_epochs_overlap=cfg.rest_epochs_overlap, + ) + + these_epochs.load_data() # Remove reference to raw + del raw # free memory + + if epochs is None: + epochs = these_epochs + else: + epochs = mne.concatenate_epochs([epochs, these_epochs], on_mismatch="warn") + + del these_epochs + del run + + # Set an EEG reference + if "eeg" in cfg.ch_types: + if cfg.ica_use_icalabel: + assert cfg.eeg_reference == "average" + projection = False # Avg. ref. needs to be applied for MNE-ICALabel + elif cfg.eeg_reference == "average": + projection = True + else: + projection = False + + if not projection: + msg = "Applying average reference to EEG epochs used for ICA fitting." + logger.info(**gen_log_kwargs(message=msg)) + + epochs.set_eeg_reference(cfg.eeg_reference, projection=projection) + + ar_reject_log = ar_n_interpolate_ = None + if cfg.ica_reject == "autoreject_local": + msg = ( + "Using autoreject to find bad epochs for ICA " + "(no interpolation will be performend)" + ) + logger.info(**gen_log_kwargs(message=msg)) + ar = autoreject.AutoReject( + n_interpolate=cfg.autoreject_n_interpolate, + random_state=cfg.random_state, + n_jobs=exec_params.n_jobs, + verbose=False, + ) + ar.fit(epochs) + ar_reject_log = ar.get_reject_log(epochs) + epochs = epochs[~ar_reject_log.bad_epochs] + + n_epochs_before_reject = len(epochs) + n_epochs_rejected = ar_reject_log.bad_epochs.sum() + n_epochs_after_reject = n_epochs_before_reject - n_epochs_rejected + + ar_n_interpolate_ = ar.n_interpolate_ + msg = ( + f"autoreject marked {n_epochs_rejected} epochs as bad " + f"(cross-validated n_interpolate limit: {ar_n_interpolate_})" + ) + logger.info(**gen_log_kwargs(message=msg)) + del ar + else: + # Reject epochs based on peak-to-peak rejection thresholds + ica_reject = _get_reject( + subject=subject, + session=session, + reject=cfg.ica_reject, + ch_types=cfg.ch_types, + param="ica_reject", + ) + n_epochs_before_reject = len(epochs) + epochs.drop_bad(reject=ica_reject) + n_epochs_after_reject = len(epochs) + n_epochs_rejected = n_epochs_before_reject - n_epochs_after_reject + + msg = ( + f"Removed {n_epochs_rejected} of {n_epochs_before_reject} epochs via PTP " + f"rejection thresholds: {ica_reject}" + ) + logger.info(**gen_log_kwargs(message=msg)) + ar = None + + if 0 < n_epochs_after_reject < 0.5 * n_epochs_before_reject: + msg = ( + "More than 50% of all epochs rejected. Please check the " + "rejection thresholds." + ) + logger.warning(**gen_log_kwargs(message=msg)) + elif n_epochs_after_reject == 0: + rejection_type = ( + cfg.ica_reject if cfg.ica_reject == "autoreject_local" else "PTP-based" + ) + raise RuntimeError( + f"No epochs remaining after {rejection_type} rejection. Cannot continue." + ) + + msg = f"Saving {n_epochs_after_reject} ICA epochs to disk." + logger.info(**gen_log_kwargs(message=msg)) + epochs.save( + out_files["epochs"], + overwrite=True, + split_naming="bids", + split_size=cfg._epochs_split_size, + ) + _update_for_splits(out_files, "epochs") + + msg = f"Calculating ICA solution using method: {cfg.ica_algorithm}." + logger.info(**gen_log_kwargs(message=msg)) + + algorithm = cfg.ica_algorithm + fit_params = None + + if algorithm == "picard": + fit_params = dict(fastica_it=5) + elif algorithm == "picard-extended_infomax": + algorithm = "picard" + fit_params = dict(ortho=False, extended=True) + elif algorithm == "extended_infomax": + algorithm = "infomax" + fit_params = dict(extended=True) + + ica = ICA( + method=algorithm, + random_state=cfg.random_state, + n_components=cfg.ica_n_components, + fit_params=fit_params, + max_iter=cfg.ica_max_iterations, + ) + ica.fit(epochs, decim=cfg.ica_decim) + explained_var = ( + ica.pca_explained_variance_[: ica.n_components_].sum() + / ica.pca_explained_variance_.sum() + ) + msg = ( + f"Fit {ica.n_components_} components (explaining " + f"{round(explained_var * 100, 1)}% of the variance) in " + f"{ica.n_iter_} iterations." + ) + logger.info(**gen_log_kwargs(message=msg)) + msg = "Saving ICA solution to disk." + logger.info(**gen_log_kwargs(message=msg)) + ica.save(out_files["ica"], overwrite=True) + + # Add to report + tags = ("ica", "epochs") + title = "ICA: epochs for fitting" + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + task=cfg.task, + ) as report: + report.add_epochs( + epochs=epochs, + title=title, + drop_log_ignore=(), + replace=True, + tags=tags, + ) + if cfg.ica_reject == "autoreject_local": + caption = ( + f"Autoreject was run to produce cleaner epochs before fitting ICA. " + f"{ar_reject_log.bad_epochs.sum()} epochs were rejected because more " + f"than {ar_n_interpolate_} channels were bad (cross-validated " + f"n_interpolate limit; excluding globally bad and non-data channels, " + f"shown in white). Note that none of the blue segments were actually " + f"interpolated before submitting the data to ICA. This is following " + f"the recommended approach for ICA described in the the Autoreject " + f"documentation." + ) + fig = ar_reject_log.plot( + orientation="horizontal", aspect="auto", show=False + ) + report.add_figure( + fig=fig, + title="Autoreject cleaning", + section=title, + caption=caption, + tags=tags + ("autoreject",), + replace=True, + ) + plt.close(fig) + del caption + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, + session: str | None = None, +) -> SimpleNamespace: + cfg = SimpleNamespace( + conditions=config.conditions, + runs=get_runs(config=config, subject=subject), + task_is_rest=config.task_is_rest, + ica_l_freq=config.ica_l_freq, + ica_algorithm=config.ica_algorithm, + ica_n_components=config.ica_n_components, + ica_max_iterations=config.ica_max_iterations, + ica_decim=config.ica_decim, + ica_reject=config.ica_reject, + ica_use_icalabel=config.ica_use_icalabel, + autoreject_n_interpolate=config.autoreject_n_interpolate, + random_state=config.random_state, + ch_types=config.ch_types, + l_freq=config.l_freq, + h_freq=config.h_freq, + epochs_decim=config.epochs_decim, + raw_resample_sfreq=config.raw_resample_sfreq, + event_repeated=config.event_repeated, + epochs_tmin=config.epochs_tmin, + epochs_tmax=config.epochs_tmax, + epochs_metadata_tmin=config.epochs_metadata_tmin, + epochs_metadata_tmax=config.epochs_metadata_tmax, + epochs_metadata_keep_first=config.epochs_metadata_keep_first, + epochs_metadata_keep_last=config.epochs_metadata_keep_last, + epochs_metadata_query=config.epochs_metadata_query, + eeg_reference=get_eeg_reference(config), + eog_channels=config.eog_channels, + rest_epochs_duration=config.rest_epochs_duration, + rest_epochs_overlap=config.rest_epochs_overlap, + processing="filt" if config.regress_artifact is None else "regress", + _epochs_split_size=config._epochs_split_size, + **_bids_kwargs(config=config), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Run ICA.""" + if config.spatial_filter != "ica": + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + parallel, run_func = parallel_func(run_ica, exec_params=config.exec_params) + logs = parallel( + run_func( + cfg=get_config(config=config, subject=subject), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + ) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py new file mode 100644 index 000000000..d2ac04e22 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -0,0 +1,465 @@ +"""Find ICA artifacts. + +This step automatically finds ECG- and EOG-related ICs in your data, and sets them +as bad components. + +To actually remove designated ICA components from your data, you will have to +run the apply_ica step. +""" + +from types import SimpleNamespace +from typing import Literal + +import mne +import numpy as np +import pandas as pd +from mne.preprocessing import create_ecg_epochs, create_eog_epochs +from mne.viz import plot_ica_components +from mne_bids import BIDSPath +from mne_icalabel import label_components +import mne_icalabel +import matplotlib.pyplot as plt + +from ..._config_utils import ( + _bids_kwargs, + get_eeg_reference, + get_runs, + get_sessions, + get_subjects, +) +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs + + +def detect_bad_components( + *, + cfg, + which: Literal["eog", "ecg"], + epochs: mne.BaseEpochs | None, + ica: mne.preprocessing.ICA, + ch_names: list[str] | None, + subject: str, + session: str | None, +) -> tuple[list[int], np.ndarray]: + artifact = which.upper() + if epochs is None: + msg = ( + f"No {artifact} events could be found. " + f"Not running {artifact} artifact detection." + ) + logger.info(**gen_log_kwargs(message=msg)) + return [], [] + msg = f"Performing automated {artifact} artifact detection …" + logger.info(**gen_log_kwargs(message=msg)) + + if which == "eog": + inds, scores = ica.find_bads_eog( + epochs, + threshold=cfg.ica_eog_threshold, + ch_name=ch_names, + ) + else: + inds, scores = ica.find_bads_ecg( + epochs, + method="ctps", + threshold=cfg.ica_ecg_threshold, + ch_name=ch_names, + ) + + if not inds: + adjust_setting = f"ica_{which}_threshold" + warn = ( + f"No {artifact}-related ICs detected, this is highly " + f"suspicious. A manual check is suggested. You may wish to " + f'lower "{adjust_setting}".' + ) + logger.warning(**gen_log_kwargs(message=warn)) + else: + msg = ( + f"Detected {len(inds)} {artifact}-related ICs in " + f"{len(epochs)} {artifact} epochs: {', '.join([str(i) for i in inds])}" + ) + logger.info(**gen_log_kwargs(message=msg)) + + return inds, scores + + +def get_input_fnames_find_ica_artifacts( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> dict: + bids_basename = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + extension=".fif", + ) + in_files = dict() + in_files["epochs"] = bids_basename.copy().update(processing="icafit", suffix="epo") + _update_for_splits(in_files, "epochs", single=True) + for run in cfg.runs: + key = f"raw_run-{run}" + in_files[key] = bids_basename.copy().update( + run=run, processing=cfg.processing, suffix="raw" + ) + _update_for_splits(in_files, key, single=True) + in_files["ica"] = bids_basename.copy().update(processing="icafit", suffix="ica") + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_find_ica_artifacts, +) +def find_ica_artifacts( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: dict, +) -> dict: + """Run ICA.""" + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] + bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) + out_files = dict() + out_files["ica"] = bids_basename.copy().update(processing="ica", suffix="ica") + out_files["ecg"] = bids_basename.copy().update(processing="ica+ecg", suffix="ave") + out_files["eog"] = bids_basename.copy().update(processing="ica+eog", suffix="ave") + + # DO NOT add this to out_files["ica"] because we expect it to be modified by users. + # If the modify it and it's in out_files, caching will detect the hash change and + # consider *this step* a cache miss, and it will run again, overwriting the user's + # changes. Instead, we want the ica.apply step to rerun (which it will if the + # file changes). + out_files_components = bids_basename.copy().update( + processing="ica", suffix="components", extension=".tsv" + ) + del bids_basename + msg = "Loading ICA solution" + logger.info(**gen_log_kwargs(message=msg)) + ica = mne.preprocessing.read_ica(in_files.pop("ica")) + + # Epochs used for ICA fitting + epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) + + # ECG component detection + epochs_ecg = None + ecg_ics, ecg_scores = [], [] + if cfg.ica_use_ecg_detection: + for ri, raw_fname in enumerate(raw_fnames): + # Have the channels needed to make ECG epochs + raw = mne.io.read_raw(raw_fname, preload=False) + # ECG epochs + if not ( + "ecg" in raw.get_channel_types() + or "meg" in cfg.ch_types + or "mag" in cfg.ch_types + ): + msg = ( + "No ECG or magnetometer channels are present, cannot " + "automate artifact detection for ECG." + ) + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating ECG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + + # We want to extract a total of 5 min of data for ECG epochs generation + # (across all runs) + total_ecg_dur = 5 * 60 + ecg_dur_per_run = total_ecg_dur / len(raw_fnames) + t_mid = (raw.times[-1] + raw.times[0]) / 2 + raw = raw.crop( + tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), + tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), + ).load_data() + + these_ecg_epochs = create_ecg_epochs( + raw, + baseline=(None, -0.2), + tmin=-0.5, + tmax=0.5, + ) + del raw # Free memory + if len(these_ecg_epochs): + if epochs.reject is not None: + these_ecg_epochs.drop_bad(reject=epochs.reject) + if len(these_ecg_epochs): + if epochs_ecg is None: + epochs_ecg = these_ecg_epochs + else: + epochs_ecg = mne.concatenate_epochs( + [epochs_ecg, these_ecg_epochs], on_mismatch="warn" + ) + del these_ecg_epochs + else: # did not break so had usable channels + ecg_ics, ecg_scores = detect_bad_components( + cfg=cfg, + which="ecg", + epochs=epochs_ecg, + ica=ica, + ch_names=None, # we currently don't allow for custom channels + subject=subject, + session=session, + ) + + # EOG component detection + epochs_eog = None + eog_ics = eog_scores = [] + if cfg.ica_use_eog_detection: + for ri, raw_fname in enumerate(raw_fnames): + raw = mne.io.read_raw_fif(raw_fname, preload=True) + if cfg.eog_channels: + ch_names = cfg.eog_channels + assert all([ch_name in raw.ch_names for ch_name in ch_names]) + else: + eog_picks = mne.pick_types(raw.info, meg=False, eog=True) + ch_names = [raw.ch_names[pick] for pick in eog_picks] + if not ch_names: + msg = "No EOG channel is present, cannot automate IC detection for EOG." + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating EOG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + these_eog_epochs = create_eog_epochs( + raw, + ch_name=ch_names, + baseline=(None, -0.2), + ) + if len(these_eog_epochs): + if epochs.reject is not None: + these_eog_epochs.drop_bad(reject=epochs.reject) + if len(these_eog_epochs): + if epochs_eog is None: + epochs_eog = these_eog_epochs + else: + epochs_eog = mne.concatenate_epochs( + [epochs_eog, these_eog_epochs], on_mismatch="warn" + ) + else: # did not break + eog_ics, eog_scores = detect_bad_components( + cfg=cfg, + which="eog", + epochs=epochs_eog, + ica=ica, + ch_names=cfg.eog_channels, + subject=subject, + session=session, + ) + + + # Run MNE-ICALabel if requested. + if cfg.ica_use_icalabel: + icalabel_ics = [] + icalabel_labels = [] + icalabel_prob = [] + msg = "Performing automated artifact detection (MNE-ICALabel) …" + logger.info(**gen_log_kwargs(message=msg)) + + label_results = mne_icalabel.label_components(inst=epochs, ica=ica, method="iclabel") + for idx, (label,prob) in enumerate(zip(label_results["labels"],label_results["y_pred_proba"])): + #icalabel_include = ["brain", "other"] + print(label) + print(prob) + + if label not in cfg.icalabel_include: + icalabel_ics.append(idx) + icalabel_labels.append(label) + icalabel_prob.append(prob) + + msg = ( + f"Detected {len(icalabel_ics)} artifact-related independent component(s) " + f"in {len(epochs)} epochs." + ) + logger.info(**gen_log_kwargs(message=msg)) + else: + icalabel_ics = [] + + ica.exclude = sorted(set(ecg_ics + eog_ics + icalabel_ics)) + + # Save updated ICA to disk. + # We also store the automatically identified ECG- and EOG-related ICs. + msg = "Saving ICA solution and detected artifacts to disk." + logger.info(**gen_log_kwargs(message=msg)) + ica.save(out_files["ica"], overwrite=True) + + # Create TSV. + tsv_data = pd.DataFrame( + dict( + component=list(range(ica.n_components_)), + type=["ica"] * ica.n_components_, + description=["Independent Component"] * ica.n_components_, + status=["good"] * ica.n_components_, + status_description=["n/a"] * ica.n_components_, + ) + ) + + if cfg.ica_use_icalabel: + assert len(icalabel_ics) == len(icalabel_labels) + for component, label in zip(icalabel_ics, icalabel_labels): + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = f"Auto-detected {label} (MNE-ICALabel)" + if cfg.ica_use_ecg_detection: + for component in ecg_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = "Auto-detected ECG artifact (MNE)" + if cfg.ica_use_eog_detection: + for component in eog_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[ + row_idx, "status_description" + ] = "Auto-detected EOG artifact (MNE)" + + tsv_data.to_csv(out_files_components, sep="\t", index=False) + + # Lastly, add info about the epochs used for the ICA fit, and plot all ICs + # for manual inspection. + + ecg_evoked = None if epochs_ecg is None else epochs_ecg.average() + eog_evoked = None if epochs_eog is None else epochs_eog.average() + + + ecg_scores = None if len(ecg_scores) == 0 else ecg_scores + eog_scores = None if len(eog_scores) == 0 else eog_scores + + # Save ECG and EOG evokeds to disk. + for artifact_name, artifact_evoked in zip(("ecg", "eog"), (ecg_evoked, eog_evoked)): + if artifact_evoked: + msg = f"Saving {artifact_name.upper()} artifact: {out_files[artifact_name]}" + logger.info(**gen_log_kwargs(message=msg)) + artifact_evoked.save(out_files[artifact_name], overwrite=True) + else: + # Don't track the non-existent output file + del out_files[artifact_name] + + del artifact_name, artifact_evoked + + title = "ICA: components" + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + task=cfg.task, + ) as report: + logger.info(**gen_log_kwargs(message=f'Adding "{title}" to report.')) + report.add_ica( + ica=ica, + title=title, + inst=epochs, + ecg_evoked=ecg_evoked, + eog_evoked=eog_evoked, + ecg_scores=ecg_scores, + eog_scores=eog_scores, + replace=True, + n_jobs=1, # avoid automatic parallelization + tags=("ica",), # the default but be explicit + ) + + # Add a plot for each excluded IC together with the given label and the probability + # TODO: Improve this plot e.g. combine all figures in one plot + for ic, label, prob in zip(icalabel_ics, icalabel_labels, icalabel_prob): + excluded_IC_figure = plot_ica_components( + ica=ica, + picks=ic, + ) + excluded_IC_figure.axes[0].text(0, -0.15, f"Label: {label} \n Probability: {prob:.3f}", ha="center", fontsize=8, bbox={"facecolor":"orange", "alpha":0.5, "pad":5}) + + report.add_figure( + fig=excluded_IC_figure, + title = f'ICA{ic:03}', + replace=True, + ) + plt.close(excluded_IC_figure) + + msg = 'Carefully review the extracted ICs and mark components "bad" in:' + logger.info(**gen_log_kwargs(message=msg, emoji="🛑")) + logger.info(**gen_log_kwargs(message=str(out_files_components), emoji="🛑")) + + assert len(in_files) == 0, in_files.keys() + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, + session: str | None = None, +) -> SimpleNamespace: + cfg = SimpleNamespace( + conditions=config.conditions, + runs=get_runs(config=config, subject=subject), + task_is_rest=config.task_is_rest, + ica_l_freq=config.ica_l_freq, + ica_reject=config.ica_reject, + ica_use_eog_detection = config.ica_use_eog_detection, + ica_eog_threshold=config.ica_eog_threshold, + ica_use_ecg_detection = config.ica_use_ecg_detection, + ica_ecg_threshold=config.ica_ecg_threshold, + ica_use_icalabel=config.ica_use_icalabel, + icalabel_include = config.icalabel_include, + autoreject_n_interpolate=config.autoreject_n_interpolate, + random_state=config.random_state, + ch_types=config.ch_types, + l_freq=config.l_freq, + epochs_decim=config.epochs_decim, + raw_resample_sfreq=config.raw_resample_sfreq, + event_repeated=config.event_repeated, + epochs_tmin=config.epochs_tmin, + epochs_tmax=config.epochs_tmax, + epochs_metadata_tmin=config.epochs_metadata_tmin, + epochs_metadata_tmax=config.epochs_metadata_tmax, + epochs_metadata_keep_first=config.epochs_metadata_keep_first, + epochs_metadata_keep_last=config.epochs_metadata_keep_last, + epochs_metadata_query=config.epochs_metadata_query, + eeg_reference=get_eeg_reference(config), + eog_channels=config.eog_channels, + rest_epochs_duration=config.rest_epochs_duration, + rest_epochs_overlap=config.rest_epochs_overlap, + processing="filt" if config.regress_artifact is None else "regress", + **_bids_kwargs(config=config), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Run ICA.""" + if config.spatial_filter != "ica": + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + parallel, run_func = parallel_func( + find_ica_artifacts, exec_params=config.exec_params + ) + logs = parallel( + run_func( + cfg=get_config(config=config, subject=subject), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + ) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py deleted file mode 100644 index c49eaa825..000000000 --- a/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py +++ /dev/null @@ -1,592 +0,0 @@ -"""Run Independent Component Analysis (ICA) for artifact correction. - -This fits ICA on epoched data filtered with 1 Hz highpass, -for this purpose only using fastICA. Separate ICAs are fitted and stored for -MEG and EEG data. - -Before performing ICA, we reject epochs based on peak-to-peak amplitude above -the 'ica_reject' to filter massive non-biological artifacts. - -To actually remove designated ICA components from your data, you will have to -run 05a-apply_ica.py. -""" - -from typing import List, Optional, Iterable, Tuple, Literal -from types import SimpleNamespace - -import pandas as pd -import numpy as np - -import mne -from mne.report import Report -from mne.preprocessing import ICA, create_ecg_epochs, create_eog_epochs -from mne_bids import BIDSPath - -from ..._config_utils import ( - get_runs, - get_sessions, - get_subjects, - get_eeg_reference, - _bids_kwargs, -) -from ..._import_data import make_epochs, annotations_to_events -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._reject import _get_reject -from ..._report import _agg_backend -from ..._run import failsafe_run, _update_for_splits, save_logs - - -def filter_for_ica( - *, - cfg, - raw: mne.io.BaseRaw, - subject: str, - session: Optional[str], - run: Optional[str] = None, -) -> None: - """Apply a high-pass filter if needed.""" - if cfg.ica_l_freq is None: - msg = ( - f"Not applying high-pass filter (data is already filtered, " - f'cutoff: {raw.info["highpass"]} Hz).' - ) - logger.info(**gen_log_kwargs(message=msg)) - else: - msg = f"Applying high-pass filter with {cfg.ica_l_freq} Hz cutoff …" - logger.info(**gen_log_kwargs(message=msg)) - raw.filter(l_freq=cfg.ica_l_freq, h_freq=None, n_jobs=1) - - -def fit_ica( - *, - cfg, - epochs: mne.BaseEpochs, - subject: str, - session: Optional[str], -) -> mne.preprocessing.ICA: - algorithm = cfg.ica_algorithm - fit_params = None - - if algorithm == "picard": - fit_params = dict(fastica_it=5) - elif algorithm == "extended_infomax": - algorithm = "infomax" - fit_params = dict(extended=True) - - ica = ICA( - method=algorithm, - random_state=cfg.random_state, - n_components=cfg.ica_n_components, - fit_params=fit_params, - max_iter=cfg.ica_max_iterations, - ) - - ica.fit(epochs, decim=cfg.ica_decim) - - explained_var = ( - ica.pca_explained_variance_[: ica.n_components_].sum() - / ica.pca_explained_variance_.sum() - ) - msg = ( - f"Fit {ica.n_components_} components (explaining " - f"{round(explained_var * 100, 1)}% of the variance) in " - f"{ica.n_iter_} iterations." - ) - logger.info(**gen_log_kwargs(message=msg)) - return ica - - -def make_ecg_epochs( - *, - cfg, - raw_path: BIDSPath, - subject: str, - session: Optional[str], - run: Optional[str] = None, - n_runs: int, -) -> Optional[mne.BaseEpochs]: - # ECG either needs an ecg channel, or avg of the mags (i.e. MEG data) - raw = mne.io.read_raw(raw_path, preload=False) - - if ( - "ecg" in raw.get_channel_types() - or "meg" in cfg.ch_types - or "mag" in cfg.ch_types - ): - msg = "Creating ECG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - - # We want to extract a total of 5 min of data for ECG epochs generation - # (across all runs) - total_ecg_dur = 5 * 60 - ecg_dur_per_run = total_ecg_dur / n_runs - t_mid = (raw.times[-1] + raw.times[0]) / 2 - raw = raw.crop( - tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), - tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), - ).load_data() - - ecg_epochs = create_ecg_epochs(raw, baseline=(None, -0.2), tmin=-0.5, tmax=0.5) - del raw # Free memory - - if len(ecg_epochs) == 0: - msg = "No ECG events could be found. Not running ECG artifact " "detection." - logger.info(**gen_log_kwargs(message=msg)) - ecg_epochs = None - else: - msg = ( - "No ECG or magnetometer channels are present. Cannot " - "automate artifact detection for ECG" - ) - logger.info(**gen_log_kwargs(message=msg)) - ecg_epochs = None - - return ecg_epochs - - -def make_eog_epochs( - *, - raw: mne.io.BaseRaw, - eog_channels: Optional[Iterable[str]], - subject: str, - session: Optional[str], - run: Optional[str] = None, -) -> Optional[mne.Epochs]: - """Create EOG epochs. No rejection thresholds will be applied.""" - if eog_channels: - ch_names = eog_channels - assert all([ch_name in raw.ch_names for ch_name in ch_names]) - else: - ch_idx = mne.pick_types(raw.info, meg=False, eog=True) - ch_names = [raw.ch_names[i] for i in ch_idx] - del ch_idx - - if ch_names: - msg = "Creating EOG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - - eog_epochs = create_eog_epochs(raw, ch_name=ch_names, baseline=(None, -0.2)) - - if len(eog_epochs) == 0: - msg = "No EOG events could be found. Not running EOG artifact " "detection." - logger.warning(**gen_log_kwargs(message=msg)) - eog_epochs = None - else: - msg = "No EOG channel is present. Cannot automate IC detection " "for EOG" - logger.info(**gen_log_kwargs(message=msg)) - eog_epochs = None - - return eog_epochs - - -def detect_bad_components( - *, - cfg, - which: Literal["eog", "ecg"], - epochs: mne.BaseEpochs, - ica: mne.preprocessing.ICA, - ch_names: Optional[List[str]], - subject: str, - session: str, -) -> Tuple[List[int], np.ndarray]: - artifact = which.upper() - msg = f"Performing automated {artifact} artifact detection …" - logger.info(**gen_log_kwargs(message=msg)) - - if which == "eog": - inds, scores = ica.find_bads_eog( - epochs, - threshold=cfg.ica_eog_threshold, - ch_name=ch_names, - ) - else: - inds, scores = ica.find_bads_ecg( - epochs, - method="ctps", - threshold=cfg.ica_ctps_ecg_threshold, - ch_name=ch_names, - ) - - if not inds: - adjust_setting = ( - "ica_eog_threshold" if which == "eog" else "ica_ctps_ecg_threshold" - ) - warn = ( - f"No {artifact}-related ICs detected, this is highly " - f"suspicious. A manual check is suggested. You may wish to " - f'lower "{adjust_setting}".' - ) - logger.warning(**gen_log_kwargs(message=warn)) - else: - msg = ( - f"Detected {len(inds)} {artifact}-related ICs in " - f"{len(epochs)} {artifact} epochs." - ) - logger.info(**gen_log_kwargs(message=msg)) - - return inds, scores - - -def get_input_fnames_run_ica( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], -) -> dict: - bids_basename = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - in_files = dict() - for run in cfg.runs: - key = f"raw_run-{run}" - in_files[key] = bids_basename.copy().update( - run=run, processing="filt", suffix="raw" - ) - _update_for_splits(in_files, key, single=True) - return in_files - - -@failsafe_run( - get_input_fnames=get_input_fnames_run_ica, -) -def run_ica( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], - in_files: dict, -) -> dict: - """Run ICA.""" - raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] - bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) - out_files = dict() - out_files["ica"] = bids_basename.copy().update(suffix="ica", extension=".fif") - out_files["components"] = bids_basename.copy().update( - processing="ica", suffix="components", extension=".tsv" - ) - out_files["report"] = bids_basename.copy().update( - processing="ica+components", suffix="report", extension=".html" - ) - del bids_basename - - # Generate a list of raw data paths (i.e., paths of individual runs) - # we want to create epochs from. - - # Generate a unique event name -> event code mapping that can be used - # across all runs. - event_name_to_code_map = annotations_to_events(raw_paths=raw_fnames) - - # Now, generate epochs from each individual run - eog_epochs_all_runs = None - ecg_epochs_all_runs = None - - for idx, (run, raw_fname) in enumerate(zip(cfg.runs, raw_fnames)): - msg = f"Loading filtered raw data from {raw_fname.basename}" - logger.info(**gen_log_kwargs(message=msg)) - - # ECG epochs - ecg_epochs = make_ecg_epochs( - cfg=cfg, - raw_path=raw_fname, - subject=subject, - session=session, - run=run, - n_runs=len(cfg.runs), - ) - if ecg_epochs is not None: - if idx == 0: - ecg_epochs_all_runs = ecg_epochs - else: - ecg_epochs_all_runs = mne.concatenate_epochs( - [ecg_epochs_all_runs, ecg_epochs], on_mismatch="warn" - ) - - del ecg_epochs - - # EOG epochs - raw = mne.io.read_raw_fif(raw_fname, preload=True) - eog_epochs = make_eog_epochs( - raw=raw, - eog_channels=cfg.eog_channels, - subject=subject, - session=session, - run=run, - ) - if eog_epochs is not None: - if idx == 0: - eog_epochs_all_runs = eog_epochs - else: - eog_epochs_all_runs = mne.concatenate_epochs( - [eog_epochs_all_runs, eog_epochs], on_mismatch="warn" - ) - - del eog_epochs - - # Produce high-pass filtered version of the data for ICA. - # Sanity check – make sure we're using the correct data! - if cfg.raw_resample_sfreq is not None: - assert np.allclose(raw.info["sfreq"], cfg.raw_resample_sfreq) - if cfg.l_freq is not None: - assert np.allclose(raw.info["highpass"], cfg.l_freq) - - filter_for_ica(cfg=cfg, raw=raw, subject=subject, session=session, run=run) - - # Only keep the subset of the mapping that applies to the current run - event_id = event_name_to_code_map.copy() - for event_name in event_id.copy().keys(): - if event_name not in raw.annotations.description: - del event_id[event_name] - - msg = "Creating task-related epochs …" - logger.info(**gen_log_kwargs(message=msg)) - epochs = make_epochs( - subject=subject, - session=session, - task=cfg.task, - conditions=cfg.conditions, - raw=raw, - event_id=event_id, - tmin=cfg.epochs_tmin, - tmax=cfg.epochs_tmax, - metadata_tmin=cfg.epochs_metadata_tmin, - metadata_tmax=cfg.epochs_metadata_tmax, - metadata_keep_first=cfg.epochs_metadata_keep_first, - metadata_keep_last=cfg.epochs_metadata_keep_last, - metadata_query=cfg.epochs_metadata_query, - event_repeated=cfg.event_repeated, - epochs_decim=cfg.epochs_decim, - task_is_rest=cfg.task_is_rest, - rest_epochs_duration=cfg.rest_epochs_duration, - rest_epochs_overlap=cfg.rest_epochs_overlap, - ) - - epochs.load_data() # Remove reference to raw - del raw # free memory - - if idx == 0: - epochs_all_runs = epochs - else: - epochs_all_runs = mne.concatenate_epochs( - [epochs_all_runs, epochs], on_mismatch="warn" - ) - - del epochs - - # Clean up namespace - epochs = epochs_all_runs - epochs_ecg = ecg_epochs_all_runs - epochs_eog = eog_epochs_all_runs - - del epochs_all_runs, eog_epochs_all_runs, ecg_epochs_all_runs - - # Set an EEG reference - if "eeg" in cfg.ch_types: - projection = True if cfg.eeg_reference == "average" else False - epochs.set_eeg_reference(cfg.eeg_reference, projection=projection) - - # Reject epochs based on peak-to-peak rejection thresholds - ica_reject = _get_reject( - subject=subject, - session=session, - reject=cfg.ica_reject, - ch_types=cfg.ch_types, - param="ica_reject", - ) - - msg = f"Using PTP rejection thresholds: {ica_reject}" - logger.info(**gen_log_kwargs(message=msg)) - - epochs.drop_bad(reject=ica_reject) - if epochs_eog is not None: - epochs_eog.drop_bad(reject=ica_reject) - if epochs_ecg is not None: - epochs_ecg.drop_bad(reject=ica_reject) - - # Now actually perform ICA. - msg = "Calculating ICA solution." - logger.info(**gen_log_kwargs(message=msg)) - ica = fit_ica(cfg=cfg, epochs=epochs, subject=subject, session=session) - - # Start a report - title = f"ICA – sub-{subject}" - if session is not None: - title += f", ses-{session}" - if cfg.task is not None: - title += f", task-{cfg.task}" - - # ECG and EOG component detection - if epochs_ecg: - ecg_ics, ecg_scores = detect_bad_components( - cfg=cfg, - which="ecg", - epochs=epochs_ecg, - ica=ica, - ch_names=None, # we currently don't allow for custom channels - subject=subject, - session=session, - ) - else: - ecg_ics = ecg_scores = [] - - if epochs_eog: - eog_ics, eog_scores = detect_bad_components( - cfg=cfg, - which="eog", - epochs=epochs_eog, - ica=ica, - ch_names=cfg.eog_channels, - subject=subject, - session=session, - ) - else: - eog_ics = eog_scores = [] - - # Save ICA to disk. - # We also store the automatically identified ECG- and EOG-related ICs. - msg = "Saving ICA solution and detected artifacts to disk." - logger.info(**gen_log_kwargs(message=msg)) - ica.exclude = sorted(set(ecg_ics + eog_ics)) - ica.save(out_files["ica"], overwrite=True) - _update_for_splits(out_files, "ica") - - # Create TSV. - tsv_data = pd.DataFrame( - dict( - component=list(range(ica.n_components_)), - type=["ica"] * ica.n_components_, - description=["Independent Component"] * ica.n_components_, - status=["good"] * ica.n_components_, - status_description=["n/a"] * ica.n_components_, - ) - ) - - for component in ecg_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact" - - for component in eog_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact" - - tsv_data.to_csv(out_files["components"], sep="\t", index=False) - - # Lastly, add info about the epochs used for the ICA fit, and plot all ICs - # for manual inspection. - msg = "Adding diagnostic plots for all ICA components to the HTML report …" - logger.info(**gen_log_kwargs(message=msg)) - - report = Report(info_fname=epochs, title=title, verbose=False) - ecg_evoked = None if epochs_ecg is None else epochs_ecg.average() - eog_evoked = None if epochs_eog is None else epochs_eog.average() - ecg_scores = None if len(ecg_scores) == 0 else ecg_scores - eog_scores = None if len(eog_scores) == 0 else eog_scores - - with _agg_backend(): - report.add_epochs( - epochs=epochs, - title="Epochs used for ICA fitting", - drop_log_ignore=(), - replace=True, - ) - report.add_ica( - ica=ica, - title="ICA cleaning", - inst=epochs, - ecg_evoked=ecg_evoked, - eog_evoked=eog_evoked, - ecg_scores=ecg_scores, - eog_scores=eog_scores, - replace=True, - n_jobs=1, # avoid automatic parallelization - ) - - msg = ( - f"ICA completed. Please carefully review the extracted ICs in the " - f"report {out_files['report'].basename}, and mark all components " - f"you wish to reject as 'bad' in " - f"{out_files['components'].basename}" - ) - logger.info(**gen_log_kwargs(message=msg)) - - report.save( - out_files["report"], - overwrite=True, - open_browser=exec_params.interactive, - ) - - assert len(in_files) == 0, in_files.keys() - return out_files - - -def get_config( - *, - config: SimpleNamespace, - subject: Optional[str] = None, - session: Optional[str] = None, -) -> SimpleNamespace: - cfg = SimpleNamespace( - conditions=config.conditions, - runs=get_runs(config=config, subject=subject), - task_is_rest=config.task_is_rest, - ica_l_freq=config.ica_l_freq, - ica_algorithm=config.ica_algorithm, - ica_n_components=config.ica_n_components, - ica_max_iterations=config.ica_max_iterations, - ica_decim=config.ica_decim, - ica_reject=config.ica_reject, - ica_eog_threshold=config.ica_eog_threshold, - ica_ctps_ecg_threshold=config.ica_ctps_ecg_threshold, - random_state=config.random_state, - ch_types=config.ch_types, - l_freq=config.l_freq, - epochs_decim=config.epochs_decim, - raw_resample_sfreq=config.raw_resample_sfreq, - event_repeated=config.event_repeated, - epochs_tmin=config.epochs_tmin, - epochs_tmax=config.epochs_tmax, - epochs_metadata_tmin=config.epochs_metadata_tmin, - epochs_metadata_tmax=config.epochs_metadata_tmax, - epochs_metadata_keep_first=config.epochs_metadata_keep_first, - epochs_metadata_keep_last=config.epochs_metadata_keep_last, - epochs_metadata_query=config.epochs_metadata_query, - eeg_reference=get_eeg_reference(config), - eog_channels=config.eog_channels, - rest_epochs_duration=config.rest_epochs_duration, - rest_epochs_overlap=config.rest_epochs_overlap, - **_bids_kwargs(config=config), - ) - return cfg - - -def main(*, config: SimpleNamespace) -> None: - """Run ICA.""" - if config.spatial_filter != "ica": - msg = "Skipping …" - logger.info(**gen_log_kwargs(message=msg, emoji="skip")) - return - - with get_parallel_backend(config.exec_params): - parallel, run_func = parallel_func(run_ica, exec_params=config.exec_params) - logs = parallel( - run_func( - cfg=get_config(config=config, subject=subject), - exec_params=config.exec_params, - subject=subject, - session=session, - ) - for subject in get_subjects(config) - for session in get_sessions(config) - ) - save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py b/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py index 2d73a6ce3..b17816a7e 100644 --- a/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py +++ b/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py @@ -1,35 +1,42 @@ -"""Run Signal Subspace Projections (SSP) for artifact correction. +"""Compute SSP. +Signal subspace projections (SSP) vectors are computed from EOG and ECG signals. These are often also referred to as PCA vectors. """ -from typing import Optional from types import SimpleNamespace import mne -from mne.preprocessing import create_eog_epochs, create_ecg_epochs -from mne import compute_proj_evoked, compute_proj_epochs +import numpy as np +from mne import compute_proj_epochs, compute_proj_evoked +from mne.preprocessing import find_ecg_events, find_eog_events from mne_bids import BIDSPath -from mne.utils import _pl from ..._config_utils import ( + _bids_kwargs, + _pl, + _proj_path, get_runs, get_sessions, get_subjects, - _bids_kwargs, ) from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend +from ..._parallel import get_parallel_backend, parallel_func from ..._reject import _get_reject from ..._report import _open_report -from ..._run import failsafe_run, _update_for_splits, save_logs +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs + + +def _find_ecg_events(raw: mne.io.Raw, ch_name: str | None) -> np.ndarray: + """Wrap find_ecg_events to use the same defaults as create_ecg_events.""" + return find_ecg_events(raw, ch_name=ch_name, l_freq=8, h_freq=16)[0] def get_input_fnames_run_ssp( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> dict: bids_basename = BIDSPath( subject=subject, @@ -47,7 +54,7 @@ def get_input_fnames_run_ssp( for run in cfg.runs: key = f"raw_run-{run}" in_files[key] = bids_basename.copy().update( - run=run, processing="filt", suffix="raw" + run=run, processing=cfg.processing, suffix="raw" ) _update_for_splits(in_files, key, single=True) return in_files @@ -61,22 +68,15 @@ def run_ssp( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, in_files: dict, ) -> dict: import matplotlib.pyplot as plt - # compute SSP on first run of raw + # compute SSP on all runs of raw raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] - # when saving proj, use run=None - out_files = dict() - out_files["proj"] = ( - raw_fnames[0] - .copy() - .update(run=None, suffix="proj", split=None, processing=None, check=False) - ) - + out_files = dict(proj=_proj_path(cfg=cfg, subject=subject, session=session)) msg = ( f"Input{_pl(raw_fnames)} ({len(raw_fnames)}): " f'{raw_fnames[0].basename}{_pl(raw_fnames, pl=" ...")}' @@ -93,7 +93,7 @@ def run_ssp( projs = dict() proj_kinds = ("ecg", "eog") rate_names = dict(ecg="heart", eog="blink") - epochs_fun = dict(ecg=create_ecg_epochs, eog=create_eog_epochs) + events_fun = dict(ecg=_find_ecg_events, eog=find_eog_events) minimums = dict(ecg=cfg.min_ecg_epochs, eog=cfg.min_eog_epochs) rejects = dict(ecg=cfg.ssp_reject_ecg, eog=cfg.ssp_reject_eog) avg = dict(ecg=cfg.ecg_proj_from_average, eog=cfg.eog_proj_from_average) @@ -111,17 +111,38 @@ def run_ssp( projs[kind] = [] if not any(n_projs[kind].values()): continue - proj_epochs = epochs_fun[kind]( - raw, - ch_name=ch_name[kind], - decim=cfg.epochs_decim, - ) - n_orig = len(proj_epochs.selection) + events = events_fun[kind](raw=raw, ch_name=ch_name[kind]) + n_orig = len(events) rate = n_orig / raw.times[-1] * 60 bpm_msg = f"{rate:5.1f} bpm" msg = f"Detected {rate_names[kind]} rate: {bpm_msg}" logger.info(**gen_log_kwargs(message=msg)) - # Enough to start + # Enough to create epochs + if len(events) < minimums[kind]: + msg = ( + f"No {kind.upper()} projectors computed: got " + f"{len(events)} original events < {minimums[kind]} {bpm_msg}" + ) + logger.warning(**gen_log_kwargs(message=msg)) + continue + out_files[f"events_{kind}"] = ( + out_files["proj"] + .copy() + .update(suffix=f"{kind}-eve", split=None, check=False, extension=".txt") + ) + mne.write_events(out_files[f"events_{kind}"], events, overwrite=True) + proj_epochs = mne.Epochs( + raw, + events=events, + event_id=events[0, 2], + tmin=-0.5, + tmax=0.5, + proj=False, + baseline=(None, None), + reject_by_annotation=True, + preload=True, + decim=cfg.epochs_decim, + ) if len(proj_epochs) >= minimums[kind]: reject_ = _get_reject( subject=subject, @@ -134,7 +155,6 @@ def run_ssp( proj_epochs.drop_bad(reject=reject_) # Still enough after rejection if len(proj_epochs) >= minimums[kind]: - proj_epochs.apply_baseline((None, None)) use = proj_epochs.average() if avg[kind] else proj_epochs fun = compute_proj_evoked if avg[kind] else compute_proj_epochs desc_prefix = ( @@ -205,7 +225,7 @@ def run_ssp( replace=True, ) plt.close(fig) - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( @@ -229,6 +249,7 @@ def get_config( epochs_decim=config.epochs_decim, use_maxwell_filter=config.use_maxwell_filter, runs=get_runs(config=config, subject=subject), + processing="filt" if config.regress_artifact is None else "regress", **_bids_kwargs(config=config), ) return cfg diff --git a/mne_bids_pipeline/steps/preprocessing/_05_make_epochs.py b/mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py similarity index 92% rename from mne_bids_pipeline/steps/preprocessing/_05_make_epochs.py rename to mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py index 516061726..47f717959 100644 --- a/mne_bids_pipeline/steps/preprocessing/_05_make_epochs.py +++ b/mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py @@ -7,36 +7,37 @@ To save space, the epoch data can be decimated. """ +import inspect from types import SimpleNamespace -from typing import Optional import mne from mne_bids import BIDSPath from ..._config_utils import ( - get_runs, - get_subjects, + _bids_kwargs, get_eeg_reference, + get_runs, get_sessions, - _bids_kwargs, + get_subjects, ) -from ..._import_data import make_epochs, annotations_to_events +from ..._import_data import annotations_to_events, make_epochs from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func from ..._report import _open_report from ..._run import ( + _prep_out_files, + _sanitize_callable, + _update_for_splits, failsafe_run, save_logs, - _update_for_splits, - _sanitize_callable, ) -from ..._parallel import parallel_func, get_parallel_backend def get_input_fnames_epochs( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> dict: """Get paths of files required by filter_data function.""" # Construct the basenames of the files we wish to load, and of the empty- @@ -53,7 +54,7 @@ def get_input_fnames_epochs( extension=".fif", datatype=cfg.datatype, root=cfg.deriv_root, - processing="filt", + processing=cfg.processing, ).update(suffix="raw", check=False) # Generate a list of raw data paths (i.e., paths of individual runs) @@ -77,7 +78,7 @@ def run_epochs( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, in_files: dict, ) -> dict: """Extract epochs for one subject.""" @@ -213,7 +214,10 @@ def run_epochs( logger.info(**gen_log_kwargs(message=msg)) out_files = dict() out_files["epochs"] = bids_path_in.copy().update( - suffix="epo", processing=None, check=False + suffix="epo", + processing=None, + check=False, + split=None, ) epochs.save( out_files["epochs"], @@ -255,6 +259,7 @@ def run_epochs( psd=psd, drop_log_ignore=(), replace=True, + **_add_epochs_image_kwargs(cfg), ) # Interactive @@ -262,7 +267,15 @@ def run_epochs( epochs.plot() epochs.plot_image(combine="gfp", sigma=2.0, cmap="YlGnBu_r") assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def _add_epochs_image_kwargs(cfg: SimpleNamespace) -> dict: + arg_spec = inspect.getfullargspec(mne.Report.add_epochs) + kwargs = dict() + if cfg.report_add_epochs_image_kwargs and "image_kwargs" in arg_spec.kwonlyargs: + kwargs["image_kwargs"] = cfg.report_add_epochs_image_kwargs + return kwargs # TODO: ideally we wouldn't need this anymore and could refactor the code above @@ -275,7 +288,7 @@ def _get_events(cfg, subject, session): acquisition=cfg.acq, recording=cfg.rec, space=cfg.space, - processing="filt", + processing=cfg.processing, suffix="raw", extension=".fif", datatype=cfg.datatype, @@ -314,6 +327,7 @@ def get_config( epochs_metadata_query=config.epochs_metadata_query, event_repeated=config.event_repeated, epochs_decim=config.epochs_decim, + report_add_epochs_image_kwargs=config.report_add_epochs_image_kwargs, ch_types=config.ch_types, noise_cov=_sanitize_callable(config.noise_cov), eeg_reference=get_eeg_reference(config), @@ -321,6 +335,7 @@ def get_config( rest_epochs_overlap=config.rest_epochs_overlap, _epochs_split_size=config._epochs_split_size, runs=get_runs(config=config, subject=subject), + processing="filt" if config.regress_artifact is None else "regress", **_bids_kwargs(config=config), ) return cfg diff --git a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py deleted file mode 100644 index 2f0f84bdd..000000000 --- a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Apply ICA and obtain the cleaned epochs. - -Blinks and ECG artifacts are automatically detected and the corresponding ICA -components are removed from the data. -This relies on the ICAs computed in 04-run_ica.py - -!! If you manually add components to remove (config.rejcomps_man), -make sure you did not re-run the ICA in the meantime. Otherwise (especially if -the random state was not set, or you used a different machine, the component -order might differ). - -""" - -from types import SimpleNamespace -from typing import Optional - -import pandas as pd -import mne -from mne.preprocessing import read_ica -from mne.report import Report - -from mne_bids import BIDSPath - -from ..._config_utils import ( - get_subjects, - get_sessions, - _bids_kwargs, -) -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._reject import _get_reject -from ..._report import _open_report, _agg_backend -from ..._run import failsafe_run, _update_for_splits, save_logs - - -def get_input_fnames_apply_ica( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], -) -> dict: - bids_basename = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - in_files = dict() - in_files["ica"] = bids_basename.copy().update(suffix="ica", extension=".fif") - in_files["components"] = bids_basename.copy().update( - processing="ica", suffix="components", extension=".tsv" - ) - in_files["epochs"] = bids_basename.copy().update(suffix="epo", extension=".fif") - return in_files - - -@failsafe_run( - get_input_fnames=get_input_fnames_apply_ica, -) -def apply_ica( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], - in_files: dict, -) -> dict: - bids_basename = in_files["ica"].copy().update(processing=None) - out_files = dict() - out_files["epochs"] = in_files["epochs"].copy().update(processing="ica") - out_files["report"] = bids_basename.copy().update( - processing="ica", suffix="report", extension=".html" - ) - - title = f"ICA artifact removal – sub-{subject}" - if session is not None: - title += f", ses-{session}" - if cfg.task is not None: - title += f", task-{cfg.task}" - - # Load ICA. - msg = f"Reading ICA: {in_files['ica']}" - logger.debug(**gen_log_kwargs(message=msg)) - ica = read_ica(fname=in_files.pop("ica")) - - # Select ICs to remove. - tsv_data = pd.read_csv(in_files.pop("components"), sep="\t") - ica.exclude = tsv_data.loc[tsv_data["status"] == "bad", "component"].to_list() - - # Load epochs to reject ICA components. - msg = f'Input: {in_files["epochs"].basename}' - logger.info(**gen_log_kwargs(message=msg)) - msg = f'Output: {out_files["epochs"].basename}' - logger.info(**gen_log_kwargs(message=msg)) - - epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) - ica_reject = _get_reject( - subject=subject, - session=session, - reject=cfg.ica_reject, - ch_types=cfg.ch_types, - param="ica_reject", - ) - epochs.drop_bad(ica_reject) - - # Now actually reject the components. - msg = f'Rejecting ICs: {", ".join([str(ic) for ic in ica.exclude])}' - logger.info(**gen_log_kwargs(message=msg)) - epochs_cleaned = ica.apply(epochs.copy()) # Copy b/c works in-place! - - msg = "Saving reconstructed epochs after ICA." - logger.info(**gen_log_kwargs(message=msg)) - epochs_cleaned.save( - out_files["epochs"], - overwrite=True, - split_naming="bids", - split_size=cfg._epochs_split_size, - ) - _update_for_splits(out_files, "epochs") - - # Compare ERP/ERF before and after ICA artifact rejection. The evoked - # response is calculated across ALL epochs, just like ICA was run on - # all epochs, regardless of their respective experimental condition. - # - # We apply baseline correction here to (hopefully!) make the effects of - # ICA easier to see. Otherwise, individual channels might just have - # arbitrary DC shifts, and we wouldn't be able to easily decipher what's - # going on! - report = Report(out_files["report"], title=title, verbose=False) - picks = ica.exclude if ica.exclude else None - with _agg_backend(): - report.add_ica( - ica=ica, - title="Effects of ICA cleaning", - inst=epochs.copy().apply_baseline(cfg.baseline), - picks=picks, - replace=True, - n_jobs=1, # avoid automatic parallelization - ) - report.save( - out_files["report"], - overwrite=True, - open_browser=exec_params.interactive, - ) - - assert len(in_files) == 0, in_files.keys() - - # Report - if ica.exclude: - msg = "Adding ICA to report." - else: - msg = "Skipping ICA addition to report, no components marked as bad." - logger.info(**gen_log_kwargs(message=msg)) - if ica.exclude: - with _open_report( - cfg=cfg, exec_params=exec_params, subject=subject, session=session - ) as report: - report.add_ica( - ica=ica, - title="ICA", - inst=epochs, - picks=ica.exclude, - # TODO upstream - # captions=f'Evoked response (across all epochs) ' - # f'before and after ICA ' - # f'({len(ica.exclude)} ICs removed)' - replace=True, - ) - - return out_files - - -def get_config( - *, - config: SimpleNamespace, -) -> SimpleNamespace: - cfg = SimpleNamespace( - baseline=config.baseline, - ica_reject=config.ica_reject, - ch_types=config.ch_types, - _epochs_split_size=config._epochs_split_size, - **_bids_kwargs(config=config), - ) - return cfg - - -def main(*, config: SimpleNamespace) -> None: - """Apply ICA.""" - if not config.spatial_filter == "ica": - msg = "Skipping …" - logger.info(**gen_log_kwargs(message=msg, emoji="skip")) - return - - with get_parallel_backend(config.exec_params): - parallel, run_func = parallel_func(apply_ica, exec_params=config.exec_params) - logs = parallel( - run_func( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject=subject, - session=session, - ) - for subject in get_subjects(config) - for session in get_sessions(config) - ) - save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py b/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py deleted file mode 100644 index 84ee81a3d..000000000 --- a/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Apply SSP projections and obtain the cleaned epochs. - -Blinks and ECG artifacts are automatically detected and the corresponding SSP -projections components are removed from the data. - -""" - -from types import SimpleNamespace -from typing import Optional - -import mne -from mne_bids import BIDSPath - -from ..._config_utils import ( - get_sessions, - get_subjects, - _bids_kwargs, -) -from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, _update_for_splits, save_logs -from ..._parallel import parallel_func, get_parallel_backend - - -def get_input_fnames_apply_ssp( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], -) -> dict: - bids_basename = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - datatype=cfg.datatype, - root=cfg.deriv_root, - extension=".fif", - check=False, - ) - in_files = dict() - in_files["epochs"] = bids_basename.copy().update(suffix="epo", check=False) - in_files["proj"] = bids_basename.copy().update(suffix="proj", check=False) - return in_files - - -@failsafe_run( - get_input_fnames=get_input_fnames_apply_ssp, -) -def apply_ssp( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], - in_files: dict, -) -> dict: - # load epochs to reject ICA components - # compute SSP on first run of raw - out_files = dict() - out_files["epochs"] = ( - in_files["epochs"].copy().update(processing="ssp", check=False) - ) - msg = f"Input epochs: {in_files['epochs'].basename}" - logger.info(**gen_log_kwargs(message=msg)) - msg = f'Input SSP: {in_files["proj"].basename}' - logger.info(**gen_log_kwargs(message=msg)) - msg = f"Output: {out_files['epochs'].basename}" - logger.info(**gen_log_kwargs(message=msg)) - epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) - projs = mne.read_proj(in_files.pop("proj")) - epochs_cleaned = epochs.copy().add_proj(projs).apply_proj() - epochs_cleaned.save( - out_files["epochs"], - overwrite=True, - split_naming="bids", - split_size=cfg._epochs_split_size, - ) - _update_for_splits(out_files, "epochs") - assert len(in_files) == 0, in_files.keys() - return out_files - - -def get_config( - *, - config: SimpleNamespace, -) -> SimpleNamespace: - cfg = SimpleNamespace( - _epochs_split_size=config._epochs_split_size, - **_bids_kwargs(config=config), - ) - return cfg - - -def main(*, config: SimpleNamespace) -> None: - """Apply ssp.""" - if not config.spatial_filter == "ssp": - msg = "Skipping …" - logger.info(**gen_log_kwargs(message=msg, emoji="skip")) - return - - with get_parallel_backend(config.exec_params): - parallel, run_func = parallel_func(apply_ssp, exec_params=config.exec_params) - logs = parallel( - run_func( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject=subject, - session=session, - ) - for subject in get_subjects(config) - for session in get_sessions(config) - ) - save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py b/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py deleted file mode 100644 index 9f1607055..000000000 --- a/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py +++ /dev/null @@ -1,195 +0,0 @@ -"""Remove epochs based on peak-to-peak (PTP) amplitudes. - -Epochs containing peak-to-peak above the thresholds defined -in the 'reject' parameter are removed from the data. - -This step will drop epochs containing non-biological artifacts -but also epochs containing biological artifacts not sufficiently -corrected by the ICA or the SSP processing. -""" - -from types import SimpleNamespace -from typing import Optional - -import mne -from mne_bids import BIDSPath - -from ..._config_utils import ( - get_sessions, - get_subjects, - _bids_kwargs, -) -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._reject import _get_reject -from ..._report import _open_report -from ..._run import failsafe_run, _update_for_splits, save_logs - - -def get_input_fnames_drop_ptp( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], -) -> dict: - bids_path = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - suffix="epo", - extension=".fif", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - in_files = dict() - in_files["epochs"] = bids_path.copy().update(processing=cfg.spatial_filter) - return in_files - - -@failsafe_run( - get_input_fnames=get_input_fnames_drop_ptp, -) -def drop_ptp( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], - in_files: dict, -) -> dict: - out_files = dict() - out_files["epochs"] = in_files["epochs"].copy().update(processing="clean") - msg = f'Input: {in_files["epochs"].basename}' - logger.info(**gen_log_kwargs(message=msg)) - msg = f'Output: {out_files["epochs"].basename}' - logger.info(**gen_log_kwargs(message=msg)) - - # Get rejection parameters and drop bad epochs - epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) - reject = _get_reject( - subject=subject, - session=session, - reject=cfg.reject, - ch_types=cfg.ch_types, - param="reject", - epochs=epochs, - ) - if cfg.spatial_filter == "ica": - ica_reject = _get_reject( - subject=subject, - session=session, - reject=cfg.ica_reject, - ch_types=cfg.ch_types, - param="ica_reject", - ) - else: - ica_reject = None - - if ica_reject is not None: - for ch_type, threshold in ica_reject.items(): - if ch_type in reject and threshold < reject[ch_type]: - # This can only ever happen in case of - # reject = 'autoreject_global' - msg = ( - f"Adjusting PTP rejection threshold proposed by " - f"autoreject, as it is greater than ica_reject: " - f"{ch_type}: {reject[ch_type]} -> {threshold}" - ) - logger.info(**gen_log_kwargs(message=msg)) - reject[ch_type] = threshold - - msg = f"Using PTP rejection thresholds: {reject}" - logger.info(**gen_log_kwargs(message=msg)) - - n_epochs_before_reject = len(epochs) - epochs.reject_tmin = cfg.reject_tmin - epochs.reject_tmax = cfg.reject_tmax - epochs.drop_bad(reject=reject) - n_epochs_after_reject = len(epochs) - - if 0 < n_epochs_after_reject < 0.5 * n_epochs_before_reject: - msg = ( - "More than 50% of all epochs rejected. Please check the " - "rejection thresholds." - ) - logger.warning(**gen_log_kwargs(message=msg)) - elif n_epochs_after_reject == 0: - raise RuntimeError( - "No epochs remaining after peak-to-peak-based " - "rejection. Cannot continue." - ) - - msg = "Saving cleaned, baseline-corrected epochs …" - - epochs.apply_baseline(cfg.baseline) - epochs.save( - out_files["epochs"], - overwrite=True, - split_naming="bids", - split_size=cfg._epochs_split_size, - ) - _update_for_splits(out_files, "epochs") - assert len(in_files) == 0, in_files.keys() - - # Report - msg = "Adding cleaned epochs to report." - logger.info(**gen_log_kwargs(message=msg)) - # Add PSD plots for 30s of data or all epochs if we have less available - if len(epochs) * (epochs.tmax - epochs.tmin) < 30: - psd = True - else: - psd = 30 - with _open_report( - cfg=cfg, exec_params=exec_params, subject=subject, session=session - ) as report: - report.add_epochs( - epochs=epochs, - title="Epochs: after cleaning", - psd=psd, - drop_log_ignore=(), - replace=True, - ) - return out_files - - -def get_config( - *, - config: SimpleNamespace, -) -> SimpleNamespace: - cfg = SimpleNamespace( - baseline=config.baseline, - reject_tmin=config.reject_tmin, - reject_tmax=config.reject_tmax, - spatial_filter=config.spatial_filter, - ica_reject=config.ica_reject, - reject=config.reject, - ch_types=config.ch_types, - _epochs_split_size=config._epochs_split_size, - **_bids_kwargs(config=config), - ) - return cfg - - -def main(*, config: SimpleNamespace) -> None: - """Run epochs.""" - parallel, run_func = parallel_func(drop_ptp, exec_params=config.exec_params) - - with get_parallel_backend(config.exec_params): - logs = parallel( - run_func( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject=subject, - session=session, - ) - for subject in get_subjects(config) - for session in get_sessions(config) - ) - save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py new file mode 100644 index 000000000..430c4cdd3 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py @@ -0,0 +1,302 @@ +"""Apply ICA. + +!! If you manually add components to remove, make sure you did not re-run the ICA in +the meantime. Otherwise (especially if the random state was not set, or you used a +different machine) the component order might differ. +""" + +from types import SimpleNamespace + +import mne +import pandas as pd +from mne.preprocessing import read_ica +from mne_bids import BIDSPath + +from ..._config_utils import ( + get_runs_tasks, + get_sessions, + get_subjects, +) +from ..._import_data import _get_run_rest_noise_path, _import_data_kwargs +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _add_raw, _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs + + +def _ica_paths( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +): + bids_basename = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + ) + in_files = dict() + in_files["ica"] = bids_basename.copy().update( + processing="ica", + suffix="ica", + extension=".fif", + ) + in_files["components"] = bids_basename.copy().update( + processing="ica", suffix="components", extension=".tsv" + ) + return in_files + + +def _read_ica_and_exclude( + in_files: dict, +) -> None: + ica = read_ica(fname=in_files.pop("ica")) + tsv_data = pd.read_csv(in_files.pop("components"), sep="\t") + ica.exclude = tsv_data.loc[tsv_data["status"] == "bad", "component"].to_list() + return ica + + +def get_input_fnames_apply_ica_epochs( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> dict: + in_files = _ica_paths(cfg=cfg, subject=subject, session=session) + in_files["epochs"] = ( + in_files["ica"] + .copy() + .update( + suffix="epo", + extension=".fif", + processing=None, + ) + ) + _update_for_splits(in_files, "epochs", single=True) + return in_files + + +def get_input_fnames_apply_ica_raw( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, +) -> dict: + in_files = _get_run_rest_noise_path( + cfg=cfg, + subject=subject, + session=session, + run=run, + task=task, + kind="filt", + mf_reference_run=cfg.mf_reference_run, + ) + assert len(in_files) + in_files.update(_ica_paths(cfg=cfg, subject=subject, session=session)) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_apply_ica_epochs, +) +def apply_ica_epochs( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: dict, +) -> dict: + out_files = dict() + out_files["epochs"] = in_files["epochs"].copy().update(processing="ica", split=None) + + title = f"ICA artifact removal – sub-{subject}" + if session is not None: + title += f", ses-{session}" + if cfg.task is not None: + title += f", task-{cfg.task}" + + # Load ICA. + msg = f"Reading ICA: {in_files['ica']}" + logger.debug(**gen_log_kwargs(message=msg)) + ica = _read_ica_and_exclude(in_files) + + # Load epochs. + msg = f'Input: {in_files["epochs"].basename}' + logger.info(**gen_log_kwargs(message=msg)) + msg = f'Output: {out_files["epochs"].basename}' + logger.info(**gen_log_kwargs(message=msg)) + + epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) + + # Now actually reject the components. + msg = ( + f'Rejecting ICs with the following indices: ' + f'{", ".join([str(i) for i in ica.exclude])}' + ) + logger.info(**gen_log_kwargs(message=msg)) + epochs_cleaned = ica.apply(epochs.copy()) # Copy b/c works in-place! + + msg = f"Saving {len(epochs)} reconstructed epochs after ICA." + logger.info(**gen_log_kwargs(message=msg)) + epochs_cleaned.save( + out_files["epochs"], + overwrite=True, + split_naming="bids", + split_size=cfg._epochs_split_size, + ) + _update_for_splits(out_files, "epochs") + assert len(in_files) == 0, in_files.keys() + + # Report + kwargs = dict() + if ica.exclude: + msg = "Adding ICA to report." + else: + msg = "Skipping ICA addition to report, no components marked as bad." + kwargs["emoji"] = "skip" + logger.info(**gen_log_kwargs(message=msg, **kwargs)) + if ica.exclude: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + ) as report: + report.add_ica( + ica=ica, + title="ICA: removals", + inst=epochs, + picks=ica.exclude, + # TODO upstream + # captions=f'Evoked response (across all epochs) ' + # f'before and after ICA ' + # f'({len(ica.exclude)} ICs removed)' + replace=True, + n_jobs=1, # avoid automatic parallelization + ) + + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +@failsafe_run( + get_input_fnames=get_input_fnames_apply_ica_raw, +) +def apply_ica_raw( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, + in_files: dict, +) -> dict: + ica = _read_ica_and_exclude(in_files) + in_key = list(in_files)[0] + assert in_key.startswith("raw"), in_key + raw_fname = in_files.pop(in_key) + assert len(in_files) == 0, in_files + out_files = dict() + out_files[in_key] = raw_fname.copy().update(processing="clean", split=None) + msg = f"Writing {out_files[in_key].basename} …" + logger.info(**gen_log_kwargs(message=msg)) + raw = mne.io.read_raw_fif(raw_fname, preload=True) + ica.apply(raw) + raw.save(out_files[in_key], overwrite=True, split_size=cfg._raw_split_size) + _update_for_splits(out_files, in_key) + # Report + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) as report: + msg = "Adding cleaned raw data to report" + logger.info(**gen_log_kwargs(message=msg)) + _add_raw( + cfg=cfg, + report=report, + bids_path_in=out_files[in_key], + title="Raw (clean)", + tags=("clean",), + raw=raw, + ) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, +) -> SimpleNamespace: + cfg = SimpleNamespace( + baseline=config.baseline, + ica_reject=config.ica_reject, + processing="filt" if config.regress_artifact is None else "regress", + _epochs_split_size=config._epochs_split_size, + **_import_data_kwargs(config=config, subject=subject), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Apply ICA.""" + if not config.spatial_filter == "ica": + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + # Epochs + parallel, run_func = parallel_func( + apply_ica_epochs, exec_params=config.exec_params + ) + logs = parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + ) + # Raw + parallel, run_func = parallel_func( + apply_ica_raw, exec_params=config.exec_params + ) + logs += parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + for run, task in get_runs_tasks( + config=config, + subject=subject, + session=session, + ) + ) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py b/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py new file mode 100644 index 000000000..6ab00dc12 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py @@ -0,0 +1,208 @@ +"""Apply SSP. + +Blinks and ECG artifacts are automatically detected and the corresponding SSP +projections components are removed from the data. +""" + +from types import SimpleNamespace + +import mne + +from ..._config_utils import ( + _proj_path, + get_runs_tasks, + get_sessions, + get_subjects, +) +from ..._import_data import _get_run_rest_noise_path, _import_data_kwargs +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _add_raw, _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs + + +def get_input_fnames_apply_ssp_epochs( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> dict: + in_files = dict() + in_files["proj"] = _proj_path(cfg=cfg, subject=subject, session=session) + in_files["epochs"] = in_files["proj"].copy().update(suffix="epo", check=False) + _update_for_splits(in_files, "epochs", single=True) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_apply_ssp_epochs, +) +def apply_ssp_epochs( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: dict, +) -> dict: + out_files = dict() + out_files["epochs"] = ( + in_files["epochs"].copy().update(processing="ssp", split=None, check=False) + ) + msg = f"Input epochs: {in_files['epochs'].basename}" + logger.info(**gen_log_kwargs(message=msg)) + msg = f'Input SSP: {in_files["proj"].basename}' + logger.info(**gen_log_kwargs(message=msg)) + msg = f"Output: {out_files['epochs'].basename}" + logger.info(**gen_log_kwargs(message=msg)) + epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) + projs = mne.read_proj(in_files.pop("proj")) + epochs_cleaned = epochs.copy().add_proj(projs).apply_proj() + + msg = f"Saving {len(epochs_cleaned)} reconstructed epochs after SSP." + logger.info(**gen_log_kwargs(message=msg)) + + epochs_cleaned.save( + out_files["epochs"], + overwrite=True, + split_naming="bids", + split_size=cfg._epochs_split_size, + ) + _update_for_splits(out_files, "epochs") + assert len(in_files) == 0, in_files.keys() + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_input_fnames_apply_ssp_raw( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, +) -> dict: + in_files = _get_run_rest_noise_path( + cfg=cfg, + subject=subject, + session=session, + run=run, + task=task, + kind="filt", + mf_reference_run=cfg.mf_reference_run, + ) + assert len(in_files) + in_files["proj"] = _proj_path(cfg=cfg, subject=subject, session=session) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_apply_ssp_raw, +) +def apply_ssp_raw( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, + in_files: dict, +) -> dict: + projs = mne.read_proj(in_files.pop("proj")) + in_key = list(in_files.keys())[0] + assert in_key.startswith("raw"), in_key + raw_fname = in_files.pop(in_key) + assert len(in_files) == 0, in_files.keys() + raw = mne.io.read_raw_fif(raw_fname) + raw.add_proj(projs) + out_files = dict() + out_files[in_key] = raw_fname.copy().update(processing="clean", split=None) + msg = f"Writing {out_files[in_key].basename} …" + logger.info(**gen_log_kwargs(message=msg)) + raw.save(out_files[in_key], overwrite=True, split_size=cfg._raw_split_size) + _update_for_splits(out_files, in_key) + # Report + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) as report: + msg = "Adding cleaned raw data to report" + logger.info(**gen_log_kwargs(message=msg)) + _add_raw( + cfg=cfg, + report=report, + bids_path_in=out_files[in_key], + title="Raw (clean)", + tags=("clean",), + raw=raw, + ) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, +) -> SimpleNamespace: + cfg = SimpleNamespace( + processing="filt" if config.regress_artifact is None else "regress", + _epochs_split_size=config._epochs_split_size, + **_import_data_kwargs(config=config, subject=subject), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Apply ssp.""" + if not config.spatial_filter == "ssp": + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + # Epochs + parallel, run_func = parallel_func( + apply_ssp_epochs, exec_params=config.exec_params + ) + logs = parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + ) + # Raw + parallel, run_func = parallel_func( + apply_ssp_raw, exec_params=config.exec_params + ) + logs += parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + for run, task in get_runs_tasks( + config=config, + subject=subject, + session=session, + ) + ) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_09_ptp_reject.py b/mne_bids_pipeline/steps/preprocessing/_09_ptp_reject.py new file mode 100644 index 000000000..bf4bcfe40 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_09_ptp_reject.py @@ -0,0 +1,279 @@ +"""Remove epochs based on PTP amplitudes. + +Epochs containing peak-to-peak (PTP) above the thresholds defined +in the 'reject' parameter are removed from the data. + +This step will drop epochs containing non-biological artifacts +but also epochs containing biological artifacts not sufficiently +corrected by the ICA or the SSP processing. +""" + +from types import SimpleNamespace + +import autoreject +import mne +import numpy as np +from mne_bids import BIDSPath + +from ..._config_utils import ( + _bids_kwargs, + get_sessions, + get_subjects, +) +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._reject import _get_reject +from ..._report import _open_report +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs +from ._07_make_epochs import _add_epochs_image_kwargs + + +def get_input_fnames_drop_ptp( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> dict: + bids_path = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + recording=cfg.rec, + space=cfg.space, + suffix="epo", + extension=".fif", + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + ) + in_files = dict() + in_files["epochs"] = bids_path.copy().update(processing=cfg.spatial_filter) + _update_for_splits(in_files, "epochs", single=True) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_drop_ptp, +) +def drop_ptp( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: dict, +) -> dict: + import matplotlib.pyplot as plt + + out_files = dict() + out_files["epochs"] = ( + in_files["epochs"] + .copy() + .update( + processing="clean", + split=None, + ) + ) + msg = f'Input: {in_files["epochs"].basename}' + logger.info(**gen_log_kwargs(message=msg)) + msg = f'Output: {out_files["epochs"].basename}' + logger.info(**gen_log_kwargs(message=msg)) + + # Get rejection parameters and drop bad epochs + epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) + + if cfg.reject == "autoreject_local": + msg = ( + "Using autoreject to find and repair bad epochs (interpolating bad " + "segments)" + ) + logger.info(**gen_log_kwargs(message=msg)) + + ar = autoreject.AutoReject( + n_interpolate=np.array(cfg.autoreject_n_interpolate), + random_state=cfg.random_state, + n_jobs=exec_params.n_jobs, + verbose=False, + ) + n_epochs_before_reject = len(epochs) + epochs, reject_log = ar.fit_transform(epochs, return_log=True) + n_epochs_after_reject = len(epochs) + assert ( + n_epochs_before_reject - n_epochs_after_reject + == reject_log.bad_epochs.sum() + ) + + msg = ( + f"autoreject marked {reject_log.bad_epochs.sum()} epochs as bad " + f"(cross-validated n_interpolate limit: {ar.n_interpolate_})" + ) + logger.info(**gen_log_kwargs(message=msg)) + else: + reject = _get_reject( + subject=subject, + session=session, + reject=cfg.reject, + ch_types=cfg.ch_types, + param="reject", + epochs=epochs, + ) + + if cfg.spatial_filter == "ica" and cfg.ica_reject != "autoreject_local": + ica_reject = _get_reject( + subject=subject, + session=session, + reject=cfg.ica_reject, + ch_types=cfg.ch_types, + param="ica_reject", + ) + else: + ica_reject = None + + if ica_reject is not None: + for ch_type, threshold in ica_reject.items(): + if ch_type in reject and threshold < reject[ch_type]: + # This can only ever happen in case of + # reject = 'autoreject_global' + msg = ( + f"Adjusting PTP rejection threshold proposed by " + f"autoreject, as it is greater than ica_reject: " + f"{ch_type}: {reject[ch_type]} -> {threshold}" + ) + logger.info(**gen_log_kwargs(message=msg)) + reject[ch_type] = threshold + + n_epochs_before_reject = len(epochs) + epochs.reject_tmin = cfg.reject_tmin + epochs.reject_tmax = cfg.reject_tmax + epochs.drop_bad(reject=reject) + n_epochs_after_reject = len(epochs) + n_epochs_rejected = n_epochs_before_reject - n_epochs_after_reject + + msg = ( + f"Removed {n_epochs_rejected} of {n_epochs_before_reject} epochs via PTP " + f"rejection thresholds: {reject}" + ) + logger.info(**gen_log_kwargs(message=msg)) + + if 0 < n_epochs_after_reject < 0.5 * n_epochs_before_reject: + msg = ( + "More than 50% of all epochs rejected. Please check the " + "rejection thresholds." + ) + logger.warning(**gen_log_kwargs(message=msg)) + elif n_epochs_after_reject == 0: + rejection_type = ( + cfg.reject + if cfg.reject in ["autoreject_global", "autoreject_local"] + else "PTP-based" + ) + raise RuntimeError( + f"No epochs remaining after {rejection_type} rejection. Cannot continue." + ) + + msg = f"Saving {n_epochs_after_reject} cleaned, baseline-corrected epochs …" + + epochs.apply_baseline(cfg.baseline) + epochs.save( + out_files["epochs"], + overwrite=True, + split_naming="bids", + split_size=cfg._epochs_split_size, + ) + _update_for_splits(out_files, "epochs") + assert len(in_files) == 0, in_files.keys() + + # Report + msg = "Adding cleaned epochs to report." + logger.info(**gen_log_kwargs(message=msg)) + # Add PSD plots for 30s of data or all epochs if we have less available + if len(epochs) * (epochs.tmax - epochs.tmin) < 30: + psd = True + else: + psd = 30 + tags = ("epochs", "clean") + kind = cfg.reject if isinstance(cfg.reject, str) else "Rejection" + title = "Epochs: after cleaning" + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + if cfg.reject == "autoreject_local": + caption = ( + f"Autoreject was run to produce cleaner epochs. " + f"{reject_log.bad_epochs.sum()} epochs were rejected because more than " + f"{ar.n_interpolate_} channels were bad (cross-validated n_interpolate " + f"limit; excluding globally bad and non-data channels, shown in white)." + ) + fig = reject_log.plot(orientation="horizontal", aspect="auto", show=False) + report.add_figure( + fig=fig, + title=f"{kind} cleaning", + caption=caption, + section=title, + tags=tags, + replace=True, + ) + plt.close(fig) + del caption + else: + report.add_html( + html=f"{reject}", + title=f"{kind} thresholds", + section=title, + replace=True, + tags=tags, + ) + + report.add_epochs( + epochs=epochs, + title=title, + psd=psd, + drop_log_ignore=(), + tags=tags, + replace=True, + **_add_epochs_image_kwargs(cfg=cfg), + ) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, +) -> SimpleNamespace: + cfg = SimpleNamespace( + baseline=config.baseline, + reject_tmin=config.reject_tmin, + reject_tmax=config.reject_tmax, + spatial_filter=config.spatial_filter, + ica_reject=config.ica_reject, + reject=config.reject, + autoreject_n_interpolate=config.autoreject_n_interpolate, + random_state=config.random_state, + ch_types=config.ch_types, + _epochs_split_size=config._epochs_split_size, + report_add_epochs_image_kwargs=config.report_add_epochs_image_kwargs, + **_bids_kwargs(config=config), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Run epochs.""" + parallel, run_func = parallel_func(drop_ptp, exec_params=config.exec_params) + + with get_parallel_backend(config.exec_params): + logs = parallel( + run_func( + cfg=get_config( + config=config, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject in get_subjects(config) + for session in get_sessions(config) + ) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/__init__.py b/mne_bids_pipeline/steps/preprocessing/__init__.py index 95637ecab..f9072617c 100644 --- a/mne_bids_pipeline/steps/preprocessing/__init__.py +++ b/mne_bids_pipeline/steps/preprocessing/__init__.py @@ -1,25 +1,31 @@ """Preprocessing.""" -from . import _01_data_quality -from . import _02_head_pos -from . import _03_maxfilter -from . import _04_frequency_filter -from . import _05_make_epochs -from . import _06a_run_ica -from . import _06b_run_ssp -from . import _07a_apply_ica -from . import _07b_apply_ssp -from . import _08_ptp_reject +from . import ( + _01_data_quality, + _02_head_pos, + _03_maxfilter, + _04_frequency_filter, + _05_regress_artifact, + _06a1_fit_ica, + _06a2_find_ica_artifacts, + _06b_run_ssp, + _07_make_epochs, + _08a_apply_ica, + _08b_apply_ssp, + _09_ptp_reject, +) _STEPS = ( _01_data_quality, _02_head_pos, _03_maxfilter, _04_frequency_filter, - _05_make_epochs, - _06a_run_ica, + _05_regress_artifact, + _06a1_fit_ica, + _06a2_find_ica_artifacts, _06b_run_ssp, - _07a_apply_ica, - _07b_apply_ssp, - _08_ptp_reject, + _07_make_epochs, + _08a_apply_ica, + _08b_apply_ssp, + _09_ptp_reject, ) diff --git a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py index 0c0ed1ecf..879baf298 100644 --- a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py +++ b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py @@ -1,29 +1,36 @@ """Extract evoked data for each condition.""" from types import SimpleNamespace -from typing import Optional import mne from mne_bids import BIDSPath from ..._config_utils import ( - get_sessions, - get_subjects, - get_all_contrasts, _bids_kwargs, + _pl, _restrict_analyze_channels, + get_all_contrasts, + get_eeg_reference, + get_sessions, + get_subjects, ) from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _sanitize_cond_tag -from ..._run import failsafe_run, save_logs, _sanitize_callable +from ..._parallel import get_parallel_backend, parallel_func +from ..._report import _all_conditions, _open_report, _sanitize_cond_tag +from ..._run import ( + _prep_out_files, + _sanitize_callable, + _update_for_splits, + failsafe_run, + save_logs, +) def get_input_fnames_evoked( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> dict: fname_epochs = BIDSPath( subject=subject, @@ -42,6 +49,7 @@ def get_input_fnames_evoked( ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -53,12 +61,19 @@ def run_evoked( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, in_files: dict, ) -> dict: out_files = dict() out_files["evoked"] = ( - in_files["epochs"].copy().update(suffix="ave", processing=None, check=False) + in_files["epochs"] + .copy() + .update( + suffix="ave", + processing=None, + check=False, + split=None, + ) ) msg = f'Input: {in_files["epochs"].basename}' @@ -98,10 +113,17 @@ def run_evoked( # Report if evokeds: - msg = f"Adding {len(evokeds)} evoked signals and contrasts to the " f"report." + n_contrasts = len(cfg.contrasts) + n_signals = len(evokeds) - n_contrasts + msg = ( + f"Adding {n_signals} evoked response{_pl(n_signals)} and " + f"{n_contrasts} contrast{_pl(n_contrasts)} to the report." + ) else: msg = "No evoked conditions or contrasts found." logger.info(**gen_log_kwargs(message=msg)) + all_conditions = _all_conditions(cfg=cfg) + assert list(all_conditions) == list(all_evoked) # otherwise we have a bug with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session ) as report: @@ -138,7 +160,7 @@ def run_evoked( # topomap_args=topomap_args) assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( @@ -150,6 +172,7 @@ def get_config( contrasts=get_all_contrasts(config), noise_cov=_sanitize_callable(config.noise_cov), analyze_channels=config.analyze_channels, + eeg_reference=get_eeg_reference(config), ch_types=config.ch_types, report_evoked_n_time_points=config.report_evoked_n_time_points, **_bids_kwargs(config=config), diff --git a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py index ddc32bf3f..597ee409f 100644 --- a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py +++ b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py @@ -10,48 +10,46 @@ import os.path as op from types import SimpleNamespace -from typing import Optional +import mne import numpy as np import pandas as pd -from scipy.io import savemat, loadmat - -from sklearn.model_selection import cross_val_score -from sklearn.pipeline import make_pipeline -from sklearn.model_selection import StratifiedKFold - -import mne -from mne.decoding import Scaler, Vectorizer +from mne.decoding import Vectorizer from mne_bids import BIDSPath +from scipy.io import loadmat, savemat +from sklearn.model_selection import StratifiedKFold, cross_val_score +from sklearn.pipeline import make_pipeline from ..._config_utils import ( - get_sessions, - get_subjects, - get_eeg_reference, - get_decoding_contrasts, _bids_kwargs, + _get_decoding_proc, _restrict_analyze_channels, + get_decoding_contrasts, + get_eeg_reference, + get_sessions, + get_subjects, ) +from ..._decoding import LogReg, _decoding_preproc_steps from ..._logging import gen_log_kwargs, logger -from ..._decoding import LogReg -from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, save_logs +from ..._parallel import get_parallel_backend, parallel_func from ..._report import ( - _open_report, _contrasts_to_names, + _open_report, _plot_full_epochs_decoding_scores, _sanitize_cond_tag, ) +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs def get_input_fnames_epochs_decoding( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, condition1: str, condition2: str, ) -> dict: + proc = _get_decoding_proc(config=cfg) fname_epochs = BIDSPath( subject=subject, session=session, @@ -60,6 +58,7 @@ def get_input_fnames_epochs_decoding( run=None, recording=cfg.rec, space=cfg.space, + processing=proc, suffix="epo", extension=".fif", datatype=cfg.datatype, @@ -68,6 +67,7 @@ def get_input_fnames_epochs_decoding( ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -79,7 +79,7 @@ def run_epochs_decoding( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, condition1: str, condition2: str, in_files: dict, @@ -89,11 +89,10 @@ def run_epochs_decoding( msg = f"Contrasting conditions: {condition1} – {condition2}" logger.info(**gen_log_kwargs(message=msg)) out_files = dict() - bids_path = in_files["epochs"].copy() + bids_path = in_files["epochs"].copy().update(split=None) epochs = mne.read_epochs(in_files.pop("epochs")) _restrict_analyze_channels(epochs, cfg) - epochs.crop(cfg.decoding_epochs_tmin, cfg.decoding_epochs_tmax) # We define the epochs and the labels if isinstance(cfg.conditions, dict): @@ -110,15 +109,26 @@ def run_epochs_decoding( [epochs[epochs_conds[0]], epochs[epochs_conds[1]]], verbose="error" ) + # Crop to the desired analysis interval. Do it only after the concatenation to work + # around https://github.com/mne-tools/mne-python/issues/12153 + epochs.crop(cfg.decoding_epochs_tmin, cfg.decoding_epochs_tmax) + # omit bad channels and reference MEG sensors + epochs.pick_types(meg=True, eeg=True, ref_meg=False, exclude="bads") + pre_steps = _decoding_preproc_steps( + subject=subject, + session=session, + epochs=epochs, + ) + n_cond1 = len(epochs[epochs_conds[0]]) n_cond2 = len(epochs[epochs_conds[1]]) X = epochs.get_data() y = np.r_[np.ones(n_cond1), np.zeros(n_cond2)] - classification_pipeline = make_pipeline( - Scaler(scalings="mean"), - Vectorizer(), # So we can pass the data to scikit-learn + clf = make_pipeline( + *pre_steps, + Vectorizer(), LogReg( solver="liblinear", # much faster than the default random_state=cfg.random_state, @@ -134,7 +144,13 @@ def run_epochs_decoding( n_splits=cfg.decoding_n_splits, ) scores = cross_val_score( - estimator=classification_pipeline, X=X, y=y, cv=cv, scoring="roc_auc", n_jobs=1 + estimator=clf, + X=X, + y=y, + cv=cv, + scoring="roc_auc", + n_jobs=1, + error_score="raise", ) # Save the scores @@ -184,7 +200,7 @@ def run_epochs_decoding( all_contrasts.append(contrast) del fname_decoding, processing, a_vs_b, decoding_data - fig, caption = _plot_full_epochs_decoding_scores( + fig, caption, _ = _plot_full_epochs_decoding_scores( contrast_names=_contrasts_to_names(all_contrasts), scores=all_decoding_scores, metric=cfg.decoding_metric, @@ -209,7 +225,7 @@ def run_epochs_decoding( plt.close(fig) assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( @@ -220,6 +236,7 @@ def get_config( conditions=config.conditions, contrasts=get_decoding_contrasts(config), decode=config.decode, + decoding_which_epochs=config.decoding_which_epochs, decoding_metric=config.decoding_metric, decoding_epochs_tmin=config.decoding_epochs_tmin, decoding_epochs_tmax=config.decoding_epochs_tmax, @@ -237,12 +254,12 @@ def main(*, config: SimpleNamespace) -> None: """Run time-by-time decoding.""" if not config.contrasts: msg = "No contrasts specified; not performing decoding." - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return if not config.decode: msg = "No decoding requested by user." - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return with get_parallel_backend(config.exec_params): diff --git a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py index fe4d64fb1..8e5402d96 100644 --- a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py +++ b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py @@ -13,50 +13,51 @@ import os.path as op from types import SimpleNamespace -from typing import Optional +import mne import numpy as np import pandas as pd -from scipy.io import savemat, loadmat - -import mne -from mne.decoding import GeneralizingEstimator, SlidingEstimator, cross_val_multiscore - +from mne.decoding import ( + GeneralizingEstimator, + SlidingEstimator, + Vectorizer, + cross_val_multiscore, +) from mne_bids import BIDSPath - -from sklearn.preprocessing import StandardScaler -from sklearn.pipeline import make_pipeline +from scipy.io import loadmat, savemat from sklearn.model_selection import StratifiedKFold +from sklearn.pipeline import make_pipeline from ..._config_utils import ( - get_sessions, - get_subjects, - get_eeg_reference, - get_decoding_contrasts, _bids_kwargs, + _get_decoding_proc, _restrict_analyze_channels, + get_decoding_contrasts, + get_eeg_reference, + get_sessions, + get_subjects, ) -from ..._decoding import LogReg +from ..._decoding import LogReg, _decoding_preproc_steps from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, save_logs from ..._parallel import get_parallel_backend, get_parallel_backend_name from ..._report import ( _open_report, _plot_decoding_time_generalization, - _sanitize_cond_tag, _plot_time_by_time_decoding_scores, + _sanitize_cond_tag, ) +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs def get_input_fnames_time_decoding( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, condition1: str, condition2: str, ) -> dict: - # TODO: Shouldn't this at least use the PTP-rejected epochs if available? + proc = _get_decoding_proc(config=cfg) fname_epochs = BIDSPath( subject=subject, session=session, @@ -65,6 +66,7 @@ def get_input_fnames_time_decoding( run=None, recording=cfg.rec, space=cfg.space, + processing=proc, suffix="epo", extension=".fif", datatype=cfg.datatype, @@ -73,6 +75,7 @@ def get_input_fnames_time_decoding( ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -84,7 +87,7 @@ def run_time_decoding( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, condition1: str, condition2: str, in_files: dict, @@ -98,7 +101,7 @@ def run_time_decoding( msg = f"Contrasting conditions ({kind}): {condition1} – {condition2}" logger.info(**gen_log_kwargs(message=msg)) out_files = dict() - bids_path = in_files["epochs"].copy() + bids_path = in_files["epochs"].copy().update(split=None) epochs = mne.read_epochs(in_files.pop("epochs")) _restrict_analyze_channels(epochs, cfg) @@ -122,6 +125,22 @@ def run_time_decoding( epochs = mne.concatenate_epochs([epochs[epochs_conds[0]], epochs[epochs_conds[1]]]) n_cond1 = len(epochs[epochs_conds[0]]) n_cond2 = len(epochs[epochs_conds[1]]) + epochs.pick_types(meg=True, eeg=True, ref_meg=False, exclude="bads") + # We can't use the full rank here because the number of samples can just be the + # number of epochs (which can be fewer than the number of channels) + pre_steps = _decoding_preproc_steps( + subject=subject, + session=session, + epochs=epochs, + pca=False, + ) + # At some point we might want to enable this, but it's really slow and arguably + # unnecessary so let's omit it for now: + # pre_steps.append( + # mne.decoding.UnsupervisedSpatialFilter( + # PCA(n_components=0.999, whiten=True), + # ) + # ) decim = cfg.decoding_time_generalization_decim if cfg.decoding_time_generalization and decim > 1: @@ -133,7 +152,8 @@ def run_time_decoding( verbose = get_parallel_backend_name(exec_params=exec_params) != "dask" with get_parallel_backend(exec_params): clf = make_pipeline( - StandardScaler(), + *pre_steps, + Vectorizer(), LogReg( solver="liblinear", # much faster than the default random_state=cfg.random_state, @@ -286,7 +306,7 @@ def run_time_decoding( del decoding_data, cond_1, cond_2, caption assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( @@ -297,6 +317,7 @@ def get_config( conditions=config.conditions, contrasts=get_decoding_contrasts(config), decode=config.decode, + decoding_which_epochs=config.decoding_which_epochs, decoding_metric=config.decoding_metric, decoding_n_splits=config.decoding_n_splits, decoding_time_generalization=config.decoding_time_generalization, @@ -314,12 +335,12 @@ def main(*, config: SimpleNamespace) -> None: """Run time-by-time decoding.""" if not config.contrasts: msg = "No contrasts specified; not performing decoding." - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return if not config.decode: msg = "No decoding requested by user." - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return # Here we go parallel inside the :class:`mne.decoding.SlidingEstimator` diff --git a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py index 1d88c2813..be04ca547 100644 --- a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py +++ b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py @@ -5,37 +5,31 @@ """ from types import SimpleNamespace -from typing import Optional - -import numpy as np import mne - +import numpy as np from mne_bids import BIDSPath from ..._config_utils import ( + _bids_kwargs, + _restrict_analyze_channels, + get_eeg_reference, get_sessions, get_subjects, - get_eeg_reference, sanitize_cond_name, - _bids_kwargs, - _restrict_analyze_channels, ) from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, save_logs from ..._parallel import get_parallel_backend, parallel_func from ..._report import _open_report, _sanitize_cond_tag +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs def get_input_fnames_time_frequency( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> dict: - processing = None - if cfg.spatial_filter is not None: - processing = "clean" fname_epochs = BIDSPath( subject=subject, session=session, @@ -46,13 +40,14 @@ def get_input_fnames_time_frequency( space=cfg.space, datatype=cfg.datatype, root=cfg.deriv_root, - processing=processing, + processing="clean", suffix="epo", extension=".fif", check=False, ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -64,16 +59,17 @@ def run_time_frequency( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, in_files: dict, ) -> dict: import matplotlib.pyplot as plt - msg = f'Input: {in_files["epochs"].basename}' + epochs_path = in_files.pop("epochs") + msg = f"Reading {epochs_path.basename}" logger.info(**gen_log_kwargs(message=msg)) - bids_path = in_files["epochs"].copy().update(processing=None) - - epochs = mne.read_epochs(in_files.pop("epochs")) + epochs = mne.read_epochs(epochs_path) + bids_path = epochs_path.copy().update(processing=None, split=None) + del epochs_path _restrict_analyze_channels(epochs, cfg) if cfg.time_frequency_subtract_evoked: @@ -87,6 +83,7 @@ def run_time_frequency( out_files = dict() for condition in cfg.time_frequency_conditions: + logger.info(**gen_log_kwargs(message=f"Computing TFR for {condition}")) this_epochs = epochs[condition] power, itc = mne.time_frequency.tfr_morlet( this_epochs, freqs=freqs, return_itc=True, n_cycles=time_frequency_cycles @@ -106,8 +103,8 @@ def run_time_frequency( # conform to MNE filename checks. This is because BIDS has not # finalized how derivatives should be named. Once this is done, we # should update our names and/or MNE's checks. - power.save(out_files[power_key], overwrite=True, verbose="error") - itc.save(out_files[itc_key], overwrite=True, verbose="error") + power.save(out_files[power_key].fpath, overwrite=True, verbose="error") + itc.save(out_files[itc_key].fpath, overwrite=True, verbose="error") # Report with _open_report( @@ -117,8 +114,8 @@ def run_time_frequency( logger.info(**gen_log_kwargs(message=msg)) for condition in cfg.time_frequency_conditions: cond = sanitize_cond_name(condition) - fname_tfr_pow_cond = out_files[f"power-{cond}"] - fname_tfr_itc_cond = out_files[f"itc-{cond}"] + fname_tfr_pow_cond = out_files[f"power-{cond}"].fpath + fname_tfr_itc_cond = out_files[f"itc-{cond}"].fpath with mne.use_log_level("error"): # filename convention power = mne.time_frequency.read_tfrs(fname_tfr_pow_cond, condition=0) power.apply_baseline( @@ -153,7 +150,7 @@ def run_time_frequency( del itc assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( @@ -182,7 +179,7 @@ def main(*, config: SimpleNamespace) -> None: """Run Time-frequency decomposition.""" if not config.time_frequency_conditions: msg = "Skipping …" - logger.info(**gen_log_kwargs(message=msg)) + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) return parallel, run_func = parallel_func( diff --git a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py index a51371c3d..8c2368d8c 100644 --- a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py +++ b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py @@ -1,50 +1,47 @@ -""" -Decoding based on common spatial patterns (CSP). -""" +"""Decoding based on common spatial patterns (CSP).""" import os.path as op from types import SimpleNamespace -from typing import Dict, Optional, Tuple +import matplotlib.transforms import mne import numpy as np import pandas as pd -import matplotlib.transforms -from mne.decoding import CSP, UnsupervisedSpatialFilter +from mne.decoding import CSP from mne_bids import BIDSPath -from sklearn.decomposition import PCA from sklearn.model_selection import StratifiedKFold, cross_val_score from sklearn.pipeline import make_pipeline from ..._config_utils import ( - get_sessions, - get_subjects, - get_eeg_reference, - get_decoding_contrasts, _bids_kwargs, + _get_decoding_proc, _restrict_analyze_channels, + get_decoding_contrasts, + get_eeg_reference, + get_sessions, + get_subjects, ) -from ..._decoding import LogReg, _handle_csp_args -from ..._logging import logger, gen_log_kwargs -from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, save_logs +from ..._decoding import LogReg, _decoding_preproc_steps, _handle_csp_args +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func from ..._report import ( + _imshow_tf, _open_report, - _sanitize_cond_tag, _plot_full_epochs_decoding_scores, - _imshow_tf, + _sanitize_cond_tag, ) +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs -def _prepare_labels(*, epochs: mne.BaseEpochs, contrast: Tuple[str, str]) -> np.ndarray: +def _prepare_labels(*, epochs: mne.BaseEpochs, contrast: tuple[str, str]) -> np.ndarray: """Return the projection of the events_id on a boolean vector. This projection is useful in the case of hierarchical events: we project the different events contained in one condition into just one label. - Returns: - -------- + Returns + ------- A boolean numpy array containing the labels. """ epochs_cond_0 = epochs[contrast[0]] @@ -78,27 +75,18 @@ def _prepare_labels(*, epochs: mne.BaseEpochs, contrast: Tuple[str, str]) -> np. def prepare_epochs_and_y( - *, epochs: mne.BaseEpochs, contrast: Tuple[str, str], cfg, fmin: float, fmax: float -) -> Tuple[mne.BaseEpochs, np.ndarray]: + *, epochs: mne.BaseEpochs, contrast: tuple[str, str], cfg, fmin: float, fmax: float +) -> tuple[mne.BaseEpochs, np.ndarray]: """Band-pass between, sub-select the desired epochs, and prepare y.""" - epochs_filt = epochs.copy().pick_types( - meg=True, - eeg=True, - ) - - # We only take mag to speed up computation - # because the information is redundant between grad and mag - if cfg.datatype == "meg" and cfg.use_maxwell_filter: - epochs_filt.pick_types(meg="mag") - # filtering out the conditions we are not interested in, to ensure here we # have a valid partition between the condition of the contrast. - # + # XXX Hack for handling epochs selection via metadata + # This also makes a copy if contrast[0].startswith("event_name.isin"): - epochs_filt = epochs_filt[f"{contrast[0]} or {contrast[1]}"] + epochs_filt = epochs[f"{contrast[0]} or {contrast[1]}"] else: - epochs_filt = epochs_filt[contrast] + epochs_filt = epochs[contrast] # Filtering is costly, so do it last, after the selection of the channels # and epochs. We know that often the filter will be longer than the signal, @@ -113,9 +101,10 @@ def get_input_fnames_csp( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - contrast: Tuple[str], + session: str | None, + contrast: tuple[str], ) -> dict: + proc = _get_decoding_proc(config=cfg) fname_epochs = BIDSPath( subject=subject, session=session, @@ -124,6 +113,7 @@ def get_input_fnames_csp( run=None, recording=cfg.rec, space=cfg.space, + processing=proc, suffix="epo", extension=".fif", datatype=cfg.datatype, @@ -132,6 +122,7 @@ def get_input_fnames_csp( ) in_files = dict() in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files @@ -142,8 +133,8 @@ def one_subject_decoding( exec_params: SimpleNamespace, subject: str, session: str, - contrast: Tuple[str, str], - in_files: Dict[str, BIDSPath], + contrast: tuple[str, str], + in_files: dict[str, BIDSPath], ) -> dict: """Run one subject. @@ -157,33 +148,27 @@ def one_subject_decoding( msg = f"Contrasting conditions: {condition1} – {condition2}" logger.info(**gen_log_kwargs(msg)) - bids_path = in_files["epochs"].copy().update(processing=None) + bids_path = in_files["epochs"].copy().update(processing=None, split=None) epochs = mne.read_epochs(in_files.pop("epochs")) _restrict_analyze_channels(epochs, cfg) + epochs.pick_types(meg=True, eeg=True, ref_meg=False, exclude="bads") if cfg.time_frequency_subtract_evoked: epochs.subtract_evoked() - # Perform rank reduction via PCA. - # - # Select the channel type with the smallest rank. - # Limit it to a maximum of 100. - ranks = mne.compute_rank(inst=epochs, rank="info") - ch_type_smallest_rank = min(ranks, key=ranks.get) - rank = min(ranks[ch_type_smallest_rank], 100) - del ch_type_smallest_rank, ranks - - msg = f"Reducing data dimension via PCA; new rank: {rank}." - logger.info(**gen_log_kwargs(msg)) - pca = UnsupervisedSpatialFilter(PCA(rank), average=False) + preproc_steps = _decoding_preproc_steps( + subject=subject, + session=session, + epochs=epochs, + ) # Classifier csp = CSP( n_components=4, # XXX revisit reg=0.1, # XXX revisit - rank="info", ) clf = make_pipeline( + *preproc_steps, csp, LogReg( solver="liblinear", # much faster than the default @@ -198,8 +183,14 @@ def one_subject_decoding( ) # Loop over frequencies (all time points lumped together) - freq_name_to_bins_map = _handle_csp_args( - cfg.decoding_csp_times, cfg.decoding_csp_freqs, cfg.decoding_metric + freq_name_to_bins_map, time_bins = _handle_csp_args( + cfg.decoding_csp_times, + cfg.decoding_csp_freqs, + cfg.decoding_metric, + epochs_tmin=cfg.epochs_tmin, + epochs_tmax=cfg.epochs_tmax, + time_frequency_freq_min=cfg.time_frequency_freq_min, + time_frequency_freq_max=cfg.time_frequency_freq_max, ) freq_decoding_table_rows = [] for freq_range_name, freq_bins in freq_name_to_bins_map.items(): @@ -251,19 +242,14 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non # Get the data for all time points X = epochs_filt.get_data() - # We apply PCA before running CSP: - # - much faster CSP processing - # - reduced risk of numerical instabilities. - X_pca = pca.fit_transform(X) - del X - cv_scores = cross_val_score( estimator=clf, - X=X_pca, + X=X, y=y, scoring=cfg.decoding_metric, cv=cv, n_jobs=1, + error_score="raise", ) freq_decoding_table.loc[idx, "mean_crossval_score"] = cv_scores.mean() freq_decoding_table.at[idx, "scores"] = cv_scores @@ -272,11 +258,6 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non # # Note: We don't support varying time ranges for different frequency # ranges to avoid leaking of information. - time_bins = np.array(cfg.decoding_csp_times) - if time_bins.ndim == 1: - time_bins = np.array(list(zip(time_bins[:-1], time_bins[1:]))) - assert time_bins.ndim == 2 - tf_decoding_table_rows = [] for freq_range_name, freq_bins in freq_name_to_bins_map.items(): @@ -300,13 +281,18 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non } tf_decoding_table_rows.append(row) - tf_decoding_table = pd.concat( - [pd.DataFrame.from_dict(row) for row in tf_decoding_table_rows], - ignore_index=True, - ) + if len(tf_decoding_table_rows): + tf_decoding_table = pd.concat( + [pd.DataFrame.from_dict(row) for row in tf_decoding_table_rows], + ignore_index=True, + ) + else: + tf_decoding_table = pd.DataFrame() del tf_decoding_table_rows for idx, row in tf_decoding_table.iterrows(): + if len(row) == 0: + break # no data tmin = row["t_min"] tmax = row["t_max"] fmin = row["f_min"] @@ -321,18 +307,16 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non # Crop data to the time window of interest if tmax is not None: # avoid warnings about outside the interval tmax = min(tmax, epochs_filt.times[-1]) - epochs_filt.crop(tmin, tmax) - X = epochs_filt.get_data() - X_pca = pca.transform(X) - del X - + X = epochs_filt.crop(tmin, tmax).get_data() + del epochs_filt cv_scores = cross_val_score( estimator=clf, - X=X_pca, + X=X, y=y, scoring=cfg.decoding_metric, cv=cv, n_jobs=1, + error_score="raise", ) score = cv_scores.mean() tf_decoding_table.loc[idx, "mean_crossval_score"] = score @@ -351,8 +335,10 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non ) with pd.ExcelWriter(fname_results) as w: freq_decoding_table.to_excel(w, sheet_name="CSP Frequency", index=False) - tf_decoding_table.to_excel(w, sheet_name="CSP Time-Frequency", index=False) + if not tf_decoding_table.empty: + tf_decoding_table.to_excel(w, sheet_name="CSP Time-Frequency", index=False) out_files = {"csp-excel": fname_results} + del freq_decoding_table # Report with _open_report( @@ -361,11 +347,6 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non msg = "Adding CSP decoding results to the report." logger.info(**gen_log_kwargs(message=msg)) section = "Decoding: CSP" - freq_name_to_bins_map = _handle_csp_args( - cfg.decoding_csp_times, - cfg.decoding_csp_freqs, - cfg.decoding_metric, - ) all_csp_tf_results = dict() for contrast in cfg.decoding_contrasts: cond_1, cond_2 = contrast @@ -388,14 +369,15 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non csp_freq_results["scores"] = csp_freq_results["scores"].apply( lambda x: np.array(x[1:-1].split(), float) ) - csp_tf_results = pd.read_excel( - fname_decoding, sheet_name="CSP Time-Frequency" - ) - csp_tf_results["scores"] = csp_tf_results["scores"].apply( - lambda x: np.array(x[1:-1].split(), float) - ) - all_csp_tf_results[contrast] = csp_tf_results - del csp_tf_results + if not tf_decoding_table.empty: + csp_tf_results = pd.read_excel( + fname_decoding, sheet_name="CSP Time-Frequency" + ) + csp_tf_results["scores"] = csp_tf_results["scores"].apply( + lambda x: np.array(x[1:-1].split(), float) + ) + all_csp_tf_results[contrast] = csp_tf_results + del csp_tf_results all_decoding_scores = list() contrast_names = list() @@ -412,7 +394,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non contrast_names.append( f"{freq_range_name}\n" f"({f_min:0.1f}-{f_max:0.1f} Hz)" ) - fig, caption = _plot_full_epochs_decoding_scores( + fig, caption, _ = _plot_full_epochs_decoding_scores( contrast_names=contrast_names, scores=all_decoding_scores, metric=cfg.decoding_metric, @@ -445,11 +427,13 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non results = all_csp_tf_results[contrast] mean_crossval_scores = list() tmin, tmax, fmin, fmax = list(), list(), list(), list() - mean_crossval_scores.extend(results["mean_crossval_score"].ravel()) - tmin.extend(results["t_min"].ravel()) - tmax.extend(results["t_max"].ravel()) - fmin.extend(results["f_min"].ravel()) - fmax.extend(results["f_max"].ravel()) + mean_crossval_scores.extend( + results["mean_crossval_score"].to_numpy().ravel() + ) + tmin.extend(results["t_min"].to_numpy().ravel()) + tmax.extend(results["t_max"].to_numpy().ravel()) + fmin.extend(results["f_min"].to_numpy().ravel()) + fmax.extend(results["f_max"].to_numpy().ravel()) mean_crossval_scores = np.array(mean_crossval_scores, float) fig, ax = plt.subplots(constrained_layout=True) # XXX Add support for more metrics @@ -502,13 +486,15 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non tags=tags, replace=True, ) + plt.close(fig) + del fig, title assert len(in_files) == 0, in_files.keys() - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( - *, config: SimpleNamespace, subject: str, session: Optional[str] + *, config: SimpleNamespace, subject: str, session: str | None ) -> SimpleNamespace: cfg = SimpleNamespace( # Data parameters @@ -517,7 +503,12 @@ def get_config( ch_types=config.ch_types, eeg_reference=get_eeg_reference(config), # Processing parameters + epochs_tmin=config.epochs_tmin, + epochs_tmax=config.epochs_tmax, + time_frequency_freq_min=config.time_frequency_freq_min, + time_frequency_freq_max=config.time_frequency_freq_max, time_frequency_subtract_evoked=config.time_frequency_subtract_evoked, + decoding_which_epochs=config.decoding_which_epochs, decoding_metric=config.decoding_metric, decoding_csp_freqs=config.decoding_csp_freqs, decoding_csp_times=config.decoding_csp_times, diff --git a/mne_bids_pipeline/steps/sensor/_06_make_cov.py b/mne_bids_pipeline/steps/sensor/_06_make_cov.py index 01f80625f..0b907b679 100644 --- a/mne_bids_pipeline/steps/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/steps/sensor/_06_make_cov.py @@ -3,35 +3,40 @@ Covariance matrices are computed and saved. """ -from typing import Optional from types import SimpleNamespace import mne from mne_bids import BIDSPath +from ..._config_import import _import_config from ..._config_utils import ( + _bids_kwargs, + _restrict_analyze_channels, + get_eeg_reference, + get_noise_cov_bids_path, get_sessions, get_subjects, - get_noise_cov_bids_path, - _bids_kwargs, ) -from ..._config_import import _import_config -from ..._config_utils import _restrict_analyze_channels, get_all_contrasts from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func -from ..._report import _open_report, _sanitize_cond_tag, _all_conditions -from ..._run import failsafe_run, save_logs, _sanitize_callable +from ..._report import _all_conditions, _open_report, _sanitize_cond_tag +from ..._run import ( + _prep_out_files, + _sanitize_callable, + _update_for_splits, + failsafe_run, + save_logs, +) def get_input_fnames_cov( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> dict: cov_type = _get_cov_type(cfg) in_files = dict() - processing = "clean" if cfg.spatial_filter is not None else None fname_epochs = BIDSPath( subject=subject, session=session, @@ -42,14 +47,13 @@ def get_input_fnames_cov( space=cfg.space, extension=".fif", suffix="epo", - processing=processing, + processing="clean", datatype=cfg.datatype, root=cfg.deriv_root, check=False, ) - in_files["report_info"] = fname_epochs.copy().update( - processing="clean", suffix="epo" - ) + in_files["report_info"] = fname_epochs.copy().update(processing="clean") + _update_for_splits(in_files, "report_info", single=True) fname_evoked = fname_epochs.copy().update( suffix="ave", processing=None, check=False ) @@ -67,15 +71,14 @@ def get_input_fnames_cov( run=None, recording=cfg.rec, space=cfg.space, - processing="filt", + processing="clean", suffix="raw", extension=".fif", datatype=cfg.datatype, root=cfg.deriv_root, check=False, ) - run_type = "resting-state" if cfg.noise_cov == "rest" else "empty-room" - if run_type == "resting-state": + if cfg.noise_cov == "rest": bids_path_raw_noise.task = "rest" else: bids_path_raw_noise.task = "noise" @@ -83,17 +86,18 @@ def get_input_fnames_cov( else: assert cov_type == "epochs", cov_type in_files["epochs"] = fname_epochs + _update_for_splits(in_files, "epochs", single=True) return in_files def compute_cov_from_epochs( *, - tmin: Optional[float], - tmax: Optional[float], + tmin: float | None, + tmax: float | None, cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, in_files: dict, out_files: dict, ) -> mne.Covariance: @@ -111,7 +115,7 @@ def compute_cov_from_epochs( epochs, tmin=tmin, tmax=tmax, - method="shrunk", + method=cfg.noise_cov_method, rank="info", verbose="error", # TODO: not baseline corrected, maybe problematic? ) @@ -123,13 +127,13 @@ def compute_cov_from_raw( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, in_files: dict, out_files: dict, ) -> mne.Covariance: fname_raw = in_files.pop("raw") - run_type = "resting-state" if fname_raw.task == "rest" else "empty-room" - msg = f"Computing regularized covariance based on {run_type} recording." + run_msg = "resting-state" if fname_raw.task == "rest" else "empty-room" + msg = f"Computing regularized covariance based on {run_msg} recording." logger.info(**gen_log_kwargs(message=msg)) msg = f"Input: {fname_raw.basename}" logger.info(**gen_log_kwargs(message=msg)) @@ -146,7 +150,7 @@ def retrieve_custom_cov( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, in_files: dict, out_files: dict, ) -> mne.Covariance: @@ -168,7 +172,7 @@ def retrieve_custom_cov( task=cfg.task, acquisition=cfg.acq, run=None, - processing=cfg.proc, + processing="clean", recording=cfg.rec, space=cfg.space, suffix="ave", @@ -178,7 +182,7 @@ def retrieve_custom_cov( check=False, ) - msg = "Retrieving noise covariance matrix from custom user-supplied " "function" + msg = "Retrieving noise covariance matrix from custom user-supplied function" logger.info(**gen_log_kwargs(message=msg)) msg = f'Output: {out_files["cov"].basename}' logger.info(**gen_log_kwargs(message=msg)) @@ -207,7 +211,7 @@ def run_covariance( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str] = None, + session: str | None = None, in_files: dict, ) -> dict: import matplotlib.pyplot as plt @@ -258,10 +262,8 @@ def run_covariance( for evoked, condition in zip(all_evoked, conditions): _restrict_analyze_channels(evoked, cfg) tags = ("evoked", "covariance", _sanitize_cond_tag(condition)) - if condition in cfg.conditions: - title = f"Whitening: {condition}" - else: # It's a contrast of two conditions. - title = f"Whitening: {condition}" + title = f"Whitening: {condition}" + if condition not in cfg.conditions: tags = tags + ("contrast",) fig = evoked.plot_white(cov, verbose="error") report.add_figure( @@ -274,7 +276,7 @@ def run_covariance( plt.close(fig) assert len(in_files) == 0, in_files - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( @@ -287,8 +289,10 @@ def get_config( run_source_estimation=config.run_source_estimation, noise_cov=_sanitize_callable(config.noise_cov), conditions=config.conditions, - all_contrasts=get_all_contrasts(config), + contrasts=config.contrasts, analyze_channels=config.analyze_channels, + eeg_reference=get_eeg_reference(config), + noise_cov_method=config.noise_cov_method, **_bids_kwargs(config=config), ) return cfg diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index b126eecec..ffc645cdf 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -5,43 +5,52 @@ import os import os.path as op -from collections import defaultdict -from typing import Optional, TypedDict, List, Tuple +from functools import partial from types import SimpleNamespace +import mne import numpy as np import pandas as pd -from scipy.io import loadmat, savemat - -import mne from mne_bids import BIDSPath +from scipy.io import loadmat, savemat from ..._config_utils import ( + _bids_kwargs, + _pl, + _restrict_analyze_channels, + get_decoding_contrasts, + get_eeg_reference, get_sessions, get_subjects, - get_eeg_reference, - get_decoding_contrasts, - get_all_contrasts, - _bids_kwargs, ) from ..._decoding import _handle_csp_args from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func -from ..._run import failsafe_run, save_logs -from ..._report import run_report_average_sensor +from ..._report import ( + _all_conditions, + _contrasts_to_names, + _open_report, + _plot_decoding_time_generalization, + _plot_full_epochs_decoding_scores, + _plot_time_by_time_decoding_scores_gavg, + _sanitize_cond_tag, + add_csp_grand_average, + add_event_counts, + plot_time_by_time_decoding_t_values, +) +from ..._run import _prep_out_files, _update_for_splits, failsafe_run, save_logs +from ...typing import TypedDict -def average_evokeds( +def get_input_fnames_average_evokeds( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> List[mne.Evoked]: - # Container for all conditions: - all_evokeds = defaultdict(list) - + session: dict | None, +) -> dict: + in_files = dict() for this_subject in cfg.subjects: - fname_in = BIDSPath( + in_files[f"evoked-{this_subject}"] = BIDSPath( subject=this_subject, session=session, task=cfg.task, @@ -55,28 +64,49 @@ def average_evokeds( root=cfg.deriv_root, check=False, ) + return in_files - msg = f"Input: {fname_in.basename}" - logger.info(**gen_log_kwargs(message=msg)) - evokeds = mne.read_evokeds(fname_in) - for idx, evoked in enumerate(evokeds): - all_evokeds[idx].append(evoked) # Insert into the container +@failsafe_run( + get_input_fnames=get_input_fnames_average_evokeds, +) +def average_evokeds( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: dict, +) -> dict: + logger.info(**gen_log_kwargs(message="Creating grand averages")) + # Container for all conditions: + conditions = _all_conditions(cfg=cfg) + evokeds = [list() for _ in range(len(conditions))] - for idx, evokeds in all_evokeds.items(): - all_evokeds[idx] = mne.grand_average( - evokeds, interpolate_bads=cfg.interpolate_bads_grand_average + keys = list(in_files) + for key in keys: + if not key.startswith("evoked-"): + continue + fname_in = in_files.pop(key) + these_evokeds = mne.read_evokeds(fname_in) + for idx, evoked in enumerate(these_evokeds): + evokeds[idx].append(evoked) # Insert into the container + + for idx, these_evokeds in enumerate(evokeds): + evokeds[idx] = mne.grand_average( + these_evokeds, interpolate_bads=cfg.interpolate_bads_grand_average ) # Combine subjects # Keep condition in comment - all_evokeds[idx].comment = "Grand average: " + evokeds[0].comment + evokeds[idx].comment = "Grand average: " + these_evokeds[0].comment - fname_out = BIDSPath( + out_files = dict() + fname_out = out_files["evokeds"] = BIDSPath( subject=subject, session=session, task=cfg.task, acquisition=cfg.acq, run=None, - processing=cfg.proc, + processing="clean", recording=cfg.rec, space=cfg.space, suffix="ave", @@ -91,8 +121,59 @@ def average_evokeds( msg = f"Saving grand-averaged evoked sensor data: {fname_out.basename}" logger.info(**gen_log_kwargs(message=msg)) - mne.write_evokeds(fname_out, list(all_evokeds.values()), overwrite=True) - return list(all_evokeds.values()) + mne.write_evokeds(fname_out, evokeds, overwrite=True) + if exec_params.interactive: + for evoked in evokeds: + evoked.plot() + + # Reporting + evokeds = [_restrict_analyze_channels(evoked, cfg) for evoked in evokeds] + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + # Add event stats. + add_event_counts( + cfg=cfg, + report=report, + subject=subject, + session=session, + ) + + # Evoked responses + if evokeds: + n_contrasts = len(cfg.contrasts) + n_signals = len(evokeds) - n_contrasts + msg = ( + f"Adding {n_signals} evoked response{_pl(n_signals)} and " + f"{n_contrasts} contrast{_pl(n_contrasts)} to the report." + ) + else: + msg = "No evoked conditions or contrasts found." + logger.info(**gen_log_kwargs(message=msg)) + for condition, evoked in zip(conditions, evokeds): + tags = ("evoked", _sanitize_cond_tag(condition)) + if condition in cfg.conditions: + title = f"Average (sensor): {condition}, N = {len(cfg.subjects)}" + else: # It's a contrast of two conditions. + title = ( + f"Average (sensor) contrast: {condition}, " + f"N = {len(cfg.subjects)}" + ) + tags = tags + ("contrast",) + + report.add_evokeds( + evokeds=evoked, + titles=title, + projs=False, + tags=tags, + n_time_points=cfg.report_evoked_n_time_points, + # captions=evoked.comment, # TODO upstream + replace=True, + n_jobs=1, # don't auto parallelize + ) + + assert len(in_files) == 0, list(in_files) + return _prep_out_files(exec_params=exec_params, out_files=out_files) class ClusterAcrossTime(TypedDict): @@ -103,10 +184,10 @@ class ClusterAcrossTime(TypedDict): def _decoding_cluster_permutation_test( scores: np.ndarray, times: np.ndarray, - cluster_forming_t_threshold: Optional[float], + cluster_forming_t_threshold: float | None, n_permutations: int, random_seed: int, -) -> Tuple[np.ndarray, List[ClusterAcrossTime], int]: +) -> tuple[np.ndarray, list[ClusterAcrossTime], int]: """Perform a cluster permutation test on decoding scores. The clusters are formed across time points. @@ -119,7 +200,7 @@ def _decoding_cluster_permutation_test( out_type="mask", tail=1, # one-sided: significantly above chance level seed=random_seed, - verbose=True, + verbose="error", # ignore No clusters found ) n_permutations = H0.size - 1 @@ -134,10 +215,14 @@ def _decoding_cluster_permutation_test( return t_vals, clusters, n_permutations -def average_time_by_time_decoding(cfg: SimpleNamespace, session: str): - # Get the time points from the very first subject. They are identical - # across all subjects and conditions, so this should suffice. - fname_epo = BIDSPath( +def _get_epochs_in_files( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> dict: + in_files = dict() + in_files["epochs"] = BIDSPath( subject=cfg.subjects[0], session=session, task=cfg.task, @@ -151,220 +236,201 @@ def average_time_by_time_decoding(cfg: SimpleNamespace, session: str): root=cfg.deriv_root, check=False, ) - epochs = mne.read_epochs(fname_epo) - dtg_decim = cfg.decoding_time_generalization_decim - if cfg.decoding_time_generalization and dtg_decim > 1: - epochs.decimate(dtg_decim, verbose="error") - times = epochs.times - subjects = cfg.subjects - del epochs, fname_epo + _update_for_splits(in_files, "epochs", single=True) + return in_files - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - if cfg.decoding_time_generalization: - time_points_shape = (len(times), len(times)) - else: - time_points_shape = (len(times),) - - contrast_score_stats = { - "cond_1": cond_1, - "cond_2": cond_2, - "times": times, - "N": len(subjects), - "decim": dtg_decim, - "mean": np.empty(time_points_shape), - "mean_min": np.empty(time_points_shape), - "mean_max": np.empty(time_points_shape), - "mean_se": np.empty(time_points_shape), - "mean_ci_lower": np.empty(time_points_shape), - "mean_ci_upper": np.empty(time_points_shape), - "cluster_all_times": np.array([]), - "cluster_all_t_values": np.array([]), - "cluster_t_threshold": np.nan, - "cluster_n_permutations": np.nan, - "clusters": list(), - } - - processing = ( - f"{cond_1}+{cond_2}+TimeByTime+{cfg.decoding_metric}".replace(op.sep, "") - .replace("_", "-") - .replace("-", "") - ) - # Extract mean CV scores from all subjects. - mean_scores = np.empty((len(subjects), *time_points_shape)) +def _decoding_out_fname( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + cond_1: str | None, + cond_2: str | None, + kind: str, + extension: str = ".mat", +): + if cond_1 is None: + assert cond_2 is None + processing = "" + else: + assert cond_2 is not None + processing = f"{cond_1}+{cond_2}+" + processing = ( + f"{processing}{kind}+{cfg.decoding_metric}".replace(op.sep, "") + .replace("_", "-") + .replace("-", "") + ) + return BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + run=None, + recording=cfg.rec, + space=cfg.space, + processing=processing, + suffix="decoding", + extension=extension, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + ) - for sub_idx, subject in enumerate(subjects): - fname_mat = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - processing=processing, - suffix="decoding", - extension=".mat", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - decoding_data = loadmat(fname_mat) - mean_scores[sub_idx, :] = decoding_data["scores"].mean(axis=0) - - # Cluster permutation test. - # We can only permute for two or more subjects - # - # If we've performed time generalization, we will only use the diagonal - # CV scores here (classifiers trained and tested at the same time - # points). - - if len(subjects) > 1: - # Constrain cluster permutation test to time points of the - # time-locked event or later. - # We subtract the chance level from the scores as we'll be - # performing a 1-sample test (i.e., test against 0)! - idx = np.where(times >= 0)[0] - - if cfg.decoding_time_generalization: - cluster_permutation_scores = mean_scores[:, idx, idx] - 0.5 - else: - cluster_permutation_scores = mean_scores[:, idx] - 0.5 +def _get_input_fnames_decoding( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + cond_1: str, + cond_2: str, + kind: str, + extension: str = ".mat", +) -> dict: + in_files = _get_epochs_in_files(cfg=cfg, subject=subject, session=session) + for this_subject in cfg.subjects: + in_files[f"scores-{this_subject}"] = _decoding_out_fname( + cfg=cfg, + subject=this_subject, + session=session, + cond_1=cond_1, + cond_2=cond_2, + kind=kind, + extension=extension, + ) + return in_files - cluster_permutation_times = times[idx] - if cfg.cluster_forming_t_threshold is None: - import scipy.stats - cluster_forming_t_threshold = scipy.stats.t.ppf( - 1 - 0.05, len(cluster_permutation_scores) - 1 - ) - else: - cluster_forming_t_threshold = cfg.cluster_forming_t_threshold - - t_vals, clusters, n_perm = _decoding_cluster_permutation_test( - scores=cluster_permutation_scores, - times=cluster_permutation_times, - cluster_forming_t_threshold=cluster_forming_t_threshold, - n_permutations=cfg.cluster_n_permutations, - random_seed=cfg.random_state, - ) +@failsafe_run( + get_input_fnames=partial( + _get_input_fnames_decoding, + kind="TimeByTime", + ), +) +def average_time_by_time_decoding( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + cond_1: str, + cond_2: str, + in_files: dict, +) -> dict: + logger.info(**gen_log_kwargs(message="Averaging time-by-time decoding results")) + # Get the time points from the very first subject. They are identical + # across all subjects and conditions, so this should suffice. + epochs = mne.read_epochs(in_files.pop("epochs"), preload=False) + dtg_decim = cfg.decoding_time_generalization_decim + if cfg.decoding_time_generalization and dtg_decim > 1: + epochs.decimate(dtg_decim, verbose="error") + times = epochs.times + del epochs - contrast_score_stats.update( - { - "cluster_all_times": cluster_permutation_times, - "cluster_all_t_values": t_vals, - "cluster_t_threshold": cluster_forming_t_threshold, - "clusters": clusters, - "cluster_n_permutations": n_perm, - } - ) + if cfg.decoding_time_generalization: + time_points_shape = (len(times), len(times)) + else: + time_points_shape = (len(times),) - del cluster_permutation_scores, cluster_permutation_times, n_perm + n_subjects = len(cfg.subjects) + contrast_score_stats = { + "cond_1": cond_1, + "cond_2": cond_2, + "times": times, + "N": n_subjects, + "decim": dtg_decim, + "mean": np.empty(time_points_shape), + "mean_min": np.empty(time_points_shape), + "mean_max": np.empty(time_points_shape), + "mean_se": np.empty(time_points_shape), + "mean_ci_lower": np.empty(time_points_shape), + "mean_ci_upper": np.empty(time_points_shape), + "cluster_all_times": np.array([]), + "cluster_all_t_values": np.array([]), + "cluster_t_threshold": np.nan, + "cluster_n_permutations": np.nan, + "clusters": list(), + } - # Now we can calculate some descriptive statistics on the mean scores. - # We use the [:] here as a safeguard to ensure we don't mess up the - # dimensions. - # - # For time generalization, all values (each time point vs each other) - # are considered. - contrast_score_stats["mean"][:] = mean_scores.mean(axis=0) - contrast_score_stats["mean_min"][:] = mean_scores.min(axis=0) - contrast_score_stats["mean_max"][:] = mean_scores.max(axis=0) + # Extract mean CV scores from all subjects. + mean_scores = np.empty((n_subjects, *time_points_shape)) + + # Remaining in_files are all decoding data + assert len(in_files) == n_subjects, list(in_files.keys()) + for sub_idx, key in enumerate(list(in_files)): + decoding_data = loadmat(in_files.pop(key)) + mean_scores[sub_idx, :] = decoding_data["scores"].mean(axis=0) + + # Cluster permutation test. + # We can only permute for two or more subjects + # + # If we've performed time generalization, we will only use the diagonal + # CV scores here (classifiers trained and tested at the same time + # points). + + if n_subjects > 1: + # Constrain cluster permutation test to time points of the + # time-locked event or later. + # We subtract the chance level from the scores as we'll be + # performing a 1-sample test (i.e., test against 0)! + idx = np.where(times >= 0)[0] - # Finally, for each time point, bootstrap the mean, and calculate the - # SD of the bootstrapped distribution: this is the standard error of - # the mean. We also derive 95% confidence intervals. - rng = np.random.default_rng(seed=cfg.random_state) - for time_idx in range(len(times)): - if cfg.decoding_time_generalization: - data = mean_scores[:, time_idx, time_idx] - else: - data = mean_scores[:, time_idx] - scores_resampled = rng.choice( - data, size=(cfg.n_boot, len(subjects)), replace=True - ) - bootstrapped_means = scores_resampled.mean(axis=1) + if cfg.decoding_time_generalization: + cluster_permutation_scores = mean_scores[:, idx, idx] - 0.5 + else: + cluster_permutation_scores = mean_scores[:, idx] - 0.5 - # SD of the bootstrapped distribution == SE of the metric. - se = bootstrapped_means.std(ddof=1) - ci_lower = np.quantile(bootstrapped_means, q=0.025) - ci_upper = np.quantile(bootstrapped_means, q=0.975) - - contrast_score_stats["mean_se"][time_idx] = se - contrast_score_stats["mean_ci_lower"][time_idx] = ci_lower - contrast_score_stats["mean_ci_upper"][time_idx] = ci_upper - - del bootstrapped_means, se, ci_lower, ci_upper - - fname_out = fname_mat.copy().update(subject="average") - savemat(fname_out, contrast_score_stats) - del contrast_score_stats, fname_out - - -def average_full_epochs_decoding(cfg: SimpleNamespace, session: str): - for contrast in cfg.decoding_contrasts: - cond_1, cond_2 = contrast - n_subjects = len(cfg.subjects) - - contrast_score_stats = { - "cond_1": cond_1, - "cond_2": cond_2, - "N": n_subjects, - "subjects": cfg.subjects, - "scores": np.nan, - "mean": np.nan, - "mean_min": np.nan, - "mean_max": np.nan, - "mean_se": np.nan, - "mean_ci_lower": np.nan, - "mean_ci_upper": np.nan, - } - - processing = ( - f"{cond_1}+{cond_2}+FullEpochs+{cfg.decoding_metric}".replace(op.sep, "") - .replace("_", "-") - .replace("-", "") - ) + cluster_permutation_times = times[idx] + if cfg.cluster_forming_t_threshold is None: + import scipy.stats - # Extract mean CV scores from all subjects. - mean_scores = np.empty(n_subjects) - for sub_idx, subject in enumerate(cfg.subjects): - fname_mat = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - processing=processing, - suffix="decoding", - extension=".mat", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, + cluster_forming_t_threshold = scipy.stats.t.ppf( + 1 - 0.05, len(cluster_permutation_scores) - 1 ) + else: + cluster_forming_t_threshold = cfg.cluster_forming_t_threshold - decoding_data = loadmat(fname_mat) - mean_scores[sub_idx] = decoding_data["scores"].mean() - - # Now we can calculate some descriptive statistics on the mean scores. - # We use the [:] here as a safeguard to ensure we don't mess up the - # dimensions. - contrast_score_stats["scores"] = mean_scores - contrast_score_stats["mean"] = mean_scores.mean() - contrast_score_stats["mean_min"] = mean_scores.min() - contrast_score_stats["mean_max"] = mean_scores.max() + t_vals, clusters, n_perm = _decoding_cluster_permutation_test( + scores=cluster_permutation_scores, + times=cluster_permutation_times, + cluster_forming_t_threshold=cluster_forming_t_threshold, + n_permutations=cfg.cluster_n_permutations, + random_seed=cfg.random_state, + ) - # Finally, bootstrap the mean, and calculate the - # SD of the bootstrapped distribution: this is the standard error of - # the mean. We also derive 95% confidence intervals. - rng = np.random.default_rng(seed=cfg.random_state) - scores_resampled = rng.choice( - mean_scores, size=(cfg.n_boot, n_subjects), replace=True + contrast_score_stats.update( + { + "cluster_all_times": cluster_permutation_times, + "cluster_all_t_values": t_vals, + "cluster_t_threshold": cluster_forming_t_threshold, + "clusters": clusters, + "cluster_n_permutations": n_perm, + } ) + + del cluster_permutation_scores, cluster_permutation_times, n_perm + + # Now we can calculate some descriptive statistics on the mean scores. + # We use the [:] here as a safeguard to ensure we don't mess up the + # dimensions. + # + # For time generalization, all values (each time point vs each other) + # are considered. + contrast_score_stats["mean"][:] = mean_scores.mean(axis=0) + contrast_score_stats["mean_min"][:] = mean_scores.min(axis=0) + contrast_score_stats["mean_max"][:] = mean_scores.max(axis=0) + + # Finally, for each time point, bootstrap the mean, and calculate the + # SD of the bootstrapped distribution: this is the standard error of + # the mean. We also derive 95% confidence intervals. + rng = np.random.default_rng(seed=cfg.random_state) + for time_idx in range(len(times)): + if cfg.decoding_time_generalization: + data = mean_scores[:, time_idx, time_idx] + else: + data = mean_scores[:, time_idx] + scores_resampled = rng.choice(data, size=(cfg.n_boot, n_subjects), replace=True) bootstrapped_means = scores_resampled.mean(axis=1) # SD of the bootstrapped distribution == SE of the metric. @@ -372,174 +438,481 @@ def average_full_epochs_decoding(cfg: SimpleNamespace, session: str): ci_lower = np.quantile(bootstrapped_means, q=0.025) ci_upper = np.quantile(bootstrapped_means, q=0.975) - contrast_score_stats["mean_se"] = se - contrast_score_stats["mean_ci_lower"] = ci_lower - contrast_score_stats["mean_ci_upper"] = ci_upper + contrast_score_stats["mean_se"][time_idx] = se + contrast_score_stats["mean_ci_lower"][time_idx] = ci_lower + contrast_score_stats["mean_ci_upper"][time_idx] = ci_upper del bootstrapped_means, se, ci_lower, ci_upper - fname_out = fname_mat.copy().update(subject="average") - if not fname_out.fpath.parent.exists(): - os.makedirs(fname_out.fpath.parent) - savemat(fname_out, contrast_score_stats) - del contrast_score_stats, fname_out + out_files = dict() + out_files["mat"] = _decoding_out_fname( + cfg=cfg, + subject=subject, + session=session, + cond_1=cond_1, + cond_2=cond_2, + kind="TimeByTime", + ) + savemat(out_files["mat"], contrast_score_stats) + + section = f"Decoding: time-by-time, N = {len(cfg.subjects)}" + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + logger.info(**gen_log_kwargs(message="Adding time-by-time decoding results")) + import matplotlib.pyplot as plt + + tags = ( + "epochs", + "contrast", + "decoding", + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", + ) + decoding_data = loadmat(out_files["mat"]) + + # Plot scores + fig = _plot_time_by_time_decoding_scores_gavg( + cfg=cfg, + decoding_data=decoding_data, + ) + caption = ( + f'Based on N={decoding_data["N"].squeeze()} ' + f"subjects. Standard error and confidence interval " + f"of the mean were bootstrapped with {cfg.n_boot} " + f"resamples. CI must not be used for statistical inference here, " + f"as it is not corrected for multiple testing." + ) + if len(get_subjects(cfg)) > 1: + caption += ( + f" Time periods with decoding performance significantly above " + f"chance, if any, were derived with a one-tailed " + f"cluster-based permutation test " + f'({decoding_data["cluster_n_permutations"].squeeze()} ' + f"permutations) and are highlighted in yellow." + ) + title = f"Decoding over time: {cond_1} vs. {cond_2}" + report.add_figure( + fig=fig, + title=title, + caption=caption, + section=section, + tags=tags, + replace=True, + ) + plt.close(fig) + + # Plot t-values used to form clusters + if len(get_subjects(cfg)) > 1: + fig = plot_time_by_time_decoding_t_values(decoding_data=decoding_data) + t_threshold = np.round(decoding_data["cluster_t_threshold"], 3).item() + caption = ( + f"Observed t-values. Time points with " + f"t-values > {t_threshold} were used to form clusters." + ) + report.add_figure( + fig=fig, + title=f"t-values across time: {cond_1} vs. {cond_2}", + caption=caption, + section=section, + tags=tags, + replace=True, + ) + plt.close(fig) + if cfg.decoding_time_generalization: + fig = _plot_decoding_time_generalization( + decoding_data=decoding_data, + metric=cfg.decoding_metric, + kind="grand-average", + ) + caption = ( + f"Time generalization (generalization across time, GAT): " + f"each classifier is trained on each time point, and tested " + f"on all other time points. The results were averaged across " + f'N={decoding_data["N"].item()} subjects.' + ) + title = f"Time generalization: {cond_1} vs. {cond_2}" + report.add_figure( + fig=fig, + title=title, + caption=caption, + section=section, + tags=tags, + replace=True, + ) + plt.close(fig) -def average_csp_decoding( + return _prep_out_files(out_files=out_files, exec_params=exec_params) + + +@failsafe_run( + get_input_fnames=partial( + _get_input_fnames_decoding, + kind="FullEpochs", + ), +) +def average_full_epochs_decoding( + *, cfg: SimpleNamespace, - session: str, + exec_params: SimpleNamespace, subject: str, - condition_1: str, - condition_2: str, -): - msg = f"Summarizing CSP results: {condition_1} - {condition_2}." - logger.info(**gen_log_kwargs(message=msg)) + session: str | None, + cond_1: str, + cond_2: str, + in_files: dict, +) -> dict: + n_subjects = len(cfg.subjects) + in_files.pop("epochs") # not used but okay to include + + contrast_score_stats = { + "cond_1": cond_1, + "cond_2": cond_2, + "N": n_subjects, + "subjects": cfg.subjects, + "scores": np.nan, + "mean": np.nan, + "mean_min": np.nan, + "mean_max": np.nan, + "mean_se": np.nan, + "mean_ci_lower": np.nan, + "mean_ci_upper": np.nan, + } # Extract mean CV scores from all subjects. - a_vs_b = f"{condition_1}+{condition_2}".replace(op.sep, "") - processing = f"{a_vs_b}+CSP+{cfg.decoding_metric}" - processing = processing.replace("_", "-").replace("-", "") + mean_scores = np.empty(n_subjects) + for sub_idx, key in enumerate(list(in_files)): + decoding_data = loadmat(in_files.pop(key)) + mean_scores[sub_idx] = decoding_data["scores"].mean() + + # Now we can calculate some descriptive statistics on the mean scores. + # We use the [:] here as a safeguard to ensure we don't mess up the + # dimensions. + contrast_score_stats["scores"] = mean_scores + contrast_score_stats["mean"] = mean_scores.mean() + contrast_score_stats["mean_min"] = mean_scores.min() + contrast_score_stats["mean_max"] = mean_scores.max() + + # Finally, bootstrap the mean, and calculate the + # SD of the bootstrapped distribution: this is the standard error of + # the mean. We also derive 95% confidence intervals. + rng = np.random.default_rng(seed=cfg.random_state) + scores_resampled = rng.choice( + mean_scores, size=(cfg.n_boot, n_subjects), replace=True + ) + bootstrapped_means = scores_resampled.mean(axis=1) - all_decoding_data_freq = [] - all_decoding_data_time_freq = [] + # SD of the bootstrapped distribution == SE of the metric. + se = bootstrapped_means.std(ddof=1) + ci_lower = np.quantile(bootstrapped_means, q=0.025) + ci_upper = np.quantile(bootstrapped_means, q=0.975) + + contrast_score_stats["mean_se"] = se + contrast_score_stats["mean_ci_lower"] = ci_lower + contrast_score_stats["mean_ci_upper"] = ci_upper - # First load the data. - fname_out = BIDSPath( - subject="average", + del bootstrapped_means, se, ci_lower, ci_upper + + out_files = dict() + fname_out = out_files["mat"] = _decoding_out_fname( + cfg=cfg, + subject=subject, session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - recording=cfg.rec, - space=cfg.space, - processing=processing, - suffix="decoding", + cond_1=cond_1, + cond_2=cond_2, + kind="FullEpochs", + ) + if not fname_out.fpath.parent.exists(): + os.makedirs(fname_out.fpath.parent) + savemat(fname_out, contrast_score_stats) + return _prep_out_files(out_files=out_files, exec_params=exec_params) + + +def get_input_files_average_full_epochs_report( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + decoding_contrasts: list[list[str]], +) -> dict: + in_files = dict() + for contrast in decoding_contrasts: + in_files[f"decoding-full-epochs-{contrast}"] = _decoding_out_fname( + cfg=cfg, + subject=subject, + session=session, + cond_1=contrast[0], + cond_2=contrast[1], + kind="FullEpochs", + ) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_files_average_full_epochs_report, +) +def average_full_epochs_report( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + decoding_contrasts: list[list[str]], + in_files: dict, +) -> dict: + """Add decoding results to the grand average report.""" + out_files = dict() + out_files["cluster"] = _decoding_out_fname( + cfg=cfg, + subject=subject, + session=session, + cond_1=None, + cond_2=None, + kind="FullEpochs", extension=".xlsx", - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, ) - for subject in cfg.subjects: - fname_xlsx = fname_out.copy().update(subject=subject) - decoding_data_freq = pd.read_excel( - fname_xlsx, - sheet_name="CSP Frequency", - dtype={"subject": str}, # don't drop trailing zeros + + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + import matplotlib.pyplot as plt # nested import to help joblib + + logger.info( + **gen_log_kwargs(message="Adding full-epochs decoding results to report") + ) + + # Full-epochs decoding + all_decoding_scores = [] + for key in list(in_files): + if not key.startswith("decoding-full-epochs-"): + continue + decoding_data = loadmat(in_files.pop(key)) + all_decoding_scores.append(np.atleast_1d(decoding_data["scores"].squeeze())) + del decoding_data + + fig, caption, data = _plot_full_epochs_decoding_scores( + contrast_names=_contrasts_to_names(decoding_contrasts), + scores=all_decoding_scores, + metric=cfg.decoding_metric, + kind="grand-average", ) - decoding_data_time_freq = pd.read_excel( - fname_xlsx, - sheet_name="CSP Time-Frequency", - dtype={"subject": str}, # don't drop trailing zeros + with pd.ExcelWriter(out_files["cluster"]) as w: + data.to_excel(w, sheet_name="FullEpochs", index=False) + report.add_figure( + fig=fig, + title="Full-epochs decoding", + section=f"Decoding: full-epochs, N = {len(cfg.subjects)}", + caption=caption, + tags=( + "epochs", + "contrast", + "decoding", + *[ + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}" + for cond_1, cond_2 in cfg.decoding_contrasts + ], + ), + replace=True, ) - all_decoding_data_freq.append(decoding_data_freq) - all_decoding_data_time_freq.append(decoding_data_time_freq) + # close figure to save memory + plt.close(fig) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +@failsafe_run( + get_input_fnames=partial( + _get_input_fnames_decoding, + kind="CSP", + extension=".xlsx", + ), +) +def average_csp_decoding( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + cond_1: str, + cond_2: str, + in_files: dict, +): + msg = f"Summarizing CSP results: {cond_1} - {cond_2}." + logger.info(**gen_log_kwargs(message=msg)) + in_files.pop("epochs") + + all_decoding_data_freq = [] + all_decoding_data_time_freq = [] + for key in list(in_files): + fname_xlsx = in_files.pop(key) + with pd.ExcelFile(fname_xlsx) as xf: + decoding_data_freq = pd.read_excel( + xf, + sheet_name="CSP Frequency", + dtype={"subject": str}, # don't drop trailing zeros + ) + all_decoding_data_freq.append(decoding_data_freq) + if "CSP Time-Frequency" in xf.sheet_names: + decoding_data_time_freq = pd.read_excel( + xf, + sheet_name="CSP Time-Frequency", + dtype={"subject": str}, # don't drop trailing zeros + ) + all_decoding_data_time_freq.append(decoding_data_time_freq) del fname_xlsx # Now calculate descriptes and bootstrap CIs. grand_average_freq = _average_csp_time_freq( cfg=cfg, + subject=subject, + session=session, data=all_decoding_data_freq, ) - grand_average_time_freq = _average_csp_time_freq( + if len(all_decoding_data_time_freq): + grand_average_time_freq = _average_csp_time_freq( + cfg=cfg, + subject=subject, + session=session, + data=all_decoding_data_time_freq, + ) + else: + grand_average_time_freq = None + + out_files = dict() + out_files["freq"] = _decoding_out_fname( cfg=cfg, - data=all_decoding_data_time_freq, + subject=subject, + session=session, + cond_1=cond_1, + cond_2=cond_2, + kind="CSP", + extension=".xlsx", ) - - with pd.ExcelWriter(fname_out) as w: + with pd.ExcelWriter(out_files["freq"]) as w: grand_average_freq.to_excel(w, sheet_name="CSP Frequency", index=False) - grand_average_time_freq.to_excel( - w, sheet_name="CSP Time-Frequency", index=False - ) + if grand_average_time_freq is not None: + grand_average_time_freq.to_excel( + w, sheet_name="CSP Time-Frequency", index=False + ) + del grand_average_time_freq # Perform a cluster-based permutation test. subjects = cfg.subjects - time_bins = np.array(cfg.decoding_csp_times) - if time_bins.ndim == 1: - time_bins = np.array(list(zip(time_bins[:-1], time_bins[1:]))) - time_bins = pd.DataFrame(time_bins, columns=["t_min", "t_max"]) - freq_name_to_bins_map = _handle_csp_args( - cfg.decoding_csp_times, cfg.decoding_csp_freqs, cfg.decoding_metric + freq_name_to_bins_map, time_bins = _handle_csp_args( + cfg.decoding_csp_times, + cfg.decoding_csp_freqs, + cfg.decoding_metric, + epochs_tmin=cfg.epochs_tmin, + epochs_tmax=cfg.epochs_tmax, + time_frequency_freq_min=cfg.time_frequency_freq_min, + time_frequency_freq_max=cfg.time_frequency_freq_max, ) - data_for_clustering = {} - for freq_range_name in freq_name_to_bins_map: - a = np.empty( - shape=( - len(subjects), - len(time_bins), - len(freq_name_to_bins_map[freq_range_name]), + if not len(time_bins): + fname_csp_cluster_results = None + else: + time_bins = pd.DataFrame(time_bins, columns=["t_min", "t_max"]) + data_for_clustering = {} + for freq_range_name in freq_name_to_bins_map: + a = np.empty( + shape=( + len(subjects), + len(time_bins), + len(freq_name_to_bins_map[freq_range_name]), + ) ) + a.fill(np.nan) + data_for_clustering[freq_range_name] = a + + g = pd.concat(all_decoding_data_time_freq).groupby( + ["subject", "freq_range_name", "t_min", "t_max"] ) - a.fill(np.nan) - data_for_clustering[freq_range_name] = a - g = pd.concat(all_decoding_data_time_freq).groupby( - ["subject", "freq_range_name", "t_min", "t_max"] - ) + for (subject_, freq_range_name, t_min, t_max), df in g: + scores = df["mean_crossval_score"] + sub_idx = subjects.index(subject_) + time_bin_idx = time_bins.loc[ + (np.isclose(time_bins["t_min"], t_min)) + & (np.isclose(time_bins["t_max"], t_max)), + :, + ].index + assert len(time_bin_idx) == 1 + time_bin_idx = time_bin_idx[0] + data_for_clustering[freq_range_name][sub_idx][time_bin_idx] = scores + + if cfg.cluster_forming_t_threshold is None: + import scipy.stats + + cluster_forming_t_threshold = scipy.stats.t.ppf( + 1 - 0.05, + len(cfg.subjects) - 1, # one-sided test + ) + else: + cluster_forming_t_threshold = cfg.cluster_forming_t_threshold - for (subject, freq_range_name, t_min, t_max), df in g: - scores = df["mean_crossval_score"] - sub_idx = subjects.index(subject) - time_bin_idx = time_bins.loc[ - (np.isclose(time_bins["t_min"], t_min)) - & (np.isclose(time_bins["t_max"], t_max)), - :, - ].index - assert len(time_bin_idx) == 1 - time_bin_idx = time_bin_idx[0] - data_for_clustering[freq_range_name][sub_idx][time_bin_idx] = scores - - if cfg.cluster_forming_t_threshold is None: - import scipy.stats - - cluster_forming_t_threshold = scipy.stats.t.ppf( - 1 - 0.05, len(cfg.subjects) - 1 # one-sided test - ) - else: - cluster_forming_t_threshold = cfg.cluster_forming_t_threshold - - cluster_permutation_results = {} - for freq_range_name, X in data_for_clustering.items(): - ( - t_vals, - all_clusters, - cluster_p_vals, - H0, - ) = mne.stats.permutation_cluster_1samp_test( # noqa: E501 - X=X - 0.5, # One-sample test against zero. - threshold=cluster_forming_t_threshold, - n_permutations=cfg.cluster_n_permutations, - adjacency=None, # each time & freq bin connected to its neighbors - out_type="mask", - tail=1, # one-sided: significantly above chance level - seed=cfg.random_state, + cluster_permutation_results = {} + for freq_range_name, X in data_for_clustering.items(): + if len(X) < 2: + t_vals = np.full(X.shape[1:], np.nan) + H0 = all_clusters = cluster_p_vals = np.array([]) + else: + ( + t_vals, + all_clusters, + cluster_p_vals, + H0, + ) = mne.stats.permutation_cluster_1samp_test( # noqa: E501 + X=X - 0.5, # One-sample test against zero. + threshold=cluster_forming_t_threshold, + n_permutations=cfg.cluster_n_permutations, + adjacency=None, # each time & freq bin connected to its neighbors + out_type="mask", + tail=1, # one-sided: significantly above chance level + seed=cfg.random_state, + ) + n_permutations = H0.size - 1 + all_clusters = np.array(all_clusters) # preserve "empty" 0th dimension + cluster_permutation_results[freq_range_name] = { + "mean_crossval_scores": X.mean(axis=0), + "t_vals": t_vals, + "clusters": all_clusters, + "cluster_p_vals": cluster_p_vals, + "cluster_t_threshold": cluster_forming_t_threshold, + "n_permutations": n_permutations, + "time_bin_edges": cfg.decoding_csp_times, + "freq_bin_edges": cfg.decoding_csp_freqs[freq_range_name], + } + + out_files["cluster"] = out_files["freq"].copy().update(extension=".mat") + savemat(file_name=out_files["cluster"], mdict=cluster_permutation_results) + fname_csp_cluster_results = out_files["cluster"] + + assert subject == "average" + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + add_csp_grand_average( + cfg=cfg, + subject=subject, + session=session, + report=report, + cond_1=cond_1, + cond_2=cond_2, + fname_csp_freq_results=out_files["freq"], + fname_csp_cluster_results=fname_csp_cluster_results, ) - n_permutations = H0.size - 1 - all_clusters = np.array(all_clusters) # preserve "empty" 0th dimension - cluster_permutation_results[freq_range_name] = { - "mean_crossval_scores": X.mean(axis=0), - "t_vals": t_vals, - "clusters": all_clusters, - "cluster_p_vals": cluster_p_vals, - "cluster_t_threshold": cluster_forming_t_threshold, - "n_permutations": n_permutations, - "time_bin_edges": cfg.decoding_csp_times, - "freq_bin_edges": cfg.decoding_csp_freqs[freq_range_name], - } - - fname_out.update(extension=".mat") - savemat(file_name=fname_out, mdict=cluster_permutation_results) + return _prep_out_files(out_files=out_files, exec_params=exec_params) def _average_csp_time_freq( *, cfg: SimpleNamespace, + subject: str, + session: str | None, data: pd.DataFrame, ) -> pd.DataFrame: # Prepare a dataframe for storing the results. grand_average = data[0].copy() del grand_average["mean_crossval_score"] - grand_average["subject"] = "average" + grand_average["subject"] = subject grand_average["mean"] = np.nan grand_average["mean_se"] = np.nan grand_average["mean_ci_lower"] = np.nan @@ -567,7 +940,8 @@ def _average_csp_time_freq( bootstrapped_means = scores_resampled.mean(axis=1) # SD of the bootstrapped distribution == SE of the metric. - se = bootstrapped_means.std(ddof=1) + with np.errstate(over="raise"): + se = bootstrapped_means.std(ddof=1) ci_lower = np.quantile(bootstrapped_means, q=0.025) ci_upper = np.quantile(bootstrapped_means, q=0.975) @@ -593,17 +967,20 @@ def get_config( *, config, ) -> SimpleNamespace: - dtg_decim = config.decoding_time_generalization_decim cfg = SimpleNamespace( subjects=get_subjects(config), task_is_rest=config.task_is_rest, conditions=config.conditions, - contrasts=get_all_contrasts(config), + contrasts=config.contrasts, + epochs_tmin=config.epochs_tmin, + epochs_tmax=config.epochs_tmax, + time_frequency_freq_min=config.time_frequency_freq_min, + time_frequency_freq_max=config.time_frequency_freq_max, decode=config.decode, decoding_metric=config.decoding_metric, decoding_n_splits=config.decoding_n_splits, decoding_time_generalization=config.decoding_time_generalization, - decoding_time_generalization_decim=dtg_decim, + decoding_time_generalization_decim=config.decoding_time_generalization_decim, decoding_csp=config.decoding_csp, decoding_csp_freqs=config.decoding_csp_freqs, decoding_csp_times=config.decoding_csp_times, @@ -618,7 +995,6 @@ def get_config( eeg_reference=get_eeg_reference(config), sessions=get_sessions(config), exclude_subjects=config.exclude_subjects, - all_contrasts=get_all_contrasts(config), report_evoked_n_time_points=config.report_evoked_n_time_points, cluster_permutation_p_threshold=config.cluster_permutation_p_threshold, # TODO: needed because get_datatype gets called again... @@ -628,67 +1004,93 @@ def get_config( return cfg -@failsafe_run() -def run_group_average_sensor( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, -) -> None: - if cfg.task_is_rest: +def main(*, config: SimpleNamespace) -> None: + if config.task_is_rest: msg = ' … skipping: for "rest" task.' logger.info(**gen_log_kwargs(message=msg)) return - - sessions = get_sessions(cfg) - if not sessions: - sessions = [None] - + cfg = get_config( + config=config, + ) + exec_params = config.exec_params + subject = "average" + sessions = get_sessions(config=config) + if cfg.decode or cfg.decoding_csp: + decoding_contrasts = get_decoding_contrasts(config=cfg) + else: + decoding_contrasts = [] + logs = list() with get_parallel_backend(exec_params): - for session in sessions: - evokeds = average_evokeds( + # 1. Evoked data + logs += [ + average_evokeds( cfg=cfg, + exec_params=exec_params, subject=subject, session=session, ) - if exec_params.interactive: - for evoked in evokeds: - evoked.plot() - - if cfg.decode: - average_full_epochs_decoding(cfg, session) - average_time_by_time_decoding(cfg, session) - if cfg.decoding_csp: + for session in sessions + ] + + # 2. Time decoding + if cfg.decode and decoding_contrasts: + # Full epochs (single report function plots across all contrasts + # so it's a separate cached step) + logs += [ + average_full_epochs_decoding( + cfg=cfg, + subject=subject, + session=session, + cond_1=contrast[0], + cond_2=contrast[1], + exec_params=exec_params, + ) + for session in sessions + for contrast in decoding_contrasts + ] + logs += [ + average_full_epochs_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + decoding_contrasts=decoding_contrasts, + ) + for session in sessions + ] + # Time-by-time parallel, run_func = parallel_func( - average_csp_decoding, exec_params=exec_params + average_time_by_time_decoding, exec_params=exec_params ) - parallel( + logs += parallel( run_func( cfg=cfg, - session=session, + exec_params=exec_params, subject=subject, - condition_1=contrast[0], - condition_2=contrast[1], + session=session, + cond_1=contrast[0], + cond_2=contrast[1], ) - for session in get_sessions(config=cfg) - for contrast in get_decoding_contrasts(config=cfg) + for session in sessions + for contrast in decoding_contrasts ) - for session in sessions: - run_report_average_sensor( - cfg=cfg, - exec_params=exec_params, - subject=subject, - session=session, + # 3. CSP + if cfg.decoding_csp and decoding_contrasts: + parallel, run_func = parallel_func( + average_csp_decoding, exec_params=exec_params + ) + logs += parallel( + run_func( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + cond_1=contrast[0], + cond_2=contrast[1], + ) + for contrast in get_decoding_contrasts(config=cfg) + for session in sessions ) - -def main(*, config: SimpleNamespace) -> None: - log = run_group_average_sensor( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject="average", - ) - save_logs(config=config, logs=[log]) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/sensor/__init__.py b/mne_bids_pipeline/steps/sensor/__init__.py index fc76bf551..848efadf8 100644 --- a/mne_bids_pipeline/steps/sensor/__init__.py +++ b/mne_bids_pipeline/steps/sensor/__init__.py @@ -1,12 +1,14 @@ """Sensor-space analysis.""" -from . import _01_make_evoked -from . import _02_decoding_full_epochs -from . import _03_decoding_time_by_time -from . import _04_time_frequency -from . import _05_decoding_csp -from . import _06_make_cov -from . import _99_group_average +from . import ( + _01_make_evoked, + _02_decoding_full_epochs, + _03_decoding_time_by_time, + _04_time_frequency, + _05_decoding_csp, + _06_make_cov, + _99_group_average, +) _STEPS = ( _01_make_evoked, diff --git a/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py b/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py index 110a9c103..95c451327 100644 --- a/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py +++ b/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py @@ -10,14 +10,17 @@ import mne from ..._config_utils import ( - get_fs_subject, - get_subjects, + _bids_kwargs, _get_bem_conductivity, + get_fs_subject, get_fs_subjects_dir, + get_sessions, + get_subjects, ) -from ..._logging import logger, gen_log_kwargs +from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func -from ..._run import failsafe_run, save_logs +from ..._report import _open_report, _render_bem +from ..._run import _prep_out_files, failsafe_run, save_logs def _get_bem_params(cfg: SimpleNamespace): @@ -38,6 +41,7 @@ def get_input_fnames_make_bem_surfaces( *, cfg: SimpleNamespace, subject: str, + session: str | None, ) -> dict: in_files = dict() mri_images, mri_dir, flash_dir = _get_bem_params(cfg) @@ -54,6 +58,7 @@ def get_output_fnames_make_bem_surfaces( *, cfg: SimpleNamespace, subject: str, + session: str | None, ) -> dict: out_files = dict() conductivity, _ = _get_bem_conductivity(cfg) @@ -73,6 +78,7 @@ def make_bem_surfaces( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, + session: str | None, in_files: dict, ) -> dict: mri_images, _, _ = _get_bem_params(cfg) @@ -96,8 +102,21 @@ def make_bem_surfaces( show=show, verbose=cfg.freesurfer_verbose, ) - out_files = get_output_fnames_make_bem_surfaces(cfg=cfg, subject=subject) - return out_files + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + _render_bem(report=report, cfg=cfg, subject=subject, session=session) + out_files = get_output_fnames_make_bem_surfaces( + cfg=cfg, + subject=subject, + session=session, + ) + return _prep_out_files( + exec_params=exec_params, + out_files=out_files, + check_relative=cfg.fs_subjects_dir, + bids_only=False, + ) def get_config( @@ -112,6 +131,7 @@ def get_config( freesurfer_verbose=config.freesurfer_verbose, use_template_mri=config.use_template_mri, ch_types=config.ch_types, + **_bids_kwargs(config=config), ) return cfg @@ -143,6 +163,7 @@ def main(*, config: SimpleNamespace) -> None: ), exec_params=config.exec_params, subject=subject, + session=get_sessions(config)[0], force_run=config.recreate_bem, ) for subject in get_subjects(config) diff --git a/mne_bids_pipeline/steps/source/_02_make_bem_solution.py b/mne_bids_pipeline/steps/source/_02_make_bem_solution.py index c8b3ddc51..1f2947d01 100644 --- a/mne_bids_pipeline/steps/source/_02_make_bem_solution.py +++ b/mne_bids_pipeline/steps/source/_02_make_bem_solution.py @@ -10,13 +10,13 @@ from ..._config_utils import ( _get_bem_conductivity, - get_fs_subjects_dir, get_fs_subject, + get_fs_subjects_dir, get_subjects, ) -from ..._logging import logger, gen_log_kwargs -from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, save_logs +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._run import _prep_out_files, failsafe_run, save_logs def get_input_fnames_make_bem_solution( @@ -69,7 +69,12 @@ def make_bem_solution( out_files = get_output_fnames_make_bem_solution(cfg=cfg, subject=subject) mne.write_bem_surfaces(out_files["model"], bem_model, overwrite=True) mne.write_bem_solution(out_files["sol"], bem_sol, overwrite=True) - return out_files + return _prep_out_files( + exec_params=exec_params, + out_files=out_files, + check_relative=cfg.fs_subjects_dir, + bids_only=False, + ) def get_config( @@ -94,7 +99,7 @@ def main(*, config) -> None: return if config.use_template_mri is not None: - msg = "Skipping, BEM solution computation not needed for " "MRI template …" + msg = "Skipping, BEM solution computation not needed for MRI template …" logger.info(**gen_log_kwargs(message=msg, emoji="skip")) if config.use_template_mri == "fsaverage": # Ensure we have the BEM diff --git a/mne_bids_pipeline/steps/source/_03_setup_source_space.py b/mne_bids_pipeline/steps/source/_03_setup_source_space.py index 52c2538dd..bcd4bef34 100644 --- a/mne_bids_pipeline/steps/source/_03_setup_source_space.py +++ b/mne_bids_pipeline/steps/source/_03_setup_source_space.py @@ -8,9 +8,9 @@ import mne from ..._config_utils import get_fs_subject, get_fs_subjects_dir, get_subjects -from ..._logging import logger, gen_log_kwargs -from ..._run import failsafe_run, save_logs -from ..._parallel import parallel_func, get_parallel_backend +from ..._logging import gen_log_kwargs, logger +from ..._parallel import get_parallel_backend, parallel_func +from ..._run import _prep_out_files, failsafe_run, save_logs def get_input_fnames_setup_source_space(*, cfg, subject): @@ -55,7 +55,12 @@ def run_setup_source_space( in_files.clear() # all used by setup_source_space out_files = get_output_fnames_setup_source_space(cfg=cfg, subject=subject) mne.write_source_spaces(out_files["src"], src, overwrite=True) - return out_files + return _prep_out_files( + exec_params=exec_params, + out_files=out_files, + check_relative=cfg.fs_subjects_dir, + bids_only=False, + ) def get_config( diff --git a/mne_bids_pipeline/steps/source/_04_make_forward.py b/mne_bids_pipeline/steps/source/_04_make_forward.py index c8b7ae7e1..1596cadff 100644 --- a/mne_bids_pipeline/steps/source/_04_make_forward.py +++ b/mne_bids_pipeline/steps/source/_04_make_forward.py @@ -4,32 +4,33 @@ """ from types import SimpleNamespace -from typing import Optional import mne +import numpy as np from mne.coreg import Coregistration from mne_bids import BIDSPath, get_head_mri_trans from ..._config_utils import ( - get_fs_subject, - get_subjects, + _bids_kwargs, _get_bem_conductivity, + _meg_in_ch_types, + get_fs_subject, get_fs_subjects_dir, get_runs, - _meg_in_ch_types, get_sessions, - _bids_kwargs, + get_subjects, ) -from ..._config_import import _import_config -from ..._logging import logger, gen_log_kwargs +from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func -from ..._report import _open_report -from ..._run import failsafe_run, save_logs +from ..._report import _open_report, _render_bem +from ..._run import _prep_out_files, _sanitize_callable, failsafe_run, save_logs def _prepare_trans_template( *, cfg: SimpleNamespace, + subject: str, + session: str | None, info: mne.Info, ) -> mne.transforms.Transform: assert isinstance(cfg.use_template_mri, str) @@ -47,58 +48,40 @@ def _prepare_trans_template( ) else: fiducials = "estimated" # get fiducials from fsaverage + logger.info(**gen_log_kwargs("Matching template MRI using fiducials")) coreg = Coregistration( info, cfg.fs_subject, cfg.fs_subjects_dir, fiducials=fiducials ) - coreg.fit_fiducials(verbose=True) + # Adapted from MNE-Python + coreg.fit_fiducials(verbose=False) + dist = np.median(coreg.compute_dig_mri_distances() * 1000) + logger.info(**gen_log_kwargs(f"Median dig ↔ MRI distance: {dist:6.2f} mm")) trans = coreg.trans return trans -def _prepare_trans( +def _prepare_trans_subject( *, cfg: SimpleNamespace, exec_params: SimpleNamespace, + subject: str, + session: str | None, bids_path: BIDSPath, ) -> mne.transforms.Transform: # Generate a head ↔ MRI transformation matrix from the # electrophysiological and MRI sidecar files, and save it to an MNE # "trans" file in the derivatives folder. - subject, session = bids_path.subject, bids_path.session - - # TODO: This breaks our encapsulation - config = _import_config( - config_path=exec_params.config_path, - check=False, - log=False, - ) - if config.mri_t1_path_generator is None: - t1_bids_path = None - else: - t1_bids_path = BIDSPath(subject=subject, session=session, root=cfg.bids_root) - t1_bids_path = config.mri_t1_path_generator(t1_bids_path.copy()) - if t1_bids_path.suffix is None: - t1_bids_path.update(suffix="T1w") - if t1_bids_path.datatype is None: - t1_bids_path.update(datatype="anat") - - if config.mri_landmarks_kind is None: - landmarks_kind = None - else: - landmarks_kind = config.mri_landmarks_kind( - BIDSPath(subject=subject, session=session) - ) - msg = "Estimating head ↔ MRI transform" + msg = "Computing head ↔ MRI transform from matched fiducials" logger.info(**gen_log_kwargs(message=msg)) trans = get_head_mri_trans( bids_path.copy().update(run=cfg.runs[0], root=cfg.bids_root, extension=None), - t1_bids_path=t1_bids_path, + t1_bids_path=cfg.t1_bids_path, fs_subject=cfg.fs_subject, fs_subjects_dir=cfg.fs_subjects_dir, - kind=landmarks_kind, + kind=cfg.landmarks_kind, ) return trans @@ -119,7 +102,18 @@ def get_input_fnames_forward(*, cfg, subject, session): check=False, ) in_files = dict() - in_files["info"] = bids_path.copy().update(**cfg.source_info_path_update) + # for consistency with 05_make_inverse, read the info from the + # data used for the noise_cov + if cfg.source_info_path_update is None: + if cfg.noise_cov in ("rest", "noise"): + source_info_path_update = dict( + processing="clean", suffix="raw", task=cfg.noise_cov + ) + else: + source_info_path_update = dict(suffix="ave") + else: + source_info_path_update = cfg.source_info_path_update + in_files["info"] = bids_path.copy().update(**source_info_path_update) bem_path = cfg.fs_subjects_dir / cfg.fs_subject / "bem" _, tag = _get_bem_conductivity(cfg) in_files["bem"] = bem_path / f"{cfg.fs_subject}-{tag}-bem-sol.fif" @@ -135,7 +129,7 @@ def run_forward( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, in_files: dict, ) -> dict: bids_path = BIDSPath( @@ -174,11 +168,15 @@ def run_forward( if cfg.use_template_mri is not None: trans = _prepare_trans_template( cfg=cfg, + subject=subject, + session=session, info=info, ) else: - trans = _prepare_trans( + trans = _prepare_trans_subject( cfg=cfg, + subject=subject, + session=session, exec_params=exec_params, bids_path=bids_path, ) @@ -200,17 +198,7 @@ def run_forward( ) as report: msg = "Adding forward information to report" logger.info(**gen_log_kwargs(message=msg)) - msg = "Rendering MRI slices with BEM contours." - logger.info(**gen_log_kwargs(message=msg)) - report.add_bem( - subject=cfg.fs_subject, - subjects_dir=cfg.fs_subjects_dir, - title="BEM", - width=256, - decim=8, - replace=True, - n_jobs=1, # prevent automatic parallelization - ) + _render_bem(report=report, cfg=cfg, subject=subject, session=session) msg = "Rendering sensor alignment (coregistration)" logger.info(**gen_log_kwargs(message=msg)) report.add_trans( @@ -233,14 +221,31 @@ def run_forward( ) assert len(in_files) == 0, in_files - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( *, config: SimpleNamespace, subject: str, + session: str | None, ) -> SimpleNamespace: + if config.mri_t1_path_generator is None: + t1_bids_path = None + else: + t1_bids_path = BIDSPath(subject=subject, session=session, root=config.bids_root) + t1_bids_path = config.mri_t1_path_generator(t1_bids_path.copy()) + if t1_bids_path.suffix is None: + t1_bids_path.update(suffix="T1w") + if t1_bids_path.datatype is None: + t1_bids_path.update(datatype="anat") + if config.mri_landmarks_kind is None: + landmarks_kind = None + else: + landmarks_kind = config.mri_landmarks_kind( + BIDSPath(subject=subject, session=session) + ) + cfg = SimpleNamespace( runs=get_runs(config=config, subject=subject), mindist=config.mindist, @@ -248,9 +253,12 @@ def get_config( use_template_mri=config.use_template_mri, adjust_coreg=config.adjust_coreg, source_info_path_update=config.source_info_path_update, + noise_cov=_sanitize_callable(config.noise_cov), ch_types=config.ch_types, fs_subject=get_fs_subject(config=config, subject=subject), fs_subjects_dir=get_fs_subjects_dir(config), + t1_bids_path=t1_bids_path, + landmarks_kind=landmarks_kind, **_bids_kwargs(config=config), ) return cfg @@ -267,7 +275,7 @@ def main(*, config: SimpleNamespace) -> None: parallel, run_func = parallel_func(run_forward, exec_params=config.exec_params) logs = parallel( run_func( - cfg=get_config(config=config, subject=subject), + cfg=get_config(config=config, subject=subject, session=session), exec_params=config.exec_params, subject=subject, session=session, diff --git a/mne_bids_pipeline/steps/source/_05_make_inverse.py b/mne_bids_pipeline/steps/source/_05_make_inverse.py index c60fe44e0..9cc01b74f 100644 --- a/mne_bids_pipeline/steps/source/_05_make_inverse.py +++ b/mne_bids_pipeline/steps/source/_05_make_inverse.py @@ -3,38 +3,36 @@ Compute and apply an inverse solution for each evoked data set. """ -import pathlib from types import SimpleNamespace -from typing import Optional import mne from mne.minimum_norm import ( - make_inverse_operator, apply_inverse, + make_inverse_operator, write_inverse_operator, ) from mne_bids import BIDSPath from ..._config_utils import ( + _bids_kwargs, + get_fs_subject, + get_fs_subjects_dir, get_noise_cov_bids_path, + get_sessions, get_subjects, sanitize_cond_name, - get_sessions, - get_fs_subjects_dir, - get_fs_subject, - _bids_kwargs, ) -from ..._logging import logger, gen_log_kwargs +from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func -from ..._report import _open_report, _sanitize_cond_tag -from ..._run import failsafe_run, save_logs, _sanitize_callable +from ..._report import _all_conditions, _open_report, _sanitize_cond_tag +from ..._run import _prep_out_files, _sanitize_callable, failsafe_run, save_logs def get_input_fnames_inverse( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ): bids_path = BIDSPath( subject=subject, @@ -50,7 +48,19 @@ def get_input_fnames_inverse( check=False, ) in_files = dict() - in_files["info"] = bids_path.copy().update(**cfg.source_info_path_update) + # make sure the info matches the data from which the noise cov + # is computed to avoid rank-mismatch + if cfg.source_info_path_update is None: + if cfg.noise_cov in ("rest", "noise"): + source_info_path_update = dict( + processing="clean", suffix="raw", task=cfg.noise_cov + ) + else: + source_info_path_update = dict(suffix="ave") + # XXX is this the right solution also for noise_cov = 'ad-hoc'? + else: + source_info_path_update = cfg.source_info_path_update + in_files["info"] = bids_path.copy().update(**source_info_path_update) in_files["forward"] = bids_path.copy().update(suffix="fwd") if cfg.noise_cov != "ad-hoc": in_files["cov"] = get_noise_cov_bids_path( @@ -69,7 +79,7 @@ def run_inverse( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, in_files: dict, ) -> dict: # TODO: Eventually we should maybe loop over ch_types, e.g., to create @@ -97,22 +107,18 @@ def run_inverse( # Apply inverse snr = 3.0 lambda2 = 1.0 / snr**2 - - if isinstance(cfg.conditions, dict): - conditions = list(cfg.conditions.keys()) - else: - conditions = cfg.conditions - + conditions = _all_conditions(cfg=cfg) method = cfg.inverse_method if "evoked" in in_files: fname_ave = in_files.pop("evoked") evokeds = mne.read_evokeds(fname_ave) for condition, evoked in zip(conditions, evokeds): - pick_ori = None - cond_str = sanitize_cond_name(condition) - key = f"{cond_str}+{method}+hemi" - out_files[key] = fname_ave.copy().update(suffix=key, extension=None) + suffix = f"{sanitize_cond_name(condition)}+{method}+hemi" + out_files[condition] = fname_ave.copy().update( + suffix=suffix, + extension=".h5", + ) if "eeg" in cfg.ch_types: evoked.set_eeg_reference("average", projection=True) @@ -122,10 +128,9 @@ def run_inverse( inverse_operator=inverse_operator, lambda2=lambda2, method=method, - pick_ori=pick_ori, + pick_ori=None, ) - stc.save(out_files[key], overwrite=True) - out_files[key] = pathlib.Path(str(out_files[key]) + "-lh.stc") + stc.save(out_files[condition], ftype="h5", overwrite=True) with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session @@ -133,16 +138,13 @@ def run_inverse( msg = "Adding inverse information to report" logger.info(**gen_log_kwargs(message=msg)) for condition in conditions: - cond_str = sanitize_cond_name(condition) - key = f"{cond_str}+{method}+hemi" - if key not in out_files: - continue msg = f"Rendering inverse solution for {condition}" logger.info(**gen_log_kwargs(message=msg)) - fname_stc = out_files[key] tags = ("source-estimate", _sanitize_cond_tag(condition)) + if condition not in cfg.conditions: + tags = tags + ("contrast",) report.add_stc( - stc=fname_stc, + stc=out_files[condition], title=f"Source: {condition}", subject=cfg.fs_subject, subjects_dir=cfg.fs_subjects_dir, @@ -152,7 +154,7 @@ def run_inverse( ) assert len(in_files) == 0, in_files - return out_files + return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( @@ -165,6 +167,7 @@ def get_config( inverse_targets=config.inverse_targets, ch_types=config.ch_types, conditions=config.conditions, + contrasts=config.contrasts, loose=config.loose, depth=config.depth, inverse_method=config.inverse_method, diff --git a/mne_bids_pipeline/steps/source/_99_group_average.py b/mne_bids_pipeline/steps/source/_99_group_average.py index 3212e0249..81de0a01b 100644 --- a/mne_bids_pipeline/steps/source/_99_group_average.py +++ b/mne_bids_pipeline/steps/source/_99_group_average.py @@ -4,30 +4,39 @@ """ from types import SimpleNamespace -from typing import Optional, List - -import numpy as np import mne +import numpy as np from mne_bids import BIDSPath from ..._config_utils import ( + _bids_kwargs, + get_fs_subject, get_fs_subjects_dir, + get_sessions, get_subjects, sanitize_cond_name, - get_fs_subject, - get_sessions, - get_all_contrasts, - _bids_kwargs, ) -from ..._logging import logger, gen_log_kwargs +from ..._logging import gen_log_kwargs, logger from ..._parallel import get_parallel_backend, parallel_func -from ..._report import run_report_average_source -from ..._run import failsafe_run, save_logs +from ..._report import _all_conditions, _open_report +from ..._run import _prep_out_files, failsafe_run, save_logs -def morph_stc(cfg, subject, fs_subject, session=None): - bids_path = BIDSPath( +def _stc_path( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + condition: str, + morphed: bool, +) -> BIDSPath: + cond_str = sanitize_cond_name(condition) + suffix = [cond_str, cfg.inverse_method, "hemi"] + if morphed: + suffix.insert(2, "morph2fsaverage") + suffix = "+".join(suffix) + return BIDSPath( subject=subject, session=session, task=cfg.task, @@ -37,35 +46,47 @@ def morph_stc(cfg, subject, fs_subject, session=None): space=cfg.space, datatype=cfg.datatype, root=cfg.deriv_root, + suffix=suffix, + extension=".h5", check=False, ) - morphed_stcs = [] - - if cfg.task_is_rest: - conditions = [cfg.task.lower()] - else: - if isinstance(cfg.conditions, dict): - conditions = list(cfg.conditions.keys()) - else: - conditions = cfg.conditions - for condition in conditions: - method = cfg.inverse_method - cond_str = sanitize_cond_name(condition) - inverse_str = method - hemi_str = "hemi" # MNE will auto-append '-lh' and '-rh'. - morph_str = "morph2fsaverage" - - fname_stc = bids_path.copy().update( - suffix=f"{cond_str}+{inverse_str}+{hemi_str}" - ) - fname_stc_fsaverage = bids_path.copy().update( - suffix=f"{cond_str}+{inverse_str}+{morph_str}+{hemi_str}" +def get_input_fnames_morph_stc( + *, + cfg: SimpleNamespace, + subject: str, + fs_subject: str, + session: str | None, +) -> dict: + in_files = dict() + for condition in _all_conditions(cfg=cfg): + in_files[f"original-{condition}"] = _stc_path( + cfg=cfg, + subject=subject, + session=session, + condition=condition, + morphed=False, ) + return in_files - stc = mne.read_source_estimate(fname_stc) +@failsafe_run( + get_input_fnames=get_input_fnames_morph_stc, +) +def morph_stc( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + fs_subject: str, + session: str | None, + in_files: dict, +) -> dict: + out_files = dict() + for condition in _all_conditions(cfg=cfg): + fname_stc = in_files.pop(f"original-{condition}") + stc = mne.read_source_estimate(fname_stc) morph = mne.compute_source_morph( stc, subject_from=fs_subject, @@ -73,51 +94,98 @@ def morph_stc(cfg, subject, fs_subject, session=None): subjects_dir=cfg.fs_subjects_dir, ) stc_fsaverage = morph.apply(stc) - stc_fsaverage.save(fname_stc_fsaverage, overwrite=True) - morphed_stcs.append(stc_fsaverage) + key = f"morphed-{condition}" + out_files[key] = _stc_path( + cfg=cfg, + subject=subject, + session=session, + condition=condition, + morphed=True, + ) + stc_fsaverage.save(out_files[key], ftype="h5", overwrite=True) - del fname_stc, fname_stc_fsaverage + assert len(in_files) == 0, in_files + return _prep_out_files(out_files=out_files, exec_params=exec_params) - return morphed_stcs +def get_input_fnames_run_average( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> dict: + in_files = dict() + assert subject == "average" + for condition in _all_conditions(cfg=cfg): + for this_subject in cfg.subjects: + in_files[f"{this_subject}-{condition}"] = _stc_path( + cfg=cfg, + subject=this_subject, + session=session, + condition=condition, + morphed=True, + ) + return in_files + +@failsafe_run( + get_input_fnames=get_input_fnames_run_average, +) def run_average( *, cfg: SimpleNamespace, + exec_params: SimpleNamespace, subject: str, - session: Optional[str], - mean_morphed_stcs: List[mne.SourceEstimate], + session: str | None, + in_files: dict, ): - bids_path = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - run=None, - processing=cfg.proc, - recording=cfg.rec, - space=cfg.space, - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - - if isinstance(cfg.conditions, dict): - conditions = list(cfg.conditions.keys()) - else: - conditions = cfg.conditions - - for condition, stc in zip(conditions, mean_morphed_stcs): - method = cfg.inverse_method - cond_str = sanitize_cond_name(condition) - inverse_str = method - hemi_str = "hemi" # MNE will auto-append '-lh' and '-rh'. - morph_str = "morph2fsaverage" - - fname_stc_avg = bids_path.copy().update( - suffix=f"{cond_str}+{inverse_str}+{morph_str}+{hemi_str}" + assert subject == "average" + out_files = dict() + conditions = _all_conditions(cfg=cfg) + for condition in conditions: + stc = np.array( + [ + mne.read_source_estimate(in_files.pop(f"{this_subject}-{condition}")) + for this_subject in cfg.subjects + ] + ).mean(axis=0) + out_files[condition] = _stc_path( + cfg=cfg, + subject=subject, + session=session, + condition=condition, + morphed=True, ) - stc.save(fname_stc_avg, overwrite=True) + stc.save(out_files[condition], ftype="h5", overwrite=True) + + ####################################################################### + # + # Visualize forward solution, inverse operator, and inverse solutions. + # + with _open_report( + cfg=cfg, exec_params=exec_params, subject=subject, session=session + ) as report: + for condition in conditions: + msg = f"Rendering inverse solution for {condition}" + logger.info(**gen_log_kwargs(message=msg)) + cond_str = sanitize_cond_name(condition) + tags = ("source-estimate", cond_str) + if condition in cfg.conditions: + title = f"Average (source): {condition}" + else: # It's a contrast of two conditions. + title = f"Average (source) contrast: {condition}" + tags = tags + ("contrast",) + report.add_stc( + stc=out_files[condition], + title=title, + subject="fsaverage", + subjects_dir=cfg.fs_subjects_dir, + n_time_points=cfg.report_stc_n_time_points, + tags=tags, + replace=True, + ) + assert len(in_files) == 0, in_files + return _prep_out_files(out_files=out_files, exec_params=exec_params) def get_config( @@ -131,11 +199,11 @@ def get_config( fs_subjects_dir=get_fs_subjects_dir(config), subjects_dir=get_fs_subjects_dir(config), ch_types=config.ch_types, - subjects=config.subjects, + subjects=get_subjects(config=config), exclude_subjects=config.exclude_subjects, sessions=get_sessions(config), use_template_mri=config.use_template_mri, - all_contrasts=get_all_contrasts(config), + contrasts=config.contrasts, report_stc_n_time_points=config.report_stc_n_time_points, # TODO: needed because get_datatype gets called again... data_type=config.data_type, @@ -144,64 +212,39 @@ def get_config( return cfg -# pass 'average' subject for logging -@failsafe_run() -def run_group_average_source( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, -) -> None: - """Run group average in source space""" +def main(*, config: SimpleNamespace) -> None: + if not config.run_source_estimation: + msg = "Skipping, run_source_estimation is set to False …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return - mne.datasets.fetch_fsaverage(subjects_dir=get_fs_subjects_dir(cfg)) + mne.datasets.fetch_fsaverage(subjects_dir=get_fs_subjects_dir(config)) + cfg = get_config(config=config) + exec_params = config.exec_params + subjects = get_subjects(config) + sessions = get_sessions(config) + logs = list() with get_parallel_backend(exec_params): parallel, run_func = parallel_func(morph_stc, exec_params=exec_params) - all_morphed_stcs = parallel( + logs += parallel( run_func( cfg=cfg, + exec_params=exec_params, subject=subject, fs_subject=get_fs_subject(config=cfg, subject=subject), session=session, ) - for subject in get_subjects(cfg) - for session in get_sessions(cfg) + for subject in subjects + for session in sessions ) - mean_morphed_stcs = np.array(all_morphed_stcs).mean(axis=0) - - # XXX to fix - sessions = get_sessions(cfg) - if sessions: - session = sessions[0] - else: - session = None - + logs += [ run_average( - cfg=cfg, - session=session, - subject=subject, - mean_morphed_stcs=mean_morphed_stcs, - ) - run_report_average_source( cfg=cfg, exec_params=exec_params, - subject=subject, session=session, + subject="average", ) - - -def main(*, config: SimpleNamespace) -> None: - if not config.run_source_estimation: - msg = "Skipping, run_source_estimation is set to False …" - logger.info(**gen_log_kwargs(message=msg, emoji="skip")) - return - - log = run_group_average_source( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject="average", - ) - save_logs(config=config, logs=[log]) + for session in sessions + ] + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/source/__init__.py b/mne_bids_pipeline/steps/source/__init__.py index c748f7f8b..89b757670 100644 --- a/mne_bids_pipeline/steps/source/__init__.py +++ b/mne_bids_pipeline/steps/source/__init__.py @@ -1,11 +1,13 @@ """Source-space analysis.""" -from . import _01_make_bem_surfaces -from . import _02_make_bem_solution -from . import _03_setup_source_space -from . import _04_make_forward -from . import _05_make_inverse -from . import _99_group_average +from . import ( + _01_make_bem_surfaces, + _02_make_bem_solution, + _03_setup_source_space, + _04_make_forward, + _05_make_inverse, + _99_group_average, +) _STEPS = ( _01_make_bem_surfaces, diff --git a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py index 8d0f9fdfe..650aa7395 100644 --- a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py +++ b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py @@ -1,7 +1,6 @@ -""" -ERP CORE +"""ERP CORE. -This example demonstrate how to process 5 participants from the +This example demonstrates how to process 5 participants from the [ERP CORE](https://erpinfo.org/erp-core) dataset. It shows how to obtain 7 ERP components from a total of 6 experimental tasks: @@ -23,11 +22,12 @@ event-related potential research. *NeuroImage* 225: 117465. [https://doi.org/10.1016/j.neuroimage.2020.117465](https://doi.org/10.1016/j.neuroimage.2020.117465) """ + import argparse -import mne import sys -study_name = "ERP-CORE" +import mne + bids_root = "~/mne_data/ERP_CORE" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ERP_CORE" @@ -47,6 +47,8 @@ interactive = False raw_resample_sfreq = 128 +# Suppress "Data file name in EEG.data (sub-019_task-ERN_eeg.fdt) is incorrect..." +read_raw_bids_verbose = "error" eeg_template_montage = mne.channels.make_standard_montage("standard_1005") eeg_bipolar_channels = { @@ -69,20 +71,34 @@ t_break_annot_start_after_previous_event = 3.0 t_break_annot_stop_before_next_event = 1.5 -ica_reject = dict(eeg=350e-6, eog=500e-6) -reject = "autoreject_global" +if task == "N400": # test autoreject local without ICA + spatial_filter = None + reject = "autoreject_local" + autoreject_n_interpolate = [2, 4] +elif task == "N170": # test autoreject local before ICA, and MNE-ICALabel + spatial_filter = "ica" + ica_algorithm = "picard-extended_infomax" + ica_use_icalabel = True + ica_l_freq = 1 + h_freq = 100 + ica_reject = "autoreject_local" + reject = "autoreject_global" + autoreject_n_interpolate = [12] # only for testing! +else: + spatial_filter = "ica" + ica_reject = dict(eeg=350e-6, eog=500e-6) + reject = "autoreject_global" -spatial_filter = "ica" +# These settings are only used for the cases where spatial_filter="ica" ica_max_iterations = 1000 ica_eog_threshold = 2 ica_decim = 2 # speed up ICA fitting run_source_estimation = False - on_rename_missing_events = "ignore" parallel_backend = "dask" -dask_worker_memory_limit = "2G" +dask_worker_memory_limit = "2.5G" n_jobs = 4 if task == "N400": @@ -102,7 +118,6 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) epochs_tmin = -0.2 epochs_tmax = 0.8 epochs_metadata_tmin = 0 @@ -136,7 +151,6 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) epochs_tmin = -0.6 epochs_tmax = 0.4 baseline = (-0.4, -0.2) @@ -169,7 +183,6 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) epochs_tmin = -0.8 epochs_tmax = 0.2 baseline = (None, -0.6) @@ -182,7 +195,6 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) epochs_tmin = -0.2 epochs_tmax = 0.8 baseline = (None, 0) @@ -203,7 +215,41 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) + # Analyze all EEG channels -- we only specify the channels here for the purpose of + # demonstration + analyze_channels = [ + "FP1", + "F3", + "F7", + "FC3", + "C3", + "C5", + "P3", + "P7", + "P9", + "PO7", + "PO3", + "O1", + "Oz", + "Pz", + "CPz", + "FP2", + "Fz", + "F4", + "F8", + "FC4", + "FCz", + "Cz", + "C4", + "C6", + "P4", + "P8", + "P10", + "PO8", + "PO4", + "O2", + ] + epochs_tmin = -0.2 epochs_tmax = 0.8 baseline = (None, 0) @@ -216,6 +262,41 @@ } eeg_reference = "average" + # Analyze all EEG channels -- we only specify the channels here for the purpose of + # demonstration + analyze_channels = [ + "FP1", + "F3", + "F7", + "FC3", + "C3", + "C5", + "P3", + "P7", + "P9", + "PO7", + "PO3", + "O1", + "Oz", + "Pz", + "CPz", + "FP2", + "Fz", + "F4", + "F8", + "FC4", + "FCz", + "Cz", + "C4", + "C6", + "P4", + "P8", + "P10", + "PO8", + "PO4", + "O2", + ] + ica_n_components = 30 - 1 for i in range(1, 180 + 1): orig_name = f"stimulus/{i}" @@ -270,7 +351,6 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) epochs_tmin = -0.2 epochs_tmax = 0.8 baseline = (None, 0) diff --git a/mne_bids_pipeline/tests/configs/config_MNE_phantom_KIT_data.py b/mne_bids_pipeline/tests/configs/config_MNE_phantom_KIT_data.py new file mode 100644 index 000000000..49689bd3e --- /dev/null +++ b/mne_bids_pipeline/tests/configs/config_MNE_phantom_KIT_data.py @@ -0,0 +1,27 @@ +""" +KIT phantom data. + +https://mne.tools/dev/documentation/datasets.html#kit-phantom-dataset +""" + +bids_root = "~/mne_data/MNE-phantom-KIT-data" +deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/MNE-phantom-KIT-data" +task = "phantom" +ch_types = ["meg"] + +# Preprocessing +l_freq = None +h_freq = 40.0 +regress_artifact = dict( + picks="meg", picks_artifact=["MISC 001", "MISC 002", "MISC 003"] +) + +# Epochs +epochs_tmin = -0.08 +epochs_tmax = 0.18 +epochs_decim = 10 # 2000->200 Hz +baseline = (None, 0) +conditions = ["dip01", "dip13", "dip25", "dip37", "dip49"] + +# Decoding +decode = True # should be very good performance diff --git a/mne_bids_pipeline/tests/configs/config_ds000117.py b/mne_bids_pipeline/tests/configs/config_ds000117.py index b46db99bd..2e49f1a4e 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000117.py +++ b/mne_bids_pipeline/tests/configs/config_ds000117.py @@ -1,8 +1,5 @@ -""" -Faces dataset -""" +"""Faces dataset.""" -study_name = "ds000117" bids_root = "~/mne_data/ds000117" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000117" @@ -18,6 +15,7 @@ find_flat_channels_meg = True find_noisy_channels_meg = True use_maxwell_filter = True +process_empty_room = True mf_reference_run = "02" mf_cal_fname = bids_root + "/derivatives/meg_derivatives/sss_cal.dat" diff --git a/mne_bids_pipeline/tests/configs/config_ds000246.py b/mne_bids_pipeline/tests/configs/config_ds000246.py index a32267b00..1aa58f244 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000246.py +++ b/mne_bids_pipeline/tests/configs/config_ds000246.py @@ -1,16 +1,15 @@ -""" -Brainstorm - Auditory Dataset. +"""Brainstorm - Auditory Dataset. See https://openneuro.org/datasets/ds000246/versions/1.0.0 for more information. """ -study_name = "ds000246" bids_root = "~/mne_data/ds000246" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000246" runs = ["01"] crop_runs = (0, 120) # Reduce memory usage on CI system +read_raw_bids_verbose = "error" # No BIDS -> MNE mapping found for channel ... l_freq = 0.3 h_freq = 100 epochs_decim = 4 @@ -18,6 +17,7 @@ ch_types = ["meg"] reject = dict(mag=4e-12, eog=250e-6) conditions = ["standard", "deviant", "button"] +epochs_metadata_tmin = ["standard", "deviant"] # for testing only contrasts = [("deviant", "standard")] decode = True decoding_time_generalization = True diff --git a/mne_bids_pipeline/tests/configs/config_ds000247.py b/mne_bids_pipeline/tests/configs/config_ds000247.py index 8d2b0451f..fc4f42464 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000247.py +++ b/mne_bids_pipeline/tests/configs/config_ds000247.py @@ -1,12 +1,9 @@ -""" -OMEGA Resting State Sample Data -""" -import numpy as np +"""OMEGA Resting State Sample Data.""" +import numpy as np -study_name = "ds000247" -bids_root = f"~/mne_data/{study_name}" -deriv_root = f"~/mne_data/derivatives/mne-bids-pipeline/{study_name}" +bids_root = "~/mne_data/ds000247" +deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000247" subjects = ["0002"] sessions = ["01"] diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_FLASH_BEM.py b/mne_bids_pipeline/tests/configs/config_ds000248_FLASH_BEM.py index f09fdc6d5..5d37fde67 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_FLASH_BEM.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_FLASH_BEM.py @@ -1,7 +1,5 @@ -""" -MNE Sample Data: BEM from FLASH images -""" -study_name = "ds000248" +"""MNE Sample Data: BEM from FLASH images.""" + bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_FLASH_BEM" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_T1_BEM.py b/mne_bids_pipeline/tests/configs/config_ds000248_T1_BEM.py index df315e035..0fdfdbf76 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_T1_BEM.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_T1_BEM.py @@ -1,8 +1,5 @@ -""" -MNE Sample Data: BEM from T1 images -""" +"""MNE Sample Data: BEM from T1 images.""" -study_name = "ds000248" bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_T1_BEM" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_base.py b/mne_bids_pipeline/tests/configs/config_ds000248_base.py index 795ed618b..2f4db6a10 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_base.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_base.py @@ -1,9 +1,7 @@ -""" -MNE Sample Data: M/EEG combined processing -""" +"""MNE Sample Data: M/EEG combined processing.""" + import mne -study_name = "ds000248" bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_base" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" @@ -21,13 +19,14 @@ find_flat_channels_meg = True find_noisy_channels_meg = True use_maxwell_filter = True -_raw_split_size = "60MB" # hits both task-noise and task-audiovisual -_epochs_split_size = "30MB" def noise_cov(bp): + """Estimate the noise covariance.""" # Use pre-stimulus period as noise source - bp = bp.copy().update(processing="clean", suffix="epo") + bp = bp.copy().update(suffix="epo") + if not bp.fpath.exists(): + bp.update(split="01") epo = mne.read_epochs(bp) cov = mne.compute_covariance(epo, rank="info", tmax=0) return cov @@ -48,5 +47,6 @@ def noise_cov(bp): def mri_t1_path_generator(bids_path): + """Return the path to a T1 image.""" # don't really do any modifications – just for testing! return bids_path diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_coreg_surfaces.py b/mne_bids_pipeline/tests/configs/config_ds000248_coreg_surfaces.py index 9262fdcb8..dba51f97d 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_coreg_surfaces.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_coreg_surfaces.py @@ -1,8 +1,5 @@ -""" -MNE Sample Data: Head surfaces from FreeSurfer surfaces for coregistration step -""" +"""MNE Sample Data: Head surfaces from FreeSurfer surfaces for coregistration step.""" -study_name = "ds000248" bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_coreg_surfaces" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_ica.py b/mne_bids_pipeline/tests/configs/config_ds000248_ica.py index 176a2f592..6dfb49c7b 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_ica.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_ica.py @@ -1,7 +1,5 @@ -""" -MNE Sample Data: ICA -""" -study_name = 'MNE "sample" dataset' +"""MNE Sample Data: ICA.""" + bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_ica" diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_no_mri.py b/mne_bids_pipeline/tests/configs/config_ds000248_no_mri.py index 9941d2842..08b98e9bb 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_no_mri.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_no_mri.py @@ -1,8 +1,5 @@ -""" -MNE Sample Data: Using the `fsaverage` template MRI -""" +"""MNE Sample Data: Using the `fsaverage` template MRI.""" -study_name = "ds000248" bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_no_mri" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds001810.py b/mne_bids_pipeline/tests/configs/config_ds001810.py index 064e8ddb7..3771d6cd3 100644 --- a/mne_bids_pipeline/tests/configs/config_ds001810.py +++ b/mne_bids_pipeline/tests/configs/config_ds001810.py @@ -1,8 +1,5 @@ -""" -tDCS EEG -""" +"""tDCS EEG.""" -study_name = "ds001810" bids_root = "~/mne_data/ds001810" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds001810" @@ -15,6 +12,7 @@ conditions = ["61450", "61511"] contrasts = [("61450", "61511")] decode = True +decoding_n_splits = 3 # only for testing, use 5 otherwise l_freq = 0.3 diff --git a/mne_bids_pipeline/tests/configs/config_ds001971.py b/mne_bids_pipeline/tests/configs/config_ds001971.py index 2f3307c85..349dbe23e 100644 --- a/mne_bids_pipeline/tests/configs/config_ds001971.py +++ b/mne_bids_pipeline/tests/configs/config_ds001971.py @@ -3,8 +3,6 @@ See ds001971 on OpenNeuro: https://github.com/OpenNeuroDatasets/ds001971 """ - -study_name = "ds001971" bids_root = "~/mne_data/ds001971" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds001971" @@ -13,6 +11,21 @@ ch_types = ["eeg"] reject = {"eeg": 150e-6} conditions = ["AdvanceTempo", "DelayTempo"] +contrasts = [("AdvanceTempo", "DelayTempo")] subjects = ["001"] runs = ["01"] +epochs_decim = 5 # to 100 Hz + +# This is mostly for testing purposes! +decode = True +decoding_time_generalization = True +decoding_time_generalization_decim = 2 +decoding_csp = True +decoding_csp_freqs = { + "beta": [13, 20, 30], +} +decoding_csp_times = [-0.2, 0.0, 0.2, 0.4] + +# Just to test that MD5 works +memory_file_method = "hash" diff --git a/mne_bids_pipeline/tests/configs/config_ds003104.py b/mne_bids_pipeline/tests/configs/config_ds003104.py index c88d07161..d47a0a64c 100644 --- a/mne_bids_pipeline/tests/configs/config_ds003104.py +++ b/mne_bids_pipeline/tests/configs/config_ds003104.py @@ -1,6 +1,5 @@ -"""Somato -""" -study_name = "MNE-somato-data-anonymized" +"""Somato.""" + bids_root = "~/mne_data/ds003104" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds003104" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds003392.py b/mne_bids_pipeline/tests/configs/config_ds003392.py index 02134768a..37c8a46c3 100644 --- a/mne_bids_pipeline/tests/configs/config_ds003392.py +++ b/mne_bids_pipeline/tests/configs/config_ds003392.py @@ -1,15 +1,14 @@ -""" -hMT+ Localizer -""" -study_name = "localizer" +"""hMT+ Localizer.""" + bids_root = "~/mne_data/ds003392" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds003392" subjects = ["01"] task = "localizer" -find_flat_channels_meg = True -find_noisy_channels_meg = True +# usually a good idea to use True, but we know no bads are detected for this dataset +find_flat_channels_meg = False +find_noisy_channels_meg = False use_maxwell_filter = True ch_types = ["meg"] @@ -20,10 +19,10 @@ # Artifact correction. spatial_filter = "ica" -ica_max_iterations = 500 +ica_algorithm = "picard-extended_infomax" +ica_max_iterations = 1000 ica_l_freq = 1.0 ica_n_components = 0.99 -ica_reject_components = "auto" # Epochs epochs_tmin = -0.2 @@ -38,6 +37,11 @@ decoding_time_generalization = True decoding_time_generalization_decim = 4 contrasts = [("incoherent", "coherent")] +decoding_csp = True +decoding_csp_times = [] +decoding_csp_freqs = { + "alpha": (8, 12), +} # Noise estimation noise_cov = "emptyroom" diff --git a/mne_bids_pipeline/tests/configs/config_ds003775.py b/mne_bids_pipeline/tests/configs/config_ds003775.py index 4dae88993..219b1e23a 100644 --- a/mne_bids_pipeline/tests/configs/config_ds003775.py +++ b/mne_bids_pipeline/tests/configs/config_ds003775.py @@ -1,8 +1,5 @@ -""" -SRM Resting-state EEG -""" +"""SRM Resting-state EEG.""" -study_name = "ds003775" bids_root = "~/mne_data/ds003775" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds003775" diff --git a/mne_bids_pipeline/tests/configs/config_ds004107.py b/mne_bids_pipeline/tests/configs/config_ds004107.py index 7a32d952c..0dd70a5ef 100644 --- a/mne_bids_pipeline/tests/configs/config_ds004107.py +++ b/mne_bids_pipeline/tests/configs/config_ds004107.py @@ -1,5 +1,4 @@ -""" -MIND DATA +"""MIND DATA. M.P. Weisend, F.M. Hanlon, R. Montaño, S.P. Ahlfors, A.C. Leuthold, D. Pantazis, J.C. Mosher, A.P. Georgopoulos, M.S. Hämäläinen, C.J. @@ -7,11 +6,11 @@ Paving the way for cross-site pooling of magnetoencephalography (MEG) data. International Congress Series, Volume 1300, Pages 615-618. """ + # This has auditory, median, indx, visual, rest, and emptyroom but let's just # process the auditory (it's the smallest after rest) -study_name = "ds004107" -bids_root = f"~/mne_data/{study_name}" -deriv_root = f"~/mne_data/derivatives/mne-bids-pipeline/{study_name}" +bids_root = "~/mne_data/ds004107" +deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds004107" subjects = ["mind002"] sessions = ["01"] conditions = ["left", "right"] # there are also tone and noise diff --git a/mne_bids_pipeline/tests/configs/config_ds004229.py b/mne_bids_pipeline/tests/configs/config_ds004229.py index b7fd70d05..878b743b0 100644 --- a/mne_bids_pipeline/tests/configs/config_ds004229.py +++ b/mne_bids_pipeline/tests/configs/config_ds004229.py @@ -1,12 +1,11 @@ -""" -Single-subject infant dataset for testing maxwell_filter with movecomp. +"""Single-subject infant dataset for testing maxwell_filter with movecomp. https://openneuro.org/datasets/ds004229 """ + import mne import numpy as np -study_name = "amnoise" bids_root = "~/mne_data/ds004229" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds004229" @@ -16,8 +15,6 @@ find_flat_channels_meg = True find_noisy_channels_meg = True use_maxwell_filter = True -mf_cal_fname = bids_root + "/derivatives/meg_derivatives/sss_cal.dat" -mf_ctc_fname = bids_root + "/derivatives/meg_derivatives/ct_sparse.fif" mf_destination = mne.transforms.translation( # rotate backward and move up z=0.055, ) @ mne.transforms.rotation(x=np.deg2rad(-15)) @@ -27,6 +24,10 @@ mf_mc_t_step_min = 0.5 # just for speed! mf_mc_t_window = 0.2 # cleaner cHPI filtering on this dataset mf_filter_chpi = False # for speed, not needed as we low-pass anyway +mf_mc_rotation_velocity_limit = 30.0 # deg/s for annotations +mf_mc_translation_velocity_limit = 20e-3 # m/s +mf_esss = 8 +mf_esss_reject = {"grad": 10000e-13, "mag": 40000e-15} ch_types = ["meg"] l_freq = None @@ -44,6 +45,10 @@ epochs_tmax = 1 epochs_decim = 6 # 1200->200 Hz baseline = (None, 0) +report_add_epochs_image_kwargs = { + "grad": {"vmin": 0, "vmax": 1e13 * reject["grad"]}, # fT/cm + "mag": {"vmin": 0, "vmax": 1e15 * reject["mag"]}, # fT +} # Conditions / events to consider when epoching conditions = ["auditory"] diff --git a/mne_bids_pipeline/tests/configs/config_eeg_matchingpennies.py b/mne_bids_pipeline/tests/configs/config_eeg_matchingpennies.py index 643e51799..5cb0b1390 100644 --- a/mne_bids_pipeline/tests/configs/config_eeg_matchingpennies.py +++ b/mne_bids_pipeline/tests/configs/config_eeg_matchingpennies.py @@ -1,8 +1,5 @@ -""" -Matchingpennies EEG experiment -""" +"""Matchingpennies EEG experiment.""" -study_name = "eeg_matchingpennies" bids_root = "~/mne_data/eeg_matchingpennies" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/eeg_matchingpennies" diff --git a/mne_bids_pipeline/tests/conftest.py b/mne_bids_pipeline/tests/conftest.py index c06e8694a..9f45230eb 100644 --- a/mne_bids_pipeline/tests/conftest.py +++ b/mne_bids_pipeline/tests/conftest.py @@ -2,6 +2,7 @@ def pytest_addoption(parser): + """Add pytest command line options.""" parser.addoption( "--download", action="store_true", @@ -10,6 +11,7 @@ def pytest_addoption(parser): def pytest_configure(config): + """Add pytest configuration settings.""" # register an additional marker config.addinivalue_line("markers", "dataset_test: mark that a test runs a dataset") warning_lines = r""" @@ -34,6 +36,43 @@ def pytest_configure(config): # seaborn calling tight layout, etc. ignore:The figure layout has changed to tight:UserWarning ignore:The \S+_cmap function was deprecated.*:DeprecationWarning + # seaborn->pandas + ignore:is_categorical_dtype is deprecated.*:FutureWarning + ignore:use_inf_as_na option is deprecated.*:FutureWarning + # Dask distributed with jsonschema 4.18 + ignore:jsonschema\.RefResolver is deprecated.*:DeprecationWarning + # seaborn->pandas + ignore:is_categorical_dtype is deprecated.*:FutureWarning + ignore:use_inf_as_na option is deprecated.*:FutureWarning + ignore:All-NaN axis encountered.*:RuntimeWarning + # sklearn class not enough samples for cv=5 + always:The least populated class in y has only.*:UserWarning + # constrained layout fails on ds003392 + # mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py:551: in run_ica + # report.add_ica( + #../python_env/lib/python3.10/site-packages/mne/report/report.py:1974: in add_ica + # self._add_ica( + #../python_env/lib/python3.10/site-packages/mne/report/report.py:1872: in _add_ica + # self._add_ica_artifact_sources( + #../python_env/lib/python3.10/site-packages/mne/report/report.py:1713: + # in _add_ica_artifact_sources + # self._add_figure( + always:constrained_layout not applied.*:UserWarning + ignore:datetime\.datetime\.utcfromtimestamp.*:DeprecationWarning + ignore:datetime\.datetime\.utcnow.*:DeprecationWarning + # pandas with no good workaround + ignore:The behavior of DataFrame concatenation with empty.*:FutureWarning + # joblib on Windows sometimes + ignore:Persisting input arguments took.*:UserWarning + # matplotlib needs to update + ignore:Conversion of an array with ndim.*:DeprecationWarning + # scipy + ignore:nperseg .* is greater.*:UserWarning + # NumPy 2.0 + ignore:__array_wrap__ must accept context.*:DeprecationWarning + ignore:__array__ implementation doesn't accept.*:DeprecationWarning + # Seaborn + ignore:.*bool was deprecated in Matplotlib.*:DeprecationWarning """ for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() diff --git a/mne_bids_pipeline/tests/datasets.py b/mne_bids_pipeline/tests/datasets.py index f93f3206d..92e365cda 100644 --- a/mne_bids_pipeline/tests/datasets.py +++ b/mne_bids_pipeline/tests/datasets.py @@ -1,48 +1,34 @@ """Definition of the testing datasets.""" -from typing import Dict, List, TypedDict +from typing import TypedDict -class DATASET_OPTIONS_T(TypedDict): - git: str - openneuro: str - osf: str - web: str - include: List[str] - exclude: List[str] +# If not supplied below, the effective defaults are listed in comments +class DATASET_OPTIONS_T(TypedDict, total=False): + """A container for sources, hash, include and excludes of a dataset.""" + git: str # "" + openneuro: str # "" + osf: str # "" + web: str # "" + mne: str # "" + include: list[str] # [] + exclude: list[str] # [] + hash: str # "" -DATASET_OPTIONS: Dict[str, DATASET_OPTIONS_T] = { + +DATASET_OPTIONS: dict[str, DATASET_OPTIONS_T] = { "ERP_CORE": { - "git": "", - "openneuro": "", - "osf": "", # original dataset: '9f5w7' + # original dataset: "osf": "9f5w7" "web": "https://osf.io/3zk6n/download?version=2", - "include": [], - "exclude": [], + "hash": "sha256:ddc94a7c9ba1922637f2770592dd51c019d341bf6bc8558e663e1979a4cb002f", # noqa: E501 }, "eeg_matchingpennies": { - # This dataset started out on osf.io as dataset https://osf.io/cj2dr - # then moved to g-node.org. As of 2023/02/28 when we download it via - # datalad it's too (~200 kB/sec!) and times out at the end: - # - # "git": "https://gin.g-node.org/sappelhoff/eeg_matchingpennies", - # "web": "", - # "include": ["sub-05"], - # - # So now we mirror this datalad-fetched git repo back on osf.io! - "git": "", - "openneuro": "", - "osf": "", # original dataset: 'cj2dr' "web": "https://osf.io/download/8rbfk?version=1", - "include": [], - "exclude": [], + "hash": "sha256:06bfbe52c50b9343b6b8d2a5de3dd33e66ad9303f7f6bfbe6868c3c7c375fafd", # noqa: E501 }, "ds003104": { # Anonymized "somato" dataset. - "git": "", "openneuro": "ds003104", - "osf": "", - "web": "", "include": ["sub-01", "derivatives/freesurfer/subjects"], "exclude": [ "derivatives/freesurfer/subjects/01/mri/aparc+aseg.mgz", @@ -51,30 +37,19 @@ class DATASET_OPTIONS_T(TypedDict): ], }, "ds000246": { - "git": "", "openneuro": "ds000246", - "osf": "", - "web": "", "include": [ "sub-0001/meg/sub-0001_task-AEF_run-01_meg.ds", "sub-0001/meg/sub-0001_task-AEF_run-01_meg.json", "sub-0001/meg/sub-0001_task-AEF_run-01_channels.tsv", ], - "exclude": [], }, "ds000247": { - "git": "", "openneuro": "ds000247", - "osf": "", - "web": "", "include": ["sub-0002/ses-01/meg"], - "exclude": [], }, "ds000248": { - "git": "", "openneuro": "ds000248", - "osf": "", - "web": "", "include": ["sub-01", "sub-emptyroom", "derivatives/freesurfer/subjects"], "exclude": [ "derivatives/freesurfer/subjects/fsaverage/mri/aparc.a2005s+aseg.mgz", # noqa: E501 @@ -88,10 +63,7 @@ class DATASET_OPTIONS_T(TypedDict): ], }, "ds000117": { - "git": "", "openneuro": "ds000117", - "osf": "", - "web": "", "include": [ "sub-01/ses-meg/meg/sub-01_ses-meg_task-facerecognition_run-01_*", # noqa: E501 "sub-01/ses-meg/meg/sub-01_ses-meg_task-facerecognition_run-02_*", # noqa: E501 @@ -102,29 +74,19 @@ class DATASET_OPTIONS_T(TypedDict): "derivatives/meg_derivatives/ct_sparse.fif", "derivatives/meg_derivatives/sss_cal.dat", ], - "exclude": [], }, "ds003775": { - "git": "", "openneuro": "ds003775", - "osf": "", - "web": "", "include": ["sub-010"], - "exclude": [], + # See https://github.com/OpenNeuroOrg/openneuro/issues/2976 + "exclude": ["sub-010/ses-t1/sub-010_ses-t1_scans.tsv"], }, "ds001810": { - "git": "", "openneuro": "ds001810", - "osf": "", - "web": "", "include": ["sub-01"], - "exclude": [], }, "ds001971": { - "git": "", "openneuro": "ds001971", - "osf": "", - "web": "", "include": [ "sub-001/eeg/sub-001_task-AudioCueWalkingStudy_run-01_events.tsv", "sub-001/eeg/sub-001_task-AudioCueWalkingStudy_run-01_eeg.set", @@ -134,38 +96,26 @@ class DATASET_OPTIONS_T(TypedDict): "sub-001/eeg/sub-001_task-AudioCueWalkingStudy_run-01_coordsystem.json", # noqa: E501 "sub-001/eeg/sub-001_task-AudioCueWalkingStudy_run-01_channels.tsv", # noqa: E501 ], - "exclude": [], }, "ds003392": { - "git": "", "openneuro": "ds003392", - "osf": "", - "web": "", "include": ["sub-01", "sub-emptyroom/ses-19111211"], - "exclude": [], }, "ds004107": { - "git": "", "openneuro": "ds004107", - "osf": "", - "web": "", "include": [ "sub-mind002/ses-01/meg/*coordsystem*", "sub-mind002/ses-01/meg/*auditory*", ], - "exclude": [], }, "ds004229": { - "git": "", "openneuro": "ds004229", - "osf": "", - "web": "", "include": [ "sub-102", "sub-emptyroom/ses-20000101", - "derivatives/meg_derivatives/ct_sparse.fif", - "derivatives/meg_derivatives/sss_cal.dat", ], - "exclude": [], + }, + "MNE-phantom-KIT-data": { + "mne": "phantom_kit", }, } diff --git a/mne_bids_pipeline/tests/sub-010_ses-t1_scans.tsv b/mne_bids_pipeline/tests/sub-010_ses-t1_scans.tsv new file mode 100644 index 000000000..54b711284 --- /dev/null +++ b/mne_bids_pipeline/tests/sub-010_ses-t1_scans.tsv @@ -0,0 +1,2 @@ +filename acq_time +eeg/sub-010_ses-t1_task-resteyesc_eeg.edf 2017-05-09T12:11:44 diff --git a/mne_bids_pipeline/tests/test_cli.py b/mne_bids_pipeline/tests/test_cli.py index 607cbdd67..5b8b33b5a 100644 --- a/mne_bids_pipeline/tests/test_cli.py +++ b/mne_bids_pipeline/tests/test_cli.py @@ -2,11 +2,14 @@ import importlib import sys + import pytest + from mne_bids_pipeline._main import main def test_config_generation(tmp_path, monkeypatch): + """Test the generation of a default config file.""" cmd = ["mne_bids_pipeline", "--create-config"] monkeypatch.setattr(sys, "argv", cmd) with pytest.raises(SystemExit, match="2"): diff --git a/mne_bids_pipeline/tests/test_documented.py b/mne_bids_pipeline/tests/test_documented.py index 2175b7af9..a6275a4c8 100644 --- a/mne_bids_pipeline/tests/test_documented.py +++ b/mne_bids_pipeline/tests/test_documented.py @@ -1,13 +1,17 @@ """Test that all config values are documented.""" + import ast -from pathlib import Path import os import re +import sys +from pathlib import Path + import yaml +from mne_bids_pipeline._config_import import _get_default_config +from mne_bids_pipeline._docs import _EXECUTION_OPTIONS, _ParseConfigSteps from mne_bids_pipeline.tests.datasets import DATASET_OPTIONS from mne_bids_pipeline.tests.test_run import TEST_SUITE -from mne_bids_pipeline._config_import import _get_default_config root_path = Path(__file__).parent.parent @@ -15,7 +19,7 @@ def test_options_documented(): """Test that all options are suitably documented.""" # use ast to parse _config.py for assignments - with open(root_path / "_config.py", "r") as fid: + with open(root_path / "_config.py") as fid: contents = fid.read() contents = ast.parse(contents) in_config = [ @@ -28,31 +32,57 @@ def test_options_documented(): config_names = set(d for d in dir(config) if not d.startswith("_")) assert in_config == config_names settings_path = root_path.parent / "docs" / "source" / "settings" + sys.path.append(str(settings_path)) + try: + from gen_settings import main + finally: + sys.path.pop() + main() assert settings_path.is_dir() - in_doc = set() + in_doc = dict() key = " - " - allowed_duplicates = set( - [ - "source_info_path_update", - ] - ) for dirpath, _, fnames in os.walk(settings_path): for fname in fnames: if not fname.endswith(".md"): continue # This is a .md file - with open(Path(dirpath) / fname, "r") as fid: + # convert to relative path + fname = os.path.join(os.path.relpath(dirpath, settings_path), fname) + assert fname not in in_doc + in_doc[fname] = set() + with open(settings_path / fname) as fid: for line in fid: if not line.startswith(key): continue # The line starts with our magic key val = line[len(key) :].strip() - if val not in allowed_duplicates: - assert val not in in_doc, "Duplicate documentation" - in_doc.add(val) + for other in in_doc: + why = f"Duplicate docs in {fname} and {other} for {val}" + assert val not in in_doc[other], why + in_doc[fname].add(val) what = "docs/source/settings doc" - assert in_doc.difference(in_config) == set(), f"Extra values in {what}" - assert in_config.difference(in_doc) == set(), f"Values missing from {what}" + in_doc_all = set() + for vals in in_doc.values(): + in_doc_all.update(vals) + assert in_doc_all.difference(in_config) == set(), f"Extra values in {what}" + assert in_config.difference(in_doc_all) == set(), f"Values missing from {what}" + + +def test_config_options_used(): + """Test that all config options are used somewhere.""" + config = _get_default_config() + config_names = set(d for d in dir(config) if not d.startswith("__")) + for key in ("_epochs_split_size", "_raw_split_size"): + config_names.add(key) + for key in _EXECUTION_OPTIONS: + config_names.remove(key) + pcs = _ParseConfigSteps(force_empty=()) + missing_from_config = sorted(set(pcs.steps) - config_names) + assert missing_from_config == [], f"Missing from config: {missing_from_config}" + missing_from_steps = sorted(config_names - set(pcs.steps)) + assert missing_from_steps == [], f"Missing from steps: {missing_from_steps}" + for key, val in pcs.steps.items(): + assert val, f"No steps for {key}" def test_datasets_in_doc(): @@ -67,7 +97,7 @@ def test_datasets_in_doc(): # So let's make sure they stay in sync. # 1. Read cache, test, etc. entries from CircleCI - with open(root_path.parent / ".circleci" / "config.yml", "r") as fid: + with open(root_path.parent / ".circleci" / "config.yml") as fid: circle_yaml_src = fid.read() circle_yaml = yaml.safe_load(circle_yaml_src) caches = [job[6:] for job in circle_yaml["jobs"] if job.startswith("cache_")] @@ -80,7 +110,7 @@ def test_datasets_in_doc(): # make sure everything is consistent there (too much work), let's at least # check that we get the correct number using `.count`. counts = dict(ERP_CORE=7, ds000248=6) - counts_noartifact = dict(ds000248=3) # 3 are actually tests, not for docs + counts_noartifact = dict(ds000248=1) # 1 is actually a test, not for docs for name in sorted(caches): get = f"Get {name}" n_found = circle_yaml_src.count(get) @@ -117,19 +147,13 @@ def test_datasets_in_doc(): # jobs: test_*: steps: persist_to_workspace pw = re.compile( f"- mne_data/derivatives/mne-bids-pipeline/{name}[^\\.]+\\*.html" - ) # noqa: E501 + ) n_found = len(pw.findall(circle_yaml_src)) assert n_found == this_count, f"{pw} ({n_found} != {this_count})" # jobs: test_*: steps: run test - cp = re.compile( - f"""\ - DS={name}.* - \\$RUN_TESTS \\${{DS}}.* - mkdir -p ~/reports/\\${{DS}} - cp -av ~/mne_data/derivatives/mne-bids-pipeline/\\${{DS}}/[^\\.]+.html""" - ) # noqa: E501 + cp = re.compile(rf" command: \$RUN_TESTS[ -rc]+{name}.*") n_found = len(cp.findall(circle_yaml_src)) - assert n_found == this_count, f"{cp} ({n_found} != {this_count})" + assert n_found == count, f"{cp} ({n_found} != {count})" # 3. Read examples from docs (being careful about tags we can't read) class SafeLoaderIgnoreUnknown(yaml.SafeLoader): @@ -140,7 +164,7 @@ def ignore_unknown(self, node): None, SafeLoaderIgnoreUnknown.ignore_unknown ) - with open(root_path.parent / "docs" / "mkdocs.yml", "r") as fid: + with open(root_path.parent / "docs" / "mkdocs.yml") as fid: examples = yaml.load(fid.read(), Loader=SafeLoaderIgnoreUnknown) examples = [n for n in examples["nav"] if list(n)[0] == "Examples"][0] examples = [ex for ex in examples["Examples"] if isinstance(ex, str)] diff --git a/mne_bids_pipeline/tests/test_functions.py b/mne_bids_pipeline/tests/test_functions.py new file mode 100644 index 000000000..f4d64adf4 --- /dev/null +++ b/mne_bids_pipeline/tests/test_functions.py @@ -0,0 +1,64 @@ +"""Test some properties of our core processing-step functions.""" + +import ast +import inspect + +import pytest + +from mne_bids_pipeline._config_utils import _get_step_modules + +# mne_bids_pipeline.init._01_init_derivatives_dir: +FLAT_MODULES = {x.__name__: x for x in sum(_get_step_modules().values(), ())} + + +@pytest.mark.parametrize("module_name", list(FLAT_MODULES)) +def test_all_functions_return(module_name): + """Test that all functions decorated with failsafe_run return a dict.""" + # Find the functions within the module that use the failsafe_run decorator + module = FLAT_MODULES[module_name] + funcs = list() + for name in dir(module): + obj = getattr(module, name) + if not callable(obj): + continue + if getattr(obj, "__module__", None) != module_name: + continue + if not hasattr(obj, "__wrapped__"): + continue + # All our failsafe_run decorated functions should look like this + assert "__mne_bids_pipeline_failsafe_wrapper__" in repr(obj.__code__) + funcs.append(obj) + # Some module names we know don't have any + if module_name.split(".")[-1] in ("_01_recon_all",): + assert len(funcs) == 0 + return + + assert len(funcs) != 0, f"No failsafe_runs functions found in {module_name}" + + # Adapted from numpydoc RT01 validation + def get_returns_not_on_nested_functions(node): + returns = [node] if isinstance(node, ast.Return) else [] + for child in ast.iter_child_nodes(node): + # Ignore nested functions and its subtrees. + if not isinstance(child, ast.FunctionDef): + child_returns = get_returns_not_on_nested_functions(child) + returns.extend(child_returns) + return returns + + for func in funcs: + what = f"{module_name}.{func.__name__}" + tree = ast.parse(inspect.getsource(func.__wrapped__)).body + if func.__closure__[-1].cell_contents is False: + continue # last closure node is require_output=False + assert tree, f"Failed to parse source code for {what}" + returns = get_returns_not_on_nested_functions(tree[0]) + return_values = [r.value for r in returns] + # Replace Constant nodes valued None for None. + for i, v in enumerate(return_values): + if isinstance(v, ast.Constant) and v.value is None: + return_values[i] = None + assert len(return_values), f"Function does not return anything: {what}" + for r in return_values: + assert ( + isinstance(r, ast.Call) and r.func.id == "_prep_out_files" + ), f"Function does _prep_out_files: {what}" diff --git a/mne_bids_pipeline/tests/test_run.py b/mne_bids_pipeline/tests/test_run.py index 593a5968e..952be5f13 100644 --- a/mne_bids_pipeline/tests/test_run.py +++ b/mne_bids_pipeline/tests/test_run.py @@ -1,13 +1,16 @@ """Download test data and run a test suite.""" -import sys + +import os import shutil +import sys +from collections.abc import Collection from pathlib import Path -from typing import Collection, Dict, Optional, TypedDict +from typing import TypedDict import pytest -from mne_bids_pipeline._main import main from mne_bids_pipeline._download import main as download_main +from mne_bids_pipeline._main import main BIDS_PIPELINE_DIR = Path(__file__).absolute().parents[1] @@ -17,24 +20,18 @@ # Once PEP655 lands in 3.11 we can use NotRequired instead of total=False +# Effective defaults are listed in comments class _TestOptionsT(TypedDict, total=False): - dataset: str - config: str - steps: Collection[str] - task: Optional[str] - env: Dict[str, str] - - -# If not supplied below, the defaults are: -# key: { -# 'dataset': key.split('_')[0], -# 'config': f'config_{key}.py', -# 'steps': ('preprocessing', 'sensor'), -# 'env': {}, -# 'task': None, -# } -# -TEST_SUITE: Dict[str, _TestOptionsT] = { + dataset: str # key.split("_")[0] + config: str # f"config_{key}.py" + steps: Collection[str] # ("preprocessing", "sensor") + task: str | None # None + env: dict[str, str] # {} + requires: Collection[str] # () + extra_config: str # "" + + +TEST_SUITE: dict[str, _TestOptionsT] = { "ds003392": {}, "ds004229": {}, "ds001971": {}, @@ -56,16 +53,32 @@ class _TestOptionsT(TypedDict, total=False): }, "ds000248_base": { "steps": ("preprocessing", "sensor", "source"), + "requires": ("freesurfer",), + "extra_config": """ +_raw_split_size = "60MB" # hits both task-noise and task-audiovisual +_epochs_split_size = "30MB" +# use n_jobs=1 here to ensure that we get coverage for metadata_query +_n_jobs = {"preprocessing/_05_make_epochs": 1} +""", + }, + "ds000248_ica": { + "extra_config": """ +_raw_split_size = "60MB" +_epochs_split_size = "30MB" +_n_jobs = {} +""" }, - "ds000248_ica": {}, "ds000248_T1_BEM": { "steps": ("source/make_bem_surfaces",), + "requires": ("freesurfer",), }, "ds000248_FLASH_BEM": { "steps": ("source/make_bem_surfaces",), + "requires": ("freesurfer",), }, "ds000248_coreg_surfaces": { "steps": ("freesurfer/coreg_surfaces",), + "requires": ("freesurfer",), }, "ds000248_no_mri": { "steps": ("preprocessing", "sensor", "source"), @@ -85,6 +98,13 @@ class _TestOptionsT(TypedDict, total=False): "dataset": "ERP_CORE", "config": "config_ERP_CORE.py", "task": "ERN", + "extra_config": """ +# use n_jobs = 1 with loky to ensure that the CSP steps get proper coverage +_n_jobs = { + "sensor/_05_decoding_csp": 1, + "sensor/_99_group_average": 1, +} +""", }, "ERP_CORE_LRP": { "dataset": "ERP_CORE", @@ -111,15 +131,22 @@ class _TestOptionsT(TypedDict, total=False): "config": "config_ERP_CORE.py", "task": "P3", }, + "MNE-phantom-KIT-data": { + "config": "config_MNE_phantom_KIT_data.py", + }, } @pytest.fixture() def dataset_test(request): + """Provide a defined context for our dataset tests.""" # There is probably a cleaner way to get this param, but this works for now capsys = request.getfixturevalue("capsys") dataset = request.getfixturevalue("dataset") test_options = TEST_SUITE[dataset] + if "freesurfer" in test_options.get("requires", ()): + if "FREESURFER_HOME" not in os.environ: + pytest.skip("FREESURFER_HOME required but not found") dataset_name = test_options.get("dataset", dataset.split("_")[0]) with capsys.disabled(): if request.config.getoption("--download", False): # download requested @@ -129,33 +156,39 @@ def dataset_test(request): @pytest.mark.dataset_test @pytest.mark.parametrize("dataset", list(TEST_SUITE)) -def test_run(dataset, monkeypatch, dataset_test, capsys): +def test_run(dataset, monkeypatch, dataset_test, capsys, tmp_path): """Test running a dataset.""" test_options = TEST_SUITE[dataset] - - # export the environment variables - monkeypatch.setenv("DATASET", dataset) - for key, value in test_options.get("env", {}).items(): - monkeypatch.setenv(key, value) - config = test_options.get("config", f"config_{dataset}.py") config_path = BIDS_PIPELINE_DIR / "tests" / "configs" / config + extra_config = TEST_SUITE[dataset].get("extra_config", "") + if extra_config: + extra_path = tmp_path / "extra_config.py" + extra_path.write_text(extra_config) + monkeypatch.setenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", str(extra_path)) # XXX Workaround for buggy date in ds000247. Remove this and the # XXX file referenced here once fixed!!! fix_path = Path(__file__).parent if dataset == "ds000247": - shutil.copy( - src=fix_path / "ds000247_scans.tsv", - dst=Path( - "~/mne_data/ds000247/sub-0002/ses-01/" "sub-0002_ses-01_scans.tsv" - ).expanduser(), + dst = ( + DATA_DIR / "ds000247" / "sub-0002" / "ses-01" / "sub-0002_ses-01_scans.tsv" ) + shutil.copy(src=fix_path / "ds000247_scans.tsv", dst=dst) # XXX Workaround for buggy participant_id in ds001971 elif dataset == "ds001971": shutil.copy( src=fix_path / "ds001971_participants.tsv", - dst=Path("~/mne_data/ds001971/participants.tsv").expanduser(), + dst=DATA_DIR / "ds001971" / "participants.tsv", + ) + elif dataset == "ds003775": + shutil.copy( + src=fix_path / "sub-010_ses-t1_scans.tsv", + dst=DATA_DIR + / "ds003775" + / "sub-010" + / "ses-t1" + / "sub-010_ses-t1_scans.tsv", ) # Run the tests. diff --git a/mne_bids_pipeline/tests/test_validation.py b/mne_bids_pipeline/tests/test_validation.py index 25d5abdaa..c76130cc0 100644 --- a/mne_bids_pipeline/tests/test_validation.py +++ b/mne_bids_pipeline/tests/test_validation.py @@ -1,4 +1,7 @@ +"""Test the pipeline configuration import validator.""" + import pytest + from mne_bids_pipeline._config_import import _import_config @@ -13,7 +16,7 @@ def test_validation(tmp_path, capsys): bad_text += f"bids_root = '{tmp_path}'\n" # no ch_types config_path.write_text(bad_text) - with pytest.raises(ValueError, match="Please specify ch_types"): + with pytest.raises(ValueError, match="Value should have at least 1 item"): _import_config(config_path=config_path) bad_text += "ch_types = ['eeg']\n" # conditions diff --git a/mne_bids_pipeline/typing.py b/mne_bids_pipeline/typing.py index 555012713..8ac9ecfe4 100644 --- a/mne_bids_pipeline/typing.py +++ b/mne_bids_pipeline/typing.py @@ -1,24 +1,58 @@ -"""Typing.""" +"""Custom data types for MNE-BIDS-Pipeline.""" import pathlib -from typing import Union, List, Dict, TypedDict +import sys +from typing import Annotated + +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import TypedDict import mne +import numpy as np +from numpy.typing import ArrayLike +from pydantic import PlainValidator -PathLike = Union[str, pathlib.Path] +PathLike = str | pathlib.Path class ArbitraryContrast(TypedDict): + """Statistical contrast with arbitrary weights.""" + name: str - conditions: List[str] - weights: List[float] + conditions: list[str] + weights: list[float] class LogKwargsT(TypedDict): + """Container for logger keyword arguments.""" + msg: str - extra: Dict[str, str] + extra: dict[str, str] + + +def assert_float_array_like(val): + """Convert the input into a NumPy float array.""" + # https://docs.pydantic.dev/latest/errors/errors/#custom-errors + # Should raise ValueError or AssertionError... NumPy should do this for us + return np.array(val, dtype="float") + + +FloatArrayLike = Annotated[ + ArrayLike, + # PlainValidator will skip internal validation attempts for ArrayLike + PlainValidator(assert_float_array_like), +] + + +def assert_dig_montage(val): + """Assert that the input is a DigMontage.""" + assert isinstance(val, mne.channels.DigMontage) + return val -class ReferenceRunParams(TypedDict): - montage: mne.channels.DigMontage - dev_head_t: mne.Transform +DigMontageType = Annotated[ + mne.channels.DigMontage, + PlainValidator(assert_dig_montage), +] diff --git a/pyproject.toml b/pyproject.toml index 4e866afff..b9ceda85b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,71 +1,84 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + [project] name = "mne-bids-pipeline" # Keep in sync with README.md: description = "A full-flegded processing pipeline for your MEG and EEG data" readme = "README.md" -requires-python = ">=3.8" -license = {file = "LICENSE.txt"} +requires-python = ">=3.10" +license = { file = "LICENSE.txt" } keywords = ["science", "neuroscience", "psychology"] authors = [ - {name = "Eric Larson"}, - {name = "Alexandre Gramfort"}, - {name = "Mainak Jas"}, - {name = "Richard Höchenberger", email = "richard.hoechenberger@gmail.com"}, + { name = "Eric Larson" }, + { name = "Alexandre Gramfort" }, + { name = "Mainak Jas" }, + { name = "Richard Höchenberger", email = "richard.hoechenberger@gmail.com" }, ] classifiers = [ "Intended Audience :: Science/Research", - "Programming Language :: Python" + "Programming Language :: Python", ] dependencies = [ - "typing_extensions; python_version < '3.8'", - "importlib_metadata; python_version < '3.8'", - "psutil", # for joblib - "packaging", - "joblib >= 0.14", - "threadpoolctl", - "dask[distributed]", - "bokeh < 3", # for distributed dashboard - "jupyter-server-proxy", # to have dask and jupyter working together - "scikit-learn", - "pandas", - "seaborn", - "json_tricks", - "rich", - "python-picard", - "qtpy", - "pyvista", - "pyvistaqt", - "openpyxl", - "autoreject", - "mne[hdf5] >=1.2", - "mne-bids[full]", - "filelock", - "setuptools >=65", + "psutil", # for joblib + "packaging", + "numpy", + "scipy", + "matplotlib", + "nibabel", + "joblib >= 0.14", + "threadpoolctl", + "dask[distributed]", + "bokeh", # for distributed dashboard + "jupyter-server-proxy", # to have dask and jupyter working together + "scikit-learn", + "pandas", + "pyarrow", # from pandas + "seaborn", + "json_tricks", + "pydantic >= 2.0.0", + "annotated-types", + "rich", + "python-picard", + "qtpy", + "pyvista", + "pyvistaqt", + "openpyxl", + "autoreject", + "mne[hdf5] >=1.7", + "mne-bids[full]", + "mne-icalabel", + "onnxruntime", # for mne-icalabel + "filelock", ] dynamic = ["version"] [project.optional-dependencies] tests = [ - "pytest", - "pytest-cov", - "psutil", - "datalad", - "ruff", - "mkdocs", - "mkdocs-material >= 9.0.4", - "mkdocs-material-extensions", - "mkdocs-macros-plugin", - "mkdocs-include-markdown-plugin", - "mkdocstrings-python", - "mike", - "jinja2", - "black", # function signature formatting - "livereload", - "openneuro-py >= 2022.2.0", - "httpx >= 0.20", - "tqdm", - "Pygments", - "pyyaml", + "pytest", + "pytest-cov", + "pooch", + "psutil", + "ruff", + "jinja2", + "openneuro-py >= 2022.2.0", + "httpx >= 0.20", + "tqdm", + "Pygments", + "pyyaml", +] +docs = [ + "mkdocs", + "mkdocs-material >= 9.0.4", + "mkdocs-material-extensions", + "mkdocs-macros-plugin", + "mkdocs-include-markdown-plugin", + "mkdocs-exclude", + "mkdocstrings-python", + "mike", + "livereload", + "black", # docstring reformatting ] [project.scripts] @@ -76,19 +89,21 @@ homepage = "https://mne.tools/mne-bids-pipeline" repository = "https://github.com/mne-tools/mne-bids-pipeline" changelog = "http://mne.tools/mne-bids-pipeline/changes.html" -[build-system] -requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "wheel"] -build-backend = "setuptools.build_meta" - -[tool.setuptools_scm] -tag_regex = "^(?Pv)?(?P[0-9.]+)(?P.*)?$" -version_scheme = "release-branch-semver" +[tool.hatch.version] +source = "vcs" +raw-options = { version_scheme = "release-branch-semver" } -[tool.setuptools.packages.find] -exclude = ["false"] # on CircleCI this folder appears during pip install -ve. for an unknown reason - -[tool.setuptools.package-data] -"mne_bids_pipeline.steps.freesurfer.contrib" = ["version"] +[tool.hatch.build] +exclude = [ + "/.*", + "/codecov.yml", + "**/tests", + "/docs", + "/docs/source/examples/gen_examples.py", # specify explicitly because its exclusion is negated in .gitignore + "/Makefile", + "/CONTRIBUTING.md", + "ignore_words.txt", +] [tool.codespell] skip = "docs/site/*,*.html,steps/freesurfer/contrib/*" @@ -101,13 +116,15 @@ count = "" [tool.pytest.ini_options] addopts = "-ra -vv --tb=short --cov=mne_bids_pipeline --cov-report= --junit-xml=junit-results.xml --durations=10" -testpaths = [ - "mne_bids_pipeline", -] +testpaths = ["mne_bids_pipeline"] junit_family = "xunit2" -[tool.ruff] -exclude = ["**/freesurfer/contrib", "dist/" , "build/"] +[tool.ruff.lint] +select = ["A", "B006", "D", "E", "F", "I", "W", "UP"] +exclude = ["**/freesurfer/contrib", "dist/", "build/"] +ignore = [ + "D104", # Missing docstring in public package +] -[tool.black] -exclude = "(.*/freesurfer/contrib/.*)|(dist/)|(build/)" +[tool.ruff.lint.pydocstyle] +convention = "numpy"