diff --git a/.circleci/config.yml b/.circleci/config.yml index aa0f22479..2e030bc88 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,12 +2,12 @@ version: 2.1 orbs: browser-tools: circleci/browser-tools@1.2 - codecov: codecov/codecov@3.2.4 + codecov: codecov/codecov@5.3.0 # Aliases to reuse _imageconfig: &imageconfig docker: - - image: cimg/base:2022.10-22.04 + - image: cimg/python:3.12.4 # https://circleci.com/docs/workflows/#executing-workflows-for-a-git-tag @@ -47,10 +47,11 @@ jobs: - run: name: Get Python running command: | - pip install --upgrade --progress-bar off pip setuptools - pip install --upgrade --progress-bar off "autoreject @ https://api.github.com/repos/autoreject/autoreject/zipball/master" "mne[hdf5] @ git+https://github.com/mne-tools/mne-python@main" "mne-bids[full] @ https://api.github.com/repos/mne-tools/mne-bids/zipball/main" numba + pip install --upgrade --progress-bar off pip + # TODO: Restore https://api.github.com/repos/mne-tools/mne-bids/zipball/main pending https://github.com/mne-tools/mne-bids/pull/1349/files#r1885104885 + pip install --upgrade --progress-bar off "autoreject @ https://api.github.com/repos/autoreject/autoreject/zipball/master" "mne[hdf5] @ git+https://github.com/mne-tools/mne-python@main" "mne-bids[full] @ git+https://github.com/mne-tools/mne-bids@main" numba pip install -ve .[tests] - pip install PyQt6 + pip install "PyQt6!=6.6.1" "PyQt6-Qt6!=6.6.1,!=6.6.2,!=6.6.3,!=6.7.0" - run: name: Check Qt command: | @@ -67,7 +68,7 @@ jobs: - project - mne_data - minimal_cmds - - python_env + - .pyenv cache_ds000117: <<: *imageconfig @@ -76,7 +77,7 @@ jobs: at: ~/ - restore_cache: keys: - - data-cache-ds000117-2 + - data-cache-ds000117-3 - bash_env - run: name: Get ds000117 @@ -84,7 +85,7 @@ jobs: $DOWNLOAD_DATA ds000117 - codecov/upload - save_cache: - key: data-cache-ds000117-2 + key: data-cache-ds000117-3 paths: - ~/mne_data/ds000117 @@ -151,7 +152,7 @@ jobs: at: ~/ - restore_cache: keys: - - data-cache-ds000246-2 + - data-cache-ds000246-3 - bash_env - run: name: Get ds000246 @@ -159,7 +160,7 @@ jobs: $DOWNLOAD_DATA ds000246 - codecov/upload - save_cache: - key: data-cache-ds000246-2 + key: data-cache-ds000246-3 paths: - ~/mne_data/ds000246 @@ -170,7 +171,7 @@ jobs: at: ~/ - restore_cache: keys: - - data-cache-ds000247-2 + - data-cache-ds000247-3 - bash_env - run: name: Get ds000247 @@ -178,7 +179,7 @@ jobs: $DOWNLOAD_DATA ds000247 - codecov/upload - save_cache: - key: data-cache-ds000247-2 + key: data-cache-ds000247-3 paths: - ~/mne_data/ds000247 @@ -286,7 +287,6 @@ jobs: keys: - data-cache-eeg_matchingpennies-1 - bash_env - - gitconfig # email address is needed for datalad - run: name: Get eeg_matchingpennies command: | @@ -297,6 +297,44 @@ jobs: paths: - ~/mne_data/eeg_matchingpennies + cache_MNE-funloc-data: + <<: *imageconfig + steps: + - attach_workspace: + at: ~/ + - restore_cache: + keys: + - data-cache-MNE-funloc-data-5 + - bash_env + - run: + name: Get MNE-funloc-data + command: | + $DOWNLOAD_DATA MNE-funloc-data + - codecov/upload + - save_cache: + key: data-cache-MNE-funloc-data-5 + paths: + - ~/mne_data/MNE-funloc-data + + cache_MNE-phantom-KIT-data: + <<: *imageconfig + steps: + - attach_workspace: + at: ~/ + - restore_cache: + keys: + - data-cache-MNE-phantom-KIT-data-1 + - bash_env + - run: + name: Get MNE-phantom-KIT-data + command: | + $DOWNLOAD_DATA MNE-phantom-KIT-data + - codecov/upload + - save_cache: + key: data-cache-MNE-phantom-KIT-data-1 + paths: + - ~/mne_data/MNE-phantom-KIT-data + cache_ERP_CORE: <<: 
*imageconfig steps: @@ -324,7 +362,7 @@ jobs: - bash_env - restore_cache: keys: - - data-cache-ds000117-2 + - data-cache-ds000117-3 - run: name: test ds000117 command: $RUN_TESTS ds000117 @@ -429,7 +467,7 @@ jobs: - bash_env - restore_cache: keys: - - data-cache-ds000246-2 + - data-cache-ds000246-3 - run: name: test ds000246 no_output_timeout: 15m @@ -457,7 +495,7 @@ jobs: - bash_env - restore_cache: keys: - - data-cache-ds000247-2 + - data-cache-ds000247-3 - run: name: test ds000247 command: $RUN_TESTS ds000247 @@ -741,6 +779,7 @@ jobs: test_eeg_matchingpennies: <<: *imageconfig + resource_class: large # memory for zapline steps: - attach_workspace: at: ~/ @@ -765,6 +804,59 @@ jobs: paths: - mne_data/derivatives/mne-bids-pipeline/eeg_matchingpennies/*/*/*.html + test_MNE-funloc-data: + <<: *imageconfig + resource_class: large + steps: + - attach_workspace: + at: ~/ + - bash_env + - restore_cache: + keys: + - data-cache-MNE-funloc-data-5 + - run: + name: test MNE-funloc-data + command: $RUN_TESTS MNE-funloc-data + - codecov/upload + - store_test_results: + path: ./test-results + - store_artifacts: + path: ./test-results + destination: test-results + - store_artifacts: + path: /home/circleci/reports/MNE-funloc-data + destination: reports/MNE-funloc-data + - persist_to_workspace: + root: ~/ + paths: + - mne_data/derivatives/mne-bids-pipeline/MNE-funloc-data/*/*/*.html + + test_MNE-phantom-KIT-data: + <<: *imageconfig + steps: + - attach_workspace: + at: ~/ + - bash_env + - restore_cache: + keys: + - data-cache-MNE-phantom-KIT-data-1 + - run: + name: test MNE-phantom-KIT-data + command: $RUN_TESTS MNE-phantom-KIT-data + - codecov/upload + - store_test_results: + path: ./test-results + - store_artifacts: + path: ./test-results + destination: test-results + - store_artifacts: + path: /home/circleci/reports/MNE-phantom-KIT-data + destination: reports/MNE-phantom-KIT-data + - persist_to_workspace: + root: ~/ + paths: + - mne_data/derivatives/mne-bids-pipeline/MNE-phantom-KIT-data/*/*/*.html + test_ERP_CORE_N400: <<: *imageconfig resource_class: large @@ -986,6 +1078,10 @@ jobs: - attach_workspace: at: ~/ - bash_env + - run: + name: Install dependencies + command: | + pip install -ve .[docs] - run: name: Build documentation command: | @@ -1014,6 +1110,10 @@ jobs: at: ~/ - bash_env - gitconfig + - run: + name: Install dependencies + command: | + pip install -ve .[docs] - run: # This is a bit computationally inefficient, but it should be much # faster to "cp" directly on the machine rather than persist @@ -1191,6 +1291,24 @@ workflows: - cache_eeg_matchingpennies <<: *filter_tags + - cache_MNE-funloc-data: + requires: + - setup_env + <<: *filter_tags + - test_MNE-funloc-data: + requires: + - cache_MNE-funloc-data + <<: *filter_tags + + - cache_MNE-phantom-KIT-data: + requires: + - setup_env + <<: *filter_tags + - test_MNE-phantom-KIT-data: + requires: + - cache_MNE-phantom-KIT-data + <<: *filter_tags + - cache_ERP_CORE: requires: - setup_env @@ -1242,6 +1360,8 @@ workflows: - test_ds003392 - test_ds004229 - test_eeg_matchingpennies + - test_MNE-funloc-data + - test_MNE-phantom-KIT-data - test_ERP_CORE_N400 - test_ERP_CORE_ERN - test_ERP_CORE_LRP diff --git a/.circleci/remove_examples.sh b/.circleci/remove_examples.sh new file mode 100755 index 000000000..ee4004442 --- /dev/null +++ b/.circleci/remove_examples.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +set -eo pipefail + +VER=$1 +if [ -z "$VER" ]; then + echo "Usage: $0 " + exit 1 +fi +ROOT="$PWD/$VER/examples/" +if [ ! 
-d ${ROOT} ]; then + echo "Version directory does not exist or appears incorrect:" + echo + echo "$ROOT" + echo + echo "Are you on the gh-pages branch and is the ds000117 directory present?" + exit 1 +fi +if [ ! -d ${ROOT}ds000117 ]; then + echo "Directory does not exist:" + echo + echo "${ROOT}ds000117" + echo + echo "Assuming already pruned and exiting." + exit 0 +fi +echo "Pruning examples in ${ROOT} ..." + +find $ROOT -type d -name "*" | tail -n +2 | xargs rm -Rf +find $ROOT -name "*.html" -exec sed -i /^\Generated/,/^\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ $/{d} {} \; +find $ROOT -name "*.html" -exec sed -i '/^ $/{d}' {} \; + +echo "Done" diff --git a/.circleci/run_dataset_and_copy_files.sh b/.circleci/run_dataset_and_copy_files.sh index 34dcfa14f..23e395b8f 100755 --- a/.circleci/run_dataset_and_copy_files.sh +++ b/.circleci/run_dataset_and_copy_files.sh @@ -27,25 +27,32 @@ else fi SECONDS=0 +EMPH="\e[35m\e[1m" # bold magenta foreground +RESET="\e[0m" pytest mne_bids_pipeline --junit-xml=test-results/junit-results.xml -k ${DS_RUN} -echo "Runtime: ${SECONDS} seconds" +# Add emphasis and echo +echo -e "${EMPH}Clean test runtime: ${SECONDS} seconds${RESET}" +echo # rerun test (check caching)! SECONDS=0 +RERUN_LIMIT=60 if [[ "$RERUN_TEST" == "false" ]]; then - echo "Skipping rerun test" + echo -e "${EMPH}Skipping cache rerun test${RESET}" RUN_TIME=0 else pytest mne_bids_pipeline --cov-append -k $DS_RUN RUN_TIME=$SECONDS - echo "Runtime: ${RUN_TIME} seconds (should be < 20)" + echo -e "${EMPH}Cached test runtime: ${RUN_TIME} seconds (should be <= $RERUN_LIMIT)${RESET}" fi -test $RUN_TIME -lt 20 +test $RUN_TIME -le $RERUN_LIMIT if [[ "$COPY_FILES" == "false" ]]; then - echo "Not copying files" + echo -e "${EMPH}Not copying files${RESET}" exit 0 fi +echo +echo -e "${EMPH}Copying files${RESET}" mkdir -p ~/reports/${DS} # these should always exist cp -av ~/mne_data/derivatives/mne-bids-pipeline/${DS}/*/**/*.html ~/reports/${DS}/ diff --git a/.circleci/setup_bash.sh b/.circleci/setup_bash.sh index ee44b317b..e02c876c9 100755 --- a/.circleci/setup_bash.sh +++ b/.circleci/setup_bash.sh @@ -32,32 +32,23 @@ else fi # Set up image -sudo ln -s /usr/lib/x86_64-linux-gnu/libxcb-util.so.0 /usr/lib/x86_64-linux-gnu/libxcb-util.so.1 -wget -q -O- http://neuro.debian.net/lists/focal.us-tn.libre | sudo tee /etc/apt/sources.list.d/neurodebian.sources.list -sudo apt-key adv --recv-keys --keyserver hkps://keyserver.ubuntu.com 0xA5D32F012649A5A9 -echo "export RUN_TESTS=\".circleci/run_dataset_and_copy_files.sh\"" >> "$BASH_ENV" -echo "export DOWNLOAD_DATA=\"coverage run -m mne_bids_pipeline._download\"" >> "$BASH_ENV" +echo "export RUN_TESTS=\".circleci/run_dataset_and_copy_files.sh\"" | tee -a "$BASH_ENV" +echo "export DOWNLOAD_DATA=\"coverage run -m mne_bids_pipeline._download\"" | tee -a "$BASH_ENV" -# Similar CircleCI setup to mne-python (Xvfb, venv, minimal commands, env vars) +# Similar CircleCI setup to mne-python (Xvfb, minimal commands, env vars) wget -q https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/setup_xvfb.sh bash setup_xvfb.sh -sudo apt install -qq tcsh git-annex-standalone python3.10-venv python3-venv libxft2 -python3.10 -m venv ~/python_env +sudo apt install -qq tcsh libxft2 wget -q https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/get_minimal_commands.sh source get_minimal_commands.sh mkdir -p ~/mne_data -echo "set -e" >> "$BASH_ENV" -echo 'export OPENBLAS_NUM_THREADS=2' >> "$BASH_ENV" -echo 'shopt -s globstar' >> "$BASH_ENV" # Enable recursive globbing via ** -echo 
'export MNE_DATA=$HOME/mne_data' >> "$BASH_ENV" -echo "export PATH=~/.local/bin/:$PATH" >> "$BASH_ENV" -echo 'export DISPLAY=:99' >> "$BASH_ENV" -echo 'export XDG_RUNTIME_DIR=/tmp/runtime-circleci' >> "$BASH_ENV" -echo 'export MPLBACKEND=Agg' >> "$BASH_ENV" -echo "source ~/python_env/bin/activate" >> "$BASH_ENV" -echo "export MNE_3D_OPTION_MULTI_SAMPLES=1" >> "$BASH_ENV" -echo "export MNE_BIDS_PIPELINE_FORCE_TERMINAL=true" >> "$BASH_ENV" -mkdir -p ~/.local/bin -if [[ ! -f ~/.local/bin/python ]]; then - ln -s ~/python_env/bin/python ~/.local/bin/python -fi +echo "set -e" | tee -a "$BASH_ENV" +echo 'export OPENBLAS_NUM_THREADS=2' | tee -a "$BASH_ENV" +echo 'shopt -s globstar' | tee -a "$BASH_ENV" # Enable recursive globbing via ** +echo 'export MNE_DATA=$HOME/mne_data' | tee -a "$BASH_ENV" +echo 'export DISPLAY=:99' | tee -a "$BASH_ENV" +echo 'export XDG_RUNTIME_DIR=/tmp/runtime-circleci' | tee -a "$BASH_ENV" +echo 'export MPLBACKEND=Agg' | tee -a "$BASH_ENV" +echo "export MNE_3D_OPTION_MULTI_SAMPLES=1" | tee -a "$BASH_ENV" +echo "export MNE_BIDS_PIPELINE_FORCE_TERMINAL=true" | tee -a "$BASH_ENV" +echo "export FORCE_COLOR=1" | tee -a "$BASH_ENV" # for rich to use color in logs diff --git a/.git_archival.txt b/.git_archival.txt new file mode 100644 index 000000000..7c5100942 --- /dev/null +++ b/.git_archival.txt @@ -0,0 +1,3 @@ +node: $Format:%H$ +node-date: $Format:%cI$ +describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..00a7b00c9 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +.git_archival.txt export-subst diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..4a47c7a99 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,15 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + labels: + - "dependabot" + commit-message: + prefix: "[dependabot]" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index cd4653be5..7e90e45b0 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,3 +1,3 @@ ### Before merging … -- [ ] Changelog has been updated (`docs/source/changes.md`) +- [ ] Changelog has been updated (`docs/source/dev.md.inc`) diff --git a/.github/release.yaml b/.github/release.yaml new file mode 100644 index 000000000..0b7e54946 --- /dev/null +++ b/.github/release.yaml @@ -0,0 +1,7 @@ +--- +changelog: + exclude: + authors: + - dependabot[bot] + - pre-commit-ci[bot] + - mne-bot diff --git a/.github/workflows/automerge.yml b/.github/workflows/automerge.yml new file mode 100644 index 000000000..f102b3813 --- /dev/null +++ b/.github/workflows/automerge.yml @@ -0,0 +1,17 @@ +name: Bot auto-merge +on: pull_request_target # yamllint disable-line rule:truthy + +permissions: + contents: write + pull-requests: write + +jobs: + autobot: + runs-on: ubuntu-latest + if: (github.event.pull_request.user.login == 'dependabot[bot]' || github.event.pull_request.user.login == 'pre-commit-ci[bot]') && github.repository == 'mne-tools/mne-bids-pipeline' + steps: + - name: Enable auto-merge for bot PRs + run: gh pr merge --auto --squash "$PR_URL" + env: + PR_URL: ${{github.event.pull_request.html_url}} + GH_TOKEN: ${{secrets.MNE_BOT_TOKEN}} diff --git a/.github/workflows/autopush.yml b/.github/workflows/autopush.yml new file mode 100644 index 000000000..06f338e1c --- /dev/null +++ b/.github/workflows/autopush.yml @@ -0,0 +1,34 @@ +name: Bot auto-push +on: # yamllint disable-line rule:truthy + push: + branches: + - dependabot/** + - pre-commit-ci* + +jobs: + autobot: + permissions: + contents: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + with: + persist-credentials: true + token: ${{ secrets.MNE_BOT_TOKEN }} + ssh-user: mne-bot + - name: Push a commit for bot PRs to run CircleCI + run: | + set -xeo pipefail + git config --global user.name "mne[bot]" + git config --global user.email "50266005+mne-bot@users.noreply.github.com" + COMMIT_MESSAGE=$(git show -s --format=%s) + # Detect dependabot and pre-commit.ci commit messages + if [[ "$COMMIT_MESSAGE" == '[dependabot]'* ]] || [[ "$COMMIT_MESSAGE" == '[pre-commit.ci]'* ]] ; then + echo "Pushed commit to run CircleCI for: $COMMIT_MESSAGE" | tee -a $GITHUB_STEP_SUMMARY + git commit --allow-empty -m "mne[bot] Push commit to run CircleCI" + git push + else + echo "No need to push a commit for: $COMMIT_MESSAGE" | tee -a $GITHUB_STEP_SUMMARY + fi + env: + GH_TOKEN: ${{ secrets.MNE_BOT_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1d5a786d7..380a2c448 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,9 +18,9 @@ jobs: package: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: '3.10' - name: Install dependencies @@ -34,7 +34,7 @@ jobs: - name: Check env vars run: | echo "Triggered by: ${{ github.event_name }}" - - uses: 
actions/upload-artifact@v3 + - uses: actions/upload-artifact@v5 with: name: dist path: dist @@ -44,13 +44,12 @@ jobs: needs: package runs-on: ubuntu-latest if: github.event_name == 'release' + permissions: + id-token: write steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v6 with: name: dist path: dist - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml deleted file mode 100644 index 7f7ae75ea..000000000 --- a/.github/workflows/run-tests.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: Checks -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true - -on: [push, pull_request] - -jobs: - check-style: - name: Style - runs-on: "ubuntu-latest" - defaults: - run: - shell: bash -l {0} - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - - name: Install ruff and codespell - run: pip install ruff codespell tomli - - run: make ruff - - run: make codespell-error - - uses: psf/black@stable - check-doc: - name: Doc consistency - runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - - run: pip install --upgrade setuptools[toml] pip - - run: pip install --no-build-isolation -ve .[tests] - - run: pytest mne_bids_pipeline -m "not dataset_test" - - uses: codecov/codecov-action@v3 - if: success() - name: 'Upload coverage to CodeCov' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 000000000..0b9eb271f --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,109 @@ +name: Checks +concurrency: + group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} + cancel-in-progress: true + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +jobs: + check-doc: + name: Doc consistency and codespell + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@v5 + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - run: pip install --upgrade pip + - run: pip install -ve .[tests] "mne-bids[full] @ git+https://github.com/mne-tools/mne-bids@main" codespell tomli --only-binary="numpy,scipy,pandas,matplotlib,pyarrow,numexpr" + - run: make codespell-error + - run: pytest mne_bids_pipeline -m "not dataset_test" + - uses: codecov/codecov-action@v5 + if: success() + name: 'Upload coverage to CodeCov' + caching: + name: 'Testing and caching ${{ matrix.dataset }} on ${{ matrix.os }} py${{ matrix.python }}' + timeout-minutes: 30 + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -e {0} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, ubuntu-24.04-arm, macos-latest, windows-latest] + dataset: [ds001971, ds003392] # one uses "hash", the other "mtime" + python: ["3.12"] + include: + - os: ubuntu-latest + dataset: ds001971 + python: "3.10" + - os: ubuntu-latest + dataset: ds003392 + python: "3.13" + env: + MNE_BIDS_PIPELINE_LEGACY_WINDOWS: "false" + PYTHONIOENCODING: 'utf8' # for Windows + steps: + - uses: actions/checkout@v5 + - uses: pyvista/setup-headless-display-action@main + with: + qt: true + pyvista: false + - uses: actions/setup-python@v6 + with: + python-version: "${{ matrix.python }}" + - run: pip install -ve .[tests] --only-binary="numpy,scipy,pandas,matplotlib,pyarrow,numexpr" + - uses: 
actions/cache@v4 + with: + key: ${{ matrix.dataset }} + path: ~/mne_data/${{ matrix.dataset }} + id: dataset-cache + - run: python -m mne_bids_pipeline._download ${{ matrix.dataset }} + if: steps.dataset-cache.outputs.cache-hit != 'true' + - run: | + rm -f ~/mne_data/ds003392/sub-01/meg/sub-01_acq-calibration_meg.dat + rm -f ~/mne_data/ds003392/sub-01/meg/sub-01_acq-crosstalk_meg.fif + if: matrix.dataset == 'ds003392' + name: Remove cross-talk and cal files from ds003392 + - run: pytest --cov-append -k ${{ matrix.dataset }} mne_bids_pipeline/ + name: Run ${{ matrix.dataset }} test from scratch + - run: pytest --cov-append -k ${{ matrix.dataset }} mne_bids_pipeline/ + timeout-minutes: 1 + name: Rerun ${{ matrix.dataset }} test to check all steps cached + - uses: codecov/codecov-action@v5 + if: success() || failure() + non-doc-dataset-tests: + name: 'Non-doc dataset tests' + timeout-minutes: 30 + runs-on: ubuntu-latest + defaults: + run: + shell: bash -e {0} + strategy: + fail-fast: true + steps: + - uses: actions/checkout@v5 + - uses: pyvista/setup-headless-display-action@main + with: + qt: true + pyvista: false + - uses: actions/setup-python@v6 + with: + python-version: "3.12" + - run: pip install -ve .[tests] pyvistaqt PySide6 --only-binary="numpy,scipy,pandas,matplotlib,pyarrow,numexpr,PySide6" + - uses: actions/cache@v4 + with: + key: MNE-funloc-data + path: ~/mne_data/MNE-funloc-data + id: MNE-funloc-data-cache + - run: python -m mne_bids_pipeline._download MNE-funloc-data + if: steps.MNE-funloc-data-cache.outputs.cache-hit != 'true' + - run: pytest --cov-append -k test_session_specific_mri mne_bids_pipeline/ diff --git a/.gitignore b/.gitignore index 8c9401a3d..9b98a62c0 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ build/ .hypothesis/ .coverage* junit-results.xml +.cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e9ddf043a..f2259cf41 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,29 +1,34 @@ --- files: ^(.*\.(py|yaml))$ -# We need to match the exclude list in pyproject.toml because pre-commit.ci -# passes filenames and these do not get passed through the tool.black filter -# for example -exclude: ^(\.[^/]*cache/.*|.*/freesurfer/contrib/.*)$ repos: - - repo: https://github.com/psf/black - rev: 23.10.1 - hooks: - - id: black - args: - - --safe - - --quiet - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.4 + rev: v0.14.5 hooks: - id: ruff + args: ["--fix"] + - id: ruff-format - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.4.1 hooks: - id: codespell additional_dependencies: - tomli - repo: https://github.com/adrienverge/yamllint.git - rev: v1.32.0 + rev: v1.37.1 hooks: - id: yamllint args: [--strict] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.18.2 + hooks: + - id: mypy + additional_dependencies: + - types-PyYAML + - numpy + - pytest + - dask[distributed] + - types-tqdm + # TODO: Fix errors for pandas someday perhaps + # (though we don't use it a lot) + # - pandas-stubs diff --git a/BUILDING.md b/BUILDING.md deleted file mode 100644 index 41f1e327d..000000000 --- a/BUILDING.md +++ /dev/null @@ -1,9 +0,0 @@ -# Building a release - -* Tag a new release with `git` if necessary. 
-* Create `sdist` distribution: - - ```shell - pip install -q build - python -m build # will build sdist and wheel - ``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a403bc8f2..cea32e59d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,9 +1,9 @@ -# Overview +# Contributing to MNE-BIDS-Pipeline Contributors to MNE-BIDS-Pipeline are expected to follow our [Code of Conduct](https://github.com/mne-tools/.github/blob/main/CODE_OF_CONDUCT.md). -# Installation +## Installation First, you need to make sure you have MNE-Python installed and working on your system. See the [installation instructions](http://martinos.org/mne/stable/install_mne_python.html). @@ -11,17 +11,11 @@ Once this is done, you should be able to run this in a terminal: `$ python -c "import mne; mne.sys_info()"` -You can then install the following additional packages via `pip`. Note that +You can then install the following additional package via `pip`. Note that the URL points to the bleeding edge version of `mne_bids`: -`$ pip install datalad` `$ pip install https://github.com/mne-tools/mne-bids/zipball/main` -To get the test data, you need to install `git-annex` on your system. If you -installed MNE-Python via `conda`, you can simply call: - -`conda install -c conda-forge git-annex` - Now, get the pipeline through git: `$ git clone https://github.com/mne-tools/mne-bids-pipeline.git` @@ -32,9 +26,9 @@ If you do not know how to use git, download the pipeline as a zip file Finally, for source analysis you'll also need `FreeSurfer`, follow the instructions on [their website](https://surfer.nmr.mgh.harvard.edu/). -# Testing +## Testing -## Running the tests, and continuous integration +### Running the tests, and continuous integration The tests are run using `pytest`. You can run them by calling `pytest mne_bids_pipeline` to run @@ -48,7 +42,7 @@ For every pull request or merge into the `main` branch of the [CircleCI](https://circleci.com/gh/brainthemind/CogBrainDyn_MEG_Pipeline) will run tests as defined in `./circleci/config.yml`. -## Debugging +### Debugging To run the test in debugging mode, just pass `--pdb` to the `pytest` call as usual. This will place you in debugging mode on failure. @@ -56,9 +50,9 @@ See the [pdb help](https://docs.python.org/3/library/pdb.html#debugger-commands) for more commands. -## Config files +### Config files -Nested in the `/tests` directory is a `/configs` directory, which contains +Nested in the `tests` directory is a `configs` directory, which contains config files for specific test datasets. For example, the `config_ds001810.py` file specifies parameters only for the `ds001810` data, which should overwrite the more general parameters in the main `_config.py` file. diff --git a/Makefile b/Makefile index 8af267201..3199402d6 100644 --- a/Makefile +++ b/Makefile @@ -23,8 +23,6 @@ doc: check: which python - git-annex version - datalad --version openneuro-py --version mri_convert --version mne_bids --version @@ -33,10 +31,6 @@ check: trailing-spaces: find . -name "*.py" | xargs perl -pi -e 's/[ \t]*$$//' -ruff: - ruff . - @echo "ruff passed" - codespell: # running manually; auto-fix spelling mistakes @codespell --write-changes $(CODESPELL_DIRS) diff --git a/README.md b/README.md index a18948e22..075e1e2b8 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Structure (BIDS)](https://bids.neuroimaging.io/). * 👣 Data processing as a sequence of standard processing steps. * ⏩ Steps are cached to avoid unnecessary recomputation. 
* ⏏️ Data can be "ejected" from the pipeline at any stage. No lock-in! -* ☁️ Runs on your laptop, on a powerful server, or on a high-performance cluster via Dash. +* ☁️ Runs on your laptop, on a powerful server, or on a high-performance cluster via Dask. @@ -44,7 +44,7 @@ developed for this publication: > M. Jas, E. Larson, D. A. Engemann, J. Leppäkangas, S. Taulu, M. Hämäläinen, > A. Gramfort (2018). A reproducible MEG/EEG group study with the MNE software: > recommendations, quality assessments, and good practices. Frontiers in -> neuroscience, 12. https://doi.org/10.3389/fnins.2018.00530 +> neuroscience, 12. The current iteration is based on BIDS and relies on the extensions to BIDS for EEG and MEG. See the following two references: @@ -52,10 +52,10 @@ for EEG and MEG. See the following two references: > Pernet, C. R., Appelhoff, S., Gorgolewski, K. J., Flandin, G., > Phillips, C., Delorme, A., Oostenveld, R. (2019). EEG-BIDS, an extension > to the brain imaging data structure for electroencephalography. Scientific -> Data, 6, 103. https://doi.org/10.1038/s41597-019-0104-8 +> Data, 6, 103. > Niso, G., Gorgolewski, K. J., Bock, E., Brooks, T. L., Flandin, G., Gramfort, A., > Henson, R. N., Jas, M., Litvak, V., Moreau, J., Oostenveld, R., Schoffelen, J., > Tadel, F., Wexler, J., Baillet, S. (2018). MEG-BIDS, the brain imaging data > structure extended to magnetoencephalography. Scientific Data, 5, 180110. -> https://doi.org/10.1038/sdata.2018.110 +> diff --git a/docs/build-docs.sh b/docs/build-docs.sh index ccb159aae..6e8154b36 100755 --- a/docs/build-docs.sh +++ b/docs/build-docs.sh @@ -10,6 +10,10 @@ python $STEP_DIR/source/examples/gen_examples.py echo "Generating pipeline table …" python $STEP_DIR/source/features/gen_steps.py +echo "Generating config docs …" +python $STEP_DIR/source/settings/gen_settings.py + echo "Building the documentation …" cd $STEP_DIR -mkdocs build +mkdocs build --strict + diff --git a/docs/hooks.py b/docs/hooks.py index ab7192f6e..41d73d3d4 100644 --- a/docs/hooks.py +++ b/docs/hooks.py @@ -1,19 +1,31 @@ +"""Custom hooks for MkDocs-Material.""" + import logging -from typing import Dict, Any +from typing import Any from mkdocs.config.defaults import MkDocsConfig -from mkdocs.structure.pages import Page from mkdocs.structure.files import Files +from mkdocs.structure.pages import Page + +from mne_bids_pipeline._docs import _ParseConfigSteps logger = logging.getLogger("mkdocs") config_updated = False +# This hack can be cleaned up once this is resolved: +# https://github.com/mkdocstrings/mkdocstrings/issues/615#issuecomment-1971568301 +def on_pre_build(config: MkDocsConfig) -> None: + """Monkey patch mkdocstrings-python jinja template to have global vars.""" + python_handler = config["plugins"]["mkdocstrings"].get_handler("python") + python_handler.env.globals["pipeline_steps"] = _ParseConfigSteps() + + # Ideally there would be a better hook, but it's unclear if context can # be obtained any earlier def on_template_context( - context: Dict[str, Any], + context: dict[str, Any], template_name: str, config: MkDocsConfig, ) -> None: @@ -46,6 +58,7 @@ def on_page_markdown( config: MkDocsConfig, files: Files, ) -> str: + """Replace emojis.""" if page.file.name == "index" and page.title == "Home": for rd, md in _EMOJI_MAP.items(): markdown = markdown.replace(rd, md) diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 8763aa9c0..3e28c5314 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -21,12 +21,10 @@ theme: - content.code.copy # copy button palette: # Palette 
toggle for automated theme selection - # Currently only available to sponsors - # - # - media: "(prefers-color-scheme)" - # toggle: - # icon: material/brightness-auto - # name: Switch to light mode + - media: "(prefers-color-scheme)" + toggle: + icon: material/brightness-auto + name: Switch to system preference # Palette toggle for light mode - media: "(prefers-color-scheme: light)" @@ -34,14 +32,14 @@ theme: primary: white toggle: icon: material/brightness-7 - name: Switch to dark mode + name: Switch to light mode # Palette toggle for dark mode - media: "(prefers-color-scheme: dark)" scheme: slate toggle: icon: material/brightness-4 - name: Switch to system preference + name: Switch to dark mode custom_dir: overrides extra: @@ -69,6 +67,8 @@ extra_css: extra_javascript: - https://unpkg.com/tablesort@5.3.0/dist/tablesort.min.js - javascripts/tablesort.js +not_in_nav: | + /governance.md nav: - Home: index.md - Getting started: @@ -77,7 +77,7 @@ nav: - Preparations for source-level analyses: getting_started/freesurfer.md - Processing steps: - Overview: features/overview.md - - Detailed list of processing steps: features/steps.md + - List of processing steps: features/steps.md - Configuration options: - General settings: settings/general.md - Preprocessing: @@ -90,7 +90,7 @@ nav: - Epoching: settings/preprocessing/epochs.md - Artifact removal: - Stimulation artifact: settings/preprocessing/stim_artifact.md - - SSP & ICA: settings/preprocessing/ssp_ica.md + - SSP, ICA, and artifact regression: settings/preprocessing/ssp_ica.md - Amplitude-based artifact rejection: settings/preprocessing/artifacts.md - Sensor-level analysis: - Condition contrasts: settings/sensor/contrasts.md @@ -103,6 +103,11 @@ nav: - Source space & forward solution: settings/source/forward.md - Inverse solution: settings/source/inverse.md - Report generation: settings/reports/report_generation.md + - Caching: settings/caching.md + - Parallelization: settings/parallelization.md + - Logging: settings/logging.md + - Error handling: settings/error_handling.md + - Examples: - Examples Gallery: examples/examples.md - examples/ds003392.md @@ -116,6 +121,8 @@ nav: - examples/ds000248_no_mri.md - examples/ds003104.md - examples/eeg_matchingpennies.md + - examples/MNE-funloc-data.md + - examples/MNE-phantom-KIT-data.md - examples/ds001810.md - examples/ds000117.md - examples/ds003775.md @@ -128,8 +135,7 @@ nav: plugins: - search - macros - - tags: - tags_file: tags.md + - privacy # https://squidfunk.github.io/mkdocs-material/plugins/privacy/ - include-markdown - exclude: glob: @@ -137,6 +143,8 @@ plugins: - "*.inc" # includes - mkdocstrings: default_handler: python + enable_inventory: true + custom_templates: templates handlers: python: paths: # Where to find the packages and modules to import @@ -148,9 +156,14 @@ plugins: show_root_toc_entry: false show_root_full_path: false separate_signature: true + show_signature_annotations: true + unwrap_annotated: true + signature_crossrefs: true line_length: 88 # Black's default show_bases: false docstring_style: numpy + inventories: + - https://mne.tools/dev/objects.inv - mike: canonical_version: stable markdown_extensions: @@ -181,3 +194,13 @@ markdown_extensions: permalink: true # Add paragraph symbol to link to current headline - pymdownx.tabbed: alternate_style: true + +# https://www.mkdocs.org/user-guide/configuration/#validation +validation: + # Everything set to "warn" will cause an error when building in strict mode with + # mkdocs build --strict + omitted_files: warn + not_found: 
warn + absolute_links: warn + unrecognized_links: warn + anchors: warn diff --git a/docs/source/.gitignore b/docs/source/.gitignore index 77afb012b..b909cc5d3 100644 --- a/docs/source/.gitignore +++ b/docs/source/.gitignore @@ -1 +1,4 @@ features/steps.md +features/overview.md +settings/**/ +settings/*.md diff --git a/docs/source/changes.md b/docs/source/changes.md index 3c38f18da..a4c2ae204 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -1,3 +1,13 @@ +{% include-markdown "./dev.md.inc" %} + +{% include-markdown "./v1.9.md.inc" %} + +{% include-markdown "./v1.8.md.inc" %} + +{% include-markdown "./v1.7.md.inc" %} + +{% include-markdown "./v1.6.md.inc" %} + {% include-markdown "./v1.5.md.inc" %} {% include-markdown "./v1.4.md.inc" %} diff --git a/docs/source/dev.md.inc b/docs/source/dev.md.inc new file mode 100644 index 000000000..ca7e7f241 --- /dev/null +++ b/docs/source/dev.md.inc @@ -0,0 +1,54 @@ +## v1.10.0 (unreleased) + +### :new: New features & enhancements + +- It is now possible to use separate MRIs for each session within a subject, as in longitudinal studies. This is achieved by creating separate "subject" folders for each subject-session combination, with the naming convention `sub-XXX_ses-YYY`, in the freesurfer `SUBJECTS_DIR`. (#987 by @drammock) +- New config option [`allow_missing_sessions`][mne_bids_pipeline._config.allow_missing_sessions] allows to continue when not all sessions are present for all subjects. (#1000 by @drammock) +- New config option [`mf_extra_kws`][mne_bids_pipeline._config.mf_extra_kws] passes additional keyword arguments to `mne.preprocessing.maxwell_filter`. (#1038 by @drammock) +- New value `"twa"` for config option [`mf_destination`][mne_bids_pipeline._config.mf_destination], to use the time-weighted average head position across runs as the destination position. (#1043 and #1055 by @drammock) +- New config options [`mf_cal_missing`][mne_bids_pipeline._config.mf_cal_missing] and [`mf_ctc_missing`][mne_bids_pipeline._config.mf_ctc_missing] for handling missing calibration and cross-talk files (#1057 by @harrisonritz) +- New config options [`find_bad_channels_extra_kws`][mne_bids_pipeline._config.find_bad_channels_extra_kws], [`notch_extra_kws`][mne_bids_pipeline._config.notch_extra_kws], and [`bandpass_extra_kws`][mne_bids_pipeline._config.bandpass_extra_kws] to pass additional keyword arguments to `mne.preprocessing.find_bad_channels_maxwell`, `mne.filter.notch_filter`, and `mne.filter.filter_data` respectively (#1061 by @harrisonritz) +- Config option [`ssp_ecg_channel`][mne_bids_pipeline._config.ssp_ecg_channel] now allows dict values, for setting a different channel name for each subject/session (#1062 by @drammock) +- New config option [`epochs_custom_metadata`][mne_bids_pipeline._config.epochs_custom_metadata] allows for custom metadata when creating epochs. (#1088 by @harrisonritz) + +### :warning: Behavior changes + +- The pipeline will now raise an error if a loaded `SourceSpaces` object has a `._subject` attribute different from what the pipeline expects / would have used if creating the `SourceSpaces` anew. (#1056 by @drammock) + +[//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :package: Requirements) + +[//]: # (- Whatever (#000 by @whoever)) + +### :bug: Bug fixes + +- Empty room matching is now done for all sessions (previously only for the first session) for each subject. 
(#976 by @drammock) +- [`noise_cov_method`][mne_bids_pipeline._config.noise_cov_method] is now properly used for noise covariance estimation from raw data (#1010 by @larsoner) +- When running the pipeline with [`mf_filter_chpi`][mne_bids_pipeline._config.mf_filter_chpi] enabled (#977 by @drammock and @larsoner): + + 1. Emptyroom files that lack cHPI channels will now be processed (for line noise only) instead of raising an error. + 2. cHPI filtering is now performed before movement compensation. + +- Fix bug where the `config.proc` parameter was not used properly during forward model creation (#1014 by @larsoner) +- Fix bug where emptyroom recordings containing EEG channels would crash the pipeline during maxwell filtering (#1040 by @drammock) +- Fix bug where only having mag sensors would crash compute_rank during maxwell filtering or epoching (#1061 and #1069 by @harrisonritz) +- Improvements to template config file generation (#1074 by @drammock) +- Fix bug where `mf_int_order` wasn't passed to `maxwell_filter`. Added config option for `mf_ext_order`. (#1092 by @harrisonritz) +- Sanitize report tags that contain `"` or `'`, e.g., for certain metadata contrasts (#1097 by @harrisonritz) + +### :books: Documentation + +- Choose the theme (dark of light) automatically based on the user's operating system setting (#979 by @hoechenberger) +- Bundle all previously-external JavaScript to better preserve users' privacy (#982 by @hoechenberger) +- Document the need for offscreen rendering support when running on headless servers (#997 by @drammock) + +### :medical_symbol: Code health and infrastructure + +- Switch from using relative to using absolute imports (#969 by @hoechenberger) +- Enable strict type checking via mypy (#995, #1013, #1016 by @larsoner) +- Improve logging messages in maxwell filtering steps. (#893 by @drammock) +- Validate extra config params passed in during testing. (#1044 by @drammock) +- New testing/example dataset "funloc" added. (#1045 by @drammock) +- Bugfixes and better testing of session-specific MRIs. 
(#1039 and #1067 by @drammock) +- Drop legacy function `inst.pick_types` in favor of `inst.pick` (#1073 by @PierreGtch) diff --git a/docs/source/doc-config.py b/docs/source/doc-config.py deleted file mode 100644 index e3846c3a7..000000000 --- a/docs/source/doc-config.py +++ /dev/null @@ -1,2 +0,0 @@ -bids_root = "/tmp" -ch_types = ["meg"] diff --git a/docs/source/examples/gen_examples.py b/docs/source/examples/gen_examples.py index f24c0d29f..79a521401 100755 --- a/docs/source/examples/gen_examples.py +++ b/docs/source/examples/gen_examples.py @@ -1,26 +1,30 @@ #!/usr/bin/env python -from collections import defaultdict +"""Generate documentation pages for our examples gallery.""" + import contextlib import logging import shutil -from pathlib import Path import sys -from typing import Union, Iterable +from collections import defaultdict +from collections.abc import Generator, Iterable +from pathlib import Path +from typing import Any + +from tqdm import tqdm import mne_bids_pipeline -from mne_bids_pipeline._config_import import _import_config import mne_bids_pipeline.tests.datasets +from mne_bids_pipeline._config_import import _import_config +from mne_bids_pipeline.tests.datasets import DATASET_OPTIONS, DATASET_OPTIONS_T from mne_bids_pipeline.tests.test_run import TEST_SUITE -from mne_bids_pipeline.tests.datasets import DATASET_OPTIONS -from tqdm import tqdm this_dir = Path(__file__).parent root = Path(mne_bids_pipeline.__file__).parent.resolve(strict=True) logger = logging.getLogger() -def _bool_to_icon(x: Union[bool, Iterable]) -> str: +def _bool_to_icon(x: bool | Iterable[Any]) -> str: if x: return "✅" else: @@ -28,7 +32,7 @@ def _bool_to_icon(x: Union[bool, Iterable]) -> str: @contextlib.contextmanager -def _task_context(task): +def _task_context(task: str | None) -> Generator[None, None, None]: old_argv = sys.argv if task: sys.argv = [sys.argv[0], f"--task={task}"] @@ -38,7 +42,7 @@ def _task_context(task): sys.argv = old_argv -def _gen_demonstrated_funcs(example_config_path: Path) -> dict: +def _gen_demonstrated_funcs(example_config_path: Path) -> dict[str, bool]: """Generate dict of demonstrated functionality based on config.""" # Here we use a defaultdict, and for keys that might vary across configs # we should use an `funcs[key] = funcs[key] or ...` so that we effectively @@ -61,6 +65,8 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: key = "Maxwell filter" funcs[key] = funcs[key] or config.use_maxwell_filter funcs["Frequency filter"] = config.l_freq or config.h_freq + key = "Artifact regression" + funcs[key] = funcs[key] or (config.regress_artifact is not None) key = "SSP" funcs[key] = funcs[key] or (config.spatial_filter == "ssp") key = "ICA" @@ -104,6 +110,18 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: datasets_without_html.append(dataset_name) continue + # For ERP_CORE, cut down on what we show otherwise our website is huge + if "ERP_CORE" in test_name: + show = ["015", "average"] + orig_fnames = html_report_fnames + html_report_fnames = [ + f + for f in html_report_fnames + if any(f"sub-{s}" in f.parts and f"sub-{s}" in f.name for s in show) + ] + assert len(html_report_fnames), orig_fnames + del orig_fnames + fname_iter = tqdm( html_report_fnames, desc=f" {test_name}", @@ -133,7 +151,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: ) if dataset_name in all_demonstrated: logger.warning( - f"Duplicate dataset name {test_dataset_name} -> {dataset_name}, " "skipping" + f"Duplicate dataset name {test_dataset_name} -> 
{dataset_name}, skipping" ) continue del test_dataset_options, test_dataset_name @@ -142,7 +160,10 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: logger.warning(f"Dataset {dataset_name} has no HTML report.") continue - options = DATASET_OPTIONS[dataset_options_key].copy() # we modify locally + assert dataset_options_key in DATASET_OPTIONS, dataset_options_key + options: DATASET_OPTIONS_T = DATASET_OPTIONS[ + dataset_options_key + ].copy() # we modify locally report_str = "\n## Generated output\n\n" example_target_dir = this_dir / dataset_name @@ -198,30 +219,26 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: f"{fname.name} :fontawesome-solid-square-poll-vertical:\n\n" ) - assert sum(key in options for key in ("openneuro", "git", "web", "datalad")) == 1 + assert sum(key in options for key in ("openneuro", "web", "mne")) == 1 if "openneuro" in options: - url = f'https://openneuro.org/datasets/{options["openneuro"]}' - elif "git" in options: - url = options["git"] + url = f"https://openneuro.org/datasets/{options['openneuro']}" elif "web" in options: url = options["web"] else: - assert "datalad" in options # guaranteed above - url = "" + assert "mne" in options + url = f"https://mne.tools/dev/generated/mne.datasets.{options['mne']}.data_path.html" # noqa: E501 - source_str = ( - f"## Dataset source\n\nThis dataset was acquired from " f"[{url}]({url})\n" - ) + source_str = f"## Dataset source\n\nThis dataset was acquired from [{url}]({url})\n" if "openneuro" in options: - for key in ("include", "exclude"): - options[key] = options.get(key, []) + options.setdefault("include", []) + options.setdefault("exclude", []) download_str = ( f'\n??? example "How to download this dataset"\n' f" Run in your terminal:\n" f" ```shell\n" f" openneuro-py download \\\n" - f' --dataset={options["openneuro"]} \\\n' + f" --dataset={options['openneuro']} \\\n" ) for count, include in enumerate(options["include"], start=1): download_str += f" --include={include}" @@ -244,7 +261,9 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: # TODO: For things like ERP_CORE_ERN, decoding_csp are not populated # properly by the root config - config_path = root / "tests" / "configs" / f"config_{dataset_name}.py" + config_path = ( + root / "tests" / "configs" / f"config_{dataset_name.replace('-', '_')}.py" + ) config = config_path.read_text(encoding="utf-8-sig").strip() descr_end_idx = config[2:].find('"""') config_descr = "# " + config[: descr_end_idx + 1].replace('"""', "").strip() @@ -279,6 +298,7 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: f.write(download_str) f.write(config_str) f.write(report_str) + del dataset_name, funcs # Finally, write our examples.html file with a table of examples @@ -299,13 +319,13 @@ def _gen_demonstrated_funcs(example_config_path: Path) -> dict: with out_path.open("w", encoding="utf-8") as f: f.write(_example_header) header_written = False - for dataset_name, funcs in all_demonstrated.items(): + for this_dataset_name, these_funcs in all_demonstrated.items(): if not header_written: - f.write("Dataset | " + " | ".join(funcs.keys()) + "\n") - f.write("--------|" + "|".join([":---:"] * len(funcs)) + "\n") + f.write("Dataset | " + " | ".join(these_funcs.keys()) + "\n") + f.write("--------|" + "|".join([":---:"] * len(these_funcs)) + "\n") header_written = True f.write( - f"[{dataset_name}]({dataset_name}.md) | " - + " | ".join(_bool_to_icon(v) for v in funcs.values()) + f"[{this_dataset_name}]({this_dataset_name}.md) | " + + " 
| ".join(_bool_to_icon(v) for v in these_funcs.values()) + "\n" ) diff --git a/docs/source/features/gen_steps.py b/docs/source/features/gen_steps.py index fffc61ddf..b24803790 100755 --- a/docs/source/features/gen_steps.py +++ b/docs/source/features/gen_steps.py @@ -3,32 +3,125 @@ import importlib from pathlib import Path + from mne_bids_pipeline._config_utils import _get_step_modules -pre = """\ -# Detailed lis of processing steps +autogen_header = f"""\ +[//]: # (AUTO-GENERATED, TO CHANGE EDIT {"/".join(Path(__file__).parts[-4:])}) +""" + +steps_pre = f"""\ +{autogen_header} + +# List of processing steps The following table provides a concise summary of each processing step. The step names can be used to run individual steps or entire groups of steps by -passing their name(s) to `mne_bids_pipeline` via the `steps=...` argument. +passing their name(s) to `mne_bids_pipeline` via the `steps=...` argument. However, +we recommend using `mne_bids_pipeline config.py` to run the entire pipeline +instead to ensure that all steps affected by a given change are re-run. +""" # noqa: E501 + +overview_pre = f"""\ +{autogen_header} + +# MNE-BIDS-Pipeline overview + +MNE-BIDS-Pipeline processes your data in a sequential manner, i.e., one step +at a time. The next step is only run after the previous steps have been +successfully completed. There are, of course, exceptions; for example, if you +chose not to apply ICA or SSP, the spatial filtering steps will simply be omitted and +we'll directly move to the subsequent steps. See [the flowchart below](#flowchart) for +a visualization of the steps, or check out the +[list of processing steps](steps.md) for more information. + +All intermediate results are saved to disk for later inspection, and an +**extensive report** is generated. Analyses are conducted on individual (per-subject) +as well as group level. + +## Caching + +MNE-BIDS-Pipeline offers automated caching of intermediate results. This means that +running `mne_bids_pipeline config.py` once will generate all outputs, and running it +again will only re-run the steps that need rerunning based on: + +1. Changes to files on disk (e.g., updates to `bids_root` files), and +2. Changes to `config.py` + +This is particularly useful when you are developing your pipeline, as you can quickly +iterate over changes to your pipeline without having to re-run the entire pipeline +every time -- only the steps that need to be re-run will be executed. + +## Flowchart + +For more detailed information on each step, please refer to the [detailed list +of processing steps](steps.md). """ +icon_map = { + "Filesystem initialization and dataset inspection": ":open_file_folder:", + "Preprocessing": ":broom:", + "Sensor-space analysis": ":satellite:", + "Source-space analysis": ":brain:", + "FreeSurfer-related processing": ":person_surfing:", +} +out_dir = Path(__file__).parent + print("Generating steps …") step_modules = _get_step_modules() +char_start = ord("A") + +# In principle we could try to sort this out based on naming, but for now let's just +# set our hierarchy manually and update it when we move files around since that's easy +# (and rare) enough to do. 
+manual_order = { + "Preprocessing": ( + ("01", "02"), + ("02", "03"), + ("03", "04"), + ("04", "05"), + ("05", "06a1"), + ("06a1", "06a2"), + ("05", "06b"), + ("05", "07"), + # technically we could have the raw data flow here, but it doesn't really help + # ("05", "08a"), + # ("05", "08b"), + ("06a2", "08a"), + # Force the artifact-fitting and epoching steps on the same level, in this order + """\ + subgraph Z[" "] + direction LR + B06a1 + B07 + B06b + end + style Z fill:#0000,stroke-width:0px +""", + ("06b", "08b"), + ("07", "08a"), + ("07", "08b"), + ("08a", "09"), + ("08b", "09"), + ), +} # Construct the lines of steps.md -lines = [pre] +lines = [steps_pre] +overview_lines = [overview_pre] +used_titles = set() for di, (dir_, modules) in enumerate(step_modules.items(), 1): + # Steps if dir_ == "all": continue # this is an alias dir_module = importlib.import_module(f"mne_bids_pipeline.steps.{dir_}") + assert dir_module.__doc__ is not None dir_header = dir_module.__doc__.split("\n")[0].rstrip(".") - dir_body = dir_module.__doc__.split("\n", maxsplit=1) - if len(dir_body) > 1: - dir_body = dir_body[1].strip() - else: - dir_body = "" - lines.append(f"## {di}. {dir_header}\n") + dir_body_list = dir_module.__doc__.split("\n", maxsplit=1) + dir_body = dir_body_list[1].strip() if len(dir_body_list) > 1 else "" + icon = icon_map[dir_header] + module_header = f"{di}. {icon} {dir_header}" + lines.append(f"## {module_header}\n") if dir_body: lines.append(f"{dir_body}\n") lines.append("| Step name | Description |") @@ -37,9 +130,77 @@ dir_name, step_title = dir_, f"Run all {dir_header.lower()} steps." lines.append(f"`{dir_name}` | {step_title} |") for module in modules: + assert module.__file__ is not None + assert module.__doc__ is not None step_name = f"{dir_name}/{Path(module.__file__).name}"[:-3] step_title = module.__doc__.split("\n")[0] lines.append(f"`{step_name}` | {step_title} |") lines.append("") -with open(Path(__file__).parent / "steps.md", "w") as fid: - fid.write("\n".join(lines)) + + # Overview + overview_lines.append( + f"""\ +### {module_header} + +
+Click to expand + +```mermaid +flowchart TD""" + ) + chr_pre = chr(char_start + di - 1) # A, B, C, ... + start = None + prev_idx = None + title_map = {} + for mi, module in enumerate(modules, 1): + assert module.__doc__ is not None + assert module.__name__ is not None + step_title = module.__doc__.split("\n")[0].rstrip(".") + idx = module.__name__.split(".")[-1].split("_")[1] # 01, 05a, etc. + # Need to quote the title to deal with parens, and sanitize quotes + step_title = step_title.replace('"', "'") + assert step_title not in used_titles, f"Redundant title: {step_title}" + used_titles.add(step_title) + this_block = f'{chr_pre}{idx}["{step_title}"]' + # special case: manual order + title_map[idx] = step_title + if dir_header in manual_order: + continue + if mi == 1: + start = this_block + assert prev_idx is None + continue + if start is not None: + assert mi == 2, mi + overview_lines.append(f" {start} --> {this_block}") + start = None + else: + overview_lines.append(f" {chr_pre}{prev_idx} --> {this_block}") + prev_idx = idx + if dir_header in manual_order: + mapped = set() + for a_b in manual_order[dir_header]: + if isinstance(a_b, str): # insert directly + overview_lines.append(a_b) + continue + assert isinstance(a_b, tuple), type(a_b) + a_b_list: list[str] = list(a_b) # allow modification + del a_b + for ii, idx in enumerate(a_b_list): + assert idx in title_map, (dir_header, idx, sorted(title_map)) + if idx not in mapped: + mapped.add(idx) + a_b_list[ii] = f'{idx}["{title_map[idx]}"]' + overview_lines.append( + f" {chr_pre}{a_b_list[0]} --> {chr_pre}{a_b_list[1]}" + ) + all_steps_list: list[str] = list() + for a_b in manual_order[dir_header]: + if not isinstance(a_b, str): + all_steps_list.extend(a_b) + all_steps = set(all_steps_list) + assert mapped == all_steps, all_steps.symmetric_difference(mapped) + overview_lines.append("```\n\n
\n") + +(out_dir / "steps.md").write_text("\n".join(lines), encoding="utf8") +(out_dir / "overview.md").write_text("\n".join(overview_lines), encoding="utf8") diff --git a/docs/source/features/overview.md b/docs/source/features/overview.md deleted file mode 100644 index 9fe044038..000000000 --- a/docs/source/features/overview.md +++ /dev/null @@ -1,53 +0,0 @@ -MNE-BIDS-Pipeline processes your data in a sequential manner, i.e., one step -at a time. The next step is only run after the previous steps have been -successfully completed. There are, of course, exceptions; for example, if you -chose not to apply ICA, the respective steps will simply be omitted and we'll -directly move to the subsequent steps. The following flow chart aims to give -you a brief overview of which steps are included in the pipeline, in which -order they are run, and how we group them together. - -!!! info - All intermediate results are saved to disk for later - inspection, and an **extensive report** is generated. - -!!! info - Analyses are conducted on individual (per-subject) as well as group level. - - -## :open_file_folder: Filesystem initialization and dataset inspection -```mermaid -flowchart TD - A1[initialize the target directories] --> A2[locate empty-room recordings] -``` - -## :broom: Preprocessing -```mermaid - flowchart TD - B1[Noisy & flat channel detection] --> B2[Maxwell filter] - B2 --> B3[Frequency filter] - B3 --> B4[Epoch creation] - B4 --> B5[SSP or ICA fitting] - B5 --> B6[Artifact removal via SSP or ICA] - B6 --> B7[Amplitude-based epoch rejection] -``` - -## :satellite: Sensor-space processing -```mermaid - flowchart TD - C1[ERP / ERF calculation] --> C2[MVPA: full epochs] - C2 --> C3[MVPA: time-by-time decoding] - C3 --> C4[Time-frequency decomposition] - C4 --> C5[MVPA: CSP] - C5 --> C6[Noise covariance estimation] - C6 --> C7[Grand average] -``` - -## :brain: Source-space processing -```mermaid - flowchart TD - D1[BEM surface creation] --> D2[BEM solution] - D2 --> D3[Source space creation] - D3 --> D4[Forward model creation] - D4 --> D5[Inverse solution] - D5 --> D6[Grand average] -``` diff --git a/docs/source/getting_started/basic_usage.md b/docs/source/getting_started/basic_usage.md index c155ad680..192b799db 100644 --- a/docs/source/getting_started/basic_usage.md +++ b/docs/source/getting_started/basic_usage.md @@ -18,7 +18,7 @@ We recommend that MNE-BIDS provides a convenient way to visually inspect raw data and interactively mark problematic channels as bad by using the command ```shell - mne-bids inspect + mne_bids inspect ``` Please see the MNE-BIDS documentation for more information. diff --git a/docs/source/getting_started/freesurfer.md b/docs/source/getting_started/freesurfer.md index 68d3e3b1d..0bec62aac 100644 --- a/docs/source/getting_started/freesurfer.md +++ b/docs/source/getting_started/freesurfer.md @@ -2,7 +2,17 @@ Preparations for inverse modeling involve the installation of [FreeSurfer](https://surfer.nmr.mgh.harvard.edu/fswiki/). If you do not intend to run the source reconstruction steps of MNE-BIDS-Pipeline, - you can skip the instructions below. + you can skip the instructions on this page. + + Additionally, to visualize the source reconstruction steps and add those images to + the pipeline reports, the pipeline must be run on a machine capable of 3D rendering + (i.e., not a headless server). 
If you must use a headless server, you should either + prefix the pipeline command with `xvfb-run` (to use + [xvfb](https://www.x.org/releases/X11R7.6/doc/man/man1/Xvfb.1.xhtml) to emulate a + display), or alternatively you can install a different version of one of our + dependencies (the 3D library `vtk`) that does offscreen rendering (see + [these instructions](https://mne.tools/stable/install/advanced.html#installing-to-a-headless-linux-server) + on the MNE-Python website). !!! warning FreeSurfer does not natively run on Windows. We are currently working on @@ -41,7 +51,7 @@ nstructions](https://surfer.nmr.mgh.harvard.edu/fswiki/rel6downloads). ## :brain: Generate surfaces and brain parcellation MNE-BIDS-Pipeline provides a convenient way to invoke FreeSurfer. After -[adjusting your configuration file](basic_usage.md#adjust-your-configuration-file), +[adjusting your configuration file](../settings/general.md), invoke FreeSurfer via in the following way: ```shell @@ -70,3 +80,4 @@ creation. *[FLASH]: Fast low angle shot *[MRI]: Magnetic resonance imaging *[BEM]: Boundary element model +*[xvfb]: X virtual framebuffer \ No newline at end of file diff --git a/docs/source/settings/gen_settings.py b/docs/source/settings/gen_settings.py new file mode 100755 index 000000000..cd5d264c1 --- /dev/null +++ b/docs/source/settings/gen_settings.py @@ -0,0 +1,202 @@ +"""Generate settings .md files.""" + +# Any changes to the overall structure need to be reflected in mkdocs.yml nav section. + +import re +from pathlib import Path + +from tqdm import tqdm + +import mne_bids_pipeline._config + +config_path = Path(mne_bids_pipeline._config.__file__) +settings_dir = Path(__file__).parent + +# Mapping between first two lower-case words in the section name and the desired +# file or folder name +section_to_file = { # .md will be added to the files + # root file + "general settings": "general", + # folder + "preprocessing": "preprocessing", + "break detection": "breaks", + "bad channel": "autobads", + "maxwell filter": "maxfilter", + "filtering": "filter", + "resampling": "resample", + "epoching": "epochs", + "filtering &": None, # just a header + "artifact removal": None, + "stimulation artifact": "stim_artifact", + "ssp, ica,": "ssp_ica", + "amplitude-based artifact": "artifacts", + # folder + "sensor-level analysis": "sensor", + "condition contrasts": "contrasts", + "decoding /": "mvpa", + "time-frequency analysis": "time_freq", + "group-level analysis": "group_level", + # folder + "source-level analysis": "source", + "general source": "general", + "bem surface": "bem", + "source space": "forward", + "inverse solution": "inverse", + # folder + "reports": "reports", + "report generation": "report_generation", + # root file + "caching": "caching", + # root file + "parallelization": "parallelization", + # root file + "logging": "logging", + # root file + "error handling": "error_handling", +} +# TODO: Make sure these are consistent, autogenerate some based on section names, +# and/or autogenerate based on inputs/outputs of actual functions. 
+section_tags = { + "general settings": (), + "preprocessing": (), + "filtering &": (), + "artifact removal": (), + "break detection": ("preprocessing", "artifact-removal", "raw", "events"), + "bad channel": ("preprocessing", "raw", "bad-channels"), + "maxwell filter": ("preprocessing", "maxwell-filter", "raw"), + "filtering": ("preprocessing", "frequency-filter", "raw"), + "resampling": ("preprocessing", "resampling", "decimation", "raw", "epochs"), + "epoching": ("preprocessing", "epochs", "events", "metadata", "resting-state"), + "stimulation artifact": ("preprocessing", "artifact-removal", "raw", "epochs"), + "ssp, ica,": ("preprocessing", "artifact-removal", "raw", "epochs", "ssp", "ica"), + "amplitude-based artifact": ("preprocessing", "artifact-removal", "epochs"), + "sensor-level analysis": (), + "condition contrasts": ("epochs", "evoked", "contrast"), + "decoding /": ("epochs", "evoked", "contrast", "decoding", "mvpa"), + "time-frequency analysis": ("epochs", "evoked", "time-frequency"), + "group-level analysis": ("evoked", "group-level"), + "source-level analysis": (), + "general source": ("inverse-solution",), + "bem surface": ("inverse-solution", "bem", "freesurfer"), + "source space": ("inverse-solution", "forward-model"), + "inverse solution": ("inverse-solution",), + "reports": (), + "report generation": ("report",), + "caching": ("cache",), + "parallelization": ("paralleliation", "dask", "out-of-core"), + "logging": ("logging", "error-handling"), + "error handling": ("error-handling",), +} + +extra_headers = { + "general settings": """\ +!!! info + Many settings in this section control the pipeline behavior very early in the + pipeline. Therefore, for most of them (e.g., `bids_root`) we do not list the + steps that directly depend on the setting. The options with drop-down step + lists (e.g., `random_state`) have more localized effects. +""" +} + +option_header = """\ +::: mne_bids_pipeline._config + options: + members:""" +prefix = """\ + - """ + +# We cannot use ast for this because it doesn't preserve comments. We could use +# something like redbaron, but our code is hopefully simple enough! +assign_re = re.compile( + "^" # The line starts, then is followed by + r"(\w+): " # annotation syntax (name captured by the first group), + "(?:" # then the rest of the line can be (in a non-capturing group): + ".+ = .+" # 1. a standard assignment + "|" # 2. or + r"Literal\[" # 3. the start of a multiline type annotation like "a: Literal[" + "|" # 4. or + r"\(" # 5. the start of a multiline 3.9+ type annotation like "a: (" + ")" # Then the end of our group + "$", # and immediately the end of the line. 
+ re.MULTILINE, +) + + +def main() -> None: + """Parse the configuration and generate the markdown documentation.""" + print(f"Parsing {config_path} to generate settings .md files.") + # max file-level depth is 2 even though we have 3 subsection levels + levels = ["", ""] + current_path: Path | None = None + current_lines: list[str] = list() + text = config_path.read_text("utf-8") + lines = text.splitlines() + lines += ["# #"] # add a dummy line to trigger the last write + in_header = False + have_params = False + for li, line in enumerate(tqdm(lines)): + line = line.rstrip() + if line.startswith("# #"): # a new (sub)section / file + this_def = line[2:] + this_level_str = this_def.split()[0] + assert this_level_str.count("#") == len(this_level_str), this_level_str + this_level: int = this_level_str.count("#") - 1 + if this_level == 2: + # flatten preprocessing/filtering/filter to preprocessing/filter + # for example + this_level = 1 + assert this_level in (0, 1), (this_level, this_def) + this_def = this_def[this_level + 2 :] + levels[this_level] = this_def + # Write current lines and reset + if have_params: # more than just the header + assert current_path is not None, levels + if current_lines[0] == "": # this happens with tags + current_lines = current_lines[1:] + current_path.write_text("\n".join(current_lines + [""]), "utf-8") + have_params = False + if this_level == 0: + this_root = settings_dir + else: + this_root = settings_dir / f"{section_to_file[levels[0].lower()]}" + this_root.mkdir(exist_ok=True) + key = " ".join(this_def.split()[:2]).lower() + if key == "": + assert li == len(lines) - 1, (li, line) + continue # our dummy line + fname = section_to_file[key] + if fname is None: + current_path = None + else: + current_path = this_root / f"{fname}.md" + in_header = True + current_lines = list() + if len(section_tags[key]): + current_lines += ["---", "tags:"] + current_lines += [f" - {tag}" for tag in section_tags[key]] + current_lines += ["---"] + if key in extra_headers: + current_lines.extend(["", extra_headers[key]]) + continue + + if in_header: + if line == "": + in_header = False + if current_lines: + current_lines.append("") + current_lines.append(option_header) + else: + assert line == "#" or line.startswith("# "), (li, line) # a comment + current_lines.append(line[2:]) + continue + + # Could be an option + match = assign_re.match(line) + if match is not None: + have_params = True + current_lines.append(f"{prefix}{match.groups()[0]}") + continue + + +if __name__ == "__main__": + main() diff --git a/docs/source/settings/general.md b/docs/source/settings/general.md deleted file mode 100644 index 2640f5f2b..000000000 --- a/docs/source/settings/general.md +++ /dev/null @@ -1,48 +0,0 @@ -::: mne_bids_pipeline._config - options: - members: - - study_name - - bids_root - - deriv_root - - subjects_dir - - interactive - - sessions - - task - - task_is_rest - - runs - - exclude_runs - - crop_runs - - acq - - proc - - rec - - space - - subjects - - exclude_subjects - - process_empty_room - - process_rest - - ch_types - - data_type - - eog_channels - - eeg_bipolar_channels - - eeg_reference - - eeg_template_montage - - drop_channels - - reader_extra_params - - read_raw_bids_verbose - - analyze_channels - - plot_psd_for_runs - - n_jobs - - parallel_backend - - dask_open_dashboard - - dask_temp_dir - - dask_worker_memory_limit - - random_state - - shortest_event - - memory_location - - memory_subdir - - memory_file_method - - memory_verbose - - config_validation - - log_level - - 
mne_log_level - - on_error diff --git a/docs/source/settings/preprocessing/artifacts.md b/docs/source/settings/preprocessing/artifacts.md deleted file mode 100644 index 88407cd2c..000000000 --- a/docs/source/settings/preprocessing/artifacts.md +++ /dev/null @@ -1,20 +0,0 @@ ---- -tags: - - preprocessing - - artifact-removal - - epochs ---- - -???+ info "Good Practice / Advice" - Have a look at your raw data and train yourself to detect a blink, a heart - beat and an eye movement. - You can do a quick average of blink data and check what the amplitude looks - like. - -::: mne_bids_pipeline._config - options: - members: - - reject - - reject_tmin - - reject_tmax - - autoreject_n_interpolate diff --git a/docs/source/settings/preprocessing/autobads.md b/docs/source/settings/preprocessing/autobads.md deleted file mode 100644 index a118917a1..000000000 --- a/docs/source/settings/preprocessing/autobads.md +++ /dev/null @@ -1,27 +0,0 @@ ---- -tags: - - preprocessing - - raw - - bad-channels ---- - -!!! warning - This functionality will soon be removed from the pipeline, and - will be integrated into MNE-BIDS. - -"Bad", i.e. flat and overly noisy channels, can be automatically detected -using a procedure inspired by the commercial MaxFilter by Elekta. First, -a copy of the data is low-pass filtered at 40 Hz. Then, channels with -unusually low variability are flagged as "flat", while channels with -excessively high variability are flagged as "noisy". Flat and noisy channels -are marked as "bad" and excluded from subsequent analysis. See -:func:`mne.preprocssessing.find_bad_channels_maxwell` for more information -on this procedure. The list of bad channels detected through this procedure -will be merged with the list of bad channels already present in the dataset, -if any. - -::: mne_bids_pipeline._config - options: - members: - - find_flat_channels_meg - - find_noisy_channels_meg diff --git a/docs/source/settings/preprocessing/breaks.md b/docs/source/settings/preprocessing/breaks.md deleted file mode 100644 index 01e3159eb..000000000 --- a/docs/source/settings/preprocessing/breaks.md +++ /dev/null @@ -1,15 +0,0 @@ ---- -tags: - - preprocessing - - artifact-removal - - raw - - events ---- - -::: mne_bids_pipeline._config - options: - members: - - find_breaks - - min_break_duration - - t_break_annot_start_after_previous_event - - t_break_annot_stop_before_next_event diff --git a/docs/source/settings/preprocessing/epochs.md b/docs/source/settings/preprocessing/epochs.md deleted file mode 100644 index 02dd1f71d..000000000 --- a/docs/source/settings/preprocessing/epochs.md +++ /dev/null @@ -1,26 +0,0 @@ ---- -tags: - - preprocessing - - epochs - - events - - metadata - - resting-state ---- - -::: mne_bids_pipeline._config - options: - members: - - rename_events - - on_rename_missing_events - - event_repeated - - conditions - - epochs_tmin - - epochs_tmax - - baseline - - epochs_metadata_tmin - - epochs_metadata_tmax - - epochs_metadata_keep_first - - epochs_metadata_keep_last - - epochs_metadata_query - - rest_epochs_duration - - rest_epochs_overlap diff --git a/docs/source/settings/preprocessing/filter.md b/docs/source/settings/preprocessing/filter.md deleted file mode 100644 index 9d1301412..000000000 --- a/docs/source/settings/preprocessing/filter.md +++ /dev/null @@ -1,37 +0,0 @@ ---- -tags: - - preprocessing - - frequency-filter - - raw ---- - -It is typically better to set your filtering properties on the raw data so -as to avoid what we call border (or edge) effects. 
- -If you use this pipeline for evoked responses, you could consider -a low-pass filter cut-off of h_freq = 40 Hz -and possibly a high-pass filter cut-off of l_freq = 1 Hz -so you would preserve only the power in the 1Hz to 40 Hz band. -Note that highpass filtering is not necessarily recommended as it can -distort waveforms of evoked components, or simply wash out any low -frequency that can may contain brain signal. It can also act as -a replacement for baseline correction in Epochs. See below. - -If you use this pipeline for time-frequency analysis, a default filtering -could be a high-pass filter cut-off of l_freq = 1 Hz -a low-pass filter cut-off of h_freq = 120 Hz -so you would preserve only the power in the 1Hz to 120 Hz band. - -If you need more fancy analysis, you are already likely past this kind -of tips! 😇 - -::: mne_bids_pipeline._config - options: - members: - - l_freq - - h_freq - - l_trans_bandwidth - - h_trans_bandwidth - - notch_freq - - notch_trans_bandwidth - - notch_widths diff --git a/docs/source/settings/preprocessing/maxfilter.md b/docs/source/settings/preprocessing/maxfilter.md deleted file mode 100644 index 3cd32d9d7..000000000 --- a/docs/source/settings/preprocessing/maxfilter.md +++ /dev/null @@ -1,29 +0,0 @@ ---- -tags: - - preprocessing - - maxwell-filter - - raw ---- - -::: mne_bids_pipeline._config - options: - members: - - use_maxwell_filter - - mf_st_duration - - mf_st_correlation - - mf_head_origin - - mf_destination - - mf_int_order - - mf_reference_run - - mf_cal_fname - - mf_ctc_fname - - mf_esss - - mf_esss_reject - - mf_mc - - mf_mc_t_step_min - - mf_mc_t_window - - mf_mc_gof_limit - - mf_mc_dist_limit - - mf_mc_rotation_velocity_limit - - mf_mc_translation_velocity_limit - - mf_filter_chpi diff --git a/docs/source/settings/preprocessing/resample.md b/docs/source/settings/preprocessing/resample.md deleted file mode 100644 index 6aa824a6c..000000000 --- a/docs/source/settings/preprocessing/resample.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -tags: - - preprocessing - - resampling - - decimation - - raw - - epochs ---- - -If you have acquired data with a very high sampling frequency (e.g. 2 kHz) -you will likely want to downsample to lighten up the size of the files you -are working with (pragmatics) -If you are interested in typical analysis (up to 120 Hz) you can typically -resample your data down to 500 Hz without preventing reliable time-frequency -exploration of your data. 
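The two settings pages removed here (filtering and resampling) carry general advice that this diff relocates into comments in `_config.py`. To make that advice concrete, here is a minimal config sketch for an evoked-response analysis; the values are illustrative choices, not defaults changed by this PR:

```python
# Example MNE-BIDS-Pipeline config snippet (illustrative values only).
l_freq = None               # no high-pass; rely on baseline correction instead
h_freq = 40.0               # low-pass at 40 Hz, as suggested for ERP/ERF work
raw_resample_sfreq = 500.0  # downsample a high-rate recording to 500 Hz
epochs_decim = 1            # no further decimation at the epochs stage
```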
- -::: mne_bids_pipeline._config - options: - members: - - raw_resample_sfreq - - epochs_decim diff --git a/docs/source/settings/preprocessing/ssp_ica.md b/docs/source/settings/preprocessing/ssp_ica.md deleted file mode 100644 index b132ef4bf..000000000 --- a/docs/source/settings/preprocessing/ssp_ica.md +++ /dev/null @@ -1,32 +0,0 @@ ---- -tags: - - preprocessing - - artifact-removal - - raw - - epochs - - ssp - - ica ---- - -::: mne_bids_pipeline._config - options: - members: - - spatial_filter - - min_ecg_epochs - - min_eog_epochs - - n_proj_eog - - n_proj_ecg - - ssp_meg - - ecg_proj_from_average - - eog_proj_from_average - - ssp_reject_eog - - ssp_reject_ecg - - ssp_ecg_channel - - ica_reject - - ica_algorithm - - ica_l_freq - - ica_max_iterations - - ica_n_components - - ica_decim - - ica_ctps_ecg_threshold - - ica_eog_threshold diff --git a/docs/source/settings/preprocessing/stim_artifact.md b/docs/source/settings/preprocessing/stim_artifact.md deleted file mode 100644 index cbc142550..000000000 --- a/docs/source/settings/preprocessing/stim_artifact.md +++ /dev/null @@ -1,19 +0,0 @@ ---- -tags: - - preprocessing - - artifact-removal - - raw - - epochs ---- - -When using electric stimulation systems, e.g. for median nerve or index -stimulation, it is frequent to have a stimulation artifact. This option -allows to fix it by linear interpolation early in the pipeline on the raw -data. - -::: mne_bids_pipeline._config - options: - members: - - fix_stim_artifact - - stim_artifact_tmin - - stim_artifact_tmax diff --git a/docs/source/settings/reports/report_generation.md b/docs/source/settings/reports/report_generation.md deleted file mode 100644 index 2ccf520ed..000000000 --- a/docs/source/settings/reports/report_generation.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -tags: - - report ---- - -::: mne_bids_pipeline._config - options: - members: - - report_evoked_n_time_points - - report_stc_n_time_points diff --git a/docs/source/settings/sensor/contrasts.md b/docs/source/settings/sensor/contrasts.md deleted file mode 100644 index 576e45ee3..000000000 --- a/docs/source/settings/sensor/contrasts.md +++ /dev/null @@ -1,11 +0,0 @@ ---- -tags: - - epochs - - evoked - - contrast ---- - -::: mne_bids_pipeline._config - options: - members: - - contrasts diff --git a/docs/source/settings/sensor/group_level.md b/docs/source/settings/sensor/group_level.md deleted file mode 100644 index a330a9703..000000000 --- a/docs/source/settings/sensor/group_level.md +++ /dev/null @@ -1,10 +0,0 @@ ---- -tags: - - evoked - - group-level ---- - -::: mne_bids_pipeline._config - options: - members: - - interpolate_bads_grand_average diff --git a/docs/source/settings/sensor/mvpa.md b/docs/source/settings/sensor/mvpa.md deleted file mode 100644 index 8131cd428..000000000 --- a/docs/source/settings/sensor/mvpa.md +++ /dev/null @@ -1,25 +0,0 @@ ---- -tags: - - epochs - - evoked - - contrast - - decoding ---- - -::: mne_bids_pipeline._config - options: - members: - - decode - - decoding_epochs_tmin - - decoding_epochs_tmax - - decoding_metric - - decoding_n_splits - - decoding_time_generalization - - decoding_time_generalization_decim - - decoding_csp - - decoding_csp_times - - decoding_csp_freqs - - n_boot - - cluster_forming_t_threshold - - cluster_n_permutations - - cluster_permutation_p_threshold diff --git a/docs/source/settings/sensor/time_freq.md b/docs/source/settings/sensor/time_freq.md deleted file mode 100644 index 492296dc0..000000000 --- a/docs/source/settings/sensor/time_freq.md +++ /dev/null @@ -1,18 +0,0 @@ 
---- -tags: - - epochs - - evoked - - time-frequency ---- - -::: mne_bids_pipeline._config - options: - members: - - time_frequency_conditions - - time_frequency_freq_min - - time_frequency_freq_max - - time_frequency_cycles - - time_frequency_subtract_evoked - - time_frequency_baseline - - time_frequency_baseline_mode - - time_frequency_crop diff --git a/docs/source/settings/source/bem.md b/docs/source/settings/source/bem.md deleted file mode 100644 index f55972baf..000000000 --- a/docs/source/settings/source/bem.md +++ /dev/null @@ -1,16 +0,0 @@ ---- -tags: - - inverse-solution - - bem - - freesurfer ---- - -::: mne_bids_pipeline._config - options: - members: - - use_template_mri - - adjust_coreg - - bem_mri_images - - recreate_bem - - recreate_scalp_surface - - freesurfer_verbose diff --git a/docs/source/settings/source/forward.md b/docs/source/settings/source/forward.md deleted file mode 100644 index 8ce5c87ad..000000000 --- a/docs/source/settings/source/forward.md +++ /dev/null @@ -1,14 +0,0 @@ ---- -tags: - - inverse-solution - - forward-model ---- - -::: mne_bids_pipeline._config - options: - members: - - mri_t1_path_generator - - mri_landmarks_kind - - spacing - - mindist - - source_info_path_update diff --git a/docs/source/settings/source/general.md b/docs/source/settings/source/general.md deleted file mode 100644 index 09eac741f..000000000 --- a/docs/source/settings/source/general.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -tags: - - inverse-solution ---- - -::: mne_bids_pipeline._config - options: - members: - - run_source_estimation diff --git a/docs/source/settings/source/inverse.md b/docs/source/settings/source/inverse.md deleted file mode 100644 index 4a10f1aef..000000000 --- a/docs/source/settings/source/inverse.md +++ /dev/null @@ -1,14 +0,0 @@ ---- -tags: - - inverse-solution ---- - -::: mne_bids_pipeline._config - options: - members: - - loose - - depth - - inverse_method - - noise_cov - - source_info_path_update - - inverse_targets diff --git a/docs/source/v1.5.md.inc b/docs/source/v1.5.md.inc index 5522271b1..6ef152c1e 100644 --- a/docs/source/v1.5.md.inc +++ b/docs/source/v1.5.md.inc @@ -1,4 +1,4 @@ -## v1.5.0 (unreleased) +## v1.5.0 (2023-11-30) This release contains a number of very important bug fixes that address problems related to decoding, time-frequency analysis, and inverse modeling. All users are encouraged to update. @@ -19,12 +19,13 @@ All users are encouraged to update. per-epoch basis as the last preprocessing step; this can be enabled by setting [`reject`][mne_bids_pipeline._config.reject] to `"autoreject_local"`. The behavior can further be controlled via the new setting [`autoreject_n_interpolate`][mne_bids_pipeline._config.autoreject_n_interpolate]. (#807 by @hoechenberger) -- Added support for "local" [`autoreject`](https://autoreject.github.io) to find (and repair) bad channels on a per-epoch - basis before submitting them to ICA fitting. This can be enabled by setting [`ica_reject`][mne_bids_pipeline._config.ica_reject] - to `"autoreject_local"`. (#810 by @hoechenberger) +- Added support for "local" [`autoreject`](https://autoreject.github.io) to remove bad epochs + before submitting the data to ICA fitting. This can be enabled by setting [`ica_reject`][mne_bids_pipeline._config.ica_reject] + to `"autoreject_local"`. (#810, #816 by @hoechenberger) +- The new setting [`decoding_which_epochs`][mne_bids_pipeline._config.decoding_which_epochs] controls which epochs (e.g., uncleaned, after ICA/SSP, cleaned) shall be used for decoding. 
(#819 by @hoechenberger) - Website documentation tables can now be sorted (e.g., to find examples that use a specific feature) (#808 by @larsoner) -[//]: # (### :warning: Behavior changes) +### :warning: Behavior changes - The default cache directory is now `_cache` within the derivatives folder when using `memory_location=True`, set [`memory_subdir="joblib"`][mne_bids_pipeline._config.memory_subdir] to get the behavior from v1.4 (#778 by @larsoner) - Before cleaning epochs via ICA, we used to reject any epochs execeeding the [`ica_reject`][mne_bids_pipeline._config.ica_reject] @@ -32,6 +33,7 @@ All users are encouraged to update. we only apply `ica_reject` to the epochs used for ICA fitting. After the experimental epochs have been cleaned with ICA (`preprocessing/apply_ica` step), any remaining large-amplitude artifacts can be removed via [`reject`][mne_bids_pipeline._config.reject], which is used in the last preprocessing step, `preprocessing/ptp_reject`. (#806 by @hoechenberger) +- MVPA / decoding used to be performed on un-cleaned epochs in the past. Now, cleaned epochs will be used by default (please also see the "Bug fixes" section below). (#796 by @hoechenberger) [//]: # (- Whatever (#000 by @whoever)) @@ -51,4 +53,6 @@ All users are encouraged to update. - Fixed bug where head position files were not written with a proper suffix and extension (#761 by @larsoner) - Fixed bug where default values for `decoding_csp_times` and `decoding_csp_freqs` were not set dynamically based on the config parameters (#779 by @larsoner) - Fixed bug where the MNE logger verbosity was not respected inside parallel jobs (#813 by @larsoner) -- A number of processing steps erroneously **always** operated on un-cleaned epochs (`sensor/decoding_full_epochs`, `sensor/decoding_time_by_time`, `sensor/decoding_csp`); or operated on un-cleaned epochs (without PTP rejection) if no ICA or SSP was requested (`sensor/ime_frequency`, `sensor/make_cov`) The bug in `sensor/make_cov` could propagate to the source level, as the covariance matrix is used for inverse modeling. (#796 by @hoechenberger) \ No newline at end of file +- A number of processing steps erroneously **always** operated on un-cleaned epochs (`sensor/decoding_full_epochs`, `sensor/decoding_time_by_time`, `sensor/decoding_csp`); or operated on un-cleaned epochs (without PTP rejection) if no ICA or SSP was requested (`sensor/ime_frequency`, `sensor/make_cov`) The bug in `sensor/make_cov` could propagate to the source level, as the covariance matrix is used for inverse modeling. (#796 by @hoechenberger) +- Bad channels may have been submitted to MVPA (full epochs decoding, time-by-time decoding, CSP-based decoding) when not using Maxwell filtering + (i.e., usually only EEG data was affected). This has now been fixed and data from bad channels is omitted from decoding. 
(#817 by @hoechenberger) diff --git a/docs/source/v1.6.md.inc b/docs/source/v1.6.md.inc new file mode 100644 index 000000000..b29871c11 --- /dev/null +++ b/docs/source/v1.6.md.inc @@ -0,0 +1,44 @@ +## v1.6.0 (2024-03-01) + +:new: New features & enhancements + +- Added [`regress_artifact`][mne_bids_pipeline._config.regress_artifact] to allow artifact regression (e.g., of MEG reference sensors in KIT systems) (#837 by @larsoner) +- Chosen `reject` parameters are now saved in the generated HTML reports (#839 by @larsoner) +- Added saving of clean raw data in addition to epochs (#840 by @larsoner) +- Added saving of detected blink and cardiac events used to calculate SSP projectors (#840 by @larsoner) +- Added [`noise_cov_method`][mne_bids_pipeline._config.noise_cov_method] to allow for the use of methods other than `"shrunk"` for noise covariance estimation (#854 by @larsoner) +- Added option to pass `image_kwargs` to [`mne.Report.add_epochs`] to allow adjusting e.g. `"vmin"` and `"vmax"` of the epochs image in the report via [`report_add_epochs_image_kwargs`][mne_bids_pipeline._config.report_add_epochs_image_kwargs]. This feature requires MNE-Python 1.7 or newer. (#848 by @SophieHerbst) +- Split ICA fitting and artifact detection into separate steps. This means that now, ICA is split into a total of three consecutive steps: fitting, artifact detection, and the actual data cleaning step ("applying ICA"). This makes it easier to experiment with different settings for artifact detection without needing to re-fit ICA. (#865 by @larsoner) +- The configuration used for the pipeline is now saved in a separate spreadsheet in the `.xlsx` log file (#869 by @larsoner) + +[//]: # (### :warning: Behavior changes) + +[//]: # (- Whatever (#000 by @whoever)) + +### :package: Requirements + +- MNE-BIDS-Pipeline now requires Python 3.9 or newer. (#825 by @hoechenberger) + +### :bug: Bug fixes + +- Fixed minor issues with path handling for cross-talk and calibration files (#834 by @larsoner) +- Fixed EEG `reject` use for `ch_types = ["meg", "eeg"]` in epoch cleaning (#839 by @larsoner) +- Fixed bug where implicit `mf_reference_run` could change across invocations of `mne_bids_pipeline`, breaking caching (#839 by @larsoner) +- Fixed `--no-cache` behavior having no effect (#839 by @larsoner) +- Fixed Maxwell filtering failures when [`find_noisy_channels_meg = False`][mne_bids_pipeline._config.find_noisy_channels_meg]` is used (#847 by @larsoner) +- Fixed raw, empty-room, and custom noise covariances calculation, previously they could errantly be calculated on data without ICA or SSP applied (#840 by @larsoner) +- Fixed multiple channel type handling (e.g., MEG and EEG) in decoding (#853 by @larsoner) +- Changed the default for [`ica_n_components`][mne_bids_pipeline._config.ica_n_components] from `0.8` (too conservative) to `None` to match MNE-Python's default (#853 by @larsoner) +- Prevent events table for the average subject overflowing in reports (#854 by @larsoner) +- Fixed split file behavior for Epochs when using ICA (#855 by @larsoner) +- Fixed a bug where users could not set `_components.tsv` as it would be detected as a cache miss and overwritten on next pipeline run (#865 by @larsoner) + +### :medical_symbol: Code health + +- The package build backend has been switched from `setuptools` to `hatchling`. 
(#825 by @hoechenberger) +- Removed dependencies on `datalad` and `git-annex` for testing (#867 by @larsoner) +- Code formatting now uses `ruff format` instead of `black` (#834, #838 by @larsoner) +- Code caching is now tested using GitHub Actions (#836 by @larsoner) +- Steps in the documentation are now automatically parsed into flowcharts (#859 by @larsoner) +- New configuration options are now automatically added to the docs (#863 by @larsoner) +- Configuration options now have relevant steps listed in the docs (#866 by @larsoner) diff --git a/docs/source/v1.7.md.inc b/docs/source/v1.7.md.inc new file mode 100644 index 000000000..3db0f1dfd --- /dev/null +++ b/docs/source/v1.7.md.inc @@ -0,0 +1,29 @@ +## v1.7.0 (2024-03-13) + +### :new: New features & enhancements + +- Improved logging message during cache invalidation: We now print the selected + [`memory_file_method`][mne_bids_pipeline._config.memory_file_method] ("hash" or "mtime"). + Previously, we'd always print "hash". (#876 by @hoechenberger) + +[//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :warning: Behavior changes) + +[//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :package: Requirements) + +[//]: # (- Whatever (#000 by @whoever)) + +### :bug: Bug fixes + +- Fixed an error when using [`analyze_channels`][mne_bids_pipeline._config.analyze_channels] with EEG data, where e.g. ERP creation didn't work. (#883 by @hoechenberger) + +[//]: # (- Whatever (#000 by @whoever)) + +### :medical_symbol: Code health + +- We enabled stricter linting to guarantee a consistently high code quality! (#872 by @hoechenberger) + +[//]: # (- Whatever (#000 by @whoever)) diff --git a/docs/source/v1.8.md.inc b/docs/source/v1.8.md.inc new file mode 100644 index 000000000..d4bb7f867 --- /dev/null +++ b/docs/source/v1.8.md.inc @@ -0,0 +1,27 @@ +## v1.8.0 (2024-03-20) + +### :new: New features & enhancements + +- Disabling CSP time-frequency mode is now supported by passing an empty list to [`decoding_csp_times`][mne_bids_pipeline._config.decoding_csp_times] (#890 by @whoever) + +[//]: # (### :warning: Behavior changes) + +[//]: # (- Whatever (#000 by @whoever)) + +### :package: Requirements + +- MNE-BIDS-Pipeline now explicitly depends on `annotated-types` (#886 by @hoechenberger) + +[//]: # (- Whatever (#000 by @whoever)) + +### :bug: Bug fixes + +- Fix handling of Maxwell-filtered data in CSP (#890 by @larsoner) +- Avoid recomputation / cache miss when the same empty-room file is matched to multiple subjects (#890 by @larsoner) + +### :medical_symbol: Code health + +- We removed the unused settings `shortest_event` and `study_name`. They were relics of early days of the pipeline + and haven't been in use for a long time. (#888, #889 by @hoechenberger and @larsoner) + +[//]: # (- Whatever (#000 by @whoever)) diff --git a/docs/source/v1.9.md.inc b/docs/source/v1.9.md.inc new file mode 100644 index 000000000..8b3e18d19 --- /dev/null +++ b/docs/source/v1.9.md.inc @@ -0,0 +1,61 @@ +## v1.9.0 + +### :new: New features & enhancements + +- Added number of subject to `sub-average` report (#902, #910 by @SophieHerbst) +- The type annotations in the default configuration file are now easier to read: We + replaced `Union[X, Y]` with `X | Y` and `Optional[X]` with `X | None`. 
(#908, #911 by @hoechenberger) +- Epochs metadata creation now supports variable time windows by specifying the names of events via + [`epochs_metadata_tmin`][mne_bids_pipeline._config.epochs_metadata_tmin] and + [`epochs_metadata_tmax`][mne_bids_pipeline._config.epochs_metadata_tmax]. (#873 by @hoechenberger) +- If you requested processing of non-existing subjects, we will now provide a more helpful error message. (#928 by @hoechenberger) +- We improved the logging output for automnated epochs rejection and cleaning via ICA and SSP. (#936, #937 by @hoechenberger) +- ECG and EOG signals created during ICA artifact detection are now saved to disk. (#938 by @hoechenberger) + +### :warning: Behavior changes + +- All ICA HTML reports have been consolidated in the standard subject `*_report.html` + file instead of producing separate files. (#899 by @larsoner) +- Changed default for `source_info_path_update` to `None`. In `_04_make_forward.py` + and `_05_make_inverse.py`, we retrieve the info from the file from which + the `noise_cov` is computed. (#919 by @SophieHerbst) +- The [`depth`][mne_bids_pipeline._config.depth] parameter doesn't accept `None` + anymore. Please use `0` instead. (#915 by @hoechenberger) +- When using automated bad channel detection, now indicate the generated `*_bads.tsv` files whether a channel + had previously already been marked as bad in the dataset. Resulting entries in the TSV file may now look like: + `"pre-existing (before MNE-BIDS-pipeline was run) & auto-noisy"` (previously: only `"auto-noisy"`). (#930 by @hoechenberger) +- The `ica_ctps_ecg_threshold` has been renamed to [`ica_ecg_threshold`][mne_bids_pipeline._config.ica_ecg_threshold]. (#935 by @hoechenberger) +- We changed the behavior when setting an EEG montage: + - When applying the montage, we now also check for channel aliases (e.g. `M1 -> TP9`). + - If the data contains a channel that is not present in the montage, we now abort with an exception (previously, we emitted a warning). + This is to prevent silent errors. To proceed in this situation, select a different montage, or drop the respective channels via + the [`drop_channels`][mne_bids_pipeline._config.drop_channels] configuration option. (#960 by @hoechenberger) + +### :package: Requirements + +- The minimum required version of MNE-Python is now 1.7.0. +- We dropped support for Python 3.9. You now need Python 3.10 or newer. (#908 by @hoechenberger) + +### :book: Documentation + +- We removed the `Execution` section from configuration options documentation and + replaced it with new, more explicit sections (namely, Caching, Parallelization, + Logging, and Error handling), and enhanced documentation. (#914 by @hoechenberger, #916 by @SophieHerbst) + +### :bug: Bug fixes + +- When running the pipeline with [`find_noisy_channels_meg`][mne_bids_pipeline._config.find_noisy_channels_meg] enabled, + then disabling it and running the pipeline again, the pipeline would incorrectly still use automatically detected + bad channels from the first pipeline run. Now, we ensure that the original bad channels would be used and the + related section is removed from the report in this case. (#902 by @larsoner) +- Fixed group-average decoding statistics were not updated in some cases, even if relevant configuration options had been changed. (#902 by @larsoner) +- Fixed a compatibility bug with joblib 1.4.0. (#899 by @larsoner) +- Fixed how "original" raw data is included in the report. 
Previously, bad channels, subject, and experimenter name would not + be displayed correctly. (#930 by @hoechenberger) +- In the report's table of contents, don't put the run numbers in quotation marks. (#933 by @hoechenberger) + +### :medical_symbol: Code health and infrastructure + +- Use GitHub's `dependabot` service to automatically keep GitHub Actions up-to-date. (#893 by @hoechenberger) +- Clean up some strings that our autoformatter failed to correctly merge. (#965 by @drammock) +- Type hints are now checked using `mypy`. (#995 by @larsoner) diff --git a/docs/source/vX.Y.md.inc b/docs/source/vX.Y.md.inc index ea88c02c5..36bf65f57 100644 --- a/docs/source/vX.Y.md.inc +++ b/docs/source/vX.Y.md.inc @@ -10,10 +10,14 @@ [//]: # (- Whatever (#000 by @whoever)) -[//]: # (### :medical_symbol: Code health) +[//]: # (### :package: Requirements) [//]: # (- Whatever (#000 by @whoever)) [//]: # (### :bug: Bug fixes) [//]: # (- Whatever (#000 by @whoever)) + +[//]: # (### :medical_symbol: Code health) + +[//]: # (- Whatever (#000 by @whoever)) diff --git a/docs/templates/python/material/attribute.html.jinja b/docs/templates/python/material/attribute.html.jinja new file mode 100644 index 000000000..caad07cbf --- /dev/null +++ b/docs/templates/python/material/attribute.html.jinja @@ -0,0 +1,101 @@ +{# Modified from https://github.com/mkdocstrings/python/blob/master/src/mkdocstrings_handlers/python/templates/material/_base/attribute.html #} +{# Updated 2024/05/20. See "START NEW CODE" for block that is new. #} + +{{ log.debug("Rendering " + attribute.path) }} + +
+{% with html_id = attribute.path %} + + {% if root %} + {% set show_full_path = config.show_root_full_path %} + {% set root_members = True %} + {% elif root_members %} + {% set show_full_path = config.show_root_members_full_path or config.show_object_full_path %} + {% set root_members = False %} + {% else %} + {% set show_full_path = config.show_object_full_path %} + {% endif %} + + {% if not root or config.show_root_heading %} + + {% filter heading(heading_level, + role="data" if attribute.parent.kind.value == "module" else "attr", + id=html_id, + class="doc doc-heading", + toc_label=attribute.name) %} + + {% block heading scoped %} + {% if config.separate_signature %} + {% if show_full_path %}{{ attribute.path }}{% else %}{{ attribute.name }}{% endif %} + {% else %} + {% filter highlight(language="python", inline=True) %} + {% if show_full_path %}{{ attribute.path }}{% else %}{{ attribute.name }}{% endif %} + {% if attribute.annotation %}: {{ attribute.annotation }}{% endif %} + {% if attribute.value %} = {{ attribute.value }}{% endif %} + {% endfilter %} + {% endif %} + {% endblock heading %} + + {% block labels scoped %} + {% with labels = attribute.labels %} + {% include "labels.html" with context %} + {% endwith %} + {% endblock labels %} + + {% endfilter %} + + {% block signature scoped %} + {% if config.separate_signature %} + {% filter highlight(language="python", inline=False) %} + {% filter format_code(config.line_length) %} + {% if show_full_path %}{{ attribute.path }}{% else %}{{ attribute.name }}{% endif %} + {% if attribute.annotation %}: {{ attribute.annotation|safe }}{% endif %} + {% if attribute.value %} = {{ attribute.value|safe }}{% endif %} + {% endfilter %} + {% endfilter %} + {% endif %} + {% endblock signature %} + + {% else %} + {% if config.show_root_toc_entry %} + {% filter heading(heading_level, + role="data" if attribute.parent.kind.value == "module" else "attr", + id=html_id, + toc_label=attribute.path if config.show_root_full_path else attribute.name, + hidden=True) %} + {% endfilter %} + {% endif %} + {% set heading_level = heading_level - 1 %} + {% endif %} + +
+ {% block contents scoped %} + {% block docstring scoped %} + {% with docstring_sections = attribute.docstring.parsed %} + {% include "docstring.html" with context %} + {% endwith %} + {% endblock docstring %} + {% endblock contents %} + + {# START NEW CODE #} + {% if pipeline_steps(attribute.name) %} + {# https://squidfunk.github.io/mkdocs-material/reference/admonitions/#collapsible-blocks #} +
+    <details>
+      <summary>Pipeline steps using this setting</summary>
+      <p>
+        The following steps are directly affected by changes to
+        <code>{{ attribute.name }}</code>:
+      </p>
+      <ul>
+        {% for step in pipeline_steps(attribute.name) %}
+          <li><code>{{ step }}</code></li>
+        {% endfor %}
+      </ul>
+    </details>
+  {% endif %}
+  {# END NEW CODE #}
+
+  </div>
+
+{% endwith %}
+</div>
diff --git a/ignore_words.txt b/ignore_words.txt index e69de29bb..9c4f35236 100644 --- a/ignore_words.txt +++ b/ignore_words.txt @@ -0,0 +1,2 @@ +master +indx diff --git a/mne_bids_pipeline/__init__.py b/mne_bids_pipeline/__init__.py index 2826b97e6..2474edb8a 100644 --- a/mne_bids_pipeline/__init__.py +++ b/mne_bids_pipeline/__init__.py @@ -1,4 +1,4 @@ -from importlib.metadata import version, PackageNotFoundError +from importlib.metadata import PackageNotFoundError, version try: __version__ = version("mne_bids_pipeline") diff --git a/mne_bids_pipeline/_config.py b/mne_bids_pipeline/_config.py index 32e0c9735..9f4fe5a51 100644 --- a/mne_bids_pipeline/_config.py +++ b/mne_bids_pipeline/_config.py @@ -1,34 +1,24 @@ # Default settings for data processing and analysis. -from typing import Optional, Union, Iterable, List, Tuple, Dict, Callable, Literal +from collections.abc import Callable, Sequence +from typing import Annotated, Any, Literal +import pandas as pd +from annotated_types import Ge, Interval, Len, MinLen from mne import Covariance from mne_bids import BIDSPath from mne_bids_pipeline.typing import ( - PathLike, ArbitraryContrast, - FloatArrayLike, DigMontageType, + FloatArrayLike, + PathLike, ) +# %% +# # General settings -############################################################################### -# Config parameters -# ----------------- - -study_name: str = "" -""" -Specify the name of your study. It will be used to populate filenames for -saving the analysis results. - -???+ example "Example" - ```python - study_name = 'my-study' - ``` -""" - -bids_root: Optional[PathLike] = None +bids_root: PathLike | None = None """ Specify the BIDS root directory. Pass an empty string or ```None` to use the value specified in the `BIDS_ROOT` environment variable instead. @@ -41,7 +31,7 @@ ``` """ -deriv_root: Optional[PathLike] = None +deriv_root: PathLike | None = None """ The root of the derivatives directory in which the pipeline will store the processing results. If `None`, this will be @@ -52,7 +42,7 @@ set [`subjects_dir`][mne_bids_pipeline._config.subjects_dir] as well. """ -subjects_dir: Optional[PathLike] = None +subjects_dir: PathLike | None = None """ Path to the directory that contains the FreeSurfer reconstructions of all subjects. Specifically, this defines the `SUBJECTS_DIR` that is used by @@ -84,24 +74,35 @@ Enabling interactive mode deactivates parallel processing. """ -sessions: Union[List, Literal["all"]] = "all" +sessions: list[str] | Literal["all"] = "all" """ The sessions to process. If `'all'`, will process all sessions found in the BIDS dataset. """ +allow_missing_sessions: bool = False +""" +Whether to continue processing the dataset if some combinations of `subjects` and +`sessions` are missing. +""" + task: str = "" """ The task to process. """ -runs: Union[Iterable, Literal["all"]] = "all" +task_is_rest: bool = False +""" +Whether the task should be treated as resting-state data. +""" + +runs: Sequence[str] | Literal["all"] = "all" """ The runs to process. If `'all'`, will process all runs found in the BIDS dataset. """ -exclude_runs: Optional[Dict[str, List[str]]] = None +exclude_runs: dict[str, list[str]] | None = None """ Specify runs to exclude from analysis, for each participant individually. @@ -117,42 +118,34 @@ did not understand the instructions, etc.). """ -crop_runs: Optional[Tuple[float, float]] = None +crop_runs: tuple[float, float] | None = None """ Crop the raw data of each run to the specified time interval `[tmin, tmax]`, in seconds. 
The runs will be cropped before Maxwell or frequency filtering is applied. If `None`, do not crop the data. """ -acq: Optional[str] = None +acq: str | None = None """ The BIDS `acquisition` entity. """ -proc: Optional[str] = None +proc: str | None = None """ The BIDS `processing` entity. """ -rec: Optional[str] = None +rec: str | None = None """ The BIDS `recording` entity. """ -space: Optional[str] = None +space: str | None = None """ The BIDS `space` entity. """ -plot_psd_for_runs: Union[Literal["all"], Iterable[str]] = "all" -""" -For which runs to add a power spectral density (PSD) plot to the generated -report. This can take a considerable amount of time if you have many long -runs. In this case, specify the runs, or pass an empty list to disable raw PSD -plotting. -""" - -subjects: Union[Iterable[str], Literal["all"]] = "all" +subjects: Sequence[str] | Literal["all"] = "all" """ Subjects to analyze. If `'all'`, include all subjects. To only include a subset of subjects, pass a list of their identifiers. Even @@ -172,7 +165,7 @@ ``` """ -exclude_subjects: Iterable[str] = [] +exclude_subjects: Sequence[str] = [] """ Specify subjects to exclude from analysis. The MEG empty-room mock-subject is automatically excluded from regular analysis. @@ -202,7 +195,7 @@ covariance (via `noise_cov='rest'`). """ -ch_types: Iterable[Literal["meg", "mag", "grad", "eeg"]] = [] +ch_types: Annotated[Sequence[Literal["meg", "mag", "grad", "eeg"]], Len(1, 4)] = [] """ The channel types to consider. @@ -219,7 +212,7 @@ ``` """ -data_type: Optional[Literal["meg", "eeg"]] = None +data_type: Literal["meg", "eeg"] | None = None """ The BIDS data type. @@ -253,7 +246,7 @@ ``` """ -eog_channels: Optional[Iterable[str]] = None +eog_channels: Sequence[str] | None = None """ Specify EOG channels to use, or create virtual EOG channels. @@ -288,7 +281,7 @@ ``` """ -eeg_bipolar_channels: Optional[Dict[str, Tuple[str, str]]] = None +eeg_bipolar_channels: dict[str, tuple[str, str]] | None = None """ Combine two channels into a bipolar channel, whose signal is the **difference** between the two combined channels, and add it to the data. @@ -321,7 +314,7 @@ ``` """ -eeg_reference: Union[Literal["average"], str, Iterable["str"]] = "average" +eeg_reference: Literal["average"] | str | Sequence[str] = "average" """ The EEG reference to use. If `average`, will use the average reference, i.e. the average across all channels. If a string, must be the name of a single @@ -344,7 +337,7 @@ ``` """ -eeg_template_montage: Optional[Union[str, DigMontageType]] = None +eeg_template_montage: str | DigMontageType | None = None """ In situations where you wish to process EEG data and no individual digitization points (measured channel locations) are available, you can apply @@ -360,6 +353,13 @@ You can find an overview of supported template montages at https://mne.tools/stable/generated/mne.channels.make_standard_montage.html +!!! warning + If the data contains channel names that are not part of the template montage, the + pipeline run will fail with an error message. You must either pick a different + montage or remove those channels via + [`drop_channels`][mne_bids_pipeline._config.drop_channels] to continue. + + ???+ example "Example" Do not apply template montage: ```python @@ -372,11 +372,12 @@ ``` """ -drop_channels: Iterable[str] = [] +drop_channels: Sequence[str] = [] """ Names of channels to remove from the data. 
This can be useful, for example, if you have added a new bipolar channel via `eeg_bipolar_channels` and now wish -to remove the anode, cathode, or both. +to remove the anode, cathode, or both; or if your selected EEG template montage +doesn't contain coordinates for some channels. ???+ example "Example" Exclude channels `Fp1` and `Cz` from processing: @@ -385,9 +386,9 @@ ``` """ -analyze_channels: Union[ - Literal["all"], Literal["ch_types"], Iterable["str"] -] = "ch_types" +analyze_channels: Literal["all", "ch_types"] | Annotated[Sequence[str], MinLen(1)] = ( + "ch_types" +) """ The names of the channels to analyze during ERP/ERF and time-frequency analysis steps. For certain paradigms, e.g. EEG ERP research, it is common to constrain @@ -407,7 +408,7 @@ ``` """ -reader_extra_params: dict = {} +reader_extra_params: dict[str, Any] = {} """ Parameters to be passed to `read_raw_bids()` calls when importing raw data. @@ -418,7 +419,7 @@ ``` """ -read_raw_bids_verbose: Optional[Literal["error"]] = None +read_raw_bids_verbose: Literal["error"] | None = None """ Verbosity level to pass to `read_raw_bids(..., verbose=read_raw_bids_verbose)`. If you know your dataset will contain files that are not perfectly BIDS @@ -426,9 +427,26 @@ `'error'` to suppress warnings emitted by read_raw_bids. """ -############################################################################### -# BREAK DETECTION -# --------------- +plot_psd_for_runs: Literal["all"] | Sequence[str] = "all" +""" +For which runs to add a power spectral density (PSD) plot to the generated +report. This can take a considerable amount of time if you have many long +runs. In this case, specify the runs, or pass an empty list to disable raw PSD +plotting. +""" + +random_state: int | None = 42 +""" +You can specify the seed of the random number generator (RNG). +This setting is passed to the ICA algorithm and to the decoding function, +ensuring reproducible results. Set to `None` to avoid setting the RNG +to a defined state. +""" + +# %% +# # Preprocessing + +# ## Break detection find_breaks: bool = False """ @@ -527,10 +545,23 @@ ``` """ -############################################################################### -# MAXWELL FILTER PARAMETERS -# ------------------------- -# done in 01-import_and_maxfilter.py +# %% +# ## Bad channel detection +# +# !!! warning +# This functionality will soon be removed from the pipeline, and +# will be integrated into MNE-BIDS. +# +# "Bad", i.e. flat and overly noisy channels, can be automatically detected +# using a procedure inspired by the commercial MaxFilter by Elekta. First, +# a copy of the data is low-pass filtered at 40 Hz. Then, channels with +# unusually low variability are flagged as "flat", while channels with +# excessively high variability are flagged as "noisy". Flat and noisy channels +# are marked as "bad" and excluded from subsequent analysis. See +# :func:`mne.preprocssessing.find_bad_channels_maxwell` for more information +# on this procedure. The list of bad channels detected through this procedure +# will be merged with the list of bad channels already present in the dataset, +# if any. find_flat_channels_meg: bool = False """ @@ -543,9 +574,25 @@ Auto-detect "noisy" channels and mark them as bad. """ + +find_bad_channels_extra_kws: dict[str, Any] = {} + +""" +A dictionary of extra kwargs to pass to `mne.preprocessing.find_bad_channels_maxwell` +. If kwargs are passed here that have dedicated config settings already, an error +will be raised. 
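The new `find_bad_channels_extra_kws` option forwards additional keyword arguments to `mne.preprocessing.find_bad_channels_maxwell`. A minimal sketch; the chosen keyword (`limit`) and its value are just an example of a pass-through argument that has no dedicated config setting:

```python
# Enable automated bad-channel detection and pass one extra keyword straight
# through to find_bad_channels_maxwell() (illustrative value only).
find_flat_channels_meg = True
find_noisy_channels_meg = True
find_bad_channels_extra_kws = {"limit": 10.0}
```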
+For full documentation of the bad channel detection: +https://mne.tools/stable/generated/mne.preprocessing.find_bad_channels_maxwell +""" + + +# %% +# ## Maxwell filter + use_maxwell_filter: bool = False """ -Whether or not to use Maxwell filtering to preprocess the data. +Whether or not to use [Maxwell filtering][mne.preprocessing.maxwell_filter] to +preprocess the data. !!! warning If the data were recorded with internal active compensation (MaxShield), @@ -555,7 +602,7 @@ before applying Maxwell filter. """ -mf_st_duration: Optional[float] = None +mf_st_duration: float | None = None """ There are two kinds of Maxwell filtering: SSS (signal space separation) and tSSS (temporal signal space separation) @@ -596,7 +643,7 @@ ``` """ -mf_head_origin: Union[Literal["auto"], FloatArrayLike] = "auto" +mf_head_origin: Literal["auto"] | FloatArrayLike = "auto" """ `mf_head_origin` : array-like, shape (3,) | 'auto' Origin of internal and external multipolar moment space in meters. @@ -611,7 +658,7 @@ ``` """ -mf_destination: Union[Literal["reference_run"], FloatArrayLike] = "reference_run" +mf_destination: Literal["reference_run", "twa"] | FloatArrayLike = "reference_run" """ Despite all possible care to avoid movements in the MEG, the participant will likely slowly drift down from the Dewar or slightly shift the head @@ -628,21 +675,31 @@ transform. This will result in a device-to-head transformation that is the same across all subjects. - ???+ example "A Standardized Position" - ```python - from mne.transforms import translation - mf_destination = translation(z=0.04) - ``` + ???+ example "A Standardized Position" + ```python + from mne.transforms import translation + mf_destination = translation(z=0.04) + ``` +3. Compute the time-weighted average head position across all runs in a session, + and use that as the destination coordinates for each run. This will result in a + device-to-head transformation that differs between sessions within each subject. """ mf_int_order: int = 8 """ -Internal order for the Maxwell basis. Can be set to something lower (e.g., 6 -or higher for datasets where lower or higher spatial complexity, respectively, -is expected. +Internal order for the Maxwell basis. Can increase or decrease for datasets where +neural signals with higher or lower spatial complexity are expected. +Per MNE, the default values are appropriate for most use cases. +""" + +mf_ext_order: int = 3 +""" +External order for the Maxwell basis. Can increase or decrease for datasets where +environmental artifacts with higher or lower spatial complexity are expected. +Per MNE, the default values are appropriate for most use cases. """ -mf_reference_run: Optional[str] = None +mf_reference_run: str | None = None """ Which run to take as the reference for adjusting the head position of all runs when [`mf_destination="reference_run"`][mne_bids_pipeline._config.mf_destination]. @@ -654,7 +711,7 @@ ``` """ -mf_cal_fname: Optional[str] = None +mf_cal_fname: str | None = None """ !!! warning This parameter should only be used for BIDS datasets that don't store @@ -666,9 +723,16 @@ ```python mf_cal_fname = '/path/to/your/file/calibration_cal.dat' ``` -""" # noqa : E501 +""" -mf_ctc_fname: Optional[str] = None +mf_cal_missing: Literal["ignore", "warn", "raise"] = "raise" +""" +How to handle the situation where the MEG device's fine calibration file is missing. +Possible options are to ignore the missing file (as may be appropriate for OPM data), +issue a warning, or raise an error. 
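The Maxwell-filter section above introduces the `"twa"` destination and a policy for missing fine-calibration files. A hedged sketch combining these new options, e.g. for a recording acquired without a calibration file; the values are illustrative, not recommended defaults:

```python
# Illustrative combination of the new Maxwell-filter options.
use_maxwell_filter = True
mf_destination = "twa"   # per-session time-weighted average head position
mf_cal_missing = "warn"  # warn (rather than raise) if no fine-calibration file
```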
+""" + +mf_ctc_fname: str | None = None """ Path to the Maxwell Filter cross-talk file. If `None`, the recommended location is used. @@ -681,14 +745,21 @@ ```python mf_ctc_fname = '/path/to/your/file/crosstalk_ct.fif' ``` -""" # noqa : E501 +""" + +mf_ctc_missing: Literal["ignore", "warn", "raise"] = "raise" +""" +How to handle the situation where the MEG device's cross-talk file is missing. Possible +options are to ignore the missing file (as may be appropriate for OPM data), issue a +warning, or raise an error (appropriate for data from Electa/Neuromag/MEGIN systems). +""" mf_esss: int = 0 """ Number of extended SSS (eSSS) basis projectors to use from empty-room data. """ -mf_esss_reject: Optional[Dict[str, float]] = None +mf_esss_reject: dict[str, float] | None = None """ Rejection parameters to use when computing the extended SSS (eSSS) basis. """ @@ -703,7 +774,7 @@ Minimum time step to use during cHPI coil amplitude estimation. """ -mf_mc_t_window: Union[float, Literal["auto"]] = "auto" +mf_mc_t_window: float | Literal["auto"] = "auto" """ The window to use during cHPI coil amplitude estimation and in cHPI filtering. Can be "auto" to autodetect a reasonable value or a float (in seconds). @@ -719,78 +790,84 @@ Minimum distance (m) to accept for cHPI position fitting. """ -mf_mc_rotation_velocity_limit: Optional[float] = None +mf_mc_rotation_velocity_limit: float | None = None """ The rotation velocity limit (degrees/second) to use when annotating movement-compensated data. If `None`, no annotations will be added. """ -mf_mc_translation_velocity_limit: Optional[float] = None +mf_mc_translation_velocity_limit: float | None = None """ The translation velocity limit (meters/second) to use when annotating movement-compensated data. If `None`, no annotations will be added. """ -mf_filter_chpi: Optional[bool] = None +mf_filter_chpi: bool | None = None """ Use mne.chpi.filter_chpi after Maxwell filtering. Can be None to use the same value as [`mf_mc`][mne_bids_pipeline._config.mf_mc]. Only used when [`use_maxwell_filter=True`][mne_bids_pipeline._config.use_maxwell_filter] -""" # noqa: E501 - -############################################################################### -# STIMULATION ARTIFACT -# -------------------- -# used in 01-import_and_maxfilter.py - -fix_stim_artifact: bool = False -""" -Apply interpolation to fix stimulation artifact. - -???+ example "Example" - ```python - fix_stim_artifact = False - ``` -""" - -stim_artifact_tmin: float = 0.0 """ -Start time of the interpolation window in seconds. -???+ example "Example" - ```python - stim_artifact_tmin = 0. # on stim onset - ``` +mf_extra_kws: dict[str, Any] = {} """ - -stim_artifact_tmax: float = 0.01 +A dictionary of extra kwargs to pass to `mne.preprocessing.maxwell_filter`. If kwargs +are passed here that have dedicated config settings already, an error will be raised. +For full documentation of the Maxwell filter: +https://mne.tools/stable/generated/mne.preprocessing.maxwell_filter """ -End time of the interpolation window in seconds. -???+ example "Example" - ```python - stim_artifact_tmax = 0.01 # up to 10ms post-stimulation - ``` -""" +# ## Filtering & resampling -############################################################################### -# FREQUENCY FILTERING & RESAMPLING -# -------------------------------- -# done in 02-frequency_filter.py +# ### Filtering +# +# It is typically better to set your filtering properties on the raw data so +# as to avoid what we call border (or edge) effects. 
+# +# If you use this pipeline for evoked responses, you could consider +# a low-pass filter cut-off of h_freq = 40 Hz +# and possibly a high-pass filter cut-off of l_freq = 1 Hz +# so you would preserve only the power in the 1Hz to 40 Hz band. +# Note that highpass filtering is not necessarily recommended as it can +# distort waveforms of evoked components, or simply wash out any low +# frequency that can may contain brain signal. It can also act as +# a replacement for baseline correction in Epochs. See below. +# +# If you use this pipeline for time-frequency analysis, a default filtering +# could be a high-pass filter cut-off of l_freq = 1 Hz +# a low-pass filter cut-off of h_freq = 120 Hz +# so you would preserve only the power in the 1Hz to 120 Hz band. +# +# If you need more fancy analysis, you are already likely past this kind +# of tips! 😇 -l_freq: Optional[float] = None +l_freq: float | None = None """ The low-frequency cut-off in the highpass filtering step. Keep it `None` if no highpass filtering should be applied. """ -h_freq: Optional[float] = 40.0 +h_freq: float | None = 40.0 """ The high-frequency cut-off in the lowpass filtering step. Keep it `None` if no lowpass filtering should be applied. """ -notch_freq: Optional[Union[float, Iterable[float]]] = None +l_trans_bandwidth: float | Literal["auto"] = "auto" +""" +Specifies the transition bandwidth of the +highpass filter. By default it's `'auto'` and uses default MNE +parameters. +""" + +h_trans_bandwidth: float | Literal["auto"] = "auto" +""" +Specifies the transition bandwidth of the +lowpass filter. By default it's `'auto'` and uses default MNE +parameters. +""" + +notch_freq: float | Sequence[float] | None = None """ Notch filter frequency. More than one frequency can be supplied, e.g. to remove harmonics. Keep it `None` if no notch filter should be applied. @@ -809,31 +886,53 @@ ``` """ -l_trans_bandwidth: Union[float, Literal["auto"]] = "auto" +notch_trans_bandwidth: float = 1.0 +""" +Specifies the transition bandwidth of the notch filter. The default is `1.`. +""" + +notch_widths: float | Sequence[float] | None = None """ -Specifies the transition bandwidth of the -highpass filter. By default it's `'auto'` and uses default MNE -parameters. +Specifies the width of each stop band. `None` uses the MNE default. """ -h_trans_bandwidth: Union[float, Literal["auto"]] = "auto" +zapline_fline: float | None = None """ -Specifies the transition bandwidth of the -lowpass filter. By default it's `'auto'` and uses default MNE -parameters. +Specifies frequency to remove using Zapline filtering. If None, zapline will not +be used. """ -notch_trans_bandwidth: float = 1.0 +zapline_iter: bool = False """ -Specifies the transition bandwidth of the notch filter. The default is `1.`. +Specifies if the iterative version of the Zapline algorithm should be run. """ -notch_widths: Optional[Union[float, Iterable[float]]] = None +notch_extra_kws: dict[str, Any] = {} """ -Specifies the width of each stop band. `None` uses the MNE default. +A dictionary of extra kwargs to pass to `mne.filter.notch_filter`. If kwargs +are passed here that have dedicated config settings already, an error will be raised. +For full documentation of the notch filter: +https://mne.tools/stable/generated/mne.filter.notch_filter. +""" + +bandpass_extra_kws: dict[str, Any] = {} """ +A dictionary of extra kwargs to pass to `mne.filter.filter_data`. If kwargs +are passed here that have dedicated config settings already, an error will be raised. 
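The new Zapline options offer an alternative to notch filtering for power-line noise. A minimal sketch for a 50 Hz mains environment; the frequency is an example, not a default set by this PR:

```python
# Remove 50 Hz line noise with Zapline instead of a conventional notch filter.
zapline_fline = 50.0  # mains frequency to remove
zapline_iter = True   # use the iterative Zapline variant
notch_freq = None     # no additional notch filtering
```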
+For full documatation of the bandpass filter: +https://mne.tools/stable/generated/mne.filter.filter_data +""" + +# ### Resampling +# +# If you have acquired data with a very high sampling frequency (e.g. 2 kHz) +# you will likely want to downsample to lighten up the size of the files you +# are working with (pragmatics) +# If you are interested in typical analysis (up to 120 Hz) you can typically +# resample your data down to 500 Hz without preventing reliable time-frequency +# exploration of your data. -raw_resample_sfreq: Optional[float] = None +raw_resample_sfreq: float | None = None """ Specifies at which sampling frequency the data should be resampled. If `None`, then no resampling will be done. @@ -845,10 +944,6 @@ ``` """ -############################################################################### -# DECIMATION -# ---------- - epochs_decim: int = 1 """ Says how much to decimate data at the epochs level. @@ -867,11 +962,9 @@ """ -############################################################################### -# RENAME EXPERIMENTAL EVENTS -# -------------------------- +# ## Epoching -rename_events: dict = dict() +rename_events: dict[str, str] = dict() """ A dictionary specifying which events in the BIDS dataset to rename upon loading, and before processing begins. @@ -895,10 +988,6 @@ to only get a warning instead, or `'ignore'` to ignore it completely. """ -############################################################################### -# HANDLING OF REPEATED EVENTS -# --------------------------- - event_repeated: Literal["error", "drop", "merge"] = "error" """ How to handle repeated events. We call events "repeated" if more than one event @@ -914,25 +1003,54 @@ April 1st, 2021. """ -############################################################################### -# EPOCHING -# -------- +epochs_custom_metadata: pd.DataFrame | dict[str, Any] | None = None + +""" +Pandas `DataFrame` containing custom metadata. The custom metadata will be +horizontally joined with the metadata generated from `events.tsv`. +The number of rows in the custom metadata must match the number of rows in +the events metadata (after filtering by `conditions`). + +The metadata can also be formatted as a `dict`, with keys being the `subject`, +`session`, and/or `task`, and the values being a `DataFrame`. e.g.: +```python +epochs_custom_metadata = {'sub-01': {'ses-01': {'task-taskA': my_DataFrame}}} +epochs_custom_metadata = {'ses-01': my_DataFrame1, 'ses-02': my_DataFrame2} +``` -epochs_metadata_tmin: Optional[float] = None +If None, don't use custom metadata. """ -The beginning of the time window for metadata generation, in seconds, -relative to the time-locked event of the respective epoch. This may be less -than or larger than the epoch's first time point. If `None`, use the first -time point of the epoch. + + +epochs_metadata_tmin: float | str | list[str] | None = None +""" +The beginning of the time window used for epochs metadata generation. This setting +controls the `tmin` value passed to +[`mne.epochs.make_metadata`](https://mne.tools/stable/generated/mne.epochs.make_metadata.html). + +If a float, the time in seconds relative to the time-locked event of the respective +epoch. Negative indicate times before, positive values indicate times after the +time-locked event. + +If a string or a list of strings, the name(s) of events marking the start of time +window. + +If `None`, use the first time point of the epoch. 
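A hedged example of the two accepted forms (the event names are hypothetical and must exist in your `events.tsv`):

```python
# Fixed window: from 0.5 s before to 1.5 s after the time-locked event
epochs_metadata_tmin = -0.5
epochs_metadata_tmax = 1.5

# Event-based window: start at the (hypothetical) "stimulus" events and
# end at the following "response" events
epochs_metadata_tmin = ["stimulus"]
epochs_metadata_tmax = ["response"]
```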
+ +???+ info + Note that `None` here behaves differently than `tmin=None` in + `mne.epochs.make_metadata`. To achieve the same behavior, pass the name(s) of the + time-locked events instead. + """ -epochs_metadata_tmax: Optional[float] = None +epochs_metadata_tmax: float | str | list[str] | None = None """ Same as `epochs_metadata_tmin`, but specifying the **end** of the time window for metadata generation. """ -epochs_metadata_keep_first: Optional[Iterable[str]] = None +epochs_metadata_keep_first: Sequence[str] | None = None """ Event groupings using hierarchical event descriptors (HEDs) for which to store the time of the **first** occurrence of any event of this group in a new column @@ -960,16 +1078,16 @@ and `first_stimulus`. """ -epochs_metadata_keep_last: Optional[Iterable[str]] = None +epochs_metadata_keep_last: Sequence[str] | None = None """ Same as `epochs_metadata_keep_first`, but for keeping the **last** occurrence of matching event types. The columns indicating the event types will be named with a `last_` instead of a `first_` prefix. """ -epochs_metadata_query: Optional[str] = None +epochs_metadata_query: str | None = None """ -A [metadata query][https://mne.tools/stable/auto_tutorials/epochs/30_epochs_metadata.html] +A [metadata query](https://mne.tools/stable/auto_tutorials/epochs/30_epochs_metadata.html) specifying which epochs to keep. If the query fails because it refers to an unknown metadata column, a warning will be emitted and all epochs will be kept. @@ -978,9 +1096,9 @@ ```python epochs_metadata_query = ['response_missing.isna()'] ``` -""" # noqa: E501 +""" -conditions: Optional[Union[Iterable[str], Dict[str, str]]] = None +conditions: Sequence[str] | dict[str, str] | None = None """ The time-locked events based on which to create evoked responses. This can either be name of the experimental condition as specified in the @@ -994,7 +1112,8 @@ This is a **required** parameter in the configuration file, unless you are processing resting-state data. If left as `None` and -[`task_is_rest`][mne_bids_pipeline._config.task_is_rest] is not `True`, we will raise an error. +[`task_is_rest`][mne_bids_pipeline._config.task_is_rest] is not `True`, we will raise an +error. ???+ example "Example" Specifying conditions as lists of strings: @@ -1011,7 +1130,7 @@ conditions = {'simple_name': 'complex/condition/with_subconditions'} conditions = {'correct': 'response/correct', 'incorrect': 'response/incorrect'} -""" # noqa : E501 +""" epochs_tmin: float = -0.2 """ @@ -1032,23 +1151,18 @@ ``` """ -task_is_rest: bool = False -""" -Whether the task should be treated as resting-state data. -""" - -rest_epochs_duration: Optional[float] = None +rest_epochs_duration: float | None = None """ Duration of epochs in seconds. """ -rest_epochs_overlap: Optional[float] = None +rest_epochs_overlap: float | None = None """ Overlap between epochs in seconds. This is used if the task is `'rest'` and when the annotations do not contain any stimulation or behavior events. """ -baseline: Optional[Tuple[Optional[float], Optional[float]]] = (None, 0) +baseline: tuple[float | None, float | None] | None = (None, 0) """ Specifies which time interval to use for baseline correction of epochs; if `None`, no baseline correction is applied. @@ -1059,74 +1173,69 @@ ``` """ -contrasts: Iterable[Union[Tuple[str, str], ArbitraryContrast]] = [] -""" -The conditions to contrast via a subtraction of ERPs / ERFs. The list elements -can either be tuples or dictionaries (or a mix of both). 
Each element in the -list corresponds to a single contrast. - -A tuple specifies a one-vs-one contrast, where the second condition is -subtracted from the first. +# ## Artifact removal -If a dictionary, must contain the following keys: +# ### Stimulation artifact +# +# When using electric stimulation systems, e.g. for median nerve or index +# stimulation, it is frequent to have a stimulation artifact. This option +# allows to fix it by linear interpolation early in the pipeline on the raw +# data. -- `name`: a custom name of the contrast -- `conditions`: the conditions to contrast -- `weights`: the weights associated with each condition. +fix_stim_artifact: bool = False +""" +Apply interpolation to fix stimulation artifact. -Pass an empty list to avoid calculation of any contrasts. +???+ example "Example" + ```python + fix_stim_artifact = False + ``` +""" -For the contrasts to be computed, the appropriate conditions must have been -epoched, and therefore the conditions should either match or be subsets of -`conditions` above. +stim_artifact_tmin: float = 0.0 +""" +Start time of the interpolation window in seconds. ???+ example "Example" - Contrast the "left" and the "right" conditions by calculating - `left - right` at every time point of the evoked responses: ```python - contrasts = [('left', 'right')] # Note we pass a tuple inside the list! + stim_artifact_tmin = 0. # on stim onset ``` +""" - Contrast the "left" and the "right" conditions within the "auditory" and - the "visual" modality, and "auditory" vs "visual" regardless of side: +stim_artifact_tmax: float = 0.01 +""" +End time of the interpolation window in seconds. + +???+ example "Example" ```python - contrasts = [('auditory/left', 'auditory/right'), - ('visual/left', 'visual/right'), - ('auditory', 'visual')] + stim_artifact_tmax = 0.01 # up to 10ms post-stimulation ``` +""" + +# ### SSP, ICA, and artifact regression + +regress_artifact: dict[str, Any] | None = None +""" +Keyword arguments to pass to the `mne.preprocessing.EOGRegression` model used +in `mne.preprocessing.regress_artifact`. If `None`, no time-domain regression will +be applied. Note that any channels picked in `regress_artifact["picks_artifact"]` will +have the same time-domain filters applied to them as the experimental data. + +Artifact regression is applied before SSP or ICA. + +???+ example "Example" + For example, if you have MEG reference channel data recorded in three + miscellaneous channels, you could do: - Contrast the "left" and the "right" regardless of side, and compute an - arbitrary contrast with a gradient of weights: ```python - contrasts = [ - ('auditory/left', 'auditory/right'), - { - 'name': 'gradedContrast', - 'conditions': [ - 'auditory/left', - 'auditory/right', - 'visual/left', - 'visual/right' - ], - 'weights': [-1.5, -.5, .5, 1.5] - } - ] + regress_artifact = { + "picks": "meg", + "picks_artifact": ["MISC 001", "MISC 002", "MISC 003"] + } ``` """ -############################################################################### -# ARTIFACT REMOVAL -# ---------------- -# -# You can choose between ICA and SSP to remove eye and heart artifacts. -# SSP: https://mne-tools.github.io/stable/auto_tutorials/plot_artifacts_correction_ssp.html?highlight=ssp # noqa -# ICA: https://mne-tools.github.io/stable/auto_tutorials/plot_artifacts_correction_ica.html?highlight=ica # noqa -# if you choose ICA, run steps 5a and 6a -# if you choose SSP, run steps 5b and 6b -# -# Currently you cannot use both. 
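Taken together, a minimal sketch of the stimulation-artifact settings described above, using the default window of 0–10 ms after stimulus onset:

```python
fix_stim_artifact = True    # interpolate over the stimulation artifact
stim_artifact_tmin = 0.0    # window starts at stimulus onset ...
stim_artifact_tmax = 0.01   # ... and ends 10 ms later
```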
- -spatial_filter: Optional[Literal["ssp", "ica"]] = None +spatial_filter: Literal["ssp", "ica"] | None = None """ Whether to use a spatial filter to detect and remove artifacts. The BIDS Pipeline offers the use of signal-space projection (SSP) and independent @@ -1141,27 +1250,34 @@ EOG and ECG activity will be omitted during the signal reconstruction step in order to remove the artifacts. The ICA procedure can be configured in various ways using the configuration options you can find below. + +!!! warning "ICA requires manual intervention!" + After the automatic ICA component detection step, review each subject's + `*_report.html` report file check if the set of ICA components to be removed + is correct. Adjustments should be made to the `*_proc-ica_components.tsv` + file, which will then be used in the step that is applied during ICA. + + ICA component order can be considered arbitrary, so any time the ICA is + re-fit – i.e., if you change any parameters that affect steps prior to + ICA fitting – this file will need to be updated! """ -min_ecg_epochs: int = 5 +min_ecg_epochs: Annotated[int, Ge(1)] = 5 """ -Minimal number of ECG epochs needed to compute SSP or ICA rejection. +Minimal number of ECG epochs needed to compute SSP projectors. """ -min_eog_epochs: int = 5 +min_eog_epochs: Annotated[int, Ge(1)] = 5 """ -Minimal number of EOG epochs needed to compute SSP or ICA rejection. +Minimal number of EOG epochs needed to compute SSP projectors. """ - -# Rejection based on SSP -# ~~~~~~~~~~~~~~~~~~~~~~ -n_proj_eog: Dict[str, float] = dict(n_mag=1, n_grad=1, n_eeg=1) +n_proj_eog: dict[str, float] = dict(n_mag=1, n_grad=1, n_eeg=1) """ Number of SSP vectors to create for EOG artifacts for each channel type. """ -n_proj_ecg: Dict[str, float] = dict(n_mag=1, n_grad=1, n_eeg=1) +n_proj_ecg: dict[str, float] = dict(n_mag=1, n_grad=1, n_eeg=1) """ Number of SSP vectors to create for ECG artifacts for each channel type. """ @@ -1189,7 +1305,7 @@ `'separate'` otherwise. """ -ssp_reject_ecg: Optional[Union[Dict[str, float], Literal["autoreject_global"]]] = None +ssp_reject_ecg: dict[str, float] | Literal["autoreject_global"] | None = None """ Peak-to-peak amplitude limits of the ECG epochs to exclude from SSP fitting. This allows you to remove strong transient artifacts, which could negatively @@ -1207,7 +1323,7 @@ ``` """ -ssp_reject_eog: Optional[Union[Dict[str, float], Literal["autoreject_global"]]] = None +ssp_reject_eog: dict[str, float] | Literal["autoreject_global"] | None = None """ Peak-to-peak amplitude limits of the EOG epochs to exclude from SSP fitting. This allows you to remove strong transient artifacts, which could negatively @@ -1225,22 +1341,26 @@ ``` """ -ssp_ecg_channel: Optional[str] = None +ssp_ecg_channel: str | dict[str, str] | None = None """ Channel to use for ECG SSP. Can be useful when the autodetected ECG channel -is not reliable. +is not reliable. If `str`, the same channel will be used for all subjects. +If `dict`, possibly different channels will be used for each subject/session. +Dict values must be channel names, and dict keys must have the form `"sub-X"` (to use +the same channel for each session within a subject) or `"sub-X_ses-Y"` (to use possibly +different channels for each session of a given subject). 
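A hedged example of the two accepted forms (the channel names are hypothetical):

```python
# Use the same ECG channel for every subject and session
ssp_ecg_channel = "MEG 0111"

# Or pick channels per subject and, where needed, per session
ssp_ecg_channel = {
    "sub-01": "MEG 0111",         # applies to all sessions of sub-01
    "sub-02_ses-01": "MEG 0141",  # session-specific choices for sub-02
    "sub-02_ses-02": "MEG 0143",
}
```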
""" -# Rejection based on ICA -# ~~~~~~~~~~~~~~~~~~~~~~ -ica_reject: Optional[Union[Dict[str, float], Literal["autoreject_local"]]] = None +ica_reject: dict[str, float] | Literal["autoreject_local"] | None = None """ Peak-to-peak amplitude limits to exclude epochs from ICA fitting. This allows you to remove strong transient artifacts from the epochs used for fitting ICA, which could -negatively affect ICA performance. +negatively affect ICA performance. The parameter values are the same as for [`reject`][mne_bids_pipeline._config.reject], -but `"autoreject_global"` is not supported. +but `"autoreject_global"` is not supported. `"autoreject_local"` here behaves +differently, too: it is only used to exclude bad epochs from ICA fitting; we do not +perform any interpolation. ???+ info We don't support `"autoreject_global"` here (as opposed to @@ -1248,7 +1368,7 @@ rejection thresholds were too strict before running ICA, i.e., too many epochs got rejected. `"autoreject_local"`, however, usually performed nicely. The `autoreject` documentation - [recommends][https://autoreject.github.io/stable/auto_examples/plot_autoreject_workflow.html] + [recommends](https://autoreject.github.io/stable/auto_examples/plot_autoreject_workflow.html) running local `autoreject` before and after ICA, which can be achieved by setting both, `ica_reject` and [`reject`][mne_bids_pipeline._config.reject], to `"autoreject_local"`. @@ -1262,7 +1382,7 @@ to **not** specify rejection thresholds for EOG and ECG channels here – otherwise, ICA won't be able to "see" these artifacts. -???+ info +???+ info This setting is applied only to the epochs that are used for **fitting** ICA. The goal is to make it easier for ICA to produce a good decomposition. After fitting, ICA is applied to the epochs to be analyzed, usually with one or more components @@ -1278,7 +1398,7 @@ ica_reject = "autoreject_global" # find global (per channel type) PTP thresholds before fitting ICA ica_reject = "autoreject_local" # find local (per channel) thresholds and repair epochs before fitting ICA ``` -""" +""" # noqa: E501 ica_algorithm: Literal[ "picard", "fastica", "extended_infomax", "picard-extended_infomax" @@ -1289,7 +1409,7 @@ algorithm (but may converge in less time). """ -ica_l_freq: Optional[float] = 1.0 +ica_l_freq: float | None = 1.0 """ The cutoff frequency of the high-pass filter to apply before running ICA. Using a relatively high cutoff like 1 Hz will remove slow drifts from the @@ -1323,7 +1443,7 @@ limit may be too low to achieve convergence. """ -ica_n_components: Optional[Union[float, int]] = 0.8 +ica_n_components: float | int | None = None """ MNE conducts ICA as a sort of a two-step procedure: First, a PCA is run on the data (trying to exclude zero-valued components in rank-deficient @@ -1342,12 +1462,13 @@ explained variance less than the value specified here will be passed to ICA. -If `None`, **all** principal components will be used. +If `None` (default), `0.999999` will be used to avoid issues when working with +rank-deficient data. This setting may drastically alter the time required to compute ICA. """ -ica_decim: Optional[int] = None +ica_decim: int | None = None """ The decimation parameter to compute ICA. If 5 it means that 1 every 5 sample is used by ICA solver. The higher the faster @@ -1355,9 +1476,10 @@ `1` or `None` to not perform any decimation. """ -ica_ctps_ecg_threshold: float = 0.1 +ica_ecg_threshold: float = 0.1 """ -The threshold parameter passed to `find_bads_ecg` method. 
+The cross-trial phase statistics (CTPS) threshold parameter used for detecting +ECG-related ICs. """ ica_eog_threshold: float = 3.0 @@ -1367,12 +1489,17 @@ false-alarm rate increases dramatically. """ -# Rejection based on peak-to-peak amplitude -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -reject: Optional[ - Union[Dict[str, float], Literal["autoreject_global", "autoreject_local"]] -] = None +# ### Amplitude-based artifact rejection +# +# ???+ info "Good Practice / Advice" +# Have a look at your raw data and train yourself to detect a blink, a heart +# beat and an eye movement. +# You can do a quick average of blink data and check what the amplitude looks +# like. + +reject: dict[str, float] | Literal["autoreject_global", "autoreject_local"] | None = ( + None +) """ Peak-to-peak amplitude limits to mark epochs as bad. This allows you to remove epochs with strong transient artifacts. @@ -1384,7 +1511,7 @@ If `None` (default), do not apply artifact rejection. -If a dictionary, manually specify rejection thresholds (see examples). +If a dictionary, manually specify rejection thresholds (see examples). The thresholds provided here must be at least as stringent as those in [`ica_reject`][mne_bids_pipeline._config.ica_reject] if using ICA. In case of `'autoreject_global'`, thresholds for any channel that do not meet this @@ -1396,7 +1523,8 @@ channel type. If `"autoreject_local"`, use "local" `autoreject` to detect (and potentially repair) bad -channels in each epoch. Use [`autoreject_n_interpolate`][mne_bids_pipeline._config.autoreject_n_interpolate] +channels in each epoch. +Use [`autoreject_n_interpolate`][mne_bids_pipeline._config.autoreject_n_interpolate] to control how many channels are allowed to be bad before an epoch gets dropped. ???+ example "Example" @@ -1409,7 +1537,7 @@ ``` """ -reject_tmin: Optional[float] = None +reject_tmin: float | None = None """ Start of the time window used to reject epochs. If `None`, the window will start with the first time point. Has no effect if @@ -1421,7 +1549,7 @@ ``` """ -reject_tmax: Optional[float] = None +reject_tmax: float | None = None """ End of the time window used to reject epochs. If `None`, the window will end with the last time point. Has no effect if @@ -1450,43 +1578,132 @@ be considered (i.e., will remain marked as bad and not analyzed by autoreject). """ -############################################################################### -# DECODING -# -------- +# %% +# # Sensor-level analysis -decode: bool = True -""" -Whether to perform decoding (MVPA) on the specified -[`contrasts`][mne_bids_pipeline._config.contrasts]. Classifiers will be trained -on entire epochs ("full-epochs decoding"), and separately on each time point -("time-by-time decoding"), trying to learn how to distinguish the contrasting -conditions. -""" +# ## Condition contrasts -decoding_epochs_tmin: Optional[float] = 0.0 -""" -The first time sample to use for full epochs decoding. By default it starts -at 0. If `None`,, it starts at the beginning of the epoch. +contrasts: Sequence[tuple[str, str] | ArbitraryContrast] = [] """ +The conditions to contrast via a subtraction of ERPs / ERFs. The list elements +can either be tuples or dictionaries (or a mix of both). Each element in the +list corresponds to a single contrast. -decoding_epochs_tmax: Optional[float] = None -""" -The last time sample to use for full epochs decoding. By default it is set -to None so it ends at the end of the epoch. 
-""" +A tuple specifies a one-vs-one contrast, where the second condition is +subtracted from the first. -decoding_metric: str = "roc_auc" -""" -The metric to use for estimating classification performance. It can be -`'roc_auc'` or `'accuracy'` – or any other metric supported by `scikit-learn`. +If a dictionary, must contain the following keys: + +- `name`: a custom name of the contrast +- `conditions`: the conditions to contrast +- `weights`: the weights associated with each condition. + +Pass an empty list to avoid calculation of any contrasts. + +For the contrasts to be computed, the appropriate conditions must have been +epoched, and therefore the conditions should either match or be subsets of +`conditions` above. + +???+ example "Example" + Contrast the "left" and the "right" conditions by calculating + `left - right` at every time point of the evoked responses: + ```python + contrasts = [('left', 'right')] # Note we pass a tuple inside the list! + ``` + + Contrast the "left" and the "right" conditions within the "auditory" and + the "visual" modality, and "auditory" vs "visual" regardless of side: + ```python + contrasts = [('auditory/left', 'auditory/right'), + ('visual/left', 'visual/right'), + ('auditory', 'visual')] + ``` + + Contrast the "left" and the "right" regardless of side, and compute an + arbitrary contrast with a gradient of weights: + ```python + contrasts = [ + ('auditory/left', 'auditory/right'), + { + 'name': 'gradedContrast', + 'conditions': [ + 'auditory/left', + 'auditory/right', + 'visual/left', + 'visual/right' + ], + 'weights': [-1.5, -.5, .5, 1.5] + } + ] + ``` +""" + +# ## Decoding / MVPA + +decode: bool = True +""" +Whether to perform decoding (MVPA) on the specified +[`contrasts`][mne_bids_pipeline._config.contrasts]. Classifiers will be trained +on entire epochs ("full-epochs decoding"), and separately on each time point +("time-by-time decoding"), trying to learn how to distinguish the contrasting +conditions. +""" + +decoding_which_epochs: Literal["uncleaned", "after_ica", "after_ssp", "cleaned"] = ( + "cleaned" +) +""" +This setting controls which epochs will be fed into the decoding algorithms. + +!!! info + Decoding is a very powerful tool that often can deal with noisy data surprisingly + well. Depending on the specific type of data, artifacts, and analysis performed, + decoding performance may even improve with less pre-processed data, as + processing steps such as ICA or SSP often remove parts of the signal, too, in + addition to noise. By default, MNE-BIDS-Pipeline uses cleaned epochs for decoding, + but you may choose to use entirely uncleaned epochs, or epochs before the final + PTP-based rejection or Autoreject step. + +!!! info + No other sensor- and source-level processing steps will be affected by this setting + and use cleaned epochs only. + +If `"uncleaned"`, use the "raw" epochs before any ICA / SSP, PTP-based, or Autoreject +cleaning (epochs with the filename `*_epo.fif`, without a `proc-` part). + +If `"after_ica"` or `"after_ssp"`, use the epochs that were cleaned via ICA or SSP, but +before a followup cleaning through PTP-based rejection or Autorejct (epochs with the +filename `*proc-ica_epo.fif` or `*proc-ssp_epo.fif`). + +If `"cleaned"`, use the epochs after ICA / SSP and the following cleaning through +PTP-based rejection or Autoreject (epochs with the filename `*proc-clean_epo.fif`). +""" + +decoding_epochs_tmin: float | None = 0.0 +""" +The first time sample to use for full epochs decoding. By default it starts +at 0. 
If `None`,, it starts at the beginning of the epoch. Does not affect time-by-time +decoding. +""" + +decoding_epochs_tmax: float | None = None +""" +The last time sample to use for full epochs decoding. By default it is set +to None so it ends at the end of the epoch. +""" + +decoding_metric: str = "roc_auc" +""" +The metric to use for estimating classification performance. It can be +`'roc_auc'` or `'accuracy'` – or any other metric supported by `scikit-learn`. With ROC AUC, chance level is the same regardless of class balance, that is, you don't need to be worried about **exactly** balancing class sizes. """ -decoding_n_splits: int = 5 +decoding_n_splits: Annotated[int, Ge(2)] = 5 """ -The number of folds (also called "splits") to use in the cross-validation +The number of folds (also called "splits") to use in the K-fold cross-validation scheme. """ @@ -1509,7 +1726,7 @@ Because each classifier is trained and tested on **all** time points, this procedure may take a significant amount of time. -""" # noqa: E501 +""" decoding_time_generalization_decim: int = 1 """ @@ -1520,13 +1737,85 @@ resolution in the resulting matrix. """ +decoding_csp: bool = False +""" +Whether to run decoding via Common Spatial Patterns (CSP) analysis on the +data. CSP takes as input data covariances that are estimated on different +time and frequency ranges. This allows to obtain decoding scores defined over +time and frequency. +""" + +decoding_csp_times: FloatArrayLike | None = None +""" +The edges of the time bins to use for CSP decoding. +Must contain at least two elements. By default, 5 equally-spaced bins are +created across the non-negative time range of the epochs. +All specified time points must be contained in the epochs interval. +If an empty list, do not perform **time-frequency** analysis, and only run CSP on +**frequency** data. + +???+ example "Example" + Create 3 equidistant time bins (0–0.2, 0.2–0.4, 0.4–0.6 sec): + ```python + decoding_csp_times = np.linspace(start=0, stop=0.6, num=4) + ``` + Create 2 time bins of different durations (0–0.4, 0.4–0.6 sec): + ```python + decoding_csp_times = [0, 0.4, 0.6] + ``` +""" + +decoding_csp_freqs: dict[str, FloatArrayLike] | None = None +""" +The edges of the frequency bins to use for CSP decoding. + +This parameter must be a dictionary with: +- keys specifying the unique identifier or "name" to use for the frequency + range to be treated jointly during statistical testing (such as "alpha" or + "beta"), and +- values must be list-like objects containing at least two scalar values, + specifying the edges of the respective frequency bin(s), e.g., `[8, 12]`. + +Defaults to two frequency bins, one from +[`time_frequency_freq_min`][mne_bids_pipeline._config.time_frequency_freq_min] +to the midpoint between this value and +[`time_frequency_freq_max`][mne_bids_pipeline._config.time_frequency_freq_max]; +and the other from that midpoint to `time_frequency_freq_max`. 
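For instance, with the default `time_frequency_freq_min = 8` and `time_frequency_freq_max = 40`, this default is roughly equivalent to specifying the two bins explicitly under a single name (so that they are clustered together during statistical testing):

```python
decoding_csp_freqs = {
    "custom": [8, 24, 40],  # one bin from 8-24 Hz and one from 24-40 Hz
}
```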
+???+ example "Example" + Create two frequency bins, one for 4–8 Hz, and another for 8–14 Hz, which + will be clustered together during statistical testing (in the + time-frequency plane): + ```python + decoding_csp_freqs = { + 'custom_range': [4, 8, 14] + } + ``` + Create the same two frequency bins, but treat them separately during + statistical testing (i.e., temporal clustering only): + ```python + decoding_csp_freqs = { + 'theta': [4, 8], + 'alpha': [8, 14] + } + ``` + Create 5 equidistant frequency bins from 4 to 14 Hz: + ```python + decoding_csp_freqs = { + 'custom_range': np.linspace( + start=4, + stop=14, + num=5+1 # We need one more to account for the endpoint! + ) + } +""" + n_boot: int = 5000 """ The number of bootstrap resamples when estimating the standard error and confidence interval of the mean decoding scores. """ -cluster_forming_t_threshold: Optional[float] = None +cluster_forming_t_threshold: float | None = None """ The t-value threshold to use for forming clusters in the cluster-based permutation test run on the the time-by-time decoding scores. @@ -1545,7 +1834,7 @@ test to determine the significance of the decoding scores across participants. """ -cluster_permutation_p_threshold: float = 0.05 +cluster_permutation_p_threshold: Annotated[float, Interval(gt=0, lt=1)] = 0.05 """ The alpha level (p-value, p threshold) to use for rejecting the null hypothesis that the clusters show no significant difference between conditions. This is @@ -1556,28 +1845,9 @@ [`cluster_forming_t_threshold`][mne_bids_pipeline._config.cluster_forming_t_threshold]. """ -############################################################################### -# GROUP AVERAGE SENSORS -# --------------------- +# ## Time-frequency analysis -interpolate_bads_grand_average: bool = True -""" -Interpolate bad sensors in each dataset before calculating the grand -average. This parameter is passed to the `mne.grand_average` function via -the keyword argument `interpolate_bads`. It requires to have channel -locations set. - -???+ example "Example" - ```python - interpolate_bads_grand_average = True - ``` -""" - -############################################################################### -# TIME-FREQUENCY -# -------------- - -time_frequency_conditions: Iterable[str] = [] +time_frequency_conditions: Sequence[str] = [] """ The conditions to compute time-frequency decomposition on. @@ -1587,7 +1857,7 @@ ``` """ -time_frequency_freq_min: Optional[float] = 8 +time_frequency_freq_min: float | None = 8 """ Minimum frequency for the time frequency analysis, in Hz. ???+ example "Example" @@ -1596,7 +1866,7 @@ ``` """ -time_frequency_freq_max: Optional[float] = 40 +time_frequency_freq_max: float | None = 40 """ Maximum frequency for the time frequency analysis, in Hz. ???+ example "Example" @@ -1605,7 +1875,7 @@ ``` """ -time_frequency_cycles: Optional[Union[float, FloatArrayLike]] = None +time_frequency_cycles: float | FloatArrayLike | None = None """ The number of cycles to use in the Morlet wavelet. This can be a single number or one per frequency, where frequencies are calculated via @@ -1624,83 +1894,7 @@ This also applies to CSP analysis. """ -############################################################################### -# TIME-FREQUENCY CSP -# ------------------ - -decoding_csp: bool = False -""" -Whether to run decoding via Common Spatial Patterns (CSP) analysis on the -data. CSP takes as input data covariances that are estimated on different -time and frequency ranges. 
This allows to obtain decoding scores defined over -time and frequency. -""" - -decoding_csp_times: Optional[FloatArrayLike] = None -""" -The edges of the time bins to use for CSP decoding. -Must contain at least two elements. By default, 5 equally-spaced bins are -created across the non-negative time range of the epochs. -All specified time points must be contained in the epochs interval. -If `None`, do not perform **time-frequency** analysis, and only run CSP on -**frequency** data. - -???+ example "Example" - Create 3 equidistant time bins (0–0.2, 0.2–0.4, 0.4–0.6 sec): - ```python - decoding_csp_times = np.linspace(start=0, stop=0.6, num=4) - ``` - Create 2 time bins of different durations (0–0.4, 0.4–0.6 sec): - ```python - decoding_csp_times = [0, 0.4, 0.6] - ``` -""" - -decoding_csp_freqs: Optional[Dict[str, FloatArrayLike]] = None -""" -The edges of the frequency bins to use for CSP decoding. - -This parameter must be a dictionary with: -- keys specifying the unique identifier or "name" to use for the frequency - range to be treated jointly during statistical testing (such as "alpha" or - "beta"), and -- values must be list-like objects containing at least two scalar values, - specifying the edges of the respective frequency bin(s), e.g., `[8, 12]`. - -Defaults to two frequency bins, one from -[`time_frequency_freq_min`][mne_bids_pipeline._config.time_frequency_freq_min] -to the midpoint between this value and -[`time_frequency_freq_max`][mne_bids_pipeline._config.time_frequency_freq_max]; -and the other from that midpoint to `time_frequency_freq_max`. -???+ example "Example" - Create two frequency bins, one for 4–8 Hz, and another for 8–14 Hz, which - will be clustered together during statistical testing (in the - time-frequency plane): - ```python - decoding_csp_freqs = { - 'custom_range': [4, 8, 14] - } - ``` - Create the same two frequency bins, but treat them separately during - statistical testing (i.e., temporal clustering only): - ```python - decoding_csp_freqs = { - 'theta': [4, 8], - 'alpha': [8, 14] - } - ``` - Create 5 equidistant frequency bins from 4 to 14 Hz: - ```python - decoding_csp_freqs = { - 'custom_range': np.linspace( - start=4, - stop=14, - num=5+1 # We need one more to account for the endpoint! - ) - } -""" - -time_frequency_baseline: Optional[Tuple[float, float]] = None +time_frequency_baseline: tuple[float, float] | None = None """ Baseline period to use for the time-frequency analysis. If `None`, no baseline. ???+ example "Example" @@ -1719,28 +1913,45 @@ ``` """ -time_frequency_crop: Optional[dict] = None +time_frequency_crop: dict[str, float] | None = None """ Period and frequency range to crop the time-frequency analysis to. If `None`, no cropping. ???+ example "Example" ```python - time_frequency_crop = dict(tmin=-0.3, tmax=0.5, fmin=5, fmax=20) + time_frequency_crop = dict(tmin=-0.3, tmax=0.5, fmin=5., fmax=20.) ``` """ -############################################################################### -# SOURCE ESTIMATION PARAMETERS -# ---------------------------- -# +# ## Group-level analysis + +interpolate_bads_grand_average: bool = True +""" +Interpolate bad sensors in each dataset before calculating the grand +average. This parameter is passed to the `mne.grand_average` function via +the keyword argument `interpolate_bads`. It requires to have channel +locations set. 
+ +???+ example "Example" + ```python + interpolate_bads_grand_average = True + ``` +""" + +# %% +# # Source-level analysis + +# ## General source analysis settings run_source_estimation: bool = True """ Whether to run source estimation processing steps if not explicitly requested. """ -use_template_mri: Optional[str] = None +# ## BEM surface + +use_template_mri: str | None = None """ Whether to use a template MRI subject such as FreeSurfer's `fsaverage` subject. This may come in handy if you don't have individual MR scans of your @@ -1812,7 +2023,9 @@ Whether to print the complete output of FreeSurfer commands. Note that if `False`, no FreeSurfer output might be displayed at all!""" -mri_t1_path_generator: Optional[Callable[[BIDSPath], BIDSPath]] = None +# ## Source space & forward solution + +mri_t1_path_generator: Callable[[BIDSPath], BIDSPath] | None = None """ To perform source-level analyses, the Pipeline needs to generate a transformation matrix that translates coordinates from MEG and EEG sensor @@ -1872,7 +2085,7 @@ def get_t1_from_meeg(bids_path): ``` """ -mri_landmarks_kind: Optional[Callable[[BIDSPath], str]] = None +mri_landmarks_kind: Callable[[BIDSPath], str] | None = None """ This config option allows to look for specific landmarks in the json sidecar file of the T1 MRI file. This can be useful when we have different @@ -1889,7 +2102,7 @@ def mri_landmarks_kind(bids_path): ``` """ -spacing: Union[Literal["oct5", "oct6", "ico4", "ico5", "all"], int] = "oct6" +spacing: Literal["oct5", "oct6", "ico4", "ico5", "all"] | int = "oct6" """ The spacing to use. Can be `'ico#'` for a recursively subdivided icosahedron, `'oct#'` for a recursively subdivided octahedron, @@ -1904,24 +2117,35 @@ def mri_landmarks_kind(bids_path): Exclude points closer than this distance (mm) to the bounding surface. """ -loose: Union[float, Literal["auto"]] = 0.2 +# ## Inverse solution + +loose: Annotated[float, Interval(ge=0, le=1)] | Literal["auto"] = 0.2 """ -Value that weights the source variances of the dipole components -that are parallel (tangential) to the cortical surface. If `0`, then the -inverse solution is computed with **fixed orientation.** -If `1`, it corresponds to **free orientation.** -The default value, `'auto'`, is set to `0.2` for surface-oriented source -spaces, and to `1.0` for volumetric, discrete, or mixed source spaces, -unless `fixed is True` in which case the value 0. is used. +A value between 0 and 1 that weights the source variances of the dipole components +that are parallel (tangential) to the cortical surface. + +If `0`, then the inverse solution is computed with **fixed orientation**, i.e., +only dipole components perpendicular to the cortical surface are considered. + +If `1`, it corresponds to **free orientation**, i.e., dipole components with any +orientation are considered. + +The default value, `0.2`, is suitable for surface-oriented source spaces. + +For volume or mixed source spaces, choose `1.0`. + +!!! info + Support for modeling volume and mixed source spaces will be added in a future + version of MNE-BIDS-Pipeline. """ -depth: Optional[Union[float, dict]] = 0.8 +depth: Annotated[float, Interval(ge=0, le=1)] | dict[str, Any] = 0.8 """ -If float (default 0.8), it acts as the depth weighting exponent (`exp`) -to use (must be between 0 and 1). None is equivalent to 0, meaning no -depth weighting is performed. Can also be a `dict` containing additional -keyword arguments to pass to :func:`mne.forward.compute_depth_prior` -(see docstring for details and defaults). 
+If a number, it acts as the depth weighting exponent to use +(must be between `0` and`1`), with`0` meaning no depth weighting is performed. + +Can also be a dictionary containing additional keyword arguments to pass to +`mne.forward.compute_depth_prior` (see docstring for details and defaults). """ inverse_method: Literal["MNE", "dSPM", "sLORETA", "eLORETA"] = "dSPM" @@ -1930,11 +2154,11 @@ def mri_landmarks_kind(bids_path): solution. """ -noise_cov: Union[ - Tuple[Optional[float], Optional[float]], - Literal["emptyroom", "rest", "ad-hoc"], - Callable[[BIDSPath], Covariance], -] = (None, 0) +noise_cov: ( + tuple[float | None, float | None] + | Literal["emptyroom", "rest", "ad-hoc"] + | Callable[[BIDSPath], Covariance] +) = (None, 0) """ Specify how to estimate the noise covariance matrix, which is used in inverse modeling. @@ -1998,15 +2222,33 @@ def noise_cov(bids_path): ``` """ -source_info_path_update: Optional[Dict[str, str]] = dict(suffix="ave") +noise_cov_method: Literal[ + "shrunk", + "empirical", + "diagonal_fixed", + "oas", + "ledoit_wolf", + "factor_analysis", + "shrinkage", + "pca", + "auto", +] = "shrunk" +""" +The noise covariance estimation method to use. See the MNE-Python documentation +of `mne.compute_covariance` for details. """ -When computing the forward and inverse solutions, by default the pipeline -retrieves the `mne.Info` object from the cleaned evoked data. However, in -certain situations you may wish to use a different `Info`. +source_info_path_update: dict[str, str] | None = None +""" +When computing the forward and inverse solutions, it is important to +provide the `mne.Info` object from the data on which the noise covariance was +computed, to avoid problems resulting from mismatching ranks. This parameter allows you to explicitly specify from which file to retrieve the `mne.Info` object. Use this parameter to supply a dictionary to `BIDSPath.update()` during the forward and inverse processing steps. +If set to `None` (default), the info will be retrieved either from the raw +file specified in `noise_cov`, or the cleaned evoked +(if `noise_cov` is None or `ad-hoc`). ???+ example "Example" Use the `Info` object stored in the cleaned epochs: @@ -2014,9 +2256,20 @@ def noise_cov(bids_path): source_info_path_update = {'processing': 'clean', 'suffix': 'epo'} ``` + + Use the `Info` object stored in a raw file (e.g. resting state): + ```python + source_info_path_update = {'processing': 'clean', + 'suffix': 'raw', + 'task': 'rest'} + ``` + If you set `noise_cov = 'rest'` and `source_path_info = None`, + then the behavior is identical to that above + (it will automatically use the resting state data). + """ -inverse_targets: List[Literal["evoked"]] = ["evoked"] +inverse_targets: list[Literal["evoked"]] = ["evoked"] """ On which data to apply the inverse operator. Currently, the only supported @@ -2035,11 +2288,12 @@ def noise_cov(bids_path): ``` """ -############################################################################### -# Report generation -# ----------------- +# %% +# # Reports + +# ## Report generation -report_evoked_n_time_points: Optional[int] = None +report_evoked_n_time_points: int | None = None """ Specifies the number of time points to display for each evoked in the report. If `None`, it defaults to the current default in MNE-Python. 
@@ -2051,7 +2305,7 @@ def noise_cov(bids_path): ``` """ -report_stc_n_time_points: Optional[int] = None +report_stc_n_time_points: int | None = None """ Specifies the number of time points to display for each source estimates in the report. If `None`, it defaults to the current default in MNE-Python. @@ -2063,9 +2317,65 @@ def noise_cov(bids_path): ``` """ -############################################################################### -# Execution -# --------- +report_add_epochs_image_kwargs: dict[str, Any] | None = None +""" +Specifies the limits for the color scales of the epochs_image in the report. +If `None`, it defaults to the current default in MNE-Python. + +???+ example "Example" + Set vmin and vmax to the epochs rejection thresholds (with unit conversion): + + ```python + report_add_epochs_image_kwargs = { + "grad": {"vmin": 0, "vmax": 1e13 * reject["grad"]}, # fT/cm + "mag": {"vmin": 0, "vmax": 1e15 * reject["mag"]}, # fT + } + ``` +""" + +# %% +# # Caching +# +# Per default, the pipeline output is cached (temporarily stored), +# to avoid unnecessary reruns of previously computed steps. +# Yet, for consistency, changes in configuration parameters trigger +# automatic reruns of previous steps. +# !!! info +# To force rerunning a given step, run the pipeline with the option: `--no-cache`. + +memory_location: PathLike | bool | None = True +""" +If not None (or False), caching will be enabled and the cache files will be +stored in the given directory. The default (True) will use a +`"_cache"` subdirectory (name configurable via the +[`memory_subdir`][mne_bids_pipeline._config.memory_subdir] +variable) in the BIDS derivative root of the dataset. +""" + +memory_subdir: str = "_cache" +""" +The caching directory name to use if `memory_location` is `True`. +""" + +memory_file_method: Literal["mtime", "hash"] = "mtime" +""" +The method to use for cache invalidation (i.e., detecting changes). Using the +"modified time" reported by the filesystem (`'mtime'`, default) is very fast +but requires that the filesystem supports proper mtime reporting. Using file +hashes (`'hash'`) is slower and requires reading all input files but should +work on any filesystem. +""" + +memory_verbose: int = 0 +""" +The verbosity to use when using memory. The default (0) does not print, while +1 will print the function calls that will be cached. See the documentation for +the joblib.Memory class for more information.""" + +# %% +# # Parallelization +# +# These options control parallel processing (e.g., multiple subjects at once), n_jobs: int = 1 """ @@ -2087,7 +2397,7 @@ def noise_cov(bids_path): Ignored if `parallel_backend` is not `'dask'`. """ -dask_temp_dir: Optional[PathLike] = None +dask_temp_dir: PathLike | None = None """ The temporary directory to use by Dask. Dask places lock-files in this directory, and also uses it to "spill" RAM contents to disk if the amount of @@ -2105,19 +2415,10 @@ def noise_cov(bids_path): The maximum amount of RAM per Dask worker. """ -random_state: Optional[int] = 42 -""" -You can specify the seed of the random number generator (RNG). -This setting is passed to the ICA algorithm and to the decoding function, -ensuring reproducible results. Set to `None` to avoid setting the RNG -to a defined state. -""" - -shortest_event: int = 1 -""" -Minimum number of samples an event must last. If the -duration is less than this, an exception will be raised. -""" +# %% +# # Logging +# +# These options control how much logging output is produced. 
log_level: Literal["info", "error"] = "info" """ @@ -2129,6 +2430,13 @@ def noise_cov(bids_path): Set the MNE-Python logging verbosity. """ + +# %% +# # Error handling +# +# These options control how errors while processing the data or the configuration file +# are handled. + on_error: Literal["continue", "abort", "debug"] = "abort" """ Whether to abort processing as soon as an error occurs, continue with all other @@ -2139,35 +2447,6 @@ def noise_cov(bids_path): Enabling debug mode deactivates parallel processing. """ -memory_location: Optional[Union[PathLike, bool]] = True -""" -If not None (or False), caching will be enabled and the cache files will be -stored in the given directory. The default (True) will use a -`"_cache"` subdirectory (name configurable via the -[`memory_subdir`][mne_bids_pipeline._config.memory_subdir] -variable) in the BIDS derivative root of the dataset. -""" - -memory_subdir: str = "_cache" -""" -The caching directory name to use if `memory_location` is `True`. -""" - -memory_file_method: Literal["mtime", "hash"] = "mtime" -""" -The method to use for cache invalidation (i.e., detecting changes). Using the -"modified time" reported by the filesystem (`'mtime'`, default) is very fast -but requires that the filesystem supports proper mtime reporting. Using file -hashes (`'hash'`) is slower and requires reading all input files but should -work on any filesystem. -""" - -memory_verbose: int = 0 -""" -The verbosity to use when using memory. The default (0) does not print, while -1 will print the function calls that will be cached. See the documentation for -the joblib.Memory class for more information.""" - config_validation: Literal["raise", "warn", "ignore"] = "raise" """ How strictly to validate the configuration. Errors are always raised for diff --git a/mne_bids_pipeline/_config_import.py b/mne_bids_pipeline/_config_import.py index 14a55df2e..da5d65e68 100644 --- a/mne_bids_pipeline/_config_import.py +++ b/mne_bids_pipeline/_config_import.py @@ -1,36 +1,53 @@ import ast import copy -from dataclasses import field import difflib -from functools import partial -import importlib +import importlib.util import os import pathlib +import re +from dataclasses import field +from functools import partial +from inspect import signature from types import SimpleNamespace -from typing import Optional, List +from typing import Any import matplotlib -import numpy as np import mne +import numpy as np +from mne_bids import get_entity_vals +from pydantic import BaseModel, ConfigDict, ValidationError -from pydantic import ValidationError -from pydantic.dataclasses import dataclass - -from ._logging import logger, gen_log_kwargs +from ._config_utils import get_subjects_sessions +from ._logging import gen_log_kwargs, logger from .typing import PathLike +class ConfigError(ValueError): + pass + + def _import_config( *, - config_path: Optional[PathLike], - overrides: Optional[SimpleNamespace] = None, + config_path: PathLike | None, + overrides: SimpleNamespace | None = None, check: bool = True, log: bool = True, ) -> SimpleNamespace: """Import the default config and the user's config.""" # Get the default config = _get_default_config() + # Public names users generally will have in their config valid_names = [d for d in dir(config) if not d.startswith("_")] + # Names that we will reduce the SimpleConfig to before returning + # (see _update_with_user_config) + keep_names = [d for d in dir(config) if not d.startswith("__")] + [ + "config_path", + "PIPELINE_NAME", + "VERSION", + "CODE_URL", + 
"_raw_split_size", + "_epochs_split_size", + ] # Update with user config user_names = _update_with_user_config( @@ -40,27 +57,42 @@ def _import_config( log=log, ) - extra_exec_params_keys = () + extra_exec_params_keys: tuple[str, ...] = () extra_config = os.getenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", "") if extra_config: msg = f"With testing config: {extra_config}" logger.info(**gen_log_kwargs(message=msg, emoji="override")) - _update_config_from_path( - config=config, - config_path=extra_config, + extra_names = _update_config_from_path( + config=config, config_path=extra_config, include_private=True + ) + # Update valid_extra_names as needed if test configs in tests/test_run.py change + valid_extra_names = set( + ( + "_n_jobs", + "_raw_split_size", + "_epochs_split_size", + "subjects_dir", # test_session_specific_mri + "deriv_root", # test_session_specific_mri + "Path", # test_session_specific_mri + ) ) - extra_exec_params_keys = ("_n_jobs",) + assert set(extra_names) - valid_extra_names == set(), extra_names + extra_exec_params_keys = tuple(set(["_n_jobs"]) & set(extra_names)) + keep_names.extend(extra_exec_params_keys) # Check it if check: _check_config(config, config_path) _check_misspellings_removals( - config, valid_names=valid_names, user_names=user_names, log=log, + config_validation=config.config_validation, ) + # Finally, reduce to our actual supported params (all keep_names should be present) + config = SimpleNamespace(**{k: getattr(config, k) for k in keep_names}) + # Take some standard actions mne.set_log_level(verbose=config.mne_log_level.upper()) @@ -95,7 +127,7 @@ def _import_config( return config -def _get_default_config(): +def _get_default_config() -> SimpleNamespace: from . import _config # Don't use _config itself as it's mutable -- make a new object @@ -105,7 +137,7 @@ def _get_default_config(): ignore_keys = { name.asname or name.name for element in tree.body - if isinstance(element, (ast.Import, ast.ImportFrom)) + if isinstance(element, ast.Import | ast.ImportFrom) for name in element.names } config = SimpleNamespace( @@ -122,7 +154,8 @@ def _update_config_from_path( *, config: SimpleNamespace, config_path: PathLike, -): + include_private: bool = False, +) -> list[str]: user_names = list() config_path = pathlib.Path(config_path).expanduser().resolve(strict=True) # Import configuration from an arbitrary path without having to fiddle @@ -130,16 +163,19 @@ def _update_config_from_path( spec = importlib.util.spec_from_file_location( name="custom_config", location=config_path ) + assert spec is not None + assert spec.loader is not None custom_cfg = importlib.util.module_from_spec(spec) + assert custom_cfg is not None spec.loader.exec_module(custom_cfg) for key in dir(custom_cfg): if not key.startswith("__"): # don't validate private vars, but do add to config # (e.g., so that our hidden _raw_split_size is included) - if not key.startswith("_"): + if include_private or not key.startswith("_"): user_names.append(key) val = getattr(custom_cfg, key) - logger.debug("Overwriting: %s -> %s" % (key, val)) + logger.debug(f"Overwriting: {key} -> {val}") setattr(config, key, val) return user_names @@ -147,10 +183,10 @@ def _update_config_from_path( def _update_with_user_config( *, config: SimpleNamespace, # modified in-place - config_path: Optional[PathLike], - overrides: Optional[SimpleNamespace], + config_path: PathLike | None, + overrides: SimpleNamespace | None, log: bool = False, -) -> List[str]: +) -> list[str]: # 1. Basics and hidden vars from . 
import __version__ @@ -212,15 +248,14 @@ def _update_with_user_config( config.n_jobs = 1 if log and config.parallel_backend != "loky": msg = ( - 'Setting config.parallel_backend="loky" because ' - 'config.on_error="debug"' + 'Setting config.parallel_backend="loky" because config.on_error="debug"' ) logger.info(**gen_log_kwargs(message=msg, **log_kwargs)) config.parallel_backend = "loky" return user_names -def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> None: +def _check_config(config: SimpleNamespace, config_path: PathLike | None) -> None: _pydantic_validate(config=config, config_path=config_path) # Eventually all of these could be pydantic-validated, but for now we'll @@ -233,11 +268,68 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N and len(set(config.ch_types).intersection(("meg", "grad", "mag"))) == 0 ): raise ValueError("Cannot use Maxwell filter without MEG channels.") + mf_reserved_kwargs = ( + "raw", + "calibration", + "cross_talk", + "st_duration", + "st_correlation", + "origin", + "coord_frame", + "destination", + "head_pos", + "extended_proj", + "int_order", + "ext_order", + ) + # check `mf_extra_kws` for things that shouldn't be in there + if duplicates := (set(config.mf_extra_kws) & set(mf_reserved_kwargs)): + raise ConfigError( + f"`mf_extra_kws` contains keys {', '.join(sorted(duplicates))} that are " + "handled by dedicated config keys. Please remove them from `mf_extra_kws`." + ) + # if `destination="twa"` make sure `mf_mc=True` + if ( + isinstance(config.mf_destination, str) + and config.mf_destination == "twa" + and not config.mf_mc + ): + raise ConfigError( + "cannot compute time-weighted average head position (mf_destination='twa') " + "without movement compensation. Please set `mf_mc=True` in your config." + ) + # if `dict` passed for ssp_ecg_channel, make sure its keys are valid + if config.ssp_ecg_channel and isinstance(config.ssp_ecg_channel, dict): + pattern = re.compile(r"^sub-[A-Za-z\d]+(_ses-[A-Za-z\d]+)?$") + matches = set(filter(pattern.match, config.ssp_ecg_channel)) + newline_indent = "\n " + if mismatch := (set(config.ssp_ecg_channel) - matches): + raise ConfigError( + "Malformed keys in ssp_ecg_channel dict:\n " + f"{newline_indent.join(sorted(mismatch))}" + ) + # also make sure there are values for all subjects/sessions: + missing = list() + subjects_sessions = get_subjects_sessions(config) + for sub, sessions in subjects_sessions.items(): + for ses in sessions: + if ( + config.ssp_ecg_channel.get(f"sub-{sub}") is None + and config.ssp_ecg_channel.get(f"sub-{sub}_ses-{ses}") is None + ): + missing.append( + f"sub-{sub}" if ses is None else f"sub-{sub}_ses-{ses}" + ) + if missing: + raise ConfigError( + f"Missing entries in ssp_ecg_channel:\n {newline_indent.join(missing)}" + ) reject = config.reject ica_reject = config.ica_reject if config.spatial_filter == "ica": - if config.ica_l_freq < 1: + effective_ica_l_freq = max([config.ica_l_freq or 0.0, config.l_freq or 0.0]) + if effective_ica_l_freq < 1: raise ValueError( "You requested to high-pass filter the data before ICA with " f"ica_l_freq={config.ica_l_freq} Hz. 
Please increase this " @@ -270,17 +362,6 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N f'ica_reject["{ch_type}"] ({ica_reject[ch_type]})' ) - if not config.ch_types: - raise ValueError("Please specify ch_types in your configuration.") - - _VALID_TYPES = ("meg", "mag", "grad", "eeg") - if any(ch_type not in _VALID_TYPES for ch_type in config.ch_types): - raise ValueError( - "Invalid channel type passed. Please adjust `ch_types` in your " - f"configuration, got {config.ch_types} but supported types are " - f"{_VALID_TYPES}" - ) - if config.noise_cov == "emptyroom" and "eeg" in config.ch_types: raise ValueError( "You requested to process data that contains EEG channels. In " @@ -297,6 +378,16 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N "Please set process_empty_room = True" ) + if ( + config.allow_missing_sessions + and "ignore_suffixes" not in signature(get_entity_vals).parameters + ): + raise ConfigError( + "You've requested to `allow_missing_sessions`, but this functionality " + "requires a newer version of `mne_bids` than you have available. Please " + "update MNE-BIDS (or if on the latest version, install the dev version)." + ) + bl = config.baseline if bl is not None: if (bl[0] is not None and bl[0] < config.epochs_tmin) or ( @@ -313,16 +404,7 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N f"but you set baseline={bl}" ) - # check decoding parameters - if config.decoding_n_splits < 2: - raise ValueError("decoding_n_splits should be at least 2.") - # check cluster permutation parameters - if not 0 < config.cluster_permutation_p_threshold < 1: - raise ValueError( - "cluster_permutation_p_threshold should be in the (0, 1) interval." - ) - if config.cluster_n_permutations < 10 / config.cluster_permutation_p_threshold: raise ValueError( "cluster_n_permutations is not big enough to calculate " @@ -346,16 +428,19 @@ def _check_config(config: SimpleNamespace, config_path: Optional[PathLike]) -> N ) -def _default_factory(key, val): +def _default_factory(key: str, val: Any) -> Any: # convert a default to a default factory if needed, having an explicit # allowlist of non-empty ones allowlist = [ {"n_mag": 1, "n_grad": 1, "n_eeg": 1}, # n_proj_* {"custom": (8, 24.0, 40)}, # decoding_csp_freqs - {"suffix": "ave"}, # source_info_path_update ["evoked"], # inverse_targets [4, 8, 16], # autoreject_n_interpolate ] + + def default_factory() -> Any: + return val + for typ in (dict, list): if isinstance(val, typ): try: @@ -365,7 +450,7 @@ def _default_factory(key, val): default_factory = typ else: if typ is dict: - default_factory = partial(typ, **allowlist[idx]) + default_factory = partial(typ, **allowlist[idx]) # type: ignore else: assert typ is list default_factory = partial(typ, allowlist[idx]) @@ -375,44 +460,41 @@ def _default_factory(key, val): def _pydantic_validate( config: SimpleNamespace, - config_path: Optional[PathLike], -): + config_path: PathLike | None, +) -> None: """Create dataclass from config type hints and validate with pydantic.""" # https://docs.pydantic.dev/latest/usage/dataclasses/ from . 
import _config as root_config - annotations = copy.deepcopy(root_config.__annotations__) # just be safe - attrs = { - key: _default_factory(key, val) - for key, val in root_config.__dict__.items() - if key in annotations - } - # everything should be type annotated, make sure they are - asym = set(attrs).symmetric_difference(set(annotations)) - assert asym == set(), asym + # Modify annotations to add nested strict parsing + annotations = dict() + attrs = dict() + for key, annot in root_config.__annotations__.items(): + annotations[key] = annot + attrs[key] = _default_factory(key, root_config.__dict__[key]) name = "user configuration" if config_path is not None: name += f" from {config_path}" - UserConfig = type( - name, - (object,), - {"__annotations__": annotations, **attrs}, - ) - dataclass_config = dict( - arbitrary_types_allowed=False, + model_config = ConfigDict( + arbitrary_types_allowed=True, # needed in 2.6.0 to allow DigMontage for example validate_assignment=True, strict=True, # do not allow float for int for example + extra="forbid", + ) + UserConfig = type( + name, + (BaseModel,), + {"__annotations__": annotations, "model_config": model_config, **attrs}, ) - UserConfig = dataclass(config=dataclass_config)(UserConfig) # Now use pydantic to automagically validate user_vals = {key: val for key, val in config.__dict__.items() if key in annotations} try: - UserConfig(**user_vals) + UserConfig.model_validate(user_vals) # type: ignore[attr-defined] except ValidationError as err: raise ValueError(str(err)) from None -_REMOVED_NAMES = { +_REMOVED_NAMES: dict[str, dict[str, str | None]] = { "debug": dict( new_name="on_error", instead='use on_error="debug" instead', @@ -427,24 +509,26 @@ def _pydantic_validate( "N_JOBS": dict( new_name="n_jobs", ), + "ica_ctps_ecg_threshold": dict( + new_name="ica_ecg_threshold", + ), } def _check_misspellings_removals( - config: SimpleNamespace, *, - valid_names: List[str], - user_names: List[str], + valid_names: list[str], + user_names: list[str], log: bool, + config_validation: str, ) -> None: # for each name in the user names, check if it's in the valid names but # the correct one is not defined - valid_names = set(valid_names) for user_name in user_names: if user_name not in valid_names: # find the closest match closest_match = difflib.get_close_matches(user_name, valid_names, n=1) - msg = f"Found a variable named {repr(user_name)} in your custom " "config," + msg = f"Found a variable named {repr(user_name)} in your custom config," if closest_match and closest_match[0] not in user_names: this_msg = ( f"{msg} did you mean {repr(closest_match[0])}? " @@ -452,7 +536,7 @@ def _check_misspellings_removals( "the variable to reduce ambiguity and avoid this message, " "or set config.config_validation to 'warn' or 'ignore'." ) - _handle_config_error(this_msg, log, config) + _handle_config_error(this_msg, log, config_validation) if user_name in _REMOVED_NAMES: new = _REMOVED_NAMES[user_name]["new_name"] if new not in user_names: @@ -463,16 +547,16 @@ def _check_misspellings_removals( f"{msg} this variable has been removed as a valid " f"config option, {instead}." 
) - _handle_config_error(this_msg, log, config) + _handle_config_error(this_msg, log, config_validation) def _handle_config_error( msg: str, log: bool, - config: SimpleNamespace, + config_validation: str, ) -> None: - if config.config_validation == "raise": + if config_validation == "raise": raise ValueError(msg) - elif config.config_validation == "warn": + elif config_validation == "warn": if log: logger.warning(**gen_log_kwargs(message=msg, emoji="🛟")) diff --git a/mne_bids_pipeline/_config_template.py b/mne_bids_pipeline/_config_template.py index 1925e020e..ac6bcce6f 100644 --- a/mne_bids_pipeline/_config_template.py +++ b/mne_bids_pipeline/_config_template.py @@ -1,8 +1,7 @@ +import ast from pathlib import Path -from typing import List - -from ._logging import logger, gen_log_kwargs +from ._logging import gen_log_kwargs, logger CONFIG_SOURCE_PATH = Path(__file__).parent / "_config.py" @@ -17,16 +16,38 @@ def create_template_config( raise FileExistsError(f"The specified path already exists: {target_path}") # Create a template by commenting out most of the lines in _config.py - config: List[str] = [] - with open(CONFIG_SOURCE_PATH, "r", encoding="utf-8") as f: - for line in f: - line = ( - line if line.startswith(("#", "\n", "import", "from")) else f"# {line}" - ) - config.append(line) - - target_path.write_text("".join(config), encoding="utf-8") - message = f"Successfully created template configuration file at: " f"{target_path}" + config: list[str] = ["# Template config file for mne_bids_pipeline.", ""] + text = CONFIG_SOURCE_PATH.read_text(encoding="utf-8") + # skip file header + to_strip = "# Default settings for data processing and analysis.\n\n" + if text.startswith(to_strip): + text = text[len(to_strip) :] + lines = text.split("\n") + # make sure we catch all imports and assignments + tree = ast.parse(text, type_comments=True) + for ix, line in enumerate(lines, start=1): # ast.parse assigns 1-indexed `lineno`! + nodes = [_node for _node in tree.body if _node.lineno <= ix <= _node.end_lineno] # type:ignore[operator] + if not nodes: + # blank lines and comments aren't parsed by `ast.parse`: + assert line == "" or line.startswith("#"), line + else: + assert len(nodes) == 1, nodes + node = nodes[0] + # config value assignments should become commented out: + if isinstance(node, ast.AnnAssign): + line = f"# {line}" + # imports get written as-is (not commented out): + elif isinstance(node, ast.Import | ast.ImportFrom): + pass + # everything else should be (multiline) string literals: + else: + assert isinstance(node, ast.Expr), node + assert isinstance(node.value, ast.Constant), node.value + assert isinstance(node.value.value, str), node.value.value + config.append(line) + + target_path.write_text("\n".join(config), encoding="utf-8") + message = f"Successfully created template configuration file at: {target_path}" logger.info(**gen_log_kwargs(message=message, emoji="✅")) message = "Please edit the file before running the pipeline." 
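Note (illustrative aside, not part of the patch): the rewritten create_template_config above walks the AST of _config.py so that annotated config assignments get commented out while imports and module-level string literals pass through unchanged. A minimal, self-contained sketch of that classification, using a made-up three-statement source string:

import ast

source = 'from pathlib import Path\n\n"""Section header."""\n\nn_jobs: int = 1\n'
for node in ast.parse(source).body:
    if isinstance(node, ast.AnnAssign):
        kind = "config assignment -> comment out"
    elif isinstance(node, ast.Import | ast.ImportFrom):
        kind = "import -> keep as-is"
    else:
        # anything else is expected to be a module-level string literal
        assert isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant)
        kind = "string literal -> keep as-is"
    print(f"line {node.lineno}: {kind}")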
diff --git a/mne_bids_pipeline/_config_utils.py b/mne_bids_pipeline/_config_utils.py index 35ed07512..cbad13fca 100644 --- a/mne_bids_pipeline/_config_utils.py +++ b/mne_bids_pipeline/_config_utils.py @@ -3,21 +3,23 @@ import copy import functools import pathlib -from typing import List, Optional, Union, Iterable, Tuple, Dict, TypeVar, Literal, Any -from types import SimpleNamespace, ModuleType +from collections.abc import Iterable, Sized +from inspect import signature +from types import ModuleType, SimpleNamespace +from typing import Any, Literal, TypeVar -import numpy as np import mne import mne_bids +import numpy as np from mne_bids import BIDSPath -from ._logging import logger, gen_log_kwargs +from ._logging import gen_log_kwargs, logger from .typing import ArbitraryContrast try: - _keys_arbitrary_contrast = set(ArbitraryContrast.__required_keys__) + _set_keys_arbitrary_contrast = set(ArbitraryContrast.__required_keys__) except Exception: - _keys_arbitrary_contrast = set(ArbitraryContrast.__annotations__.keys()) + _set_keys_arbitrary_contrast = set(ArbitraryContrast.__annotations__.keys()) def get_fs_subjects_dir(config: SimpleNamespace) -> pathlib.Path: @@ -26,65 +28,80 @@ def get_fs_subjects_dir(config: SimpleNamespace) -> pathlib.Path: # avoid an error message when a user doesn't intend to run the source # analysis steps anyway. raise ValueError( - 'When specifying a "deriv_root", you must also supply a ' '"subjects_dir".' + 'When specifying a "deriv_root", you must also supply a "subjects_dir".' ) if not config.subjects_dir: + assert isinstance(config.bids_root, pathlib.Path) return config.bids_root / "derivatives" / "freesurfer" / "subjects" else: return pathlib.Path(config.subjects_dir).expanduser().resolve() -def get_fs_subject(config: SimpleNamespace, subject: str) -> str: +def get_fs_subject( + config: SimpleNamespace, subject: str, session: str | None = None +) -> str: subjects_dir = get_fs_subjects_dir(config) if config.use_template_mri is not None: + assert isinstance(config.use_template_mri, str), type(config.use_template_mri) return config.use_template_mri - if (pathlib.Path(subjects_dir) / subject).exists(): + if session is not None: + return f"sub-{subject}_ses-{session}" + elif (pathlib.Path(subjects_dir) / subject).exists(): return subject else: return f"sub-{subject}" -@functools.lru_cache(maxsize=None) -def _get_entity_vals_cached(*args, **kwargs) -> List[str]: - return mne_bids.get_entity_vals(*args, **kwargs) +def _has_session_specific_anat( + subject: str, session: str | None, subjects_dir: pathlib.Path +) -> bool: + return (subjects_dir / f"sub-{subject}_ses-{session}").exists() + + +@functools.cache +def _get_entity_vals_cached( + *args: list[Any], + **kwargs: dict[str, Any], +) -> tuple[str, ...]: + return tuple(str(x) for x in mne_bids.get_entity_vals(*args, **kwargs)) def get_datatype(config: SimpleNamespace) -> Literal["meg", "eeg"]: # Content of ch_types should be sanitized already, so we don't need any # extra sanity checks here. - if config.data_type is not None: - return config.data_type - elif config.data_type is None and config.ch_types == ["eeg"]: - return "eeg" - elif config.data_type is None and any( - [t in ["meg", "mag", "grad"] for t in config.ch_types] - ): + if config.data_type == "meg": return "meg" - else: - raise RuntimeError( - "This probably shouldn't happen, got " - f"config.data_type={repr(config.data_type)} and " - f"config.ch_types={repr(config.ch_types)} " - "but could not determine the datatype. 
Please contact " - "the MNE-BIDS-pipeline developers. Thank you." - ) + if config.data_type == "eeg": + return "eeg" + if config.data_type is None: + if config.ch_types == ["eeg"]: + return "eeg" + if any(t in ["meg", "mag", "grad"] for t in config.ch_types): + return "meg" + raise RuntimeError( + "This probably shouldn't happen, got " + f"config.data_type={repr(config.data_type)} and " + f"config.ch_types={repr(config.ch_types)} " + "but could not determine the datatype. Please contact " + "the MNE-BIDS-pipeline developers. Thank you." + ) -@functools.lru_cache(maxsize=None) -def _get_datatypes_cached(root): - return mne_bids.get_datatypes(root=root) +@functools.cache +def _get_datatypes_cached(root: pathlib.Path) -> tuple[str, ...]: + return tuple(mne_bids.get_datatypes(root=root)) -def _get_ignore_datatypes(config: SimpleNamespace) -> Tuple[str]: - _all_datatypes: List[str] = _get_datatypes_cached(root=config.bids_root) +def _get_ignore_datatypes(config: SimpleNamespace) -> tuple[str, ...]: + _all_datatypes = _get_datatypes_cached(root=config.bids_root) _ignore_datatypes = set(_all_datatypes) - set([get_datatype(config)]) return tuple(sorted(_ignore_datatypes)) -def get_subjects(config: SimpleNamespace) -> List[str]: +def get_subjects(config: SimpleNamespace) -> list[str]: _valid_subjects = _get_entity_vals_cached( root=config.bids_root, entity_key="subject", @@ -94,15 +111,32 @@ def get_subjects(config: SimpleNamespace) -> List[str]: s = _valid_subjects else: s = config.subjects + missing_subjects = set(s) - set(_valid_subjects) + if missing_subjects: + raise FileNotFoundError( + "The following requested subjects were not found in the dataset: " + f"{', '.join(missing_subjects)}" + ) - subjects = set(s) - set(config.exclude_subjects) - # Drop empty-room subject. 
- subjects = subjects - set(["emptyroom"]) + # Preserve order and remove excluded subjects + subjects = [ + subject + for subject in s + if subject not in config.exclude_subjects and subject != "emptyroom" + ] - return sorted(subjects) + return subjects -def get_sessions(config: SimpleNamespace) -> Union[List[None], List[str]]: +def get_sessions(config: SimpleNamespace) -> tuple[None] | tuple[str, ...]: + sessions = _get_sessions(config) + if not sessions: + return (None,) + else: + return sessions + + +def _get_sessions(config: SimpleNamespace) -> tuple[str, ...]: sessions = copy.deepcopy(config.sessions) _all_sessions = _get_entity_vals_cached( root=config.bids_root, @@ -112,16 +146,87 @@ def get_sessions(config: SimpleNamespace) -> Union[List[None], List[str]]: if sessions == "all": sessions = _all_sessions - if not sessions: - return [None] + return tuple(str(x) for x in sessions) + + +def get_subjects_sessions( + config: SimpleNamespace, +) -> dict[str, tuple[None] | tuple[str, ...]]: + subjects = get_subjects(config) + cfg_sessions = _get_sessions(config) + # easy case first: datasets that don't have (named) sessions + if not cfg_sessions: + return {subj: (None,) for subj in subjects} + + # find which tasks to ignore when deciding if a subj has data for a session + ignore_datatypes = _get_ignore_datatypes(config) + if config.task == "": + ignore_tasks = None else: - return sessions + all_tasks = _get_entity_vals_cached( + root=config.bids_root, + entity_key="task", + ignore_datatypes=ignore_datatypes, + ) + ignore_tasks = tuple(set(all_tasks) - set([config.task])) + + # loop over subjs and check for available sessions + subj_sessions: dict[str, tuple[None] | tuple[str, ...]] = dict() + kwargs = ( + dict(ignore_suffixes=("scans", "coordsystem")) + if "ignore_suffixes" in signature(mne_bids.get_entity_vals).parameters + else dict() + ) + for subject in subjects: + subj_folder = config.bids_root / f"sub-{subject}" + valid_sessions_subj = _get_entity_vals_cached( + subj_folder, + entity_key="session", + ignore_tasks=ignore_tasks, + ignore_acquisitions=("calibration", "crosstalk"), + ignore_datatypes=ignore_datatypes, + **kwargs, + ) + keep_sessions: tuple[str, ...] + # if valid_sessions_subj is empty, it might be because the dataset just doesn't + # have `session` subfolders, or it might be that none of the sessions in config + # are available for this subject. 
+ if not valid_sessions_subj: + if any([x.name.startswith("ses") for x in subj_folder.iterdir()]): + keep_sessions = () # has `ses-*` folders, just not the ones we want + else: + keep_sessions = cfg_sessions # doesn't have `ses-*` folders + else: + missing_sessions = sorted(set(cfg_sessions) - set(valid_sessions_subj)) + if missing_sessions and not config.allow_missing_sessions: + raise RuntimeError( + f"Subject {subject} is missing session{_pl(missing_sessions)} " + f"{missing_sessions}, and `config.allow_missing_sessions` is False" + ) + keep_sessions = tuple(sorted(set(cfg_sessions) & set(valid_sessions_subj))) + if len(keep_sessions): + subj_sessions[subject] = keep_sessions + return subj_sessions + + +def get_subjects_given_session( + config: SimpleNamespace, session: str | None +) -> tuple[str, ...]: + """Get the subjects who actually have data for a given session.""" + sub_ses = get_subjects_sessions(config) + subjects = ( + tuple(sub for sub, ses in sub_ses.items() if session in ses) + if config.allow_missing_sessions + else config.subjects + ) + assert not isinstance(subjects, str), subjects # make sure it's not "all" + return subjects def get_runs_all_subjects( config: SimpleNamespace, -) -> Dict[str, Union[List[None], List[str]]]: - """Gives the mapping between subjects and their runs. +) -> dict[str, tuple[None] | tuple[str, ...]]: + """Give the mapping between subjects and their runs. Returns ------- @@ -130,26 +235,24 @@ def get_runs_all_subjects( (and not for each subject present in the bids_path). """ # Use caching under the hood for speed - return copy.deepcopy( - _get_runs_all_subjects_cached( - bids_root=config.bids_root, - data_type=config.data_type, - ch_types=tuple(config.ch_types), - subjects=tuple(config.subjects) if config.subjects != "all" else "all", - exclude_subjects=tuple(config.exclude_subjects), - exclude_runs=tuple(config.exclude_runs) if config.exclude_runs else None, - ) + return _get_runs_all_subjects_cached( + bids_root=config.bids_root, + data_type=config.data_type, + ch_types=tuple(config.ch_types), + subjects=tuple(config.subjects) if config.subjects != "all" else "all", + exclude_subjects=tuple(config.exclude_subjects), + exclude_runs=tuple(config.exclude_runs) if config.exclude_runs else None, ) -@functools.lru_cache(maxsize=None) +@functools.cache def _get_runs_all_subjects_cached( - **config_dict: Dict[str, Any], -) -> Dict[str, Union[List[None], List[str]]]: + **config_dict: dict[str, Any], +) -> dict[str, tuple[None] | tuple[str, ...]]: config = SimpleNamespace(**config_dict) # Sometimes we check list equivalence for ch_types, so convert it back config.ch_types = list(config.ch_types) - subj_runs = dict() + subj_runs: dict[str, tuple[None] | tuple[str, ...]] = dict() for subject in get_subjects(config): # Only traverse through the current subject's directory valid_runs_subj = _get_entity_vals_cached( @@ -160,22 +263,32 @@ def _get_runs_all_subjects_cached( # If we don't have any `run` entities, just set it to None, as we # commonly do when creating a BIDSPath. 
- if not valid_runs_subj: - valid_runs_subj = [None] - - if subject in (config.exclude_runs or {}): - valid_runs_subj = [ - r for r in valid_runs_subj if r not in config.exclude_runs[subject] - ] - subj_runs[subject] = valid_runs_subj + if valid_runs_subj: + if subject in (config.exclude_runs or {}): + valid_runs_subj = tuple( + r for r in valid_runs_subj if r not in config.exclude_runs[subject] + ) + subj_runs[subject] = valid_runs_subj + else: + subj_runs[subject] = (None,) return subj_runs -def get_intersect_run(config: SimpleNamespace) -> List[str]: - """Returns the intersection of all the runs of all subjects.""" +def get_intersect_run(config: SimpleNamespace) -> list[str | None]: + """Return the intersection of all the runs of all subjects.""" subj_runs = get_runs_all_subjects(config) - return list(set.intersection(*map(set, subj_runs.values()))) + # Do not use something like: + # list(set.intersection(*map(set, subj_runs.values()))) + # as it will not preserve order. Instead just be explicit and preserve order. + # We could use "sorted", but it's probably better to use the order provided by + # the user (if they want to put `runs=["02", "01"]` etc. it's better to use "02") + all_runs: list[str | None] = list() + for runs in subj_runs.values(): + for run in runs: + if run not in all_runs: + all_runs.append(run) + return all_runs def get_runs( @@ -183,8 +296,8 @@ def get_runs( config: SimpleNamespace, subject: str, verbose: bool = False, -) -> Union[List[str], List[None]]: - """Returns a list of runs in the BIDS input data. +) -> list[str] | list[None]: + """Return a list of runs in the BIDS input data. Parameters ---------- @@ -230,50 +343,63 @@ def get_runs( inclusion = set(runs).issubset(set(valid_runs)) if not inclusion: raise ValueError( - f"Invalid run. It can be a subset of {valid_runs} but " f"got {runs}" + f"Invalid run. It can be a subset of {valid_runs} but got {runs}" ) - return runs + runs_out = list(runs) + if runs_out != [None]: + runs_out = list(str(x) for x in runs_out) + return runs_out def get_runs_tasks( *, config: SimpleNamespace, subject: str, - session: Optional[str], - which: Tuple[str] = ("runs", "noise", "rest"), -) -> List[Tuple[str]]: + session: str | None, + which: tuple[str, ...] 
= ("runs", "noise", "rest"), +) -> tuple[tuple[str | None, str | None], ...]: """Get (run, task) tuples for all runs plus (maybe) rest.""" from ._import_data import _get_noise_path, _get_rest_path assert isinstance(which, tuple) assert all(isinstance(inc, str) for inc in which) assert all(inc in ("runs", "noise", "rest") for inc in which) - runs = list() - tasks = list() + runs: list[str | None] = list() + tasks: list[str | None] = list() if "runs" in which: runs.extend(get_runs(config=config, subject=subject)) tasks.extend([get_task(config=config)] * len(runs)) - kwargs = dict( - cfg=config, - subject=subject, - session=session, - kind="orig", - add_bads=False, - ) - if "rest" in which and _get_rest_path(**kwargs): - runs.append(None) - tasks.append("rest") + if "rest" in which: + rest_path = _get_rest_path( + cfg=config, + subject=subject, + session=session, + kind="orig", + add_bads=False, + ) + if rest_path: + runs.append(None) + tasks.append("rest") if "noise" in which: mf_reference_run = get_mf_reference_run(config=config) - if _get_noise_path(mf_reference_run=mf_reference_run, **kwargs): + noise_path = _get_noise_path( + mf_reference_run=mf_reference_run, + cfg=config, + subject=subject, + session=session, + kind="orig", + add_bads=False, + ) + if noise_path: runs.append(None) tasks.append("noise") return tuple(zip(runs, tasks)) -def get_mf_reference_run(config: SimpleNamespace) -> str: +def get_mf_reference_run(config: SimpleNamespace) -> str | None: # Retrieve to run identifier (number, name) of the reference run if config.mf_reference_run is not None: + assert isinstance(config.mf_reference_run, str), type(config.mf_reference_run) return config.mf_reference_run # Use the first run inter_runs = get_intersect_run(config) @@ -286,19 +412,19 @@ def get_mf_reference_run(config: SimpleNamespace) -> str: f"dataset only contains the following runs: {inter_runs}" ) raise ValueError(msg) - if inter_runs: - return inter_runs[0] - else: + if not inter_runs: raise ValueError( f"The intersection of runs by subjects is empty. " f"Check the list of runs: " f"{get_runs_all_subjects(config)}" ) + return inter_runs[0] -def get_task(config: SimpleNamespace) -> Optional[str]: +def get_task(config: SimpleNamespace) -> str | None: task = config.task if task: + assert isinstance(task, str), type(task) return task _valid_tasks = _get_entity_vals_cached( root=config.bids_root, @@ -311,7 +437,17 @@ def get_task(config: SimpleNamespace) -> Optional[str]: return _valid_tasks[0] -def get_channels_to_analyze(info: mne.Info, config: SimpleNamespace) -> List[str]: +def get_ecg_channel(config: SimpleNamespace, subject: str, session: str | None) -> str: + if isinstance(config.ssp_ecg_channel, str): + return config.ssp_ecg_channel + for key in (f"sub-{subject}", f"sub-{subject}_ses-{session}"): + if val := config.ssp_ecg_channel.get(key): + assert isinstance(val, str) # mypy + return val + return "" # mypy + + +def get_channels_to_analyze(info: mne.Info, config: SimpleNamespace) -> list[str]: # Return names of the channels of the channel types we wish to analyze. # We also include channels marked as "bad" here. # `exclude=[]`: keep "bad" channels, too. @@ -350,94 +486,114 @@ def sanitize_cond_name(cond: str) -> str: def get_mf_cal_fname( - *, config: SimpleNamespace, subject: str, session: str -) -> pathlib.Path: + *, config: SimpleNamespace, subject: str, session: str | None +) -> pathlib.Path | None: + msg = "Could not find Maxwell Filter calibration file {where}." 
if config.mf_cal_fname is None: - mf_cal_fpath = BIDSPath( + bids_path = BIDSPath( subject=subject, session=session, suffix="meg", datatype="meg", root=config.bids_root, - ).meg_calibration_fpath + ) + bids_match = bids_path.match() + mf_cal_fpath = None + if len(bids_match) > 0: + mf_cal_fpath = bids_match[0].meg_calibration_fpath if mf_cal_fpath is None: - raise ValueError("Could not find Maxwell Filter Calibration file.") + msg = msg.format(where=f"from BIDSPath {bids_path}") + if config.mf_cal_missing == "raise": + raise ValueError(msg) + elif config.mf_cal_missing == "warn": + msg = f"WARNING: {msg} Set to None." + logger.info(**gen_log_kwargs(message=msg)) else: mf_cal_fpath = pathlib.Path(config.mf_cal_fname).expanduser().absolute() if not mf_cal_fpath.exists(): - raise ValueError( - f"Could not find Maxwell Filter Calibration " - f"file at {str(mf_cal_fpath)}." - ) - + msg = msg.format(where=f"at {str(config.mf_cal_fname)}") + if config.mf_cal_missing == "raise": + raise ValueError(msg) + else: + mf_cal_fpath = None + if config.mf_cal_missing == "warn": + msg = f"WARNING: {msg} Set to None." + logger.info(**gen_log_kwargs(message=msg)) + + assert isinstance(mf_cal_fpath, pathlib.Path | None), type(mf_cal_fpath) return mf_cal_fpath def get_mf_ctc_fname( - *, config: SimpleNamespace, subject: str, session: str -) -> pathlib.Path: + *, config: SimpleNamespace, subject: str, session: str | None +) -> pathlib.Path | None: + msg = "Could not find Maxwell Filter cross-talk file {where}." if config.mf_ctc_fname is None: - mf_ctc_fpath = BIDSPath( + bids_path = BIDSPath( subject=subject, session=session, suffix="meg", datatype="meg", root=config.bids_root, - ).meg_crosstalk_fpath + ) + bids_match = bids_path.match() + mf_ctc_fpath = None + if len(bids_match) > 0: + mf_ctc_fpath = bids_match[0].meg_crosstalk_fpath if mf_ctc_fpath is None: - raise ValueError("Could not find Maxwell Filter cross-talk " "file.") + msg = msg.format(where=f"from BIDSPath {bids_path}") + if config.mf_ctc_missing == "raise": + raise ValueError(msg) + elif config.mf_ctc_missing == "warn": + msg = f"WARNING: {msg} Set to None." + logger.info(**gen_log_kwargs(message=msg)) + else: mf_ctc_fpath = pathlib.Path(config.mf_ctc_fname).expanduser().absolute() if not mf_ctc_fpath.exists(): - raise ValueError( - f"Could not find Maxwell Filter cross-talk " - f"file at {str(mf_ctc_fpath)}." - ) - + msg = msg.format(where=f"at {str(config.mf_ctc_fname)}") + if config.mf_ctc_missing == "raise": + raise ValueError(msg) + else: + mf_ctc_fpath = None + if config.mf_ctc_missing == "warn": + msg = f"WARNING: {msg} Set to None." + logger.info(**gen_log_kwargs(message=msg)) + + assert isinstance(mf_ctc_fpath, pathlib.Path | None), type(mf_ctc_fpath) return mf_ctc_fpath RawEpochsEvokedT = TypeVar( - "RawEpochsEvokedT", bound=Union[mne.io.BaseRaw, mne.BaseEpochs, mne.Evoked] + "RawEpochsEvokedT", bound=mne.io.BaseRaw | mne.BaseEpochs | mne.Evoked ) def _restrict_analyze_channels( inst: RawEpochsEvokedT, cfg: SimpleNamespace ) -> RawEpochsEvokedT: - if cfg.analyze_channels: - analyze_channels = cfg.analyze_channels - if cfg.analyze_channels == "ch_types": - analyze_channels = cfg.ch_types - inst.apply_proj() - # We special-case the average reference here to work around a situation - # where e.g. `analyze_channels` might contain only a single channel: - # `concatenate_epochs` below will then fail when trying to create / - # apply the projection. 
We can avoid this by removing an existing - # average reference projection here, and applying the average reference - # directly – without going through a projector. - elif "eeg" in cfg.ch_types and cfg.eeg_reference == "average": - inst.set_eeg_reference("average") - else: - inst.apply_proj() - inst.pick(analyze_channels) - return inst - - -def _get_scalp_in_files(cfg: SimpleNamespace) -> Dict[str, pathlib.Path]: - subject_path = pathlib.Path(cfg.subjects_dir) / cfg.fs_subject - seghead = subject_path / "surf" / "lh.seghead" - in_files = dict() - if seghead.is_file(): - in_files["seghead"] = seghead + analyze_channels = cfg.analyze_channels + if cfg.analyze_channels == "ch_types": + analyze_channels = cfg.ch_types + inst.apply_proj() + # We special-case the average reference here to work around a situation + # where e.g. `analyze_channels` might contain only a single channel: + # `concatenate_epochs` below will then fail when trying to create / + # apply the projection. We can avoid this by removing an existing + # average reference projection here, and applying the average reference + # directly – without going through a projector. + elif "eeg" in cfg.ch_types and cfg.eeg_reference == "average": + inst.set_eeg_reference("average") else: - in_files["t1"] = subject_path / "mri" / "T1.mgz" - return in_files + inst.apply_proj() + inst.pick(analyze_channels) + return inst -def _get_bem_conductivity(cfg: SimpleNamespace) -> Tuple[Tuple[float], str]: +def _get_bem_conductivity(cfg: SimpleNamespace) -> tuple[tuple[float, ...] | None, str]: + conductivity: tuple[float, ...] | None = None # should never be used if cfg.fs_subject in ("fsaverage", cfg.use_template_mri): - conductivity = None # should never be used + pass tag = "5120-5120-5120" elif "eeg" in cfg.ch_types: conductivity = (0.3, 0.006, 0.3) @@ -453,7 +609,7 @@ def _meg_in_ch_types(ch_types: str) -> bool: def get_noise_cov_bids_path( - cfg: SimpleNamespace, subject: str, session: Optional[str] + cfg: SimpleNamespace, subject: str, session: str | None ) -> BIDSPath: """Retrieve the path to the noise covariance file. 
@@ -477,7 +633,7 @@ def get_noise_cov_bids_path( task=cfg.task, acquisition=cfg.acq, run=None, - processing=cfg.proc, + processing="clean", recording=cfg.rec, space=cfg.space, suffix="cov", @@ -518,7 +674,7 @@ def get_all_contrasts(config: SimpleNamespace) -> Iterable[ArbitraryContrast]: return normalized_contrasts -def get_decoding_contrasts(config: SimpleNamespace) -> Iterable[Tuple[str, str]]: +def get_decoding_contrasts(config: SimpleNamespace) -> Iterable[tuple[str, str]]: _validate_contrasts(config.contrasts) normalized_contrasts = [] for contrast in config.contrasts: @@ -538,24 +694,39 @@ def get_decoding_contrasts(config: SimpleNamespace) -> Iterable[Tuple[str, str]] return normalized_contrasts +# Map _config.decoding_which_epochs to a BIDS proc- entity +_EPOCHS_DESCRIPTION_TO_PROC_MAP = { + "uncleaned": None, + "after_ica": "ica", + "after_ssp": "ssp", + "cleaned": "clean", +} + + +def _get_decoding_proc(config: SimpleNamespace) -> str | None: + return _EPOCHS_DESCRIPTION_TO_PROC_MAP[config.decoding_which_epochs] + + def get_eeg_reference( config: SimpleNamespace, -) -> Union[Literal["average"], Iterable[str]]: +) -> Literal["average"] | Iterable[str]: if config.eeg_reference == "average": - return config.eeg_reference + return "average" elif isinstance(config.eeg_reference, str): return [config.eeg_reference] else: + assert isinstance(config.eeg_reference, Iterable) + assert all(isinstance(x, str) for x in config.eeg_reference) return config.eeg_reference -def _validate_contrasts(contrasts: SimpleNamespace) -> None: +def _validate_contrasts(contrasts: list[tuple[str, str] | dict[str, Any]]) -> None: for contrast in contrasts: if isinstance(contrast, tuple): if len(contrast) != 2: raise ValueError("Contrasts' tuples MUST be two conditions") elif isinstance(contrast, dict): - if not _keys_arbitrary_contrast.issubset(set(contrast.keys())): + if not _set_keys_arbitrary_contrast.issubset(set(contrast.keys())): raise ValueError(f"Missing key(s) in contrast {contrast}") if len(contrast["conditions"]) != len(contrast["weights"]): raise ValueError( @@ -566,12 +737,8 @@ def _validate_contrasts(contrasts: SimpleNamespace) -> None: raise ValueError("Contrasts must be tuples or well-formed dicts") -def _get_step_modules() -> Dict[str, Tuple[ModuleType]]: - from .steps import init - from .steps import preprocessing - from .steps import sensor - from .steps import source - from .steps import freesurfer +def _get_step_modules() -> dict[str, tuple[ModuleType, ...]]: + from .steps import freesurfer, init, preprocessing, sensor, source INIT_STEPS = init._STEPS PREPROCESSING_STEPS = preprocessing._STEPS @@ -599,7 +766,7 @@ def _get_step_modules() -> Dict[str, Tuple[ModuleType]]: return STEP_MODULES -def _bids_kwargs(*, config: SimpleNamespace) -> dict: +def _bids_kwargs(*, config: SimpleNamespace) -> dict[str, str | None]: """Get the standard BIDS config entries.""" return dict( proc=config.proc, @@ -614,11 +781,32 @@ def _bids_kwargs(*, config: SimpleNamespace) -> dict: def _do_mf_autobad(*, cfg: SimpleNamespace) -> bool: - return cfg.find_noisy_channels_meg or cfg.find_flat_channels_meg + return bool(cfg.find_noisy_channels_meg or cfg.find_flat_channels_meg) # Adapted from MNE-Python -def _pl(x, *, non_pl="", pl="s"): +def _pl(x: int | np.generic | Sized, *, non_pl: str = "", pl: str = "s") -> str: """Determine if plural should be used.""" - len_x = x if isinstance(x, (int, np.generic)) else len(x) + len_x = x if isinstance(x, int | np.generic) else len(x) return non_pl if len_x == 1 else pl + + 
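Illustrative aside (not part of the patch): the _pl helper above accepts either a count or a sized container and returns the plural suffix, for example:

from mne_bids_pipeline._config_utils import _pl  # private helper, shown only for illustration

assert _pl(1) == "" and _pl(3) == "s"
assert _pl(["01"]) == "" and _pl(("01", "02")) == "s"
assert _pl(2, non_pl="y", pl="ies") == "ies"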
+def _proj_path( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> BIDSPath: + return BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + root=cfg.deriv_root, + extension=".fif", + suffix="proj", + check=False, + ) diff --git a/mne_bids_pipeline/_decoding.py b/mne_bids_pipeline/_decoding.py index 2b6be3cfc..5f4912cdb 100644 --- a/mne_bids_pipeline/_decoding.py +++ b/mne_bids_pipeline/_decoding.py @@ -1,37 +1,56 @@ +from typing import Any + +import mne import numpy as np -from sklearn.linear_model import LogisticRegression from joblib import parallel_backend - from mne.utils import _validate_type +from sklearn.base import BaseEstimator +from sklearn.decomposition import PCA +from sklearn.linear_model import LogisticRegression +from ._logging import gen_log_kwargs, logger +from .typing import FloatArrayT -class LogReg(LogisticRegression): + +class LogReg(LogisticRegression): # type: ignore[misc] """Hack to avoid a warning with n_jobs != 1 when using dask.""" - def fit(self, *args, **kwargs): + def fit(self, *args, **kwargs): # type: ignore with parallel_backend("loky"): return super().fit(*args, **kwargs) def _handle_csp_args( - decoding_csp_times, - decoding_csp_freqs, - decoding_metric, + decoding_csp_times: list[float] | tuple[float, ...] | FloatArrayT | None, + decoding_csp_freqs: dict[str, Any] | None, + decoding_metric: str, *, - epochs_tmin, - epochs_tmax, - time_frequency_freq_min, - time_frequency_freq_max, -): + epochs_tmin: float, + epochs_tmax: float, + time_frequency_freq_min: float, + time_frequency_freq_max: float, +) -> tuple[dict[str, list[tuple[float, float]]], FloatArrayT]: _validate_type( decoding_csp_times, (None, list, tuple, np.ndarray), "decoding_csp_times" ) if decoding_csp_times is None: - decoding_csp_times = np.linspace(max(0, epochs_tmin), epochs_tmax, num=6) - if len(decoding_csp_times) < 2: - raise ValueError("decoding_csp_times should contain at least 2 values.") + decoding_csp_times = np.linspace( + max(0, epochs_tmin), epochs_tmax, num=6, dtype=float + ) + else: + decoding_csp_times = np.array(decoding_csp_times, float) + assert isinstance(decoding_csp_times, np.ndarray) + if decoding_csp_times.ndim != 1 or len(decoding_csp_times) == 1: + raise ValueError( + "decoding_csp_times should be 1 dimensional and contain at least 2 values " + "to define time intervals, or be empty to disable time-frequency mode, got " + f"shape {decoding_csp_times.shape}" + ) if not np.array_equal(decoding_csp_times, np.sort(decoding_csp_times)): ValueError("decoding_csp_times should be sorted.") + time_bins = np.c_[decoding_csp_times[:-1], decoding_csp_times[1:]] + assert time_bins.ndim == 2 and time_bins.shape[1] == 2, time_bins.shape + if decoding_metric != "roc_auc": raise ValueError( f'CSP decoding currently only supports the "roc_auc" ' @@ -70,4 +89,26 @@ def _handle_csp_args( freq_bins = list(zip(edges[:-1], edges[1:])) freq_name_to_bins_map[freq_range_name] = freq_bins - return freq_name_to_bins_map + return freq_name_to_bins_map, time_bins + + +def _decoding_preproc_steps( + subject: str, + session: str | None, + epochs: mne.Epochs, + pca: bool = True, +) -> list[BaseEstimator]: + scaler = mne.decoding.Scaler(epochs.info) + steps = [scaler] + if pca: + ranks = mne.compute_rank(inst=epochs, rank="info") + rank = sum(ranks.values()) + msg = f"Reducing data dimension via PCA; new rank: {rank} (from {ranks})." 
+ logger.info(**gen_log_kwargs(message=msg)) + steps.append( + mne.decoding.UnsupervisedSpatialFilter( + PCA(rank, whiten=True), + average=False, + ) + ) + return steps diff --git a/mne_bids_pipeline/_docs.py b/mne_bids_pipeline/_docs.py new file mode 100644 index 000000000..65d18d7e6 --- /dev/null +++ b/mne_bids_pipeline/_docs.py @@ -0,0 +1,295 @@ +import ast +import inspect +import re +from collections import defaultdict +from pathlib import Path +from types import FunctionType +from typing import Any + +from tqdm import tqdm + +from . import _config_utils, _import_data + +_CONFIG_RE = re.compile(r"config\.([a-zA-Z_]+)") + +_NO_CONFIG = { + "freesurfer/_01_recon_all", +} +_IGNORE_OPTIONS = { + "PIPELINE_NAME", + "VERSION", + "CODE_URL", +} +# We don't need to parse the config itself, just the steps +_MANUAL_KWS = { + "source/_04_make_forward:get_config:t1_bids_path": ("mri_t1_path_generator",), + "source/_04_make_forward:get_config:landmarks_kind": ("mri_landmarks_kind",), + "preprocessing/_01_data_quality:get_config:extra_kwargs": ( + "mf_cal_fname", + "mf_ctc_fname", + "mf_head_origin", + "find_flat_channels_meg", + "find_noisy_channels_meg", + ), +} +# Some don't show up so force them to be empty +_EXECUTION_OPTIONS = ( + # Eventually we could deduplicate these with the execution.md list + "n_jobs", + "parallel_backend", + "dask_open_dashboard", + "dask_temp_dir", + "dask_worker_memory_limit", + "log_level", + "mne_log_level", + "on_error", + "memory_location", + "memory_file_method", + "memory_subdir", + "memory_verbose", + "config_validation", + "interactive", +) +_FORCE_EMPTY = _EXECUTION_OPTIONS + ( + # Plus some BIDS one we don't detect because _bids_kwargs etc. above, + # which we could cross-check against the general.md list. A notable + # exception is random_state, since this does have more localized effects. + # These are used a lot at the very beginning, so adding them will lead + # to long lists. Instead, let's just mention at the top of General that + # messing with basic BIDS params will affect almost every step. + "bids_root", + "deriv_root", + "subjects_dir", + "sessions", + "acq", + "proc", + "rec", + "space", + "task", + "runs", + "exclude_runs", + "subjects", + "crop_runs", + "process_empty_room", + "process_rest", + "eeg_bipolar_channels", + "eeg_reference", + "eeg_template_montage", + "drop_channels", + "reader_extra_params", + "read_raw_bids_verbose", + "plot_psd_for_runs", + "shortest_event", + "find_breaks", + "min_break_duration", + "t_break_annot_start_after_previous_event", + "t_break_annot_stop_before_next_event", + "rename_events", + "on_rename_missing_events", + "mf_reference_run", # TODO: Make clearer that this changes a lot + "fix_stim_artifact", + "stim_artifact_tmin", + "stim_artifact_tmax", + # And some that we force to be empty because they affect too many things + # and what they affect is an incomplete list anyway + "exclude_subjects", + "ch_types", + "task_is_rest", + "data_type", + "allow_missing_sessions", +) +# Eventually we could parse AST to get these, but this is simple enough +_EXTRA_FUNCS = { + "_bids_kwargs": ("get_task",), + "_import_data_kwargs": ("get_mf_reference_run",), + "get_runs": ("get_runs_all_subjects",), + "get_sessions": ("_get_sessions",), +} + + +class _ParseConfigSteps: + def __init__(self, force_empty: tuple[str, ...] | None = None) -> None: + """Build a mapping from config options to tuples of steps that use each option. + + The mapping is stored in `self.steps`. 
+ """ + self._force_empty = _FORCE_EMPTY if force_empty is None else force_empty + steps: dict[str, Any] = defaultdict(list) + + def _add_step_option(step: str, option: str) -> None: + if step not in steps[option]: + steps[option].append(step) + + # Add a few helper functions + for func_extra in ( + _config_utils.get_eeg_reference, + _config_utils.get_all_contrasts, + _config_utils.get_decoding_contrasts, + _config_utils.get_fs_subject, + _config_utils.get_fs_subjects_dir, + _config_utils.get_mf_cal_fname, + _config_utils.get_mf_ctc_fname, + _config_utils.get_subjects_sessions, + ): + this_list: list[str] = [] + assert isinstance(func_extra, FunctionType) + for attr in ast.walk(ast.parse(inspect.getsource(func_extra))): + if not isinstance(attr, ast.Attribute): + continue + if not (isinstance(attr.value, ast.Name) and attr.value.id == "config"): + continue + if attr.attr not in this_list: + this_list.append(attr.attr) + _MANUAL_KWS[func_extra.__name__] = tuple(this_list) + + for module in tqdm( + sum(_config_utils._get_step_modules().values(), tuple()), + desc="Generating option->step mapping", + ): + step = "/".join(module.__name__.split(".")[-2:]) + found = False # found at least one? + # Walk the module file for "get_config*" functions (can be multiple!) + assert module.__file__ is not None + for func in ast.walk(ast.parse(Path(module.__file__).read_text("utf-8"))): + if not isinstance(func, ast.FunctionDef): + continue + where = f"{step}:{func.name}" + # Also look at config.* args in main(), e.g. config.recreate_bem + # and config.recreate_scalp_surface + if func.name == "main": + for call in ast.walk(func): + if not isinstance(call, ast.Call): + continue + for keyword in call.keywords: + if not isinstance(keyword.value, ast.Attribute): + continue + assert isinstance(keyword.value.value, ast.Name) + if keyword.value.value.id != "config": + continue + if keyword.value.attr in ("exec_params",): + continue + _add_step_option(step, keyword.value.attr) + for arg in call.args: + if not isinstance(arg, ast.Name): + continue + if arg.id != "config": + continue + assert isinstance(call.func, ast.Name) + key = call.func.id + # e.g., get_subjects_sessions(config) + if key in _MANUAL_KWS: + for option in _MANUAL_KWS[key]: + _add_step_option(step, option) + break + + # Also look for root-level conditionals like use_maxwell_filter + # or spatial_filter + for cond in ast.iter_child_nodes(func): + # is a conditional + if not isinstance(cond, ast.If): + continue + # has a return statement + if not any(isinstance(c, ast.Return) for c in ast.walk(cond)): + continue + # look at all attributes in the conditional + for attr in ast.walk(cond.test): + if not isinstance(attr, ast.Attribute): + continue + assert isinstance(attr.value, ast.Name) + if attr.value.id != "config": + continue + _add_step_option(step, attr.attr) + # Now look at get_config* functions + if not func.name.startswith("get_config"): + continue + found = True + for call in ast.walk(func): + if not isinstance(call, ast.Call): + continue + assert isinstance(call.func, ast.Name) + if call.func.id != "SimpleNamespace": + continue + break + else: + raise RuntimeError(f"Could not find SimpleNamespace in {func}") + assert call.args == [] + for keyword in call.keywords: + if isinstance(keyword.value, ast.Call): + assert isinstance(keyword.value.func, ast.Name) + key = keyword.value.func.id + if key in _MANUAL_KWS: + for option in _MANUAL_KWS[key]: + _add_step_option(step, option) + continue + if keyword.value.func.id == "_sanitize_callable": + 
assert len(keyword.value.args) == 1 + assert isinstance(keyword.value.args[0], ast.Attribute) + assert isinstance(keyword.value.args[0].value, ast.Name) + assert keyword.value.args[0].value.id == "config" + _add_step_option(step, keyword.value.args[0].attr) + continue + if key not in ( + "_bids_kwargs", + "_import_data_kwargs", + "get_runs", + "get_subjects", + "get_sessions", + ): + raise RuntimeError( + f"{where} cannot handle call {keyword.value.func.id=} " + f"for {key}" + ) + # Get the source and regex for config values + if key == "_import_data_kwargs": + funcs = [getattr(_import_data, key)] + else: + funcs = [getattr(_config_utils, key)] + for func_name in _EXTRA_FUNCS.get(key, ()): + funcs.append(getattr(_config_utils, func_name)) + for fi, func in enumerate(funcs): + assert isinstance(func, FunctionType), func + source = inspect.getsource(func) + assert "config: SimpleNamespace" in source, key + if fi == 0: + for func_name in _EXTRA_FUNCS.get(key, ()): + assert f"{func_name}(" in source, (key, func_name) + attrs = _CONFIG_RE.findall(source) + if key != "get_sessions": # pure wrapper + assert len(attrs), ( + f"No config.* found in source of {key}" + ) + for attr in attrs: + _add_step_option(step, attr) + continue + if isinstance(keyword.value, ast.Name): + key = f"{where}:{keyword.value.id}" + if key in _MANUAL_KWS: + for option in _MANUAL_KWS[f"{where}:{keyword.value.id}"]: + _add_step_option(step, option) + continue + raise RuntimeError(f"{where} cannot handle Name {key=}") + if isinstance(keyword.value, ast.IfExp): # conditional + if keyword.arg == "processing": # inline conditional for proc + continue + if not isinstance(keyword.value, ast.Attribute): + raise RuntimeError( + f"{where} cannot handle type {keyword.value=}" + ) + option = keyword.value.attr + if option in _IGNORE_OPTIONS: + continue + assert isinstance(keyword.value.value, ast.Name) + assert keyword.value.value.id == "config", f"{where} {keyword.value.value.id}" # noqa: E501 # fmt: skip + _add_step_option(step, option) + if step in _NO_CONFIG: + assert not found, f"Found unexpected get_config* in {step}" + else: + assert found, f"Could not find get_config* in {step}" + for key in self._force_empty: + steps[key] = list() + for key, val in steps.items(): + assert len(val) == len(set(val)), f"{key} {val}" + self.steps: dict[str, tuple[str, ...]] = {k: tuple(v) for k, v in steps.items()} + + def __call__(self, option: str) -> tuple[str, ...]: + return self.steps[option] diff --git a/mne_bids_pipeline/_download.py b/mne_bids_pipeline/_download.py index 33e565207..873a19422 100644 --- a/mne_bids_pipeline/_download.py +++ b/mne_bids_pipeline/_download.py @@ -1,38 +1,32 @@ """Download test data.""" + import argparse from pathlib import Path +from warnings import filterwarnings import mne +from ._config_import import _import_config +from ._config_utils import get_fs_subjects_dir from .tests.datasets import DATASET_OPTIONS DEFAULT_DATA_DIR = Path("~/mne_data").expanduser() -def _download_via_datalad(*, ds_name: str, ds_path: Path): - import datalad.api as dl - - print('datalad installing "{}"'.format(ds_name)) - options = DATASET_OPTIONS[ds_name] - git_url = options["git"] - assert "exclude" not in options - assert "hash" not in options - dataset = dl.install(path=ds_path, source=git_url) - - # XXX: git-annex bug: - # https://github.com/datalad/datalad/issues/3583 - # if datalad fails, use "get" twice, or set `n_jobs=1` - if ds_name == "ds003104": - n_jobs = 16 - else: - n_jobs = 1 - - for to_get in 
DATASET_OPTIONS[ds_name].get("include", []): - print('datalad get data "{}" for "{}"'.format(to_get, ds_name)) - dataset.get(to_get, jobs=n_jobs) +# TODO this can be removed when https://github.com/fatiando/pooch/pull/458 is merged and +# we pin to a version of pooch that includes that commit +filterwarnings( + action="ignore", + message=( + "Python 3.14 will, by default, filter extracted tar archives and reject files " + "or modify their metadata. Use the filter argument to control this behavior." + ), + category=DeprecationWarning, + module="tarfile", +) -def _download_via_openneuro(*, ds_name: str, ds_path: Path): +def _download_via_openneuro(*, ds_name: str, ds_path: Path) -> None: import openneuro options = DATASET_OPTIONS[ds_name] @@ -47,8 +41,8 @@ def _download_via_openneuro(*, ds_name: str, ds_path: Path): ) -def _download_from_web(*, ds_name: str, ds_path: Path): - """Retrieve Zip archives from a web URL.""" +def _download_from_web(*, ds_name: str, ds_path: Path) -> None: + """Retrieve `.zip` or `.tar.gz` archives from a web URL.""" import pooch options = DATASET_OPTIONS[ds_name] @@ -65,40 +59,80 @@ def _download_from_web(*, ds_name: str, ds_path: Path): ds_path.mkdir(parents=True, exist_ok=True) path = ds_path.parent.resolve(strict=True) - fname = f"{ds_name}.zip" + ext = "tar.gz" if options.get("processor") == "untar" else "zip" + processor = pooch.Untar if options.get("processor") == "untar" else pooch.Unzip + fname = f"{ds_name}.{ext}" pooch.retrieve( url=url, path=path, fname=fname, - processor=pooch.Unzip(extract_dir="."), # relative to path + processor=processor(extract_dir="."), # relative to path progressbar=True, known_hash=known_hash, ) - (path / f"{ds_name}.zip").unlink() + (path / f"{ds_name}.{ext}").unlink() + + +def _download_via_mne(*, ds_name: str, ds_path: Path) -> None: + assert ds_path.stem == ds_name, ds_path + getattr(mne.datasets, DATASET_OPTIONS[ds_name]["mne"]).data_path( + ds_path.parent, + verbose=True, + ) -def _download(*, ds_name: str, ds_path: Path): +def _download(*, ds_name: str, ds_path: Path) -> None: options = DATASET_OPTIONS[ds_name] openneuro_name = options.get("openneuro", "") - git_url = options.get("git", "") - osf_node = options.get("osf", "") web_url = options.get("web", "") - assert sum(bool(x) for x in (openneuro_name, git_url, osf_node, web_url)) == 1 + mne_mod = options.get("mne", "") + assert sum(bool(x) for x in (openneuro_name, web_url, mne_mod)) == 1 if openneuro_name: download_func = _download_via_openneuro - elif git_url: - download_func = _download_via_datalad - elif osf_node: - raise RuntimeError("OSF downloads are currently not supported.") + elif mne_mod: + download_func = _download_via_mne else: assert web_url download_func = _download_from_web download_func(ds_name=ds_name, ds_path=ds_path) - -def main(dataset): + # and fsaverage if needed + extra = DATASET_OPTIONS[ds_name].get("config_path_extra", "") + config_path = ( + Path(__file__).parent + / "tests" + / "configs" + / f"config_{ds_name.replace('-', '_')}{extra}.py" + ) + if config_path.is_file(): + has_subjects_dir = any( + "derivatives/freesurfer/subjects" in key + for key in options.get("include", []) + ) + if has_subjects_dir or options.get("fsaverage"): + cfg = _import_config(config_path=config_path) + subjects_dir = get_fs_subjects_dir(config=cfg) + n_try = 5 + for ii in range(1, n_try + 1): # osf.io fails sometimes + write_extra = f" (attempt #{ii})" if ii > 1 else "" + print(f"Checking fsaverage in {subjects_dir} ...{write_extra}") + try: + 
+                    mne.datasets.fetch_fsaverage(
+                        subjects_dir=subjects_dir,
+                        verbose=True,
+                    )
+                except Exception as exc:  # pragma: no cover
+                    if ii == n_try:
+                        raise
+                    else:
+                        print(f"Failed and will retry, got:\n{exc}")
+                else:
+                    break
+
+
+def main(dataset: str | None) -> None:
     """Download the testing data."""
     # Save everything 'MNE_DATA' dir ... defaults to ~/mne_data
     mne_data_dir = mne.get_config(key="MNE_DATA", default=False)
diff --git a/mne_bids_pipeline/_import_data.py b/mne_bids_pipeline/_import_data.py
index ca52c59e1..a41eeb749 100644
--- a/mne_bids_pipeline/_import_data.py
+++ b/mne_bids_pipeline/_import_data.py
@@ -1,41 +1,43 @@
+from collections.abc import Iterable
 from types import SimpleNamespace
-from typing import Dict, Optional, Iterable, Union, List, Literal
+from typing import Any, Literal

 import mne
-from mne_bids import BIDSPath, read_raw_bids, get_bids_path_from_fname
 import numpy as np
 import pandas as pd
+from mne_bids import BIDSPath, get_bids_path_from_fname, read_raw_bids

 from ._config_utils import (
-    get_mf_reference_run,
-    get_runs,
-    get_datatype,
-    get_task,
     _bids_kwargs,
     _do_mf_autobad,
     _pl,
+    get_datatype,
+    get_mf_reference_run,
+    get_runs,
+    get_task,
 )
-from ._io import _read_json, _empty_room_match_path
+from ._io import _read_json
 from ._logging import gen_log_kwargs, logger
 from ._run import _update_for_splits
-from .typing import PathLike
+from .typing import InFilesT, PathLike, RunKindT, RunTypeT


 def make_epochs(
     *,
     task: str,
     subject: str,
-    session: Optional[str],
+    session: str | None,
     raw: mne.io.BaseRaw,
-    event_id: Optional[Union[Dict[str, int], Literal["auto"]]],
-    conditions: Union[Iterable[str], Dict[str, str]],
+    event_id: dict[str, int] | Literal["auto"] | None,
+    conditions: Iterable[str] | dict[str, str],
     tmin: float,
     tmax: float,
-    metadata_tmin: Optional[float],
-    metadata_tmax: Optional[float],
-    metadata_keep_first: Optional[Iterable[str]],
-    metadata_keep_last: Optional[Iterable[str]],
-    metadata_query: Optional[str],
+    custom_metadata: pd.DataFrame | dict[str, Any] | None,
+    metadata_tmin: float | None,
+    metadata_tmax: float | None,
+    metadata_keep_first: Iterable[str] | None,
+    metadata_keep_last: Iterable[str] | None,
+    metadata_query: str | None,
     event_repeated: Literal["error", "drop", "merge"],
     epochs_decim: int,
     task_is_rest: bool,
@@ -105,6 +107,58 @@ def make_epochs(
         sfreq=raw.info["sfreq"],
     )

+    # If custom_metadata is provided, merge it with the generated metadata
+    if custom_metadata is not None:
+        if isinstance(
+            custom_metadata, dict
+        ):  # parse custom_metadata['sub-x']['ses-y']['task-z']
+            custom_dict = custom_metadata
+            for _ in range(3):  # loop to allow for mis-ordered keys
+                if (
+                    isinstance(custom_dict, dict)
+                    and "subj-" + subject in custom_dict
+                ):
+                    custom_dict = custom_dict["subj-" + subject]
+                if (
+                    isinstance(custom_dict, dict)
+                    and session is not None
+                    and "ses-" + session in custom_dict
+                ):
+                    custom_dict = custom_dict["ses-" + session]
+                if isinstance(custom_dict, dict) and "task-" + task in custom_dict:
+                    custom_dict = custom_dict["task-" + task]
+                if isinstance(custom_dict, pd.DataFrame):
+                    custom_df = custom_dict
+                    break
+            if not isinstance(custom_dict, pd.DataFrame):
+                msg = (
+                    f"Custom metadata not found for subject {subject} / "
+                    f"session {session} / task {task}.\n"
+                )
+                raise ValueError(msg)
+        elif isinstance(custom_metadata, pd.DataFrame):  # parse DataFrame
+            custom_df = custom_metadata
+        else:
+            msg = (
+                f"Custom metadata not found for subject {subject} / "
+                f"session {session} / task {task}.\n"
+            )
raise ValueError(msg) + + # Check if the custom metadata DataFrame has the same number of rows + if len(metadata) != len(custom_df): + msg = ( + f"Event metadata has {len(metadata)} rows, but custom " + f"metadata has {len(custom_df)} rows. Cannot safely join." + ) + raise ValueError(msg) + + # Merge the event and custom DataFrames + metadata = metadata.join(custom_df, how="right") + # Logging # Logging + msg = "Including custom metadata in epochs." + logger.info(**gen_log_kwargs(message=msg)) + # Epoch the data # Do not reject based on peak-to-peak or flatness thresholds at this stage epochs = mne.Epochs( @@ -136,8 +190,8 @@ def make_epochs( try: idx_keep = epochs.metadata.eval(metadata_query, engine="python") except pandas.core.computation.ops.UndefinedVariableError: - msg = f"Metadata query failed to select any columns: " f"{metadata_query}" - logger.warn(**gen_log_kwargs(message=msg)) + msg = f"Metadata query failed to select any columns: {metadata_query}" + logger.warning(**gen_log_kwargs(message=msg)) return epochs idx_drop = epochs.metadata.index[~idx_keep] @@ -147,12 +201,12 @@ def make_epochs( return epochs -def annotations_to_events(*, raw_paths: List[PathLike]) -> Dict[str, int]: +def annotations_to_events(*, raw_paths: list[PathLike]) -> dict[str, int]: """Generate a unique event name -> event code mapping. The mapping can that can be used across all passed raws. """ - event_names: List[str] = [] + event_names: list[str] = [] for raw_fname in raw_paths: raw = mne.io.read_raw_fif(raw_fname) _, event_id = mne.events_from_annotations(raw=raw) @@ -172,8 +226,8 @@ def _rename_events_func( cfg: SimpleNamespace, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, ) -> None: """Rename events (actually, annotations descriptions) in ``raw``. @@ -191,7 +245,7 @@ def _rename_events_func( msg = ( f"You requested to rename the following events, but " f"they are not present in the BIDS input data:\n" - f'{", ".join(sorted(list(events_not_in_raw)))}' + f"{', '.join(sorted(list(events_not_in_raw)))}" ) if cfg.on_rename_missing_events == "warn": logger.warning(**gen_log_kwargs(message=msg)) @@ -204,16 +258,15 @@ def _rename_events_func( # Do the actual event renaming. msg = "Renaming events …" logger.info(**gen_log_kwargs(message=msg)) - descriptions = list(raw.annotations.description) + descriptions_list = list(raw.annotations.description) for old_event_name, new_event_name in cfg.rename_events.items(): msg = f"… {old_event_name} -> {new_event_name}" logger.info(**gen_log_kwargs(message=msg)) - for idx, description in enumerate(descriptions.copy()): + for idx, description in enumerate(descriptions_list): if description == old_event_name: - descriptions[idx] = new_event_name + descriptions_list[idx] = new_event_name - descriptions = np.asarray(descriptions, dtype=str) - raw.annotations.description = descriptions + raw.annotations.description = np.array(descriptions_list, dtype=str) def _load_data(cfg: SimpleNamespace, bids_path: BIDSPath) -> mne.io.BaseRaw: @@ -255,14 +308,14 @@ def _drop_channels_func( cfg: SimpleNamespace, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], + session: str | None, ) -> None: """Drop channels from the data. Modifies ``raw`` in-place. 
""" if cfg.drop_channels: - msg = f'Dropping channels: {", ".join(cfg.drop_channels)}' + msg = f"Dropping channels: {', '.join(cfg.drop_channels)}" logger.info(**gen_log_kwargs(message=msg)) raw.drop_channels(cfg.drop_channels, on_missing="warn") @@ -271,8 +324,8 @@ def _create_bipolar_channels( cfg: SimpleNamespace, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, ) -> None: """Create a channel from a bipolar referencing scheme.. @@ -317,22 +370,18 @@ def _set_eeg_montage( cfg: SimpleNamespace, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, ) -> None: """Set an EEG template montage if requested. Modifies ``raw`` in-place. """ montage = cfg.eeg_template_montage - is_mne_montage = isinstance(montage, mne.channels.montage.DigMontage) - montage_name = "custom_montage" if is_mne_montage else montage if cfg.datatype == "eeg" and montage: - msg = f"Setting EEG channel locations to template montage: " f"{montage}." + msg = f"Setting EEG channel locations to template montage: {montage}." logger.info(**gen_log_kwargs(message=msg)) - if not is_mne_montage: - montage = mne.channels.make_standard_montage(montage_name) - raw.set_montage(montage, match_case=False, on_missing="warn") + raw.set_montage(montage, match_case=False, match_alias=True) def _fix_stim_artifact_func(cfg: SimpleNamespace, raw: mne.io.BaseRaw) -> None: @@ -355,8 +404,8 @@ def import_experimental_data( *, cfg: SimpleNamespace, bids_path_in: BIDSPath, - bids_path_bads_in: Optional[BIDSPath], - data_is_rest: Optional[bool], + bids_path_bads_in: BIDSPath | None, + data_is_rest: bool | None, ) -> mne.io.BaseRaw: """Run the data import. @@ -402,6 +451,7 @@ def import_experimental_data( _fix_stim_artifact_func(cfg=cfg, raw=raw) if bids_path_bads_in is not None: + run = "rest" if data_is_rest else run # improve logging bads = _read_bads_tsv(cfg=cfg, bids_path_bads=bids_path_bads_in) msg = f"Marking {len(bads)} channel{_pl(bads)} as bad." logger.info(**gen_log_kwargs(message=msg)) @@ -415,9 +465,9 @@ def import_er_data( *, cfg: SimpleNamespace, bids_path_er_in: BIDSPath, - bids_path_ref_in: Optional[BIDSPath], - bids_path_er_bads_in: Optional[BIDSPath], - bids_path_ref_bads_in: Optional[BIDSPath], + bids_path_ref_in: BIDSPath | None, + bids_path_er_bads_in: BIDSPath | None, + bids_path_ref_bads_in: BIDSPath | None, prepare_maxwell_filter: bool, ) -> mne.io.BaseRaw: """Import empty-room data. @@ -434,6 +484,8 @@ def import_er_data( The BIDS path to the empty room bad channels file. bids_path_ref_bads_in The BIDS path to the reference data bad channels file. + prepare_maxwell_filter + Whether to prepare the empty-room data for Maxwell filtering. Returns ------- @@ -449,7 +501,6 @@ def import_er_data( cfg=cfg, bids_path_bads=bids_path_er_bads_in, ) - raw_er.pick("meg", exclude=[]) # Don't deal with ref for now (initial data quality / auto bad step) if bids_path_ref_in is None: @@ -468,7 +519,8 @@ def import_er_data( ) raw_ref.info["bads"] = bads raw_ref.info._check_consistency() - raw_ref.pick_types(meg=True, exclude=[]) + raw_ref.pick("meg") + raw_er.pick("meg") if prepare_maxwell_filter: # We need to include any automatically found bad channels, if relevant. 
@@ -489,19 +541,16 @@ def import_er_data( def _find_breaks_func( *, - cfg, + cfg: SimpleNamespace, raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, ) -> None: if not cfg.find_breaks: return - msg = ( - f"Finding breaks with a minimum duration of " - f"{cfg.min_break_duration} seconds." - ) + msg = f"Finding breaks with a minimum duration of {cfg.min_break_duration} seconds." logger.info(**gen_log_kwargs(message=msg)) break_annots = mne.preprocessing.annotate_break( @@ -513,7 +562,7 @@ def _find_breaks_func( msg = ( f"Found and annotated " - f'{len(break_annots) if break_annots else "no"} break periods.' + f"{len(break_annots) if break_annots else 'no'} break periods." ) logger.info(**gen_log_kwargs(message=msg)) @@ -524,10 +573,10 @@ def _get_bids_path_in( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - kind: Literal["orig", "sss"] = "orig", + session: str | None, + run: str | None, + task: str | None, + kind: RunKindT = "orig", ) -> BIDSPath: # b/c can be used before this is updated path_kwargs = dict( @@ -541,13 +590,13 @@ def _get_bids_path_in( datatype=get_datatype(config=cfg), check=False, ) - if kind == "sss": + if kind != "orig": + assert kind in ("sss", "filt"), kind path_kwargs["root"] = cfg.deriv_root path_kwargs["suffix"] = "raw" path_kwargs["extension"] = ".fif" - path_kwargs["processing"] = "sss" + path_kwargs["processing"] = kind else: - assert kind == "orig", kind path_kwargs["root"] = cfg.bids_root path_kwargs["suffix"] = None path_kwargs["extension"] = None @@ -560,14 +609,14 @@ def _get_run_path( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - kind: Literal["orig", "sss"], - add_bads: Optional[bool] = None, + session: str | None, + run: str | None, + task: str | None, + kind: RunKindT, + add_bads: bool | None = None, allow_missing: bool = False, - key: Optional[str] = None, -) -> dict: + key: str | None = None, +) -> InFilesT: bids_path_in = _get_bids_path_in( cfg=cfg, subject=subject, @@ -583,6 +632,8 @@ def _get_run_path( add_bads=add_bads, kind=kind, allow_missing=allow_missing, + subject=subject, + session=session, ) @@ -590,10 +641,10 @@ def _get_rest_path( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - kind: Literal["orig", "sss"], - add_bads: Optional[bool] = None, -) -> dict: + session: str | None, + kind: RunKindT, + add_bads: bool | None = None, +) -> InFilesT: if not (cfg.process_rest and not cfg.task_is_rest): return dict() return _get_run_path( @@ -612,14 +663,15 @@ def _get_noise_path( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - kind: Literal["orig", "sss"], - mf_reference_run: Optional[str], - add_bads: Optional[bool] = None, -) -> dict: + session: str | None, + kind: RunKindT, + mf_reference_run: str | None, + add_bads: bool | None = None, +) -> InFilesT: if not (cfg.process_empty_room and get_datatype(config=cfg) == "meg"): return dict() - if kind == "sss": + if kind != "orig": + assert kind in ("sss", "filt") raw_fname = _get_bids_path_in( cfg=cfg, subject=subject, @@ -648,6 +700,8 @@ def _get_noise_path( add_bads=add_bads, kind=kind, allow_missing=True, + subject=subject, + session=session, ) @@ -655,36 +709,52 @@ def _get_run_rest_noise_path( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - kind: Literal["orig", "sss"], - mf_reference_run: 
Optional[str], - add_bads: Optional[bool] = None, -) -> dict: - kwargs = dict( - cfg=cfg, - subject=subject, - session=session, - kind=kind, - add_bads=add_bads, - ) + session: str | None, + run: str | None, + task: str | None, + kind: RunKindT, + mf_reference_run: str | None, + add_bads: bool | None = None, +) -> InFilesT: if run is None and task in ("noise", "rest"): if task == "noise": - return _get_noise_path(mf_reference_run=mf_reference_run, **kwargs) + path = _get_noise_path( + mf_reference_run=mf_reference_run, + cfg=cfg, + subject=subject, + session=session, + kind=kind, + add_bads=add_bads, + ) else: assert task == "rest" - return _get_rest_path(**kwargs) + path = _get_rest_path( + cfg=cfg, + subject=subject, + session=session, + kind=kind, + add_bads=add_bads, + ) else: - return _get_run_path(run=run, task=task, **kwargs) + path = _get_run_path( + run=run, + task=task, + cfg=cfg, + subject=subject, + session=session, + kind=kind, + add_bads=add_bads, + ) + return path def _get_mf_reference_run_path( + *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - add_bads: bool, -) -> dict: + session: str | None, + add_bads: bool | None = None, +) -> InFilesT: return _get_run_path( cfg=cfg, subject=subject, @@ -697,15 +767,23 @@ def _get_mf_reference_run_path( ) +def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath: + return run_path.copy().update( + extension=".json", suffix="emptyroommatch", root=cfg.deriv_root + ) + + def _path_dict( *, cfg: SimpleNamespace, bids_path_in: BIDSPath, - add_bads: Optional[bool] = None, - kind: Literal["orig", "sss"], + add_bads: bool | None = None, + kind: RunKindT, allow_missing: bool, - key: Optional[str] = None, -) -> dict: + key: str | None = None, + subject: str, + session: str | None, +) -> InFilesT: if add_bads is None: add_bads = kind == "orig" and _do_mf_autobad(cfg=cfg) in_files = dict() @@ -715,35 +793,30 @@ def _path_dict( if allow_missing and not in_files[key].fpath.exists(): return dict() if add_bads: - bads_tsv_fname = _bads_path(cfg=cfg, bids_path_in=bids_path_in) + bads_tsv_fname = _bads_path( + cfg=cfg, + bids_path_in=bids_path_in, + subject=subject, + session=session, + ) if bads_tsv_fname.fpath.is_file() or not allow_missing: in_files[f"{key}-bads"] = bads_tsv_fname return in_files -def _auto_scores_path( - *, - cfg: SimpleNamespace, - bids_path_in: BIDSPath, -) -> BIDSPath: - return bids_path_in.copy().update( - suffix="scores", - extension=".json", - root=cfg.deriv_root, - split=None, - check=False, - ) - - def _bads_path( *, cfg: SimpleNamespace, bids_path_in: BIDSPath, + subject: str, + session: str | None, ) -> BIDSPath: return bids_path_in.copy().update( suffix="bads", extension=".tsv", root=cfg.deriv_root, + subject=subject, + session=session, split=None, check=False, ) @@ -753,12 +826,15 @@ def _read_bads_tsv( *, cfg: SimpleNamespace, bids_path_bads: BIDSPath, -) -> List[str]: +) -> list[str]: bads_tsv = pd.read_csv(bids_path_bads.fpath, sep="\t", header=0) - return bads_tsv[bads_tsv.columns[0]].tolist() + out = bads_tsv[bads_tsv.columns[0]].tolist() + assert isinstance(out, list) + assert all(isinstance(o, str) for o in out) + return out -def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict: +def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict[str, Any]: """Get config params needed for any raw data loading.""" return dict( # import_experimental_data / general @@ -772,6 +848,7 @@ def _import_data_kwargs(*, config: SimpleNamespace, subject: str) 
-> dict: # automatic add_bads find_noisy_channels_meg=config.find_noisy_channels_meg, find_flat_channels_meg=config.find_flat_channels_meg, + find_bad_channels_extra_kws=config.find_bad_channels_extra_kws, # 1. _load_data reader_extra_params=config.reader_extra_params, crop_runs=config.crop_runs, @@ -802,3 +879,17 @@ def _import_data_kwargs(*, config: SimpleNamespace, subject: str) -> dict: runs=get_runs(config=config, subject=subject), # XXX needs to accept session! **_bids_kwargs(config=config), ) + + +def _read_raw_msg( + bids_path_in: BIDSPath, + run: str | None, + task: str | None, +) -> tuple[str, RunTypeT]: + run_type: RunTypeT = "experimental" + if run is None and task in ("noise", "rest"): + if task == "noise": + run_type = "empty-room" + else: + run_type = "resting-state" + return f"Reading {run_type} recording: {bids_path_in.basename}", run_type diff --git a/mne_bids_pipeline/_io.py b/mne_bids_pipeline/_io.py index 0b7485f76..f84f83521 100644 --- a/mne_bids_pipeline/_io.py +++ b/mne_bids_pipeline/_io.py @@ -1,24 +1,17 @@ """I/O helpers.""" -from types import SimpleNamespace +from typing import Any import json_tricks -from mne_bids import BIDSPath from .typing import PathLike -def _write_json(fname: PathLike, data: dict) -> None: +def _write_json(fname: PathLike, data: dict[str, Any] | None) -> None: with open(fname, "w", encoding="utf-8") as f: json_tricks.dump(data, fp=f, allow_nan=True, sort_keys=False) -def _read_json(fname: PathLike) -> dict: - with open(fname, "r", encoding="utf-8") as f: +def _read_json(fname: PathLike) -> Any: + with open(fname, encoding="utf-8") as f: return json_tricks.load(f) - - -def _empty_room_match_path(run_path: BIDSPath, cfg: SimpleNamespace) -> BIDSPath: - return run_path.copy().update( - extension=".json", suffix="emptyroommatch", root=cfg.deriv_root - ) diff --git a/mne_bids_pipeline/_logging.py b/mne_bids_pipeline/_logging.py index 6bcb21d73..9cb812cc7 100644 --- a/mne_bids_pipeline/_logging.py +++ b/mne_bids_pipeline/_logging.py @@ -1,9 +1,9 @@ """Logging.""" + import datetime import inspect import logging import os -from typing import Optional, Union import rich.console import rich.theme @@ -12,23 +12,26 @@ class _MBPLogger: - def __init__(self): + def __init__(self) -> None: self._level = logging.INFO + self.__console: rich.console.Console | None = None # Do lazy instantiation of _console so that pytest's output capture # mechanics don't get messed up @property - def _console(self): - try: + def _console(self) -> rich.console.Console: + if isinstance(self.__console, rich.console.Console): return self.__console - except AttributeError: - pass # need to instantiate it, continue - - force_terminal = os.getenv("MNE_BIDS_PIPELINE_FORCE_TERMINAL", None) - if force_terminal is not None: - force_terminal = force_terminal.lower() in ("true", "1") - kwargs = dict(soft_wrap=True, force_terminal=force_terminal) - kwargs["theme"] = rich.theme.Theme( + + force_terminal: bool | None = None + force_terminal_env = os.getenv("MNE_BIDS_PIPELINE_FORCE_TERMINAL", None) + if force_terminal_env is not None: + force_terminal = force_terminal_env.lower() in ("true", "1") + legacy_windows = None + legacy_windows_env = os.getenv("MNE_BIDS_PIPELINE_LEGACY_WINDOWS", None) + if legacy_windows_env is not None: + legacy_windows = legacy_windows_env.lower() in ("true", "1") + theme = rich.theme.Theme( dict( default="white", # Rule @@ -43,53 +46,65 @@ def _console(self): error="red", ) ) - self.__console = rich.console.Console(**kwargs) + self.__console = 
rich.console.Console( + soft_wrap=True, + force_terminal=force_terminal, + legacy_windows=legacy_windows, + theme=theme, + ) return self.__console - def title(self, title): + def title(self, title: str) -> None: # Align left with ASCTIME offset title = f"[title]┌────────┬ {title}[/]" self._console.rule(title=title, characters="─", style="title", align="left") - def end(self, msg=""): + def end(self, msg: str = "") -> None: self._console.print(f"[title]└────────┴ {msg}[/]") @property - def level(self): + def level(self) -> int: return self._level @level.setter - def level(self, level): + def level(self, level: int) -> None: level = int(level) self._level = level - def debug(self, msg: str, *, extra: Optional[LogKwargsT] = None) -> None: + def debug( + self, msg: str, *, extra: LogKwargsT | dict[str, str] | None = None + ) -> None: self._log_message(kind="debug", msg=msg, **(extra or {})) - def info(self, msg: str, *, extra: Optional[LogKwargsT] = None) -> None: + def info( + self, msg: str, *, extra: LogKwargsT | dict[str, str] | None = None + ) -> None: self._log_message(kind="info", msg=msg, **(extra or {})) - def warning(self, msg: str, *, extra: Optional[LogKwargsT] = None) -> None: + def warning( + self, msg: str, *, extra: LogKwargsT | dict[str, str] | None = None + ) -> None: self._log_message(kind="warning", msg=msg, **(extra or {})) - def error(self, msg: str, *, extra: Optional[LogKwargsT] = None) -> None: + def error( + self, msg: str, *, extra: LogKwargsT | dict[str, str] | None = None + ) -> None: self._log_message(kind="error", msg=msg, **(extra or {})) def _log_message( self, kind: str, msg: str, - subject: Optional[Union[str, int]] = None, - session: Optional[Union[str, int]] = None, - run: Optional[Union[str, int]] = None, + subject: str | None = None, + session: str | None = None, + run: str | None = None, emoji: str = "", - ): + ) -> None: this_level = getattr(logging, kind.upper()) if this_level < self.level: return # Construct str - essr = [x for x in [emoji, subject, session, run] if x] - essr = " ".join(essr) + essr = " ".join(x for x in [emoji, subject, session, run] if x) if essr: essr += " " asctime = datetime.datetime.now().strftime("│%H:%M:%S│") @@ -103,13 +118,14 @@ def _log_message( def gen_log_kwargs( message: str, *, - subject: Optional[Union[str, int]] = None, - session: Optional[Union[str, int]] = None, - run: Optional[Union[str, int]] = None, - task: Optional[str] = None, + subject: str | int | None = None, + session: str | int | None = None, + run: str | int | None = None, + task: str | None = None, emoji: str = "⏳️", ) -> LogKwargsT: # Try to figure these out + assert isinstance(message, str), type(message) stack = inspect.stack() up_locals = stack[1].frame.f_locals if subject is None: @@ -152,7 +168,7 @@ def gen_log_kwargs( return kwargs -def _linkfile(uri): +def _linkfile(uri: str) -> str: return f"[link=file://{uri}]{uri}[/link]" diff --git a/mne_bids_pipeline/_main.py b/mne_bids_pipeline/_main.py index 9489a2cca..2d7e6e331 100755 --- a/mne_bids_pipeline/_main.py +++ b/mne_bids_pipeline/_main.py @@ -1,21 +1,20 @@ import argparse import pathlib -from textwrap import dedent import time -from typing import List +from textwrap import dedent from types import ModuleType, SimpleNamespace import numpy as np -from ._config_utils import _get_step_modules from ._config_import import _import_config from ._config_template import create_template_config -from ._logging import logger, gen_log_kwargs +from ._config_utils import _get_step_modules +from ._logging 
import gen_log_kwargs, logger from ._parallel import get_parallel_backend from ._run import _short_step_path -def main(): +def main() -> None: from . import __version__ parser = argparse.ArgumentParser() @@ -37,7 +36,7 @@ def main(): metavar="FILE", help="Create a template configuration file with the specified name. " "If specified, all other parameters will be ignored.", - ), + ) parser.add_argument( "--steps", dest="steps", @@ -70,7 +69,7 @@ def main(): If unspecified, this will be derivatives/mne-bids-pipeline inside the BIDS root.""" ), - ), + ) parser.add_argument( "--subject", dest="subject", default=None, help="The subject to process." ) @@ -95,7 +94,11 @@ def main(): help="Enable interactive mode.", ) parser.add_argument( - "--debug", dest="debug", action="store_true", help="Enable debugging on error." + "--debug", + "--pdb", + dest="debug", + action="store_true", + help="Enable debugging on error.", ) parser.add_argument( "--no-cache", @@ -112,7 +115,7 @@ def main(): config = options.config config_switch = options.config_switch - bad = False + bad: str | bool = False if config is None: if config_switch is None: bad = "neither was provided" @@ -142,7 +145,6 @@ def main(): steps = (steps,) on_error = "debug" if debug else None - cache = "1" if cache else "0" processing_stages = [] processing_steps = [] @@ -182,7 +184,7 @@ def main(): if not cache: overrides.memory_location = False - step_modules: List[ModuleType] = [] + step_modules: list[ModuleType] = [] STEP_MODULES = _get_step_modules() for stage, step in zip(processing_stages, processing_steps): if stage not in STEP_MODULES.keys(): @@ -197,6 +199,7 @@ def main(): else: # User specified 'stage/step' for step_module in STEP_MODULES[stage]: + assert step_module.__file__ is not None step_name = pathlib.Path(step_module.__file__).name if step in step_name: step_modules.append(step_module) @@ -227,6 +230,7 @@ def main(): for step_module in step_modules: start = time.time() + assert step_module.__file__ is not None step = _short_step_path(pathlib.Path(step_module.__file__)) logger.title(title=f"{step}") step_module.main(config=config_imported) @@ -236,9 +240,9 @@ def main(): minutes, seconds = divmod(remainder, 60) minutes = int(minutes) seconds = int(np.ceil(seconds)) # always take full seconds - elapsed = f"{seconds}s" + elapsed_str = f"{seconds}s" if minutes: - elapsed = f"{minutes}m {elapsed}" + elapsed_str = f"{minutes}m {elapsed_str}" if hours: - elapsed = f"{hours}h {elapsed}" - logger.end(f"done ({elapsed})") + elapsed_str = f"{hours}h {elapsed_str}" + logger.end(f"done ({elapsed_str})") diff --git a/mne_bids_pipeline/_parallel.py b/mne_bids_pipeline/_parallel.py index e79ae5151..92370e4c0 100644 --- a/mne_bids_pipeline/_parallel.py +++ b/mne_bids_pipeline/_parallel.py @@ -1,12 +1,14 @@ """Parallelization.""" -from typing import Literal, Callable +from collections.abc import Callable from types import SimpleNamespace +from typing import Any, Literal import joblib -from mne.utils import use_log_level, logger as mne_logger +from mne.utils import logger as mne_logger +from mne.utils import use_log_level -from ._logging import logger, gen_log_kwargs, _is_testing +from ._logging import _is_testing, gen_log_kwargs, logger def get_n_jobs(*, exec_params: SimpleNamespace, log_override: bool = False) -> int: @@ -25,7 +27,7 @@ def get_n_jobs(*, exec_params: SimpleNamespace, log_override: bool = False) -> i if log_override and n_jobs != orig_n_jobs: msg = f"Overriding n_jobs: {orig_n_jobs}→{n_jobs}" 
logger.info(**gen_log_kwargs(message=msg, emoji="override")) - return n_jobs + return int(n_jobs) dask_client = None @@ -65,16 +67,15 @@ def setup_dask_client(*, exec_params: SimpleNamespace) -> None: "distributed.worker.memory.spill": False, } ) - client = Client( # noqa: F841 + client = Client( # type: ignore[no-untyped-call] memory_limit=exec_params.dask_worker_memory_limit, n_workers=n_workers, threads_per_worker=1, name="mne-bids-pipeline", ) - client.auto_restart = False # don't restart killed workers dashboard_url = client.dashboard_link - msg = "Dask client dashboard: " f"[link={dashboard_url}]{dashboard_url}[/link]" + msg = f"Dask client dashboard: [link={dashboard_url}]{dashboard_url}[/link]" logger.info(**gen_log_kwargs(message=msg, emoji="🌎")) if exec_params.dask_open_dashboard: @@ -90,11 +91,12 @@ def get_parallel_backend_name( *, exec_params: SimpleNamespace, ) -> Literal["dask", "loky"]: + backend: Literal["dask", "loky"] = "loky" if ( exec_params.parallel_backend == "loky" or get_n_jobs(exec_params=exec_params) == 1 ): - backend = "loky" + pass elif exec_params.parallel_backend == "dask": # Disable interactive plotting backend import matplotlib @@ -127,7 +129,11 @@ def get_parallel_backend(exec_params: SimpleNamespace) -> joblib.parallel_backen return joblib.parallel_backend(backend, **kwargs) -def parallel_func(func: Callable, *, exec_params: SimpleNamespace): +def parallel_func( + func: Callable[..., Any], + *, + exec_params: SimpleNamespace, +) -> tuple[Callable[..., Any], Callable[..., Any]]: if ( get_parallel_backend_name(exec_params=exec_params) == "loky" and get_n_jobs(exec_params=exec_params) == 1 @@ -139,7 +145,7 @@ def parallel_func(func: Callable, *, exec_params: SimpleNamespace): parallel = Parallel() - def run_verbose(*args, verbose=mne_logger.level, **kwargs): + def run_verbose(*args, verbose=mne_logger.level, **kwargs): # type: ignore with use_log_level(verbose=verbose): return func(*args, **kwargs) diff --git a/mne_bids_pipeline/_reject.py b/mne_bids_pipeline/_reject.py index 5b3729dc2..b4cf49719 100644 --- a/mne_bids_pipeline/_reject.py +++ b/mne_bids_pipeline/_reject.py @@ -1,21 +1,22 @@ """Rejection.""" -from typing import Optional, Union, Iterable, Dict, Literal +from collections.abc import Iterable +from typing import Literal import mne -from ._logging import logger, gen_log_kwargs +from ._logging import gen_log_kwargs, logger def _get_reject( *, subject: str, - session: Optional[str], - reject: Union[Dict[str, float], Literal["autoreject_global"]], + session: str | None, + reject: dict[str, float] | Literal["autoreject_global"], ch_types: Iterable[Literal["meg", "mag", "grad", "eeg"]], param: str, - epochs: Optional[mne.BaseEpochs] = None, -) -> Dict[str, float]: + epochs: mne.BaseEpochs | None = None, +) -> dict[str, float]: if reject is None: return dict() @@ -35,20 +36,20 @@ def _get_reject( msg = "Generating rejection thresholds using autoreject …" logger.info(**gen_log_kwargs(message=msg)) - reject = autoreject.get_rejection_threshold( + reject_out: dict[str, float] = autoreject.get_rejection_threshold( epochs=epochs, ch_types=ch_types_autoreject, verbose=False, ) - return reject + return reject_out # Only keep thresholds for channel types of interest reject = reject.copy() - if ch_types == ["eeg"]: - ch_types_to_remove = ("mag", "grad") - else: - ch_types_to_remove = ("eeg",) - + ch_types_to_remove: list[str] = list() + if "meg" not in ch_types: + ch_types_to_remove.extend(("mag", "grad")) + if "eeg" not in ch_types: + 
ch_types_to_remove.append("eeg") for ch_type in ch_types_to_remove: try: del reject[ch_type] diff --git a/mne_bids_pipeline/_report.py b/mne_bids_pipeline/_report.py index bf42a27a2..d63d1b816 100644 --- a/mne_bids_pipeline/_report.py +++ b/mne_bids_pipeline/_report.py @@ -1,24 +1,31 @@ import contextlib +import traceback +from collections.abc import Generator from functools import lru_cache from io import StringIO -from typing import Optional, List, Literal +from textwrap import indent from types import SimpleNamespace +from typing import Any, Literal -from filelock import FileLock +import matplotlib.axes +import matplotlib.figure +import matplotlib.image import matplotlib.transforms +import mne import numpy as np import pandas as pd -from scipy.io import loadmat - -import mne +from filelock import FileLock from mne.io import BaseRaw +from mne.report.report import _df_bootstrap_table from mne.utils import _pl from mne_bids import BIDSPath from mne_bids.stats import count_events +from scipy.io import loadmat from ._config_utils import get_all_contrasts from ._decoding import _handle_csp_args -from ._logging import logger, gen_log_kwargs, _linkfile +from ._logging import _linkfile, gen_log_kwargs, logger +from .typing import FloatArrayT @contextlib.contextmanager @@ -27,30 +34,35 @@ def _open_report( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str] = None, - task: Optional[str] = None, -): - fname_report = BIDSPath( - subject=subject, - session=session, - # Report is across all runs, but for logging purposes it's helpful - # to pass the run and task for gen_log_kwargs - run=None, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - extension=".h5", - datatype=cfg.datatype, - root=cfg.deriv_root, - suffix="report", - check=False, - ).fpath + session: str | None, + run: str | None = None, + task: str | None = None, + fname_report: BIDSPath | None = None, + name: str = "report", +) -> Generator[mne.Report, None, None]: + if fname_report is None: + fname_report = BIDSPath( + subject=subject, + session=session, + # Report is across all runs, but for logging purposes it's helpful + # to pass the run and task for gen_log_kwargs + run=None, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + extension=".h5", + datatype=cfg.datatype, + root=cfg.deriv_root, + suffix="report", + check=False, + ) + fname_report = fname_report.fpath + assert fname_report.suffix == ".h5", fname_report.suffix # prevent parallel file access with FileLock(f"{fname_report}.lock"), _agg_backend(): if not fname_report.is_file(): - msg = "Initializing report HDF5 file" + msg = f"Initializing {name} HDF5 file" logger.info(**gen_log_kwargs(message=msg)) report = _gen_empty_report( cfg=cfg, @@ -62,26 +74,26 @@ def _open_report( report = mne.open_report(fname_report) except Exception as exc: raise exc.__class__( - f"Could not open report HDF5 file:\n{fname_report}\n" - f"Got error:\n{exc}\nPerhaps you need to delete it?" + f"Could not open {name} HDF5 file:\n{fname_report}, " + "Perhaps you need to delete it? 
Got error:\n\n" + f"{indent(traceback.format_exc(), ' ')}" ) from None try: yield report finally: try: - msg = "Adding config and sys info to report" - logger.info(**gen_log_kwargs(message=msg)) _finalize( report=report, exec_params=exec_params, subject=subject, session=session, run=run, + task=task, ) except Exception as exc: logger.warning(f"Failed: {exc}") fname_report_html = fname_report.with_suffix(".html") - msg = f"Saving report: {_linkfile(fname_report_html)}" + msg = f"Saving {name}: {_linkfile(fname_report_html)}" logger.info(**gen_log_kwargs(message=msg)) report.save(fname_report, overwrite=True) report.save(fname_report_html, overwrite=True, open_browser=False) @@ -123,11 +135,11 @@ def _open_report( def _plot_full_epochs_decoding_scores( - contrast_names: List[str], - scores: List[np.ndarray], + contrast_names: list[str], + scores: list[FloatArrayT], metric: str, kind: Literal["single-subject", "grand-average"] = "single-subject", -): +) -> tuple[matplotlib.figure.Figure, str, pd.DataFrame]: """Plot cross-validation results from full-epochs decoding.""" import matplotlib.pyplot as plt # nested import to help joblib import seaborn as sns @@ -177,7 +189,7 @@ def _plot_full_epochs_decoding_scores( ) # And now add the mean CV score on top. - def _plot_mean_cv_score(x, **kwargs): + def _plot_mean_cv_score(x: FloatArrayT, **kwargs: dict[str, Any]) -> None: plt.plot(x.mean(), **kwargs) g.map( @@ -202,17 +214,17 @@ def _plot_mean_cv_score(x, **kwargs): g.set_xlabels("") fig = g.fig - return fig, caption + return fig, caption, data def _plot_time_by_time_decoding_scores( *, - times: np.ndarray, - cross_val_scores: np.ndarray, + times: FloatArrayT, + cross_val_scores: FloatArrayT, metric: str, time_generalization: bool, decim: int, -): +) -> matplotlib.figure.Figure: """Plot cross-validation results from time-by-time decoding.""" import matplotlib.pyplot as plt # nested import to help joblib @@ -251,7 +263,13 @@ def _plot_time_by_time_decoding_scores( return fig -def _label_time_by_time(ax, *, decim, xlabel=None, ylabel=None): +def _label_time_by_time( + ax: matplotlib.axes.Axes, + *, + decim: int, + xlabel: str | None = None, + ylabel: str | None = None, +) -> None: extra = "" if decim > 1: extra = f" (decim={decim})" @@ -261,7 +279,9 @@ def _label_time_by_time(ax, *, decim, xlabel=None, ylabel=None): ax.set_ylabel(f"{ylabel}{extra}") -def _plot_time_by_time_decoding_scores_gavg(*, cfg, decoding_data): +def _plot_time_by_time_decoding_scores_gavg( + *, cfg: SimpleNamespace, decoding_data: dict[str, Any] +) -> matplotlib.figure.Figure: """Plot the grand-averaged decoding scores.""" import matplotlib.pyplot as plt # nested import to help joblib @@ -300,9 +320,7 @@ def _plot_time_by_time_decoding_scores_gavg(*, cfg, decoding_data): # Only add the label once if n_significant_clusters_plotted == 0: - label = ( - f"$p$ < {cfg.cluster_permutation_p_threshold} " f"(cluster pemutation)" - ) + label = f"$p$ < {cfg.cluster_permutation_p_threshold} (cluster pemutation)" else: label = None @@ -336,7 +354,7 @@ def _plot_time_by_time_decoding_scores_gavg(*, cfg, decoding_data): ax.text( 0.05, 0.05, - s=f'$N$={decoding_data["N"].squeeze()}', + s=f"$N$={decoding_data['N'].squeeze()}", fontsize="x-large", horizontalalignment="left", verticalalignment="bottom", @@ -352,7 +370,9 @@ def _plot_time_by_time_decoding_scores_gavg(*, cfg, decoding_data): return fig -def plot_time_by_time_decoding_t_values(decoding_data): +def plot_time_by_time_decoding_t_values( + decoding_data: dict[str, Any], +) -> 
matplotlib.figure.Figure: """Plot the t-values used to form clusters for the permutation test.""" import matplotlib.pyplot as plt # nested import to help joblib @@ -369,7 +389,7 @@ def plot_time_by_time_decoding_t_values(decoding_data): ax.text( 0.05, 0.05, - s=f'$N$={decoding_data["N"].squeeze()}', + s=f"$N$={decoding_data['N'].squeeze()}", fontsize="x-large", horizontalalignment="left", verticalalignment="bottom", @@ -395,8 +415,10 @@ def plot_time_by_time_decoding_t_values(decoding_data): def _plot_decoding_time_generalization( - decoding_data, metric: str, kind: Literal["single-subject", "grand-average"] -): + decoding_data: dict[str, Any], + metric: str, + kind: Literal["single-subject", "grand-average"], +) -> matplotlib.figure.Figure: """Plot time generalization matrix.""" import matplotlib.pyplot as plt # nested import to help joblib @@ -446,7 +468,7 @@ def _plot_decoding_time_generalization( def _gen_empty_report( - *, cfg: SimpleNamespace, subject: str, session: Optional[str] + *, cfg: SimpleNamespace, subject: str, session: str | None ) -> mne.Report: title = f"sub-{subject}" if session is not None: @@ -458,12 +480,16 @@ def _gen_empty_report( return report -def _contrasts_to_names(contrasts: List[List[str]]) -> List[str]: +def _contrasts_to_names(contrasts: list[list[str]]) -> list[str]: return [f"{c[0]} vs.\n{c[1]}" for c in contrasts] def add_event_counts( - *, cfg, subject: Optional[str], session: Optional[str], report: mne.Report + *, + cfg: SimpleNamespace, + subject: str | None, + session: str | None, + report: mne.Report, ) -> None: try: df_events = count_events(BIDSPath(root=cfg.bids_root, session=session)) @@ -474,30 +500,13 @@ def add_event_counts( logger.info(**gen_log_kwargs(message="Adding event counts to report …")) if df_events is not None: - css_classes = ("table", "table-striped", "table-borderless", "table-hover") + df_events.reset_index(drop=False, inplace=True, col_level=1) report.add_html( - f'
\n' - f"{df_events.to_html(classes=css_classes, border=0)}\n" - f"
", + _df_bootstrap_table(df=df_events, data_id="events"), title="Event counts", tags=("events",), replace=True, ) - css = ( - ".event-counts {\n" - " display: -webkit-box;\n" - " display: -ms-flexbox;\n" - " display: -webkit-flex;\n" - " display: flex;\n" - " justify-content: center;\n" - " text-align: center;\n" - "}\n\n" - "th, td {\n" - " text-align: center;\n" - "}\n" - ) - if css not in report.include: - report.add_custom_css(css=css) def _finalize( @@ -505,14 +514,19 @@ def _finalize( report: mne.Report, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], + session: str | None, + run: str | None, + task: str | None, ) -> None: """Add system information and the pipeline configuration to the report.""" # ensure they are always appended titles = ["Configuration file", "System information"] for title in titles: report.remove(title=title, remove_all=True) + # Print this exactly once + if _cached_sys_info.cache_info()[-1] == 0: # never run + msg = "Adding config and sys info to report" + logger.info(**gen_log_kwargs(message=msg)) # No longer need replace=True in these report.add_code( code=exec_params.config_path, @@ -537,40 +551,40 @@ def _finalize( # We make a lot of calls to this function and it takes > 1 sec generally # to run, so run it just once (it shouldn't meaningfully change anyway) @lru_cache(maxsize=1) -def _cached_sys_info(): +def _cached_sys_info() -> str: with StringIO() as f: mne.sys_info(f) return f.getvalue() -def _all_conditions(*, cfg): +def _all_conditions(*, cfg: SimpleNamespace) -> list[str]: if isinstance(cfg.conditions, dict): conditions = list(cfg.conditions.keys()) else: - conditions = cfg.conditions.copy() + conditions = list(cfg.conditions) all_contrasts = get_all_contrasts(cfg) conditions.extend([contrast["name"] for contrast in all_contrasts]) return conditions -def _sanitize_cond_tag(cond): - return str(cond).lower().replace(" ", "-") +def _sanitize_cond_tag(cond: str) -> str: + return str(cond).lower().replace("'", "").replace('"', "").replace(" ", "-") def _imshow_tf( - vals, - ax, + vals: FloatArrayT, + ax: matplotlib.axes.Axes, *, - tmin, - tmax, - fmin, - fmax, - vmin, - vmax, - cmap="RdBu_r", - mask=None, - cmap_masked=None, -): + tmin: FloatArrayT, + tmax: FloatArrayT, + fmin: FloatArrayT, + fmax: FloatArrayT, + vmin: float, + vmax: float, + cmap: str = "RdBu_r", + mask: FloatArrayT | None = None, + cmap_masked: Any | None = None, +) -> matplotlib.image.AxesImage: """Plot CSP TF decoding scores.""" # XXX Add support for more metrics assert len(vals) == len(tmin) == len(tmax) == len(fmin) == len(fmax) @@ -597,19 +611,19 @@ def add_csp_grand_average( *, cfg: SimpleNamespace, subject: str, - session: str, + session: str | None, report: mne.Report, cond_1: str, cond_2: str, fname_csp_freq_results: BIDSPath, - fname_csp_cluster_results: pd.DataFrame, -): + fname_csp_cluster_results: pd.DataFrame | None, +) -> None: """Add CSP decoding results to the grand average report.""" import matplotlib.pyplot as plt # nested import to help joblib # First, plot decoding scores across frequency bins (entire epochs). 
- section = "Decoding: CSP" - freq_name_to_bins_map = _handle_csp_args( + section = f"Decoding: CSP, N = {len(cfg.subjects)}" + freq_name_to_bins_map, _ = _handle_csp_args( cfg.decoding_csp_times, cfg.decoding_csp_freqs, cfg.decoding_metric, @@ -622,7 +636,7 @@ def add_csp_grand_average( freq_bin_starts = list() freq_bin_widths = list() decoding_scores = list() - error_bars = list() + error_bars_list = list() csp_freq_results = pd.read_excel(fname_csp_freq_results, sheet_name="CSP Frequency") for freq_range_name, freq_bins in freq_name_to_bins_map.items(): results = csp_freq_results.loc[ @@ -638,10 +652,10 @@ def add_csp_grand_average( cis_upper = results["mean_ci_upper"][bi] error_bars_lower = decoding_scores[-1] - cis_lower error_bars_upper = cis_upper - decoding_scores[-1] - error_bars.append(np.stack([error_bars_lower, error_bars_upper])) - assert len(error_bars[-1]) == 2 # lower, upper + error_bars_list.append(np.stack([error_bars_lower, error_bars_upper])) + assert len(error_bars_list[-1]) == 2 # lower, upper del cis_lower, cis_upper, error_bars_lower, error_bars_upper - error_bars = np.array(error_bars, float).T + error_bars = np.array(error_bars_list, float).T if cfg.decoding_metric == "roc_auc": metric = "ROC AUC" @@ -692,6 +706,8 @@ def add_csp_grand_average( ) # Now, plot decoding scores across time-frequency bins. + if fname_csp_cluster_results is None: + return csp_cluster_results = loadmat(fname_csp_cluster_results) fig, ax = plt.subplots( nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True @@ -831,7 +847,7 @@ def add_csp_grand_average( @contextlib.contextmanager -def _agg_backend(): +def _agg_backend() -> Generator[None, None, None]: import matplotlib backend = matplotlib.get_backend() @@ -851,13 +867,13 @@ def _add_raw( cfg: SimpleNamespace, report: mne.report.Report, bids_path_in: BIDSPath, + raw: BaseRaw, title: str, - tags: tuple = (), - raw: Optional[BaseRaw] = None, - extra_html: Optional[str] = None, -): + tags: tuple[str, ...] 
= (), + extra_html: str | None = None, +) -> None: if bids_path_in.run is not None: - title += f", run {repr(bids_path_in.run)}" + title += f", run {bids_path_in.run}" elif bids_path_in.task in ("noise", "rest"): title += f", {bids_path_in.task}" plot_raw_psd = ( @@ -868,7 +884,7 @@ def _add_raw( tags = ("raw", f"run-{bids_path_in.run}") + tags with mne.use_log_level("error"): report.add_raw( - raw=raw or bids_path_in, + raw=raw, title=title, butterfly=5, psd=plot_raw_psd, @@ -891,8 +907,8 @@ def _render_bem( cfg: SimpleNamespace, report: mne.report.Report, subject: str, - session: Optional[str], -): + session: str | None, +) -> None: logger.info(**gen_log_kwargs(message="Rendering MRI slices with BEM contours.")) report.add_bem( subject=cfg.fs_subject, diff --git a/mne_bids_pipeline/_run.py b/mne_bids_pipeline/_run.py index c76126ea2..bc270b5bf 100644 --- a/mne_bids_pipeline/_run.py +++ b/mne_bids_pipeline/_run.py @@ -7,28 +7,32 @@ import pathlib import pdb import sys -import traceback import time -from typing import Callable, Optional, Dict, List, Literal, Union +import traceback +from collections.abc import Callable, Iterable from types import SimpleNamespace +from typing import Any, Literal -from filelock import FileLock -from joblib import Memory import json_tricks import pandas as pd +from filelock import FileLock +from joblib import Memory from mne_bids import BIDSPath from ._config_utils import get_task -from ._logging import logger, gen_log_kwargs, _is_testing +from ._logging import _is_testing, gen_log_kwargs, logger +from .typing import InFilesT, OutFilesT def failsafe_run( - get_input_fnames: Optional[Callable] = None, - get_output_fnames: Optional[Callable] = None, -) -> Callable: - def failsafe_run_decorator(func): + *, + get_input_fnames: Callable[..., Any] | None = None, + get_output_fnames: Callable[..., Any] | None = None, + require_output: bool = True, +) -> Callable[..., Any]: + def failsafe_run_decorator(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) # Preserve "identity" of original function - def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): + def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): # type: ignore __mne_bids_pipeline_step__ = pathlib.Path(inspect.getfile(func)) # noqa exec_params = kwargs["exec_params"] on_error = exec_params.on_error @@ -36,15 +40,13 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): exec_params=exec_params, get_input_fnames=get_input_fnames, get_output_fnames=get_output_fnames, + require_output=require_output, + func_name=f"{__mne_bids_pipeline_step__}::{func.__name__}", ) - kwargs_copy = copy.deepcopy(kwargs) t0 = time.time() - kwargs_copy["cfg"] = json_tricks.dumps( - kwargs_copy["cfg"], sort_keys=False, indent=4 - ) log_info = pd.concat( [ - pd.Series(kwargs_copy, dtype=object), + pd.Series(kwargs, dtype=object), pd.Series(index=["time", "success", "error_message"], dtype=object), ] ) @@ -57,29 +59,27 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): log_info["error_message"] = "" except Exception as e: # Only keep what gen_log_kwargs() can handle - kwargs_copy = { - k: v - for k, v in kwargs_copy.items() - if k in ("subject", "session", "task", "run") + kwargs_log = { + k: kwargs[k] + for k in ("subject", "session", "task", "run") + if k in kwargs } - message = ( - f"A critical error occurred. " f"The error message was: {str(e)}" - ) + message = f"A critical error occurred. 
The error message was: {str(e)}" log_info["success"] = False log_info["error_message"] = str(e) # Find the limit / step where the error occurred step_dir = pathlib.Path(__file__).parent / "steps" - tb = traceback.extract_tb(e.__traceback__) - for fi, frame in enumerate(inspect.stack()): + tb_list = list(traceback.extract_tb(e.__traceback__)) + for fi, frame in enumerate(tb_list): is_step = pathlib.Path(frame.filename).parent.parent == step_dir del frame if is_step: # omit everything before the "step" dir, which will # generally be stuff from this file and joblib - tb = tb[-fi:] + tb_list = tb_list[fi:] break - tb = "".join(traceback.format_list(tb)) + tb = "".join(traceback.format_list(tb_list)) if on_error == "abort": message += f"\n\nAborting pipeline run. The traceback is:\n\n{tb}" @@ -87,22 +87,22 @@ def __mne_bids_pipeline_failsafe_wrapper__(*args, **kwargs): if _is_testing(): raise logger.error( - **gen_log_kwargs(message=message, **kwargs_copy, emoji="❌") + **gen_log_kwargs(message=message, **kwargs_log, emoji="❌") ) sys.exit(1) elif on_error == "debug": message += "\n\nStarting post-mortem debugger." logger.error( - **gen_log_kwargs(message=message, **kwargs_copy, emoji="🐛") + **gen_log_kwargs(message=message, **kwargs_log, emoji="🐛") ) - extype, value, tb = sys.exc_info() + _, _, tb_ = sys.exc_info() print(tb) - pdb.post_mortem(tb) + pdb.post_mortem(tb_) sys.exit(1) else: message += "\n\nContinuing pipeline run." logger.error( - **gen_log_kwargs(message=message, **kwargs_copy, emoji="🔂") + **gen_log_kwargs(message=message, **kwargs_log, emoji="🔂") ) log_info["time"] = round(time.time() - t0, ndigits=1) return log_info @@ -120,7 +120,15 @@ def hash_file_path(path: pathlib.Path) -> str: class ConditionalStepMemory: - def __init__(self, *, exec_params, get_input_fnames, get_output_fnames): + def __init__( + self, + *, + exec_params: SimpleNamespace, + get_input_fnames: Callable[..., Any] | None, + get_output_fnames: Callable[..., Any] | None, + require_output: bool, + func_name: str, + ) -> None: memory_location = exec_params.memory_location if memory_location is True: use_location = exec_params.deriv_root / exec_params.memory_subdir @@ -138,9 +146,11 @@ def __init__(self, *, exec_params, get_input_fnames, get_output_fnames): self.get_input_fnames = get_input_fnames self.get_output_fnames = get_output_fnames self.memory_file_method = exec_params.memory_file_method + self.require_output = require_output + self.func_name = func_name - def cache(self, func): - def wrapper(*args, **kwargs): + def cache(self, func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> None: in_files = out_files = None force_run = kwargs.pop("force_run", False) these_kwargs = kwargs.copy() @@ -188,6 +198,7 @@ def wrapper(*args, **kwargs): hashes.append(hash_(k, sidecar)) kwargs["cfg"] = copy.deepcopy(kwargs["cfg"]) + assert isinstance(kwargs["cfg"], SimpleNamespace), type(kwargs["cfg"]) kwargs["cfg"].hashes = hashes del in_files # will be modified by func call @@ -195,7 +206,8 @@ def wrapper(*args, **kwargs): # call (https://github.com/joblib/joblib/issues/1342), but our hash # should be plenty fast so let's not bother for now. 
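For readers less familiar with the joblib machinery this wrapper builds on: stripped of the pipeline's own input/output hashing, the underlying pattern is a Memory cache plus an explicit .call() escape hatch for forced re-runs. A rough sketch (cache directory and step function are illustrative):

    from joblib import Memory

    memory = Memory(location="/tmp/mbp-cache", verbose=0)

    def step(cfg, subject):
        ...  # expensive work
        return {"report": f"results for sub-{subject}"}

    cached_step = memory.cache(step, ignore=["cfg"])  # 'cfg' left out of the cache key
    cached_step(cfg=None, subject="01")   # computes and stores
    cached_step(cfg=None, subject="01")   # replays the cached result
    # Force re-execution, bypassing the cache lookup; the return value differs
    # between joblib 1.3.2 and 1.4.0, as the comment further down notes.
    cached_step.call(cfg=None, subject="01")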
memorized_func = self.memory.cache(func, ignore=self.ignore) - msg = emoji = None + msg: str | None = None + emoji: str | None = None short_circuit = False # Used for logging automatically subject = kwargs.get("subject", None) # noqa @@ -225,13 +237,17 @@ def wrapper(*args, **kwargs): for key, (fname, this_hash) in out_files_hashes.items(): fname = pathlib.Path(fname) if not fname.exists(): - msg = "Output file missing, will recompute …" + msg = f"Output file missing: {fname}, will recompute …" emoji = "🧩" bad_out_files = True break got_hash = hash_(key, fname, kind="out")[1] if this_hash != got_hash: - msg = "Output file hash mismatch, will recompute …" + msg = ( + f"Output file {self.memory_file_method} mismatch for " + f"{fname} ({this_hash} != {got_hash}), will " + "recompute …" + ) emoji = "🚫" bad_out_files = True break @@ -256,15 +272,31 @@ def wrapper(*args, **kwargs): del out_files if msg is not None: + assert emoji is not None logger.info(**gen_log_kwargs(message=msg, emoji=emoji)) if short_circuit: return # https://joblib.readthedocs.io/en/latest/memory.html#joblib.memory.MemorizedFunc.call # noqa: E501 if force_run or unknown_inputs or bad_out_files: - memorized_func.call(*args, **kwargs) + # Joblib 1.4.0 only returns the output, but 1.3.2 returns both. + # Fortunately we can use tuple-ness to tell the difference (we always + # return None or a dict) + out_files = memorized_func.call(*args, **kwargs) + if isinstance(out_files, tuple): + out_files = out_files[0] else: - memorized_func(*args, **kwargs) + out_files = memorized_func(*args, **kwargs) + if self.require_output: + assert isinstance(out_files, dict) and len(out_files), ( + f"Internal error: step must return non-empty out_files dict, got " + f"{type(out_files).__name__} for:\n{self.func_name}" + ) + else: + assert out_files is None, ( + f"Internal error: step must return None, got {type(out_files)} " + f"for:\n{self.func_name}" + ) return wrapper @@ -272,7 +304,7 @@ def clear(self) -> None: self.memory.clear() -def save_logs(*, config: SimpleNamespace, logs) -> None: # TODO add type +def save_logs(*, config: SimpleNamespace, logs: Iterable[pd.Series]) -> None: fname = config.deriv_root / f"task-{get_task(config)}_log.xlsx" # Get the script from which the function is called for logging @@ -280,15 +312,7 @@ def save_logs(*, config: SimpleNamespace, logs) -> None: # TODO add type sheet_name = sheet_name[-30:] # shorten due to limit of excel format df = pd.DataFrame(logs) - - columns = df.columns - if "cfg" in columns: - columns = list(columns) - idx = columns.index("cfg") - del columns[idx] - columns.insert(-3, "cfg") # put it before time, success & err cols - - df = df[columns] + del logs with FileLock(fname.with_suffix(fname.suffix + ".lock")): append = fname.exists() @@ -298,13 +322,33 @@ def save_logs(*, config: SimpleNamespace, logs) -> None: # TODO add type mode="a" if append else "w", if_sheet_exists="replace" if append else None, ) + assert isinstance(config, SimpleNamespace), type(config) + cf_dict = dict() + for key, val in config.__dict__.items(): + # We need to be careful about functions, json_tricks does not work with them + if inspect.isfunction(val): + new_val = "" + if func_file := inspect.getfile(val): + new_val += f"{func_file}:" + if getattr(val, "__qualname__", None): + new_val += val.__qualname__ + val = "custom callable" if not new_val else new_val + val = json_tricks.dumps(val, indent=4, sort_keys=False) + # 32767 char limit per cell (could split over lines but if something is + # this long, you'll 
probably get the gist from the first 32k chars) + if len(val) > 32767: + val = val[:32765] + " …" + cf_dict[key] = val + cf_df = pd.DataFrame([cf_dict], dtype=object) with writer: + # Config first then the data + cf_df.to_excel(writer, sheet_name="config", index=False) df.to_excel(writer, sheet_name=sheet_name, index=False) def _update_for_splits( - files_dict: Union[Dict[str, BIDSPath], BIDSPath], - key: Optional[str], + files_dict: InFilesT | BIDSPath, + key: str | None, *, single: bool = False, allow_missing: bool = False, @@ -312,6 +356,7 @@ def _update_for_splits( if not isinstance(files_dict, dict): # fake it assert key is None files_dict, key = dict(x=files_dict), "x" + assert isinstance(key, str), type(key) bids_path = files_dict[key] if bids_path.fpath.exists(): return bids_path # no modifications needed @@ -337,7 +382,7 @@ def _update_for_splits( return bids_path -def _sanitize_callable(val): +def _sanitize_callable(val: Any) -> Any: # Callables are not nicely pickleable, so let's pass a string instead if callable(val): return "custom" @@ -346,23 +391,27 @@ def _sanitize_callable(val): def _get_step_path( - stack: Optional[List[inspect.FrameInfo]] = None, + stack: list[inspect.FrameInfo] | None = None, ) -> pathlib.Path: if stack is None: stack = inspect.stack() - paths = list() + paths: list[str] = list() for frame in stack: fname = pathlib.Path(frame.filename) + paths.append(frame.filename) if "steps" in fname.parts: return fname else: # pragma: no cover try: - return frame.frame.f_locals["__mne_bids_pipeline_step__"] + out = frame.frame.f_locals["__mne_bids_pipeline_step__"] except KeyError: pass + else: + assert isinstance(out, pathlib.Path) + return out else: # pragma: no cover - paths = "\n".join(paths) - raise RuntimeError(f"Could not find step path in call stack:\n{paths}") + paths_str = "\n".join(paths) + raise RuntimeError(f"Could not find step path in call stack:\n{paths_str}") def _short_step_path(step_path: pathlib.Path) -> str: @@ -372,12 +421,29 @@ def _short_step_path(step_path: pathlib.Path) -> str: def _prep_out_files( *, exec_params: SimpleNamespace, - out_files: Dict[str, BIDSPath], -): + out_files: InFilesT, + check_relative: pathlib.Path | None = None, + bids_only: bool = True, +) -> OutFilesT: + if check_relative is None: + check_relative = exec_params.deriv_root for key, fname in out_files.items(): + # Sanity check that we only ever write to the derivatives directory + if bids_only: + assert isinstance(fname, BIDSPath), (type(fname), fname) + # raw and epochs can split on write, and .save should check for us now, so + # we only need to check *other* types (these should never split) + if isinstance(fname, BIDSPath) and fname.suffix not in ("raw", "epo"): + assert fname.split is None, fname + fname = pathlib.Path(fname) + if not fname.is_relative_to(check_relative): + raise RuntimeError( + f"Output BIDSPath not relative to expected root {check_relative}:" + f"\n{fname}" + ) out_files[key] = _path_to_str_hash( key, - pathlib.Path(fname), + fname, method=exec_params.memory_file_method, kind="out", ) @@ -386,17 +452,18 @@ def _prep_out_files( def _path_to_str_hash( k: str, - v: Union[BIDSPath, pathlib.Path], + v: BIDSPath | pathlib.Path, *, method: Literal["mtime", "hash"], kind: str = "in", -): +) -> tuple[str, str | float]: if isinstance(v, BIDSPath): v = v.fpath assert isinstance(v, pathlib.Path), f'Bad type {type(v)}: {kind}_files["{k}"] = {v}' assert v.exists(), f'missing {kind}_files["{k}"] = {v}' + this_hash: str | float = "" if method == "mtime": - 
this_hash = v.lstat().st_mtime + this_hash = v.stat().st_mtime else: assert method == "hash" # guaranteed this_hash = hash_file_path(v) diff --git a/mne_bids_pipeline/_viz.py b/mne_bids_pipeline/_viz.py index 8e49af509..4aadf0742 100644 --- a/mne_bids_pipeline/_viz.py +++ b/mne_bids_pipeline/_viz.py @@ -1,10 +1,13 @@ -from typing import List +from typing import Any + import numpy as np import pandas as pd from matplotlib.figure import Figure -def plot_auto_scores(auto_scores, *, ch_types) -> List[Figure]: +def plot_auto_scores( + auto_scores: dict[str, Any], *, ch_types: list[str] +) -> list[Figure]: # Plot scores of automated bad channel detection. import matplotlib.pyplot as plt import seaborn as sns @@ -15,7 +18,7 @@ def plot_auto_scores(auto_scores, *, ch_types) -> List[Figure]: ch_types_[idx] = "grad" ch_types_.insert(idx + 1, "mag") - figs: List[Figure] = [] + figs: list[Figure] = [] for ch_type in ch_types_: # Only select the data for mag or grad channels. ch_subset = auto_scores["ch_types"] == ch_type diff --git a/mne_bids_pipeline/steps/freesurfer/_01_recon_all.py b/mne_bids_pipeline/steps/freesurfer/_01_recon_all.py index ee803c800..c7c0496b0 100755 --- a/mne_bids_pipeline/steps/freesurfer/_01_recon_all.py +++ b/mne_bids_pipeline/steps/freesurfer/_01_recon_all.py @@ -3,32 +3,50 @@ This will run FreeSurfer's ``recon-all --all`` if necessary. """ + import os import shutil import sys from pathlib import Path +from types import SimpleNamespace +from typing import Any from mne.utils import run_subprocess -from ..._config_utils import get_fs_subjects_dir, get_subjects -from ..._logging import logger, gen_log_kwargs -from ..._parallel import parallel_func, get_parallel_backend +from mne_bids_pipeline._config_utils import ( + _has_session_specific_anat, + get_fs_subjects_dir, + get_sessions, + get_subjects, +) +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func fs_bids_app = Path(__file__).parent / "contrib" / "run.py" -def run_recon(root_dir, subject, fs_bids_app, subjects_dir) -> None: +def run_recon( + root_dir: Path, + subject: str, + fs_bids_app: Any, + subjects_dir: Path, + session: str | None = None, +) -> None: subj_dir = subjects_dir / f"sub-{subject}" + sub_ses = f"Subject {subject}" + if session is not None: + subj_dir = subj_dir.with_name(f"{subj_dir.name}_ses-{session}") + sub_ses = f"{sub_ses} session {session}" if subj_dir.exists(): msg = ( - f"Subject {subject} is already present. Please delete the " + f"Recon for {sub_ses} is already present. Please delete the " f"directory if you want to recompute." ) logger.info(**gen_log_kwargs(message=msg)) return msg = ( - "Running recon-all on subject {subject}. This will take " + "Running recon-all on {sub_ses}. This will take " "a LONG time – it's a good idea to let it run over night." ) logger.info(**gen_log_kwargs(message=msg)) @@ -55,11 +73,13 @@ def run_recon(root_dir, subject, fs_bids_app, subjects_dir) -> None: f"--license_file={license_file}", f"--participant_label={subject}", ] + if session is not None: + cmd += [f"--session_label={session}"] logger.debug("Running: " + " ".join(cmd)) run_subprocess(cmd, env=env, verbose=logger.level) -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run freesurfer recon-all command on BIDS dataset. The script allows to run the freesurfer recon-all @@ -76,21 +96,33 @@ def main(*, config) -> None: You must have freesurfer available on your system. 
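When per-session anatomy is present, run_recon() above derives a session-specific FreeSurfer subject directory and forwards the session to the bundled BIDS-app wrapper via --session_label. The naming scheme, condensed (the subjects_dir path is a placeholder):

    from pathlib import Path

    def fs_subject_dir(subjects_dir: Path, subject: str, session: str | None) -> Path:
        subj_dir = subjects_dir / f"sub-{subject}"
        if session is not None:
            subj_dir = subj_dir.with_name(f"{subj_dir.name}_ses-{session}")
        return subj_dir

    print(fs_subject_dir(Path("/data/derivatives/freesurfer"), "01", "01"))
    # /data/derivatives/freesurfer/sub-01_ses-01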
- Run via the MNE BIDS Pipeline's `run.py`: + Run via the MNE BIDS Pipeline's CLI: - python run.py --steps=freesurfer --config=your_pipeline_config.py + mne_bids_pipeline --steps=freesurfer --config=your_pipeline_config.py """ # noqa subjects = get_subjects(config) + sessions = get_sessions(config) root_dir = config.bids_root subjects_dir = Path(get_fs_subjects_dir(config)) subjects_dir.mkdir(parents=True, exist_ok=True) + # check for session-specific MRIs within subject, and handle accordingly + subj_sess = list() + for _subj in subjects: + for _sess in sessions: + session = ( + _sess + if _has_session_specific_anat(_subj, _sess, subjects_dir) + else None + ) + subj_sess.append((_subj, session)) + with get_parallel_backend(config.exec_params): parallel, run_func = parallel_func(run_recon, exec_params=config.exec_params) parallel( - run_func(root_dir, subject, fs_bids_app, subjects_dir) - for subject in subjects + run_func(root_dir, subject, fs_bids_app, subjects_dir, session) + for subject, session in subj_sess ) # Handle fsaverage diff --git a/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py b/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py index 560448713..18d6e7f7b 100644 --- a/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py +++ b/mne_bids_pipeline/steps/freesurfer/_02_coreg_surfaces.py @@ -4,35 +4,50 @@ Use FreeSurfer's ``mkheadsurf`` and related utilities to make head surfaces suitable for coregistration. """ + from pathlib import Path from types import SimpleNamespace import mne.bem -from ..._config_utils import ( - get_fs_subjects_dir, +from mne_bids_pipeline._config_utils import ( get_fs_subject, + get_fs_subjects_dir, + get_sessions, get_subjects, - _get_scalp_in_files, ) -from ..._logging import logger, gen_log_kwargs -from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, _prep_out_files +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._run import _prep_out_files, failsafe_run +from mne_bids_pipeline.typing import InFilesPathT, OutFilesT fs_bids_app = Path(__file__).parent / "contrib" / "run.py" +def _get_scalp_in_files(cfg: SimpleNamespace) -> InFilesPathT: + subject_path = Path(cfg.fs_subjects_dir) / cfg.fs_subject + seghead = subject_path / "surf" / "lh.seghead" + in_files = dict() + if seghead.is_file(): + in_files["seghead"] = seghead + else: + in_files["t1"] = subject_path / "mri" / "T1.mgz" + return in_files + + def get_input_fnames_coreg_surfaces( *, cfg: SimpleNamespace, subject: str, -) -> dict: +) -> InFilesPathT: return _get_scalp_in_files(cfg) -def get_output_fnames_coreg_surfaces(*, cfg: SimpleNamespace, subject: str) -> dict: +def get_output_fnames_coreg_surfaces( + *, cfg: SimpleNamespace, subject: str +) -> InFilesPathT: out_files = dict() - subject_path = Path(cfg.subjects_dir) / cfg.fs_subject + subject_path = Path(cfg.fs_subjects_dir) / cfg.fs_subject out_files["seghead"] = subject_path / "surf" / "lh.seghead" for key in ("dense", "medium", "sparse"): out_files[f"head-{key}"] = ( @@ -49,34 +64,44 @@ def make_coreg_surfaces( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - in_files: dict, -) -> dict: + in_files: InFilesPathT, +) -> OutFilesT: """Create head surfaces for use with MNE-Python coregistration tools.""" msg = "Creating scalp surfaces for coregistration" logger.info(**gen_log_kwargs(message=msg)) in_files.pop("t1" if "t1" in in_files else "seghead") 
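Outside the pipeline, the head-surface generation that make_coreg_surfaces() performs can also be run on its own; the subject name and subjects_dir below are placeholders for whatever cfg.fs_subject and cfg.fs_subjects_dir resolve to:

    import mne

    mne.bem.make_scalp_surfaces(
        subject="sub-01",
        subjects_dir="/data/derivatives/freesurfer",
        force=True,       # continue even if the surface has minor topological defects
        overwrite=True,   # regenerate any existing head surfaces
    )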
mne.bem.make_scalp_surfaces( subject=cfg.fs_subject, - subjects_dir=cfg.subjects_dir, + subjects_dir=cfg.fs_subjects_dir, force=True, overwrite=True, ) out_files = get_output_fnames_coreg_surfaces(cfg=cfg, subject=subject) - return _prep_out_files(exec_params=exec_params, out_files=out_files) + return _prep_out_files( + exec_params=exec_params, + out_files=out_files, + check_relative=cfg.fs_subjects_dir, + bids_only=False, + ) -def get_config(*, config, subject) -> SimpleNamespace: +def get_config( + *, + config: SimpleNamespace, + subject: str, + session: str | None = None, +) -> SimpleNamespace: cfg = SimpleNamespace( - subject=subject, - fs_subject=get_fs_subject(config, subject), - subjects_dir=get_fs_subjects_dir(config), + fs_subject=get_fs_subject(config, subject, session=session), + fs_subjects_dir=get_fs_subjects_dir(config), ) return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: # Ensure we're also processing fsaverage if present subjects = get_subjects(config) + sessions = get_sessions(config) if (Path(get_fs_subjects_dir(config)) / "fsaverage").exists(): subjects.append("fsaverage") @@ -90,10 +115,12 @@ def main(*, config) -> None: cfg=get_config( config=config, subject=subject, + session=session, ), exec_params=config.exec_params, force_run=config.recreate_scalp_surface, subject=subject, ) for subject in subjects + for session in sessions ) diff --git a/mne_bids_pipeline/steps/freesurfer/__init__.py b/mne_bids_pipeline/steps/freesurfer/__init__.py index 84e37008a..7f4d9d088 100644 --- a/mne_bids_pipeline/steps/freesurfer/__init__.py +++ b/mne_bids_pipeline/steps/freesurfer/__init__.py @@ -3,7 +3,6 @@ Surface reconstruction via FreeSurfer. These steps are not run by default. """ -from . import _01_recon_all -from . import _02_coreg_surfaces +from . import _01_recon_all, _02_coreg_surfaces _STEPS = (_01_recon_all, _02_coreg_surfaces) diff --git a/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py b/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py index a964e6d59..df84c6eea 100644 --- a/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py +++ b/mne_bids_pipeline/steps/init/_01_init_derivatives_dir.py @@ -3,24 +3,22 @@ Initialize the derivatives directory. 
""" -from typing import Optional from types import SimpleNamespace from mne_bids.config import BIDS_VERSION from mne_bids.utils import _write_json -from ..._config_utils import get_subjects, get_sessions, _bids_kwargs -from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run +from mne_bids_pipeline._config_utils import _bids_kwargs, get_subjects_sessions +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._run import _prep_out_files, failsafe_run +from mne_bids_pipeline.typing import OutFilesT -def init_dataset(cfg) -> None: +@failsafe_run() +def init_dataset(cfg: SimpleNamespace, exec_params: SimpleNamespace) -> OutFilesT: """Prepare the pipeline directory in /derivatives.""" - fname_json = cfg.deriv_root / "dataset_description.json" - if fname_json.is_file(): - msg = "Output directories already exist …" - logger.info(**gen_log_kwargs(message=msg, emoji="✅")) - return + out_files = dict() + out_files["json"] = cfg.deriv_root / "dataset_description.json" logger.info(**gen_log_kwargs(message="Initializing output directories.")) cfg.deriv_root.mkdir(exist_ok=True, parents=True) @@ -38,16 +36,18 @@ def init_dataset(cfg) -> None: "URL": "n/a", } - _write_json(fname_json, ds_json, overwrite=True) + _write_json(out_files["json"], ds_json, overwrite=True) + return _prep_out_files( + exec_params=exec_params, out_files=out_files, bids_only=False + ) -@failsafe_run() def init_subject_dirs( *, cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> None: """Create processing data output directories for individual participants.""" out_dir = cfg.deriv_root / f"sub-{subject}" @@ -60,7 +60,7 @@ def init_subject_dirs( def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( PIPELINE_NAME=config.PIPELINE_NAME, @@ -71,13 +71,13 @@ def get_config( return cfg -def main(*, config): +def main(*, config: SimpleNamespace) -> None: """Initialize the output directories.""" - init_dataset(cfg=get_config(config=config)) + init_dataset(cfg=get_config(config=config), exec_params=config.exec_params) # Don't bother with parallelization here as I/O operations are generally # not well parallelized (and this should be very fast anyway) - for subject in get_subjects(config): - for session in get_sessions(config): + for subject, sessions in get_subjects_sessions(config).items(): + for session in sessions: init_subject_dirs( cfg=get_config( config=config, diff --git a/mne_bids_pipeline/steps/init/_02_find_empty_room.py b/mne_bids_pipeline/steps/init/_02_find_empty_room.py index d9334a9cf..6a742df5c 100644 --- a/mne_bids_pipeline/steps/init/_02_find_empty_room.py +++ b/mne_bids_pipeline/steps/init/_02_find_empty_room.py @@ -1,26 +1,31 @@ """Find empty-room data matches.""" from types import SimpleNamespace -from typing import Dict, Optional from mne_bids import BIDSPath -from ..._config_utils import ( - get_datatype, - get_sessions, - get_subjects, - get_mf_reference_run, +from mne_bids_pipeline._config_utils import ( _bids_kwargs, _pl, + get_datatype, + get_mf_reference_run, + get_subjects_sessions, +) +from mne_bids_pipeline._import_data import _empty_room_match_path +from mne_bids_pipeline._io import _write_json +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, ) -from ..._io import _empty_room_match_path, _write_json -from ..._logging import 
gen_log_kwargs, logger -from ..._run import _update_for_splits, failsafe_run, save_logs, _prep_out_files +from mne_bids_pipeline.typing import InFilesT, OutFilesT def get_input_fnames_find_empty_room( - *, subject: str, session: Optional[str], run: Optional[str], cfg: SimpleNamespace -) -> Dict[str, BIDSPath]: + *, subject: str, session: str | None, run: str | None, cfg: SimpleNamespace +) -> InFilesT: """Get paths of files required by find_empty_room function.""" bids_path_in = BIDSPath( subject=subject, @@ -35,7 +40,7 @@ def get_input_fnames_find_empty_room( root=cfg.bids_root, check=False, ) - in_files: Dict[str, BIDSPath] = dict() + in_files: InFilesT = dict() in_files[f"raw_run-{run}"] = bids_path_in _update_for_splits(in_files, f"raw_run-{run}", single=True) if hasattr(bids_path_in, "find_matching_sidecar"): @@ -62,10 +67,10 @@ def find_empty_room( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - in_files: Dict[str, BIDSPath], -) -> Dict[str, BIDSPath]: + session: str | None, + run: str | None, + in_files: InFilesT, +) -> OutFilesT: raw_path = in_files.pop(f"raw_run-{run}") in_files.pop("sidecar", None) try: @@ -101,7 +106,7 @@ def find_empty_room( def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: cfg = SimpleNamespace( **_bids_kwargs(config=config), @@ -109,7 +114,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run find_empty_room.""" if not config.process_empty_room: msg = "Skipping, process_empty_room is set to False …" @@ -122,17 +127,16 @@ def main(*, config) -> None: # This will be I/O bound if the sidecar is not complete, so let's not run # in parallel. logs = list() - for subject in get_subjects(config): - run = get_mf_reference_run(config=config) - logs.append( - find_empty_room( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject=subject, - session=get_sessions(config)[0], - run=run, + for subject, sessions in get_subjects_sessions(config).items(): + for session in sessions: + run = get_mf_reference_run(config=config) + logs.append( + find_empty_room( + cfg=get_config(config=config), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run, + ) ) - ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/init/__init__.py b/mne_bids_pipeline/steps/init/__init__.py index 72a80cf13..6435ffdfe 100644 --- a/mne_bids_pipeline/steps/init/__init__.py +++ b/mne_bids_pipeline/steps/init/__init__.py @@ -1,7 +1,6 @@ """Filesystem initialization and dataset inspection.""" -from . import _01_init_derivatives_dir -from . import _02_find_empty_room +from . 
import _01_init_derivatives_dir, _02_find_empty_room _STEPS = ( _01_init_derivatives_dir, diff --git a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py index 655280e52..06dff95c4 100644 --- a/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py +++ b/mne_bids_pipeline/steps/preprocessing/_01_data_quality.py @@ -1,64 +1,73 @@ """Assess data quality and find bad (and flat) channels.""" from types import SimpleNamespace -from typing import Optional - -import pandas as pd import mne -from mne_bids import BIDSPath +import pandas as pd -from ..._config_utils import ( +from mne_bids_pipeline._config_utils import ( + _do_mf_autobad, + _pl, get_mf_cal_fname, get_mf_ctc_fname, - get_subjects, - get_sessions, get_runs_tasks, - _do_mf_autobad, - _pl, + get_subjects_sessions, ) -from ..._import_data import ( - _get_run_rest_noise_path, - _get_mf_reference_run_path, - import_experimental_data, - import_er_data, +from mne_bids_pipeline._import_data import ( _bads_path, - _auto_scores_path, + _get_mf_reference_run_path, + _get_run_rest_noise_path, _import_data_kwargs, + _read_raw_msg, + import_er_data, + import_experimental_data, ) -from ..._io import _write_json -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _add_raw -from ..._run import failsafe_run, save_logs, _prep_out_files -from ..._viz import plot_auto_scores +from mne_bids_pipeline._io import _write_json +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _add_raw, _open_report +from mne_bids_pipeline._run import _prep_out_files, failsafe_run, save_logs +from mne_bids_pipeline._viz import plot_auto_scores +from mne_bids_pipeline.typing import FloatArrayT, InFilesT, OutFilesT def get_input_fnames_data_quality( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], -) -> dict: + session: str | None, + run: str | None, + task: str | None, +) -> InFilesT: """Get paths of files required by assess_data_quality function.""" - kwargs = dict( - cfg=cfg, - subject=subject, - session=session, - add_bads=False, - ) - in_files = _get_run_rest_noise_path( + in_files: InFilesT = _get_run_rest_noise_path( run=run, task=task, kind="orig", mf_reference_run=cfg.mf_reference_run, - **kwargs, + cfg=cfg, + subject=subject, + session=session, + add_bads=False, ) # When doing autobad for the noise run, we also need the reference run if _do_mf_autobad(cfg=cfg) and run is None and task == "noise": - in_files.update(_get_mf_reference_run_path(**kwargs)) + in_files.update( + _get_mf_reference_run_path( + cfg=cfg, + subject=subject, + session=session, + add_bads=False, + ) + ) + + # set calibration and crosstalk files (if provided) + if _do_mf_autobad(cfg=cfg): + if cfg.mf_cal_fname is not None: + in_files["mf_cal_fname"] = cfg.mf_cal_fname + if cfg.mf_ctc_fname is not None: + in_files["mf_ctc_fname"] = cfg.mf_ctc_fname + return in_files @@ -70,38 +79,128 @@ def assess_data_quality( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + run: str | None, + task: str | None, + in_files: InFilesT, +) -> OutFilesT: """Assess data quality and find and mark bad channels.""" import 
matplotlib.pyplot as plt out_files = dict() key = f"raw_task-{task}_run-{run}" bids_path_in = in_files.pop(key) + if key == "raw_task-noise_run-None": + bids_path_ref_in = in_files.pop("raw_ref_run", None) + else: + bids_path_ref_in = None + msg, _ = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task) + logger.info(**gen_log_kwargs(message=msg)) + + if run is None and task == "noise": + raw = import_er_data( + cfg=cfg, + bids_path_er_in=bids_path_in, + bids_path_er_bads_in=None, + bids_path_ref_in=bids_path_ref_in, + bids_path_ref_bads_in=None, + prepare_maxwell_filter=True, + ) + else: + data_is_rest = run is None and task == "rest" + raw = import_experimental_data( + bids_path_in=bids_path_in, + bids_path_bads_in=None, + cfg=cfg, + data_is_rest=data_is_rest, + ) + preexisting_bads = sorted(raw.info["bads"]) + + auto_scores: dict[str, FloatArrayT] | None = None + auto_noisy_chs: list[str] | None = None + auto_flat_chs: list[str] | None = None if _do_mf_autobad(cfg=cfg): - if key == "raw_task-noise_run-None": - bids_path_ref_in = in_files.pop("raw_ref_run") - else: - bids_path_ref_in = None - auto_scores = _find_bads_maxwell( + # use calibration and crosstalk files (if provided) + cfg.mf_cal_fname = in_files.pop("mf_cal_fname", None) + cfg.mf_ctc_fname = in_files.pop("mf_ctc_fname", None) + + ( + auto_noisy_chs, + auto_flat_chs, + auto_scores, + ) = _find_bads_maxwell( cfg=cfg, exec_params=exec_params, - bids_path_in=bids_path_in, - bids_path_ref_in=bids_path_ref_in, + raw=raw, subject=subject, session=session, run=run, task=task, - out_files=out_files, ) - else: - auto_scores = None + bads = sorted(set(raw.info["bads"] + auto_noisy_chs + auto_flat_chs)) + msg = f"Found {len(bads)} bad channel{_pl(bads)}." + raw.info["bads"] = bads + del bads + logger.info(**gen_log_kwargs(message=msg)) del key + # Always output the scores and bads TSV + out_files["auto_scores"] = bids_path_in.copy().update( + suffix="scores", + extension=".json", + root=cfg.deriv_root, + split=None, + check=False, + session=session, + subject=subject, + ) + _write_json(out_files["auto_scores"], auto_scores) + + # Write the bad channels to disk. 
+ out_files["bads_tsv"] = _bads_path( + cfg=cfg, + bids_path_in=bids_path_in, + subject=subject, + session=session, + ) + bads_for_tsv = [] + reasons = [] + + if auto_flat_chs: + for ch in auto_flat_chs: + reason = ( + "pre-existing (before MNE-BIDS-pipeline was run) & auto-flat" + if ch in preexisting_bads + else "auto-flat" + ) + bads_for_tsv.append(ch) + reasons.append(reason) + + if auto_noisy_chs: + for ch in auto_noisy_chs: + reason = ( + "pre-existing (before MNE-BIDS-pipeline was run) & auto-noisy" + if ch in preexisting_bads + else "auto-noisy" + ) + bads_for_tsv.append(ch) + reasons.append(reason) + + if preexisting_bads: + for ch in preexisting_bads: + if ch in bads_for_tsv: + continue + bads_for_tsv.append(ch) + reasons.append("pre-existing (before MNE-BIDS-pipeline was run)") + + tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons)) + tsv_data = tsv_data.sort_values(by="name") + tsv_data.to_csv(out_files["bads_tsv"], sep="\t", index=False) + # Report + # Restore bads to their original state so they will show up in the report + raw.info["bads"] = preexisting_bads + with _open_report( cfg=cfg, exec_params=exec_params, @@ -118,9 +217,11 @@ def assess_data_quality( cfg=cfg, report=report, bids_path_in=bids_path_in, + raw=raw, title=f"Raw ({kind})", tags=("data-quality",), ) + title = f"Bad channel detection: {run}" if cfg.find_noisy_channels_meg: assert auto_scores is not None msg = "Adding noisy channel detection to report" @@ -132,12 +233,14 @@ def assess_data_quality( fig=figs, caption=captions, section="Data quality", - title=f"Bad channel detection: {run}", + title=title, tags=tags, replace=True, ) for fig in figs: plt.close(fig) + else: + report.remove(title=title) assert len(in_files) == 0, in_files.keys() return _prep_out_files(exec_params=exec_params, out_files=out_files) @@ -147,45 +250,25 @@ def _find_bads_maxwell( *, cfg: SimpleNamespace, exec_params: SimpleNamespace, - bids_path_in: BIDSPath, - bids_path_ref_in: Optional[BIDSPath], + raw: mne.io.BaseRaw, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - out_files: dict, -): - if cfg.find_flat_channels_meg and not cfg.find_noisy_channels_meg: - msg = "Finding flat channels." - elif cfg.find_noisy_channels_meg and not cfg.find_flat_channels_meg: - msg = "Finding noisy channels using Maxwell filtering." + session: str | None, + run: str | None, + task: str | None, +) -> tuple[list[str], list[str], dict[str, FloatArrayT]]: + if cfg.find_flat_channels_meg: + if cfg.find_noisy_channels_meg: + msg = "Finding flat channels and noisy channels using Maxwell filtering." + else: + msg = "Finding flat channels." else: - msg = "Finding flat channels and noisy channels using " "Maxwell filtering." + assert cfg.find_noisy_channels_meg + msg = "Finding noisy channels using Maxwell filtering." logger.info(**gen_log_kwargs(message=msg)) - if run is None and task == "noise": - raw = import_er_data( - cfg=cfg, - bids_path_er_in=bids_path_in, - bids_path_er_bads_in=None, - bids_path_ref_in=bids_path_ref_in, - bids_path_ref_bads_in=None, - prepare_maxwell_filter=True, - ) - else: - data_is_rest = run is None and task == "rest" - raw = import_experimental_data( - bids_path_in=bids_path_in, - bids_path_bads_in=None, - cfg=cfg, - data_is_rest=data_is_rest, - ) - # Filter the data manually before passing it to find_bad_channels_maxwell() # This reduces memory usage, as we can control the number of jobs used # during filtering. 
- preexisting_bads = raw.info["bads"].copy() - bads = preexisting_bads.copy() raw_filt = raw.copy().filter(l_freq=None, h_freq=40, n_jobs=1) ( auto_noisy_chs, @@ -199,76 +282,34 @@ def _find_bads_maxwell( coord_frame="head", return_scores=True, h_freq=None, # we filtered manually above + **cfg.find_bad_channels_extra_kws, ) del raw_filt if cfg.find_flat_channels_meg: if auto_flat_chs: msg = ( - f"Found {len(auto_flat_chs)} flat channels: " - f'{", ".join(auto_flat_chs)}' + f"Found {len(auto_flat_chs)} flat channels: {', '.join(auto_flat_chs)}" ) else: msg = "Found no flat channels." logger.info(**gen_log_kwargs(message=msg)) - bads.extend(auto_flat_chs) + else: + auto_flat_chs = [] if cfg.find_noisy_channels_meg: if auto_noisy_chs: msg = ( f"Found {len(auto_noisy_chs)} noisy " f"channel{_pl(auto_noisy_chs)}: " - f'{", ".join(auto_noisy_chs)}' + f"{', '.join(auto_noisy_chs)}" ) else: msg = "Found no noisy channels." logger.info(**gen_log_kwargs(message=msg)) - bads.extend(auto_noisy_chs) - - bads = sorted(set(bads)) - msg = f"Found {len(bads)} channel{_pl(bads)} as bad." - raw.info["bads"] = bads - del bads - logger.info(**gen_log_kwargs(message=msg)) - - if cfg.find_noisy_channels_meg: - out_files["auto_scores"] = _auto_scores_path( - cfg=cfg, - bids_path_in=bids_path_in, - ) - if not out_files["auto_scores"].fpath.parent.exists(): - out_files["auto_scores"].fpath.parent.mkdir(parents=True) - _write_json(out_files["auto_scores"], auto_scores) - - # Write the bad channels to disk. - out_files["bads_tsv"] = _bads_path( - cfg=cfg, - bids_path_in=bids_path_in, - ) - bads_for_tsv = [] - reasons = [] - - if cfg.find_flat_channels_meg: - bads_for_tsv.extend(auto_flat_chs) - reasons.extend(["auto-flat"] * len(auto_flat_chs)) - preexisting_bads = set(preexisting_bads) - set(auto_flat_chs) - - if cfg.find_noisy_channels_meg: - bads_for_tsv.extend(auto_noisy_chs) - reasons.extend(["auto-noisy"] * len(auto_noisy_chs)) - preexisting_bads = set(preexisting_bads) - set(auto_noisy_chs) - - preexisting_bads = list(preexisting_bads) - if preexisting_bads: - bads_for_tsv.extend(preexisting_bads) - reasons.extend( - ["pre-existing (before MNE-BIDS-pipeline was run)"] * len(preexisting_bads) - ) - - tsv_data = pd.DataFrame(dict(name=bads_for_tsv, reason=reasons)) - tsv_data = tsv_data.sort_values(by="name") - tsv_data.to_csv(out_files["bads_tsv"], sep="\t", index=False) + else: + auto_noisy_chs = [] # Interaction if exec_params.interactive and cfg.find_noisy_channels_meg: @@ -277,17 +318,18 @@ def _find_bads_maxwell( plot_auto_scores(auto_scores, ch_types=cfg.ch_types) plt.show() - return auto_scores + return auto_noisy_chs, auto_flat_chs, auto_scores def get_config( *, config: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> SimpleNamespace: extra_kwargs = dict() if config.find_noisy_channels_meg or config.find_flat_channels_meg: + # If these change, need to update hooks.py in doc build extra_kwargs["mf_cal_fname"] = get_mf_cal_fname( config=config, subject=subject, @@ -304,6 +346,7 @@ def get_config( # detection # find_flat_channels_meg=config.find_flat_channels_meg, # find_noisy_channels_meg=config.find_noisy_channels_meg, + # find_bad_channels_extra_kws=config.find_bad_channels_extra_kws, **_import_data_kwargs(config=config, subject=subject), **extra_kwargs, ) @@ -325,8 +368,8 @@ def main(*, config: SimpleNamespace) -> None: run=run, task=task, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in 
get_subjects_sessions(config).items() + for session in sessions for run, task in get_runs_tasks( config=config, subject=subject, diff --git a/mne_bids_pipeline/steps/preprocessing/_02_head_pos.py b/mne_bids_pipeline/steps/preprocessing/_02_head_pos.py index a75cd7339..6715e6f0d 100644 --- a/mne_bids_pipeline/steps/preprocessing/_02_head_pos.py +++ b/mne_bids_pipeline/steps/preprocessing/_02_head_pos.py @@ -1,34 +1,33 @@ """Estimate head positions.""" -from typing import Optional from types import SimpleNamespace import mne +from mne_bids import BIDSPath, find_matching_paths -from ..._config_utils import ( - get_subjects, - get_sessions, - get_runs_tasks, -) -from ..._import_data import ( - import_experimental_data, +from mne_bids_pipeline._config_utils import get_runs_tasks, get_subjects_sessions +from mne_bids_pipeline._import_data import ( + _get_bids_path_in, _get_run_rest_noise_path, _import_data_kwargs, + _path_dict, + import_experimental_data, ) -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report -from ..._run import failsafe_run, save_logs, _prep_out_files +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _open_report +from mne_bids_pipeline._run import _prep_out_files, failsafe_run, save_logs +from mne_bids_pipeline.typing import InFilesT, OutFilesT def get_input_fnames_head_pos( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], -) -> dict: + session: str | None, + run: str | None, + task: str | None, +) -> InFilesT: """Get paths of files required by run_head_pos function.""" return _get_run_rest_noise_path( cfg=cfg, @@ -49,11 +48,11 @@ def run_head_pos( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + run: str | None, + task: str | None, + in_files: InFilesT, +) -> OutFilesT: import matplotlib.pyplot as plt in_key = f"raw_task-{task}_run-{run}" @@ -144,11 +143,125 @@ def run_head_pos( return _prep_out_files(exec_params=exec_params, out_files=out_files) +def get_input_fnames_twa_head_pos( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + task: str | None, +) -> dict[str, BIDSPath | list[BIDSPath]]: + """Get paths of files required by compute_twa_head_pos function.""" + in_files: dict[str, BIDSPath] = dict() + # can't use `_get_run_path()` here because we don't loop over runs/tasks. 
+ # But any run will do, as long as the file exists: + runs_tasks = get_runs_tasks( + config=cfg, subject=subject, session=session, which=("runs",) + ) + run = next(filter(lambda run_task: run_task[1] == task, runs_tasks))[0] + bids_path_in = _get_bids_path_in( + cfg=cfg, + subject=subject, + session=session, + run=run, + task=task, + kind="orig", + ) + in_files[f"raw_task-{task}"] = _path_dict( + cfg=cfg, + subject=subject, + session=session, + bids_path_in=bids_path_in, + add_bads=False, + allow_missing=False, + kind="orig", + )[f"raw_task-{task}_run-{run}"] + # ideally we'd do the path-finding for `all_runs_raw_bidspaths` and + # `all_runs_headpos_bidspaths` here, but we can't because MBP is strict about only + # returning paths, not lists of paths :( + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_twa_head_pos, +) +def compute_twa_head_pos( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | list[str] | None, + task: str | None, + in_files: InFilesT, +) -> OutFilesT: + """Compute time-weighted average head position.""" + # logging + want_mc = cfg.mf_mc + dest_is_twa = isinstance(cfg.mf_destination, str) and cfg.mf_destination == "twa" + msg = "Skipping computation of time-weighted average head position" + if not want_mc: + msg += " (no movement compensation requested)" + kwargs = dict(emoji="skip") + elif not dest_is_twa: + msg += ' (mf_destination is not "twa")' + kwargs = dict(emoji="skip") + else: + msg = "Computing time-weighted average head position" + kwargs = dict() + logger.info(**gen_log_kwargs(message=msg, **kwargs)) + # maybe bail early + if not want_mc and not dest_is_twa: + return _prep_out_files(exec_params=exec_params, out_files=dict()) + + # path to (subject+session)-level `destination.fif` in derivatives folder + bids_path_in = in_files.pop(f"raw_task-{task}") + dest_path = bids_path_in.copy().update( + check=False, + description="twa", + extension=".fif", + root=cfg.deriv_root, + run=None, + suffix="destination", + ) + # need raw files from all runs + all_runs_raw_bidspaths = find_matching_paths( + root=cfg.bids_root, + subjects=subject, + sessions=session, + tasks=task, + suffixes="meg", + ignore_json=True, + ignore_nosub=True, + check=True, + ) + raw_fnames = [bp.fpath for bp in all_runs_raw_bidspaths] + raws = [ + mne.io.read_raw_fif(fname, allow_maxshield=True, verbose="ERROR", preload=False) + for fname in raw_fnames + ] + # also need headpos files from all runs + all_runs_headpos_bidspaths = find_matching_paths( + root=cfg.deriv_root, + subjects=subject, + sessions=session, + tasks=task, + suffixes="headpos", + extensions=".txt", + check=False, + ) + head_poses = [mne.chpi.read_head_pos(bp.fpath) for bp in all_runs_headpos_bidspaths] + # compute time-weighted average head position and save it to disk + destination = mne.preprocessing.compute_average_dev_head_t(raws, head_poses) + mne.write_trans(fname=dest_path.fpath, trans=destination, overwrite=True) + # output + out_files = dict(destination_head_pos=dest_path) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + def get_config( *, config: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> SimpleNamespace: cfg = SimpleNamespace( mf_mc_t_step_min=config.mf_mc_t_step_min, @@ -178,8 +291,8 @@ def main(*, config: SimpleNamespace) -> None: run=run, task=task, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in 
get_subjects_sessions(config).items() + for session in sessions for run, task in get_runs_tasks( config=config, subject=subject, @@ -187,5 +300,21 @@ def main(*, config: SimpleNamespace) -> None: which=("runs", "rest"), ) ) + # compute time-weighted average head position + # within subject+session+task, across runs + parallel, run_func = parallel_func( + compute_twa_head_pos, exec_params=config.exec_params + ) + more_logs = parallel( + run_func( + cfg=config, + exec_params=config.exec_params, + subject=subject, + session=session, + task=config.task or None, # default task is "" + ) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions + ) - save_logs(config=config, logs=logs) + save_logs(config=config, logs=logs + more_logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py b/mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py index 099336c5c..cfb7a759f 100644 --- a/mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py +++ b/mne_bids_pipeline/steps/preprocessing/_03_maxfilter.py @@ -14,35 +14,39 @@ The function loads machine-specific calibration files. """ -from copy import deepcopy import gc -from typing import Optional +from copy import deepcopy from types import SimpleNamespace -import numpy as np import mne +import numpy as np from mne_bids import read_raw_bids -from ..._config_utils import ( +from mne_bids_pipeline._config_utils import ( + _pl, get_mf_cal_fname, get_mf_ctc_fname, - get_subjects, - get_sessions, get_runs_tasks, - _pl, + get_subjects_sessions, ) -from ..._import_data import ( - import_experimental_data, - import_er_data, +from mne_bids_pipeline._import_data import ( + _get_mf_reference_run_path, _get_run_path, _get_run_rest_noise_path, - _get_mf_reference_run_path, _import_data_kwargs, + import_er_data, + import_experimental_data, +) +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _add_raw, _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, ) -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _add_raw -from ..._run import failsafe_run, save_logs, _update_for_splits, _prep_out_files +from mne_bids_pipeline.typing import InFilesT, OutFilesT # %% eSSS @@ -50,21 +54,24 @@ def get_input_fnames_esss( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: - kwargs = dict( - cfg=cfg, - subject=subject, - session=session, - ) + session: str | None, +) -> InFilesT: in_files = _get_run_rest_noise_path( run=None, task="noise", kind="orig", mf_reference_run=cfg.mf_reference_run, - **kwargs, + cfg=cfg, + subject=subject, + session=session, + ) + in_files.update( + _get_mf_reference_run_path( + cfg=cfg, + subject=subject, + session=session, + ) ) - in_files.update(_get_mf_reference_run_path(add_bads=True, **kwargs)) return in_files @@ -76,9 +83,9 @@ def compute_esss_proj( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: import matplotlib.pyplot as plt run, task = None, "noise" @@ -182,22 +189,19 @@ def get_input_fnames_maxwell_filter( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], -) -> dict: + session: str | None, + 
run: str | None, + task: str | None, +) -> InFilesT: """Get paths of files required by maxwell_filter function.""" - kwargs = dict( - cfg=cfg, - subject=subject, - session=session, - ) in_files = _get_run_rest_noise_path( run=run, task=task, kind="orig", mf_reference_run=cfg.mf_reference_run, - **kwargs, + cfg=cfg, + subject=subject, + session=session, ) in_key = f"raw_task-{task}_run-{run}" assert in_key in in_files @@ -212,9 +216,11 @@ def get_input_fnames_maxwell_filter( task=pos_task, add_bads=False, kind="orig", - **kwargs, + cfg=cfg, + subject=subject, + session=session, )[f"raw_task-{pos_task}_run-{pos_run}"] - in_files[f"{in_key}-pos"] = path.update( + in_files[f"{in_key}-pos"] = path.copy().update( suffix="headpos", extension=".txt", root=cfg.deriv_root, @@ -222,7 +228,16 @@ def get_input_fnames_maxwell_filter( task=pos_task, run=pos_run, ) - + if isinstance(cfg.mf_destination, str) and cfg.mf_destination == "twa": + in_files[f"{in_key}-twa"] = path.update( + description="twa", + suffix="destination", + extension=".fif", + root=cfg.deriv_root, + check=False, + task=pos_task, + run=None, + ) if cfg.mf_esss: in_files["esss_basis"] = ( in_files[in_key] @@ -241,7 +256,14 @@ def get_input_fnames_maxwell_filter( ) # reference run (used for `destination` and also bad channels for noise) - in_files.update(_get_mf_reference_run_path(add_bads=True, **kwargs)) + # use add_bads=None here to mean "add if autobad is turned on" + in_files.update( + _get_mf_reference_run_path( + cfg=cfg, + subject=subject, + session=session, + ) + ) is_rest_noise = run is None and task in ("noise", "rest") if is_rest_noise: @@ -259,9 +281,12 @@ def get_input_fnames_maxwell_filter( ) _update_for_splits(in_files, key, single=True) - # standard files - in_files["mf_cal_fname"] = cfg.mf_cal_fname - in_files["mf_ctc_fname"] = cfg.mf_ctc_fname + # set calibration and crosstalk files (if provided) + if cfg.mf_cal_fname is not None: + in_files["mf_cal_fname"] = cfg.mf_cal_fname + if cfg.mf_ctc_fname is not None: + in_files["mf_ctc_fname"] = cfg.mf_ctc_fname + return in_files @@ -273,11 +298,11 @@ def run_maxwell_filter( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - run: Optional[str], - task: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + run: str | None, + task: str | None, + in_files: InFilesT, +) -> OutFilesT: if cfg.proc and "sss" in cfg.proc and cfg.use_maxwell_filter: raise ValueError( f"You cannot set use_maxwell_filter to True " @@ -286,15 +311,16 @@ def run_maxwell_filter( ) if isinstance(cfg.mf_destination, str): destination = cfg.mf_destination - assert destination == "reference_run" + assert destination in ("reference_run", "twa") else: - destination = np.array(cfg.mf_destination, float) - assert destination.shape == (4, 4) - destination = mne.transforms.Transform("meg", "head", destination) + destination_array = np.array(cfg.mf_destination, float) + assert destination_array.shape == (4, 4) + destination = mne.transforms.Transform("meg", "head", destination_array) filter_chpi = cfg.mf_mc if cfg.mf_filter_chpi is None else cfg.mf_filter_chpi is_rest_noise = run is None and task in ("noise", "rest") if is_rest_noise: + assert task is not None nice_names = dict(rest="resting-state", noise="empty-room") recording_type = nice_names[task] else: @@ -310,6 +336,7 @@ def run_maxwell_filter( extension=".fif", root=cfg.deriv_root, check=False, + split=None, ) bids_path_out = bids_path_in.copy().update(**bids_path_out_kwargs) @@ -325,14 +352,17 @@ def 
run_maxwell_filter( verbose=cfg.read_raw_bids_verbose, ) bids_path_ref_bads_in = in_files.pop("raw_ref_run-bads", None) + # triage string-valued destinations if isinstance(destination, str): - assert destination == "reference_run" - destination = raw.info["dev_head_t"] + if destination == "reference_run": + destination = raw.info["dev_head_t"] + elif destination == "twa": + destination = mne.read_trans(in_files.pop(f"{in_key}-twa")) del raw assert isinstance(destination, mne.transforms.Transform), destination # Maxwell-filter experimental data. - apply_msg = "Applying " + apply_msg = "Preparing to apply " extra = list() if cfg.mf_st_duration: apply_msg += f"tSSS ({cfg.mf_st_duration} sec, corr={cfg.mf_st_correlation})" @@ -353,8 +383,8 @@ def run_maxwell_filter( apply_msg += " to" mf_kws = dict( - calibration=in_files.pop("mf_cal_fname"), - cross_talk=in_files.pop("mf_ctc_fname"), + calibration=in_files.pop("mf_cal_fname", None), + cross_talk=in_files.pop("mf_ctc_fname", None), st_duration=cfg.mf_st_duration, st_correlation=cfg.mf_st_correlation, origin=cfg.mf_head_origin, @@ -362,7 +392,12 @@ def run_maxwell_filter( destination=destination, head_pos=head_pos, extended_proj=extended_proj, + int_order=cfg.mf_int_order, + ext_order=cfg.mf_ext_order, ) + # If the mf_kws keys above change, we need to modify our list + # of illegal keys in _config_import.py + mf_kws |= cfg.mf_extra_kws logger.info(**gen_log_kwargs(message=f"{apply_msg} {recording_type} data")) if not (run is None and task == "noise"): @@ -411,6 +446,21 @@ def run_maxwell_filter( ) logger.warning(**gen_log_kwargs(message=msg)) + if filter_chpi: + logger.info(**gen_log_kwargs(message="Filtering cHPI")) + mne.chpi.filter_chpi( + raw, + t_window=cfg.mf_mc_t_window, + allow_line_only=(task == "noise"), + ) + + msg = ( + "Maxwell Filtering" + f" (internal order: {mf_kws['int_order']}," + f" external order: {mf_kws['ext_order']})" + ) + logger.info(**gen_log_kwargs(message=msg)) + raw_sss = mne.preprocessing.maxwell_filter(raw, **mf_kws) del raw gc.collect() @@ -426,8 +476,15 @@ def run_maxwell_filter( bids_path_ref_sss = in_files.pop("raw_ref_run_sss") raw_exp = mne.io.read_raw_fif(bids_path_ref_sss) - rank_exp = mne.compute_rank(raw_exp, rank="info")["meg"] - rank_noise = mne.compute_rank(raw_sss, rank="info")["meg"] + if "grad" in raw_exp: + if "mag" in raw_exp: + type_sel = "meg" + else: + type_sel = "grad" + else: + type_sel = "mag" + rank_exp = mne.compute_rank(raw_exp, rank="info")[type_sel] + rank_noise = mne.compute_rank(raw_sss, rank="info")[type_sel] del raw_exp if task == "rest": @@ -450,13 +507,8 @@ def run_maxwell_filter( ) raise RuntimeError(msg) - if filter_chpi: - logger.info(**gen_log_kwargs(message="Filtering cHPI")) - mne.chpi.filter_chpi( - raw_sss, - t_window=cfg.mf_mc_t_window, - ) - + movement_annot: mne.Annotations | None = None + extra_html: str | None = None if cfg.mf_mc and ( cfg.mf_mc_rotation_velocity_limit is not None or cfg.mf_mc_translation_velocity_limit is not None @@ -468,7 +520,7 @@ def run_maxwell_filter( translation_velocity_limit=cfg.mf_mc_translation_velocity_limit, ) perc_time = 100 / raw_sss.times[-1] - extra_html = list() + extra_html_list = list() for kind, unit in (("translation", "m"), ("rotation", "°")): limit = getattr(cfg, f"mf_mc_{kind}_velocity_limit") if limit is None: @@ -484,14 +536,12 @@ def run_maxwell_filter( f"limit for {tot_time:0.1f} s ({perc:0.1f}%)" ) logger_meth(**gen_log_kwargs(message=msg)) - extra_html.append(f"
<li>{msg}</li>") + extra_html_list.append(f"<li>{msg}</li>") + extra_html = ( "<p>The raw data were annotated with the following movement-related bad " - f"segment annotations:<ul>{''.join(extra_html)}</ul></p>" + f"segment annotations:<ul>{''.join(extra_html_list)}</ul></p>
    " ) raw_sss.set_annotations(raw_sss.annotations + movement_annot) - else: - movement_annot = extra_html = None out_files["sss_raw"] = bids_path_out msg = f"Writing {out_files['sss_raw'].fpath.relative_to(cfg.deriv_root)}" @@ -537,7 +587,7 @@ def get_config_esss( *, config: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> SimpleNamespace: cfg = SimpleNamespace( mf_esss=config.mf_esss, @@ -551,7 +601,7 @@ def get_config_maxwell_filter( *, config: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, ) -> SimpleNamespace: cfg = SimpleNamespace( mf_cal_fname=get_mf_cal_fname( @@ -571,10 +621,12 @@ def get_config_maxwell_filter( mf_filter_chpi=config.mf_filter_chpi, mf_destination=config.mf_destination, mf_int_order=config.mf_int_order, + mf_ext_order=config.mf_ext_order, mf_mc_t_window=config.mf_mc_t_window, mf_mc_rotation_velocity_limit=config.mf_mc_rotation_velocity_limit, mf_mc_translation_velocity_limit=config.mf_mc_translation_velocity_limit, mf_esss=config.mf_esss, + mf_extra_kws=config.mf_extra_kws, **_import_data_kwargs(config=config, subject=subject), ) return cfg @@ -605,8 +657,8 @@ def main(*, config: SimpleNamespace) -> None: subject=subject, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) # Second: maxwell_filter @@ -629,8 +681,8 @@ def main(*, config: SimpleNamespace) -> None: run=run, task=task, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions for run, task in get_runs_tasks( config=config, subject=subject, diff --git a/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py b/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py index b60543121..1395e071b 100644 --- a/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py +++ b/mne_bids_pipeline/steps/preprocessing/_04_frequency_filter.py @@ -14,39 +14,46 @@ If config.interactive = True plots raw data and power spectral density. 
""" # noqa: E501 -import numpy as np +from collections.abc import Iterable from types import SimpleNamespace -from typing import Optional, Union, Literal, Iterable +from typing import Any, Literal import mne +import numpy as np +from meegkit import dss +from mne.io.pick import _picks_to_idx +from mne.preprocessing import EOGRegression -from ..._config_utils import ( - get_sessions, - get_runs_tasks, - get_subjects, -) -from ..._import_data import ( - import_experimental_data, - import_er_data, +from mne_bids_pipeline._config_utils import get_runs_tasks, get_subjects_sessions +from mne_bids_pipeline._import_data import ( _get_run_rest_noise_path, _import_data_kwargs, + _read_raw_msg, + import_er_data, + import_experimental_data, +) +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _add_raw, _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, ) -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _add_raw -from ..._run import failsafe_run, save_logs, _update_for_splits, _prep_out_files +from mne_bids_pipeline.typing import InFilesT, IntArrayT, OutFilesT, RunKindT, RunTypeT def get_input_fnames_frequency_filter( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], -) -> dict: + task: str | None, +) -> InFilesT: """Get paths of files required by filter_data function.""" - kind = "sss" if cfg.use_maxwell_filter else "orig" + kind: RunKindT = "sss" if cfg.use_maxwell_filter else "orig" return _get_run_rest_noise_path( cfg=cfg, subject=subject, @@ -58,26 +65,53 @@ def get_input_fnames_frequency_filter( ) +def zapline( + raw: mne.io.BaseRaw, + subject: str, + session: str | None, + run: str, + task: str | None, + fline: float | None, + iter_: bool, +) -> None: + """Use Zapline to remove line frequencies.""" + if fline is None: + return + + msg = f"Zapline filtering data at with {fline=} Hz." + logger.info(**gen_log_kwargs(message=msg)) + sfreq = raw.info["sfreq"] + picks = mne.pick_types(raw.info, meg=True, eeg=True) + data = raw.get_data(picks).T # transpose to (n_samples, n_channels) + func = dss.dss_line_iter if iter_ else dss.dss_line + out, _ = func(data, fline, sfreq) + raw._data[picks] = out.T + + def notch_filter( raw: mne.io.BaseRaw, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], - freqs: Optional[Union[float, Iterable[float]]], - trans_bandwidth: Union[float, Literal["auto"]], - notch_widths: Optional[Union[float, Iterable[float]]], - run_type: Literal["experimental", "empty-room", "resting-state"], + task: str | None, + freqs: float | Iterable[float] | None, + trans_bandwidth: float | Literal["auto"], + notch_widths: float | Iterable[float] | None, + run_type: RunTypeT, + picks: IntArrayT | None, + notch_extra_kws: dict[str, Any], ) -> None: """Filter data channels (MEG and EEG).""" - if freqs is None: + if freqs is None and (notch_extra_kws.get("method") != "spectrum_fit"): msg = f"Not applying notch filter to {run_type} data." + elif notch_extra_kws.get("method") == "spectrum_fit": + msg = f"Applying notch filter to {run_type} data with spectrum fitting." else: msg = f"Notch filtering {run_type} data at {freqs} Hz." 
logger.info(**gen_log_kwargs(message=msg)) - if freqs is None: + if (freqs is None) and (notch_extra_kws.get("method") != "spectrum_fit"): return raw.notch_filter( @@ -85,28 +119,32 @@ def notch_filter( trans_bandwidth=trans_bandwidth, notch_widths=notch_widths, n_jobs=1, + picks=picks, + **notch_extra_kws, ) def bandpass_filter( raw: mne.io.BaseRaw, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], - l_freq: Optional[float], - h_freq: Optional[float], - l_trans_bandwidth: Union[float, Literal["auto"]], - h_trans_bandwidth: Union[float, Literal["auto"]], - run_type: Literal["experimental", "empty-room", "resting-state"], + task: str | None, + l_freq: float | None, + h_freq: float | None, + l_trans_bandwidth: float | Literal["auto"], + h_trans_bandwidth: float | Literal["auto"], + run_type: RunTypeT, + picks: IntArrayT | None, + bandpass_extra_kws: dict[str, Any], ) -> None: """Filter data channels (MEG and EEG).""" if l_freq is not None and h_freq is None: - msg = f"High-pass filtering {run_type} data; lower bound: " f"{l_freq} Hz" + msg = f"High-pass filtering {run_type} data; lower bound: {l_freq} Hz" elif l_freq is None and h_freq is not None: - msg = f"Low-pass filtering {run_type} data; upper bound: " f"{h_freq} Hz" + msg = f"Low-pass filtering {run_type} data; upper bound: {h_freq} Hz" elif l_freq is not None and h_freq is not None: - msg = f"Band-pass filtering {run_type} data; range: " f"{l_freq} – {h_freq} Hz" + msg = f"Band-pass filtering {run_type} data; range: {l_freq} – {h_freq} Hz" else: msg = f"Not applying frequency filtering to {run_type} data." @@ -121,17 +159,19 @@ def bandpass_filter( l_trans_bandwidth=l_trans_bandwidth, h_trans_bandwidth=h_trans_bandwidth, n_jobs=1, + picks=picks, + **bandpass_extra_kws, ) def resample( raw: mne.io.BaseRaw, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], + task: str | None, sfreq: float, - run_type: Literal["experimental", "empty-room", "resting-state"], + run_type: RunTypeT, ) -> None: if not sfreq: return @@ -149,31 +189,25 @@ def filter_data( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, run: str, - task: Optional[str], - in_files: dict, -) -> dict: + task: str | None, + in_files: InFilesT, +) -> OutFilesT: """Filter data from a single subject.""" out_files = dict() in_key = f"raw_task-{task}_run-{run}" bids_path_in = in_files.pop(in_key) bids_path_bads_in = in_files.pop(f"{in_key}-bads", None) - - if run is None and task in ("noise", "rest"): - run_type = dict(rest="resting-state", noise="empty-room")[task] - else: - run_type = "experimental" - + msg, run_type = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task) + logger.info(**gen_log_kwargs(message=msg)) if cfg.use_maxwell_filter: - msg = f"Reading {run_type} recording: " f"{bids_path_in.basename}" - logger.info(**gen_log_kwargs(message=msg)) raw = mne.io.read_raw_fif(bids_path_in) elif run is None and task == "noise": raw = import_er_data( cfg=cfg, bids_path_er_in=bids_path_in, - bids_path_ref_in=in_files.pop("raw_ref_run"), + bids_path_ref_in=in_files.pop("raw_ref_run", None), bids_path_er_bads_in=bids_path_bads_in, # take bads from this run (0) bids_path_ref_bads_in=in_files.pop("raw_ref_run-bads", None), @@ -190,15 +224,39 @@ def filter_data( out_files[in_key] = bids_path_in.copy().update( root=cfg.deriv_root, + subject=subject, # save under subject's directory so all files are there + session=session, 
processing="filt", extension=".fif", suffix="raw", split=None, task=task, run=run, + check=False, ) + if cfg.regress_artifact is None: + picks = None + else: + # Need to figure out the correct picks to use + model = EOGRegression(**cfg.regress_artifact) + picks_regress = _picks_to_idx( + raw.info, model.picks, none="data", exclude=model.exclude + ) + picks_artifact = _picks_to_idx(raw.info, model.picks_artifact) + picks_data = _picks_to_idx(raw.info, "data", exclude=()) # raw.filter default + picks = np.unique(np.r_[picks_regress, picks_artifact, picks_data]) + raw.load_data() + zapline( + raw=raw, + subject=subject, + session=session, + run=run, + task=task, + fline=cfg.zapline_fline, + iter_=cfg.zapline_iter, + ) notch_filter( raw=raw, subject=subject, @@ -209,6 +267,8 @@ def filter_data( trans_bandwidth=cfg.notch_trans_bandwidth, notch_widths=cfg.notch_widths, run_type=run_type, + picks=picks, + notch_extra_kws=cfg.notch_extra_kws, ) bandpass_filter( raw=raw, @@ -221,6 +281,8 @@ def filter_data( h_trans_bandwidth=cfg.h_trans_bandwidth, l_trans_bandwidth=cfg.l_trans_bandwidth, run_type=run_type, + picks=picks, + bandpass_extra_kws=cfg.bandpass_extra_kws, ) resample( raw=raw, @@ -232,6 +294,9 @@ def filter_data( run_type=run_type, ) + # For example, might need to create + # derivatives/mne-bids-pipeline/sub-emptyroom/ses-20230412/meg + out_files[in_key].fpath.parent.mkdir(exist_ok=True, parents=True) raw.save( out_files[in_key], overwrite=True, @@ -277,11 +342,16 @@ def get_config( l_freq=config.l_freq, h_freq=config.h_freq, notch_freq=config.notch_freq, + zapline_fline=config.zapline_fline, + zapline_iter=config.zapline_iter, l_trans_bandwidth=config.l_trans_bandwidth, h_trans_bandwidth=config.h_trans_bandwidth, notch_trans_bandwidth=config.notch_trans_bandwidth, notch_widths=config.notch_widths, raw_resample_sfreq=config.raw_resample_sfreq, + regress_artifact=config.regress_artifact, + notch_extra_kws=config.notch_extra_kws, + bandpass_extra_kws=config.bandpass_extra_kws, **_import_data_kwargs(config=config, subject=subject), ) return cfg @@ -304,8 +374,8 @@ def main(*, config: SimpleNamespace) -> None: run=run, task=task, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions for run, task in get_runs_tasks( config=config, subject=subject, diff --git a/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py b/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py new file mode 100644 index 000000000..f8a9c0128 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_05_regress_artifact.py @@ -0,0 +1,172 @@ +"""Temporal regression for artifact removal.""" + +from types import SimpleNamespace + +import mne +from mne.io.pick import _picks_to_idx +from mne.preprocessing import EOGRegression + +from mne_bids_pipeline._config_utils import get_runs_tasks, get_subjects_sessions +from mne_bids_pipeline._import_data import ( + _get_run_rest_noise_path, + _import_data_kwargs, + _read_raw_msg, +) +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _add_raw, _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, +) +from mne_bids_pipeline.typing import InFilesT, OutFilesT + + +def get_input_fnames_regress_artifact( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, 
+ run: str, + task: str | None, +) -> InFilesT: + """Get paths of files required by regress_artifact function.""" + out = _get_run_rest_noise_path( + cfg=cfg, + subject=subject, + session=session, + run=run, + task=task, + kind="filt", + mf_reference_run=cfg.mf_reference_run, + ) + assert len(out) + return out + + +@failsafe_run( + get_input_fnames=get_input_fnames_regress_artifact, +) +def run_regress_artifact( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, + in_files: InFilesT, +) -> OutFilesT: + model = EOGRegression(proj=False, **cfg.regress_artifact) + out_files = dict() + in_key = f"raw_task-{task}_run-{run}" + bids_path_in = in_files.pop(in_key) + out_files[in_key] = bids_path_in.copy().update(processing="regress") + msg, _ = _read_raw_msg(bids_path_in=bids_path_in, run=run, task=task) + logger.info(**gen_log_kwargs(message=msg)) + raw = mne.io.read_raw_fif(bids_path_in).load_data() + projs = raw.info["projs"] + raw.del_proj() + model.fit(raw) + all_types = raw.get_channel_types() + picks = _picks_to_idx(raw.info, model.picks, none="data", exclude=model.exclude) + ch_types = set(all_types[pick] for pick in picks) + del picks + out_files["regress"] = bids_path_in.copy().update( + processing=None, + split=None, + suffix="regress", + extension=".h5", + ) + model.apply(raw, copy=False) + if projs: + raw.add_proj(projs) + raw.save(out_files[in_key], overwrite=True, split_size=cfg._raw_split_size) + _update_for_splits(out_files, in_key) + model.save(out_files["regress"], overwrite=True) + assert len(in_files) == 0, in_files.keys() + + # Report + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) as report: + msg = "Adding regressed raw data to report" + logger.info(**gen_log_kwargs(message=msg)) + figs, captions = list(), list() + for kind in ("mag", "grad", "eeg"): + if kind not in ch_types: + continue + figs.append(model.plot(ch_type=kind)) + captions.append(f"Run {run}: {kind}") + if figs: + report.add_figure( + fig=figs, + caption=captions, + title="Regression weights", + tags=("raw", f"run-{run}", "regression"), + replace=True, + ) + _add_raw( + cfg=cfg, + report=report, + bids_path_in=out_files[in_key], + title="Raw (regression)", + tags=("regression",), + raw=raw, + ) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, +) -> SimpleNamespace: + cfg = SimpleNamespace( + regress_artifact=config.regress_artifact, + **_import_data_kwargs(config=config, subject=subject), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Run artifact regression.""" + if config.regress_artifact is None: + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + parallel, run_func = parallel_func( + run_regress_artifact, exec_params=config.exec_params + ) + + logs = parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions + for run, task in get_runs_tasks( + config=config, + subject=subject, + session=session, + ) + ) + + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py 
b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py new file mode 100644 index 000000000..18135f66e --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py @@ -0,0 +1,387 @@ +"""Fit ICA. + +This fits Independent Component Analysis (ICA) on high-pass filtered raw data, +temporarily creating task-related epochs. The epochs created here are used for +the purpose of fitting ICA only, and will not enter any other processing steps. + +Before performing ICA, we reject epochs based on peak-to-peak amplitude above +the 'ica_reject' limits to remove high-amplitude non-biological artifacts +(e.g., voltage or flux spikes). +""" + +from types import SimpleNamespace + +import autoreject +import mne +import numpy as np +from mne.preprocessing import ICA +from mne_bids import BIDSPath + +from mne_bids_pipeline._config_utils import ( + _bids_kwargs, + get_eeg_reference, + get_runs, + get_subjects_sessions, +) +from mne_bids_pipeline._import_data import annotations_to_events, make_epochs +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._reject import _get_reject +from mne_bids_pipeline._report import _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, +) +from mne_bids_pipeline.typing import InFilesT, OutFilesT + + +def get_input_fnames_run_ica( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> InFilesT: + bids_basename = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + extension=".fif", + ) + in_files = dict() + for run in cfg.runs: + key = f"raw_run-{run}" + in_files[key] = bids_basename.copy().update( + run=run, processing=cfg.processing, suffix="raw" + ) + _update_for_splits(in_files, key, single=True) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_run_ica, +) +def run_ica( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: InFilesT, +) -> OutFilesT: + """Run ICA.""" + import matplotlib.pyplot as plt + + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] + out_files = dict() + bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) + out_files["ica"] = bids_basename.copy().update(processing="icafit", suffix="ica") + out_files["epochs"] = ( + out_files["ica"].copy().update(suffix="epo", processing="icafit") + ) + del bids_basename + + # Generate a list of raw data paths (i.e., paths of individual runs) + # we want to create epochs from. + + # Generate a unique event name -> event code mapping that can be used + # across all runs. + event_name_to_code_map = annotations_to_events(raw_paths=raw_fnames) + + epochs = None + for idx, (run, raw_fname) in enumerate(zip(cfg.runs, raw_fnames)): + msg = f"Processing raw data from {raw_fname.basename}" + logger.info(**gen_log_kwargs(message=msg)) + raw = mne.io.read_raw_fif(raw_fname, preload=True) + + # Produce high-pass filtered version of the data for ICA. + # Sanity check – make sure we're using the correct data! 
+ if cfg.raw_resample_sfreq is not None: + assert np.allclose(raw.info["sfreq"], cfg.raw_resample_sfreq) + if cfg.l_freq is not None: + assert np.allclose(raw.info["highpass"], cfg.l_freq) + + if idx == 0: + if cfg.ica_l_freq is None: + msg = ( + f"Not applying high-pass filter (data is already filtered, " + f"cutoff: {raw.info['highpass']} Hz)." + ) + logger.info(**gen_log_kwargs(message=msg)) + else: + msg = f"Applying high-pass filter with {cfg.ica_l_freq} Hz cutoff …" + logger.info(**gen_log_kwargs(message=msg)) + raw.filter(l_freq=cfg.ica_l_freq, h_freq=None, n_jobs=1) + + # Only keep the subset of the mapping that applies to the current run + event_id = event_name_to_code_map.copy() + for event_name in event_id.copy().keys(): + if event_name not in raw.annotations.description: + del event_id[event_name] + + if idx == 0: + msg = "Creating task-related epochs …" + logger.info(**gen_log_kwargs(message=msg)) + these_epochs = make_epochs( + subject=subject, + session=session, + task=cfg.task, + conditions=cfg.conditions, + raw=raw, + event_id=event_id, + tmin=cfg.epochs_tmin, + tmax=cfg.epochs_tmax, + custom_metadata=cfg.epochs_custom_metadata, + metadata_tmin=cfg.epochs_metadata_tmin, + metadata_tmax=cfg.epochs_metadata_tmax, + metadata_keep_first=cfg.epochs_metadata_keep_first, + metadata_keep_last=cfg.epochs_metadata_keep_last, + metadata_query=cfg.epochs_metadata_query, + event_repeated=cfg.event_repeated, + epochs_decim=cfg.epochs_decim, + task_is_rest=cfg.task_is_rest, + rest_epochs_duration=cfg.rest_epochs_duration, + rest_epochs_overlap=cfg.rest_epochs_overlap, + ) + + these_epochs.load_data() # Remove reference to raw + del raw # free memory + + if epochs is None: + epochs = these_epochs + else: + epochs = mne.concatenate_epochs([epochs, these_epochs], on_mismatch="warn") + + del these_epochs + del run + assert epochs is not None + + # Set an EEG reference + if "eeg" in cfg.ch_types: + projection = True if cfg.eeg_reference == "average" else False + epochs.set_eeg_reference(cfg.eeg_reference, projection=projection) + + ar_reject_log = ar_n_interpolate_ = None + if cfg.ica_reject == "autoreject_local": + msg = ( + "Using autoreject to find bad epochs for ICA " + "(no interpolation will be performend)" + ) + logger.info(**gen_log_kwargs(message=msg)) + ar = autoreject.AutoReject( + n_interpolate=cfg.autoreject_n_interpolate, + random_state=cfg.random_state, + n_jobs=exec_params.n_jobs, + verbose=False, + ) + ar.fit(epochs) + ar_reject_log = ar.get_reject_log(epochs) + epochs = epochs[~ar_reject_log.bad_epochs] + + n_epochs_before_reject = len(epochs) + n_epochs_rejected = ar_reject_log.bad_epochs.sum() + n_epochs_after_reject = n_epochs_before_reject - n_epochs_rejected + + ar_n_interpolate_ = ar.n_interpolate_ + msg = ( + f"autoreject marked {n_epochs_rejected} epochs as bad " + f"(cross-validated n_interpolate limit: {ar_n_interpolate_})" + ) + logger.info(**gen_log_kwargs(message=msg)) + del ar + else: + # Reject epochs based on peak-to-peak rejection thresholds + ica_reject = _get_reject( + subject=subject, + session=session, + reject=cfg.ica_reject, + ch_types=cfg.ch_types, + param="ica_reject", + ) + n_epochs_before_reject = len(epochs) + epochs.drop_bad(reject=ica_reject) + n_epochs_after_reject = len(epochs) + n_epochs_rejected = n_epochs_before_reject - n_epochs_after_reject + + msg = ( + f"Removed {n_epochs_rejected} of {n_epochs_before_reject} epochs via PTP " + f"rejection thresholds: {ica_reject}" + ) + logger.info(**gen_log_kwargs(message=msg)) + ar = None + + if 0 < 
n_epochs_after_reject < 0.5 * n_epochs_before_reject: + msg = ( + "More than 50% of all epochs rejected. Please check the " + "rejection thresholds." + ) + logger.warning(**gen_log_kwargs(message=msg)) + elif n_epochs_after_reject == 0: + rejection_type = ( + cfg.ica_reject if cfg.ica_reject == "autoreject_local" else "PTP-based" + ) + raise RuntimeError( + f"No epochs remaining after {rejection_type} rejection. Cannot continue." + ) + + msg = f"Saving {n_epochs_after_reject} ICA epochs to disk." + logger.info(**gen_log_kwargs(message=msg)) + epochs.save( + out_files["epochs"], + overwrite=True, + split_naming="bids", + split_size=cfg._epochs_split_size, + ) + _update_for_splits(out_files, "epochs") + + msg = f"Calculating ICA solution using method: {cfg.ica_algorithm}." + logger.info(**gen_log_kwargs(message=msg)) + + algorithm = cfg.ica_algorithm + fit_params = None + + if algorithm == "picard": + fit_params = dict(fastica_it=5) + elif algorithm == "picard-extended_infomax": + algorithm = "picard" + fit_params = dict(ortho=False, extended=True) + elif algorithm == "extended_infomax": + algorithm = "infomax" + fit_params = dict(extended=True) + + ica = ICA( + method=algorithm, + random_state=cfg.random_state, + n_components=cfg.ica_n_components, + fit_params=fit_params, + max_iter=cfg.ica_max_iterations, + ) + ica.fit(epochs, decim=cfg.ica_decim) + explained_var = ( + ica.pca_explained_variance_[: ica.n_components_].sum() + / ica.pca_explained_variance_.sum() + ) + msg = ( + f"Fit {ica.n_components_} components (explaining " + f"{round(explained_var * 100, 1)}% of the variance) in " + f"{ica.n_iter_} iterations." + ) + logger.info(**gen_log_kwargs(message=msg)) + msg = "Saving ICA solution to disk." + logger.info(**gen_log_kwargs(message=msg)) + ica.save(out_files["ica"], overwrite=True) + + # Add to report + tags = ("ica", "epochs") + title = "ICA: epochs for fitting" + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + task=cfg.task, + ) as report: + report.add_epochs( + epochs=epochs, + title=title, + drop_log_ignore=(), + replace=True, + tags=tags, + ) + if cfg.ica_reject == "autoreject_local": + assert ar_reject_log is not None + caption = ( + f"Autoreject was run to produce cleaner epochs before fitting ICA. " + f"{ar_reject_log.bad_epochs.sum()} epochs were rejected because more " + f"than {ar_n_interpolate_} channels were bad (cross-validated " + f"n_interpolate limit; excluding globally bad and non-data channels, " + f"shown in white). Note that none of the blue segments were actually " + f"interpolated before submitting the data to ICA. This is following " + f"the recommended approach for ICA described in the Autoreject " + f"documentation."
+ ) + fig = ar_reject_log.plot( + orientation="horizontal", aspect="auto", show=False + ) + report.add_figure( + fig=fig, + title="Autoreject cleaning", + section=title, + caption=caption, + tags=tags + ("autoreject",), + replace=True, + ) + plt.close(fig) + del caption + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, + session: str | None = None, +) -> SimpleNamespace: + cfg = SimpleNamespace( + conditions=config.conditions, + runs=get_runs(config=config, subject=subject), + task_is_rest=config.task_is_rest, + ica_l_freq=config.ica_l_freq, + ica_algorithm=config.ica_algorithm, + ica_n_components=config.ica_n_components, + ica_max_iterations=config.ica_max_iterations, + ica_decim=config.ica_decim, + ica_reject=config.ica_reject, + autoreject_n_interpolate=config.autoreject_n_interpolate, + random_state=config.random_state, + ch_types=config.ch_types, + l_freq=config.l_freq, + epochs_decim=config.epochs_decim, + raw_resample_sfreq=config.raw_resample_sfreq, + event_repeated=config.event_repeated, + epochs_tmin=config.epochs_tmin, + epochs_tmax=config.epochs_tmax, + epochs_custom_metadata=config.epochs_custom_metadata, + epochs_metadata_tmin=config.epochs_metadata_tmin, + epochs_metadata_tmax=config.epochs_metadata_tmax, + epochs_metadata_keep_first=config.epochs_metadata_keep_first, + epochs_metadata_keep_last=config.epochs_metadata_keep_last, + epochs_metadata_query=config.epochs_metadata_query, + eeg_reference=get_eeg_reference(config), + eog_channels=config.eog_channels, + rest_epochs_duration=config.rest_epochs_duration, + rest_epochs_overlap=config.rest_epochs_overlap, + processing="filt" if config.regress_artifact is None else "regress", + _epochs_split_size=config._epochs_split_size, + **_bids_kwargs(config=config), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Run ICA.""" + if config.spatial_filter != "ica": + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + parallel, run_func = parallel_func(run_ica, exec_params=config.exec_params) + logs = parallel( + run_func( + cfg=get_config(config=config, subject=subject), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions + ) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py new file mode 100644 index 000000000..c2211476c --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py @@ -0,0 +1,400 @@ +"""Find ICA artifacts. + +This step automatically finds ECG- and EOG-related ICs in your data, and sets them +as bad components. + +To actually remove designated ICA components from your data, you will have to +run the apply_ica step. 
+""" + +from types import SimpleNamespace +from typing import Literal + +import mne +import numpy as np +import pandas as pd +from mne.preprocessing import create_ecg_epochs, create_eog_epochs +from mne_bids import BIDSPath + +from mne_bids_pipeline._config_utils import ( + _bids_kwargs, + get_eeg_reference, + get_runs, + get_subjects_sessions, +) +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, +) +from mne_bids_pipeline.typing import FloatArrayT, InFilesT, OutFilesT + + +def detect_bad_components( + *, + cfg: SimpleNamespace, + which: Literal["eog", "ecg"], + epochs: mne.BaseEpochs | None, + ica: mne.preprocessing.ICA, + ch_names: list[str] | None, + subject: str, + session: str | None, +) -> tuple[list[int], FloatArrayT]: + artifact = which.upper() + if epochs is None: + msg = ( + f"No {artifact} events could be found. " + f"Not running {artifact} artifact detection." + ) + logger.info(**gen_log_kwargs(message=msg)) + return [], np.zeros(0) + msg = f"Performing automated {artifact} artifact detection …" + logger.info(**gen_log_kwargs(message=msg)) + + if which == "eog": + inds, scores = ica.find_bads_eog( + epochs, + threshold=cfg.ica_eog_threshold, + ch_name=ch_names, + ) + else: + inds, scores = ica.find_bads_ecg( + epochs, + method="ctps", + threshold=cfg.ica_ecg_threshold, + ch_name=ch_names, + ) + + if not inds: + adjust_setting = f"ica_{which}_threshold" + warn = ( + f"No {artifact}-related ICs detected, this is highly " + f"suspicious. A manual check is suggested. You may wish to " + f'lower "{adjust_setting}".' 
+ ) + logger.warning(**gen_log_kwargs(message=warn)) + else: + msg = ( + f"Detected {len(inds)} {artifact}-related ICs in " + f"{len(epochs)} {artifact} epochs: {', '.join([str(i) for i in inds])}" + ) + logger.info(**gen_log_kwargs(message=msg)) + + return inds, scores + + +def get_input_fnames_find_ica_artifacts( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> InFilesT: + bids_basename = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + extension=".fif", + ) + in_files = dict() + in_files["epochs"] = bids_basename.copy().update(processing="icafit", suffix="epo") + _update_for_splits(in_files, "epochs", single=True) + for run in cfg.runs: + key = f"raw_run-{run}" + in_files[key] = bids_basename.copy().update( + run=run, processing=cfg.processing, suffix="raw" + ) + _update_for_splits(in_files, key, single=True) + in_files["ica"] = bids_basename.copy().update(processing="icafit", suffix="ica") + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_find_ica_artifacts, +) +def find_ica_artifacts( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: InFilesT, +) -> OutFilesT: + """Run ICA.""" + raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] + bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) + out_files = dict() + out_files["ica"] = bids_basename.copy().update(processing="ica", suffix="ica") + out_files["ecg"] = bids_basename.copy().update(processing="ica+ecg", suffix="ave") + out_files["eog"] = bids_basename.copy().update(processing="ica+eog", suffix="ave") + + # DO NOT add this to out_files["ica"] because we expect it to be modified by users. + # If they modify it and it's in out_files, caching will detect the hash change and + # consider *this step* a cache miss, and it will run again, overwriting the user's + # changes. Instead, we want the ica.apply step to rerun (which it will if the + # file changes). + out_files_components = bids_basename.copy().update( + processing="ica", suffix="components", extension=".tsv" + ) + del bids_basename + msg = "Loading ICA solution" + logger.info(**gen_log_kwargs(message=msg)) + ica = mne.preprocessing.read_ica(in_files.pop("ica")) + + # Epochs used for ICA fitting + epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) + + # ECG component detection + epochs_ecg = None + ecg_ics: list[int] = [] + ecg_scores: FloatArrayT = np.zeros(0) + for ri, raw_fname in enumerate(raw_fnames): + # Have the channels needed to make ECG epochs + raw = mne.io.read_raw(raw_fname, preload=False) + # ECG epochs + if not ( + "ecg" in raw.get_channel_types() + or "meg" in cfg.ch_types + or "mag" in cfg.ch_types + ): + msg = ( + "No ECG or magnetometer channels are present, cannot " + "automate artifact detection for ECG."
+ ) + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating ECG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + + # We want to extract a total of 5 min of data for ECG epochs generation + # (across all runs) + total_ecg_dur = 5 * 60 + ecg_dur_per_run = total_ecg_dur / len(raw_fnames) + t_mid = (raw.times[-1] + raw.times[0]) / 2 + raw = raw.crop( + tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), + tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), + ).load_data() + + these_ecg_epochs = create_ecg_epochs( + raw, + baseline=(None, -0.2), + tmin=-0.5, + tmax=0.5, + ) + del raw # Free memory + if len(these_ecg_epochs): + if epochs.reject is not None: + these_ecg_epochs.drop_bad(reject=epochs.reject) + if len(these_ecg_epochs): + if epochs_ecg is None: + epochs_ecg = these_ecg_epochs + else: + epochs_ecg = mne.concatenate_epochs( + [epochs_ecg, these_ecg_epochs], on_mismatch="warn" + ) + del these_ecg_epochs + else: # did not break so had usable channels + ecg_ics, ecg_scores = detect_bad_components( + cfg=cfg, + which="ecg", + epochs=epochs_ecg, + ica=ica, + ch_names=None, # we currently don't allow for custom channels + subject=subject, + session=session, + ) + + # EOG component detection + epochs_eog = None + eog_ics: list[int] = [] + eog_scores: FloatArrayT = np.zeros(0) + for ri, raw_fname in enumerate(raw_fnames): + raw = mne.io.read_raw_fif(raw_fname, preload=True) + if cfg.eog_channels: + ch_names = cfg.eog_channels + assert all([ch_name in raw.ch_names for ch_name in ch_names]) + else: + eog_picks = mne.pick_types(raw.info, meg=False, eog=True) + ch_names = [raw.ch_names[pick] for pick in eog_picks] + if not ch_names: + msg = "No EOG channel is present, cannot automate IC detection for EOG." + logger.info(**gen_log_kwargs(message=msg)) + break + elif ri == 0: + msg = "Creating EOG epochs …" + logger.info(**gen_log_kwargs(message=msg)) + these_eog_epochs = create_eog_epochs( + raw, + ch_name=ch_names, + baseline=(None, -0.2), + ) + if len(these_eog_epochs): + if epochs.reject is not None: + these_eog_epochs.drop_bad(reject=epochs.reject) + if len(these_eog_epochs): + if epochs_eog is None: + epochs_eog = these_eog_epochs + else: + epochs_eog = mne.concatenate_epochs( + [epochs_eog, these_eog_epochs], on_mismatch="warn" + ) + else: # did not break + eog_ics, eog_scores = detect_bad_components( + cfg=cfg, + which="eog", + epochs=epochs_eog, + ica=ica, + ch_names=cfg.eog_channels, + subject=subject, + session=session, + ) + + # Save updated ICA to disk. + # We also store the automatically identified ECG- and EOG-related ICs. + msg = "Saving ICA solution and detected artifacts to disk." + logger.info(**gen_log_kwargs(message=msg)) + ica.exclude = sorted(set(ecg_ics + eog_ics)) + ica.save(out_files["ica"], overwrite=True) + + # Create TSV. 
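# --- Editor's illustrative aside (not part of the diff above) -----------------
# A minimal, hedged sketch of the ECG-IC detection flow this step implements:
# build ECG epochs from a filtered raw recording, then ask the fitted ICA for
# components whose time courses phase-lock to the heartbeat (CTPS method).
# The file names and the threshold value below are assumptions for illustration
# only; the pipeline itself reads these from the derivatives folder and config.
import mne
from mne.preprocessing import create_ecg_epochs, read_ica

raw = mne.io.read_raw_fif("sub-01_task-main_proc-filt_raw.fif", preload=True)  # hypothetical path
ica = read_ica("sub-01_task-main_proc-icafit_ica.fif")  # hypothetical path
ecg_epochs = create_ecg_epochs(raw, tmin=-0.5, tmax=0.5, baseline=(None, -0.2))
ecg_inds, ecg_scores = ica.find_bads_ecg(ecg_epochs, method="ctps", threshold=0.1)
print("ECG-related ICs:", ecg_inds)
# -------------------------------------------------------------------------------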
+ tsv_data = pd.DataFrame( + dict( + component=list(range(ica.n_components_)), + type=["ica"] * ica.n_components_, + description=["Independent Component"] * ica.n_components_, + status=["good"] * ica.n_components_, + status_description=["n/a"] * ica.n_components_, + ) + ) + + for component in ecg_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact" + + for component in eog_ics: + row_idx = tsv_data["component"] == component + tsv_data.loc[row_idx, "status"] = "bad" + tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact" + + tsv_data.to_csv(out_files_components, sep="\t", index=False) + + # Lastly, add info about the epochs used for the ICA fit, and plot all ICs + # for manual inspection. + + ecg_evoked = None if epochs_ecg is None else epochs_ecg.average() + eog_evoked = None if epochs_eog is None else epochs_eog.average() + + # Save ECG and EOG evokeds to disk. + for artifact_name, artifact_evoked in zip(("ecg", "eog"), (ecg_evoked, eog_evoked)): + if artifact_evoked: + msg = f"Saving {artifact_name.upper()} artifact: {out_files[artifact_name]}" + logger.info(**gen_log_kwargs(message=msg)) + artifact_evoked.save(out_files[artifact_name], overwrite=True) + else: + # Don't track the non-existent output file + del out_files[artifact_name] + + del artifact_name, artifact_evoked + + title = "ICA: components" + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + task=cfg.task, + ) as report: + logger.info(**gen_log_kwargs(message=f'Adding "{title}" to report.')) + report.add_ica( + ica=ica, + title=title, + inst=epochs, + ecg_evoked=ecg_evoked, + eog_evoked=eog_evoked, + ecg_scores=ecg_scores if len(ecg_scores) else None, + eog_scores=eog_scores if len(eog_scores) else None, + replace=True, + n_jobs=1, # avoid automatic parallelization + tags=("ica",), # the default but be explicit + ) + + msg = 'Carefully review the extracted ICs and mark components "bad" in:' + logger.info(**gen_log_kwargs(message=msg, emoji="🛑")) + logger.info(**gen_log_kwargs(message=str(out_files_components), emoji="🛑")) + + assert len(in_files) == 0, in_files.keys() + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, + session: str | None = None, +) -> SimpleNamespace: + cfg = SimpleNamespace( + conditions=config.conditions, + runs=get_runs(config=config, subject=subject), + task_is_rest=config.task_is_rest, + ica_l_freq=config.ica_l_freq, + ica_reject=config.ica_reject, + ica_eog_threshold=config.ica_eog_threshold, + ica_ecg_threshold=config.ica_ecg_threshold, + autoreject_n_interpolate=config.autoreject_n_interpolate, + random_state=config.random_state, + ch_types=config.ch_types, + l_freq=config.l_freq, + epochs_decim=config.epochs_decim, + raw_resample_sfreq=config.raw_resample_sfreq, + event_repeated=config.event_repeated, + epochs_tmin=config.epochs_tmin, + epochs_tmax=config.epochs_tmax, + epochs_metadata_tmin=config.epochs_metadata_tmin, + epochs_metadata_tmax=config.epochs_metadata_tmax, + epochs_metadata_keep_first=config.epochs_metadata_keep_first, + epochs_metadata_keep_last=config.epochs_metadata_keep_last, + epochs_metadata_query=config.epochs_metadata_query, + eeg_reference=get_eeg_reference(config), + eog_channels=config.eog_channels, + rest_epochs_duration=config.rest_epochs_duration, + rest_epochs_overlap=config.rest_epochs_overlap, + 
processing="filt" if config.regress_artifact is None else "regress", + **_bids_kwargs(config=config), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Run ICA.""" + if config.spatial_filter != "ica": + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + parallel, run_func = parallel_func( + find_ica_artifacts, exec_params=config.exec_params + ) + logs = parallel( + run_func( + cfg=get_config(config=config, subject=subject), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions + ) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py b/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py deleted file mode 100644 index efd8bec84..000000000 --- a/mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py +++ /dev/null @@ -1,633 +0,0 @@ -"""Run Independent Component Analysis (ICA) for artifact correction. - -This fits ICA on epoched data filtered with 1 Hz highpass, -for this purpose only using fastICA. Separate ICAs are fitted and stored for -MEG and EEG data. - -Before performing ICA, we reject epochs based on peak-to-peak amplitude above -the 'ica_reject' to filter massive non-biological artifacts. - -To actually remove designated ICA components from your data, you will have to -run 05a-apply_ica.py. -""" - -from typing import List, Optional, Iterable, Tuple, Literal -from types import SimpleNamespace - -import pandas as pd -import numpy as np -import autoreject - -import mne -from mne.report import Report -from mne.preprocessing import ICA, create_ecg_epochs, create_eog_epochs -from mne_bids import BIDSPath - -from ..._config_utils import ( - get_runs, - get_sessions, - get_subjects, - get_eeg_reference, - _bids_kwargs, -) -from ..._import_data import make_epochs, annotations_to_events -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._reject import _get_reject -from ..._report import _agg_backend -from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files - - -def filter_for_ica( - *, - cfg, - raw: mne.io.BaseRaw, - subject: str, - session: Optional[str], - run: Optional[str] = None, -) -> None: - """Apply a high-pass filter if needed.""" - if cfg.ica_l_freq is None: - msg = ( - f"Not applying high-pass filter (data is already filtered, " - f'cutoff: {raw.info["highpass"]} Hz).' 
- ) - logger.info(**gen_log_kwargs(message=msg)) - else: - msg = f"Applying high-pass filter with {cfg.ica_l_freq} Hz cutoff …" - logger.info(**gen_log_kwargs(message=msg)) - raw.filter(l_freq=cfg.ica_l_freq, h_freq=None, n_jobs=1) - - -def fit_ica( - *, - cfg, - epochs: mne.BaseEpochs, - subject: str, - session: Optional[str], -) -> mne.preprocessing.ICA: - algorithm = cfg.ica_algorithm - fit_params = None - - if algorithm == "picard": - fit_params = dict(fastica_it=5) - elif algorithm == "picard-extended_infomax": - algorithm = "picard" - fit_params = dict(ortho=False, extended=True) - elif algorithm == "extended_infomax": - algorithm = "infomax" - fit_params = dict(extended=True) - - ica = ICA( - method=algorithm, - random_state=cfg.random_state, - n_components=cfg.ica_n_components, - fit_params=fit_params, - max_iter=cfg.ica_max_iterations, - ) - - ica.fit(epochs, decim=cfg.ica_decim) - - explained_var = ( - ica.pca_explained_variance_[: ica.n_components_].sum() - / ica.pca_explained_variance_.sum() - ) - msg = ( - f"Fit {ica.n_components_} components (explaining " - f"{round(explained_var * 100, 1)}% of the variance) in " - f"{ica.n_iter_} iterations." - ) - logger.info(**gen_log_kwargs(message=msg)) - return ica - - -def make_ecg_epochs( - *, - cfg, - raw_path: BIDSPath, - subject: str, - session: Optional[str], - run: Optional[str] = None, - n_runs: int, -) -> Optional[mne.BaseEpochs]: - # ECG either needs an ecg channel, or avg of the mags (i.e. MEG data) - raw = mne.io.read_raw(raw_path, preload=False) - - if ( - "ecg" in raw.get_channel_types() - or "meg" in cfg.ch_types - or "mag" in cfg.ch_types - ): - msg = "Creating ECG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - - # We want to extract a total of 5 min of data for ECG epochs generation - # (across all runs) - total_ecg_dur = 5 * 60 - ecg_dur_per_run = total_ecg_dur / n_runs - t_mid = (raw.times[-1] + raw.times[0]) / 2 - raw = raw.crop( - tmin=max(t_mid - 1 / 2 * ecg_dur_per_run, 0), - tmax=min(t_mid + 1 / 2 * ecg_dur_per_run, raw.times[-1]), - ).load_data() - - ecg_epochs = create_ecg_epochs(raw, baseline=(None, -0.2), tmin=-0.5, tmax=0.5) - del raw # Free memory - - if len(ecg_epochs) == 0: - msg = "No ECG events could be found. Not running ECG artifact " "detection." - logger.info(**gen_log_kwargs(message=msg)) - ecg_epochs = None - else: - msg = ( - "No ECG or magnetometer channels are present. Cannot " - "automate artifact detection for ECG" - ) - logger.info(**gen_log_kwargs(message=msg)) - ecg_epochs = None - - return ecg_epochs - - -def make_eog_epochs( - *, - raw: mne.io.BaseRaw, - eog_channels: Optional[Iterable[str]], - subject: str, - session: Optional[str], - run: Optional[str] = None, -) -> Optional[mne.Epochs]: - """Create EOG epochs. No rejection thresholds will be applied.""" - if eog_channels: - ch_names = eog_channels - assert all([ch_name in raw.ch_names for ch_name in ch_names]) - else: - ch_idx = mne.pick_types(raw.info, meg=False, eog=True) - ch_names = [raw.ch_names[i] for i in ch_idx] - del ch_idx - - if ch_names: - msg = "Creating EOG epochs …" - logger.info(**gen_log_kwargs(message=msg)) - - eog_epochs = create_eog_epochs(raw, ch_name=ch_names, baseline=(None, -0.2)) - - if len(eog_epochs) == 0: - msg = "No EOG events could be found. Not running EOG artifact " "detection." - logger.warning(**gen_log_kwargs(message=msg)) - eog_epochs = None - else: - msg = "No EOG channel is present. 
Cannot automate IC detection " "for EOG" - logger.info(**gen_log_kwargs(message=msg)) - eog_epochs = None - - return eog_epochs - - -def detect_bad_components( - *, - cfg, - which: Literal["eog", "ecg"], - epochs: mne.BaseEpochs, - ica: mne.preprocessing.ICA, - ch_names: Optional[List[str]], - subject: str, - session: Optional[str], -) -> Tuple[List[int], np.ndarray]: - artifact = which.upper() - msg = f"Performing automated {artifact} artifact detection …" - logger.info(**gen_log_kwargs(message=msg)) - - if which == "eog": - inds, scores = ica.find_bads_eog( - epochs, - threshold=cfg.ica_eog_threshold, - ch_name=ch_names, - ) - else: - inds, scores = ica.find_bads_ecg( - epochs, - method="ctps", - threshold=cfg.ica_ctps_ecg_threshold, - ch_name=ch_names, - ) - - if not inds: - adjust_setting = ( - "ica_eog_threshold" if which == "eog" else "ica_ctps_ecg_threshold" - ) - warn = ( - f"No {artifact}-related ICs detected, this is highly " - f"suspicious. A manual check is suggested. You may wish to " - f'lower "{adjust_setting}".' - ) - logger.warning(**gen_log_kwargs(message=warn)) - else: - msg = ( - f"Detected {len(inds)} {artifact}-related ICs in " - f"{len(epochs)} {artifact} epochs." - ) - logger.info(**gen_log_kwargs(message=msg)) - - return inds, scores - - -def get_input_fnames_run_ica( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], -) -> dict: - bids_basename = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - in_files = dict() - for run in cfg.runs: - key = f"raw_run-{run}" - in_files[key] = bids_basename.copy().update( - run=run, processing="filt", suffix="raw" - ) - _update_for_splits(in_files, key, single=True) - return in_files - - -@failsafe_run( - get_input_fnames=get_input_fnames_run_ica, -) -def run_ica( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], - in_files: dict, -) -> dict: - """Run ICA.""" - raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] - bids_basename = raw_fnames[0].copy().update(processing=None, split=None, run=None) - out_files = dict() - out_files["ica"] = bids_basename.copy().update(suffix="ica", extension=".fif") - out_files["components"] = bids_basename.copy().update( - processing="ica", suffix="components", extension=".tsv" - ) - out_files["report"] = bids_basename.copy().update( - processing="ica+components", suffix="report", extension=".html" - ) - del bids_basename - - # Generate a list of raw data paths (i.e., paths of individual runs) - # we want to create epochs from. - - # Generate a unique event name -> event code mapping that can be used - # across all runs. 
- event_name_to_code_map = annotations_to_events(raw_paths=raw_fnames) - - # Now, generate epochs from each individual run - eog_epochs_all_runs = None - ecg_epochs_all_runs = None - - for idx, (run, raw_fname) in enumerate(zip(cfg.runs, raw_fnames)): - msg = f"Loading filtered raw data from {raw_fname.basename}" - logger.info(**gen_log_kwargs(message=msg)) - - # ECG epochs - ecg_epochs = make_ecg_epochs( - cfg=cfg, - raw_path=raw_fname, - subject=subject, - session=session, - run=run, - n_runs=len(cfg.runs), - ) - if ecg_epochs is not None: - if idx == 0: - ecg_epochs_all_runs = ecg_epochs - else: - ecg_epochs_all_runs = mne.concatenate_epochs( - [ecg_epochs_all_runs, ecg_epochs], on_mismatch="warn" - ) - - del ecg_epochs - - # EOG epochs - raw = mne.io.read_raw_fif(raw_fname, preload=True) - eog_epochs = make_eog_epochs( - raw=raw, - eog_channels=cfg.eog_channels, - subject=subject, - session=session, - run=run, - ) - if eog_epochs is not None: - if idx == 0: - eog_epochs_all_runs = eog_epochs - else: - eog_epochs_all_runs = mne.concatenate_epochs( - [eog_epochs_all_runs, eog_epochs], on_mismatch="warn" - ) - - del eog_epochs - - # Produce high-pass filtered version of the data for ICA. - # Sanity check – make sure we're using the correct data! - if cfg.raw_resample_sfreq is not None: - assert np.allclose(raw.info["sfreq"], cfg.raw_resample_sfreq) - if cfg.l_freq is not None: - assert np.allclose(raw.info["highpass"], cfg.l_freq) - - filter_for_ica(cfg=cfg, raw=raw, subject=subject, session=session, run=run) - - # Only keep the subset of the mapping that applies to the current run - event_id = event_name_to_code_map.copy() - for event_name in event_id.copy().keys(): - if event_name not in raw.annotations.description: - del event_id[event_name] - - msg = "Creating task-related epochs …" - logger.info(**gen_log_kwargs(message=msg)) - epochs = make_epochs( - subject=subject, - session=session, - task=cfg.task, - conditions=cfg.conditions, - raw=raw, - event_id=event_id, - tmin=cfg.epochs_tmin, - tmax=cfg.epochs_tmax, - metadata_tmin=cfg.epochs_metadata_tmin, - metadata_tmax=cfg.epochs_metadata_tmax, - metadata_keep_first=cfg.epochs_metadata_keep_first, - metadata_keep_last=cfg.epochs_metadata_keep_last, - metadata_query=cfg.epochs_metadata_query, - event_repeated=cfg.event_repeated, - epochs_decim=cfg.epochs_decim, - task_is_rest=cfg.task_is_rest, - rest_epochs_duration=cfg.rest_epochs_duration, - rest_epochs_overlap=cfg.rest_epochs_overlap, - ) - - epochs.load_data() # Remove reference to raw - del raw # free memory - - if idx == 0: - epochs_all_runs = epochs - else: - epochs_all_runs = mne.concatenate_epochs( - [epochs_all_runs, epochs], on_mismatch="warn" - ) - - del epochs - - # Clean up namespace - epochs = epochs_all_runs - epochs_ecg = ecg_epochs_all_runs - epochs_eog = eog_epochs_all_runs - - del epochs_all_runs, eog_epochs_all_runs, ecg_epochs_all_runs, run - - # Set an EEG reference - if "eeg" in cfg.ch_types: - projection = True if cfg.eeg_reference == "average" else False - epochs.set_eeg_reference(cfg.eeg_reference, projection=projection) - - if cfg.ica_reject == "autoreject_local": - msg = "Using autoreject to find and repair bad epochs before fitting ICA" - logger.info(**gen_log_kwargs(message=msg)) - - ar = autoreject.AutoReject( - n_interpolate=cfg.autoreject_n_interpolate, - random_state=cfg.random_state, - n_jobs=exec_params.n_jobs, - verbose=False, - ) - ar.fit(epochs) - epochs, reject_log = ar.transform(epochs, return_log=True) - msg = ( - f"autoreject marked 
{reject_log.bad_epochs.sum()} epochs as bad " - f"(cross-validated n_interpolate limit: {ar.n_interpolate_})" - ) - logger.info(**gen_log_kwargs(message=msg)) - else: - # Reject epochs based on peak-to-peak rejection thresholds - ica_reject = _get_reject( - subject=subject, - session=session, - reject=cfg.ica_reject, - ch_types=cfg.ch_types, - param="ica_reject", - ) - - msg = f"Using PTP rejection thresholds: {ica_reject}" - logger.info(**gen_log_kwargs(message=msg)) - - epochs.drop_bad(reject=ica_reject) - if epochs_eog is not None: - epochs_eog.drop_bad(reject=ica_reject) - if epochs_ecg is not None: - epochs_ecg.drop_bad(reject=ica_reject) - - # Now actually perform ICA. - msg = f"Calculating ICA solution using method: {cfg.ica_algorithm}." - logger.info(**gen_log_kwargs(message=msg)) - ica = fit_ica(cfg=cfg, epochs=epochs, subject=subject, session=session) - - # Start a report - title = f"ICA – sub-{subject}" - if session is not None: - title += f", ses-{session}" - if cfg.task is not None: - title += f", task-{cfg.task}" - - # ECG and EOG component detection - if epochs_ecg: - ecg_ics, ecg_scores = detect_bad_components( - cfg=cfg, - which="ecg", - epochs=epochs_ecg, - ica=ica, - ch_names=None, # we currently don't allow for custom channels - subject=subject, - session=session, - ) - else: - ecg_ics = ecg_scores = [] - - if epochs_eog: - eog_ics, eog_scores = detect_bad_components( - cfg=cfg, - which="eog", - epochs=epochs_eog, - ica=ica, - ch_names=cfg.eog_channels, - subject=subject, - session=session, - ) - else: - eog_ics = eog_scores = [] - - # Save ICA to disk. - # We also store the automatically identified ECG- and EOG-related ICs. - msg = "Saving ICA solution and detected artifacts to disk." - logger.info(**gen_log_kwargs(message=msg)) - ica.exclude = sorted(set(ecg_ics + eog_ics)) - ica.save(out_files["ica"], overwrite=True) - _update_for_splits(out_files, "ica") - - # Create TSV. - tsv_data = pd.DataFrame( - dict( - component=list(range(ica.n_components_)), - type=["ica"] * ica.n_components_, - description=["Independent Component"] * ica.n_components_, - status=["good"] * ica.n_components_, - status_description=["n/a"] * ica.n_components_, - ) - ) - - for component in ecg_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected ECG artifact" - - for component in eog_ics: - row_idx = tsv_data["component"] == component - tsv_data.loc[row_idx, "status"] = "bad" - tsv_data.loc[row_idx, "status_description"] = "Auto-detected EOG artifact" - - tsv_data.to_csv(out_files["components"], sep="\t", index=False) - - # Lastly, add info about the epochs used for the ICA fit, and plot all ICs - # for manual inspection. - msg = "Adding diagnostic plots for all ICA components to the HTML report …" - logger.info(**gen_log_kwargs(message=msg)) - - report = Report(info_fname=epochs, title=title, verbose=False) - ecg_evoked = None if epochs_ecg is None else epochs_ecg.average() - eog_evoked = None if epochs_eog is None else epochs_eog.average() - ecg_scores = None if len(ecg_scores) == 0 else ecg_scores - eog_scores = None if len(eog_scores) == 0 else eog_scores - - with _agg_backend(): - if cfg.ica_reject == "autoreject_local": - caption = ( - f"Autoreject was run to produce cleaner epochs before fitting ICA. 
" - f"{reject_log.bad_epochs.sum()} epochs were rejected because more than " - f"{ar.n_interpolate_} channels were bad (cross-validated n_interpolate " - f"limit; excluding globally bad and non-data channels, shown in white)." - ) - report.add_figure( - fig=reject_log.plot( - orientation="horizontal", aspect="auto", show=False - ), - title="Epochs: Autoreject cleaning", - caption=caption, - tags=("ica", "epochs", "autoreject"), - replace=True, - ) - del caption - - report.add_epochs( - epochs=epochs, - title="Epochs used for ICA fitting", - drop_log_ignore=(), - replace=True, - ) - report.add_ica( - ica=ica, - title="ICA cleaning", - inst=epochs, - ecg_evoked=ecg_evoked, - eog_evoked=eog_evoked, - ecg_scores=ecg_scores, - eog_scores=eog_scores, - replace=True, - n_jobs=1, # avoid automatic parallelization - ) - - msg = ( - f"ICA completed. Please carefully review the extracted ICs in the " - f"report {out_files['report'].basename}, and mark all components " - f"you wish to reject as 'bad' in " - f"{out_files['components'].basename}" - ) - logger.info(**gen_log_kwargs(message=msg)) - - report.save( - out_files["report"], - overwrite=True, - open_browser=exec_params.interactive, - ) - - assert len(in_files) == 0, in_files.keys() - return _prep_out_files(exec_params=exec_params, out_files=out_files) - - -def get_config( - *, - config: SimpleNamespace, - subject: str, - session: Optional[str] = None, -) -> SimpleNamespace: - cfg = SimpleNamespace( - conditions=config.conditions, - runs=get_runs(config=config, subject=subject), - task_is_rest=config.task_is_rest, - ica_l_freq=config.ica_l_freq, - ica_algorithm=config.ica_algorithm, - ica_n_components=config.ica_n_components, - ica_max_iterations=config.ica_max_iterations, - ica_decim=config.ica_decim, - ica_reject=config.ica_reject, - ica_eog_threshold=config.ica_eog_threshold, - ica_ctps_ecg_threshold=config.ica_ctps_ecg_threshold, - autoreject_n_interpolate=config.autoreject_n_interpolate, - random_state=config.random_state, - ch_types=config.ch_types, - l_freq=config.l_freq, - epochs_decim=config.epochs_decim, - raw_resample_sfreq=config.raw_resample_sfreq, - event_repeated=config.event_repeated, - epochs_tmin=config.epochs_tmin, - epochs_tmax=config.epochs_tmax, - epochs_metadata_tmin=config.epochs_metadata_tmin, - epochs_metadata_tmax=config.epochs_metadata_tmax, - epochs_metadata_keep_first=config.epochs_metadata_keep_first, - epochs_metadata_keep_last=config.epochs_metadata_keep_last, - epochs_metadata_query=config.epochs_metadata_query, - eeg_reference=get_eeg_reference(config), - eog_channels=config.eog_channels, - rest_epochs_duration=config.rest_epochs_duration, - rest_epochs_overlap=config.rest_epochs_overlap, - **_bids_kwargs(config=config), - ) - return cfg - - -def main(*, config: SimpleNamespace) -> None: - """Run ICA.""" - if config.spatial_filter != "ica": - msg = "Skipping …" - logger.info(**gen_log_kwargs(message=msg, emoji="skip")) - return - - with get_parallel_backend(config.exec_params): - parallel, run_func = parallel_func(run_ica, exec_params=config.exec_params) - logs = parallel( - run_func( - cfg=get_config(config=config, subject=subject), - exec_params=config.exec_params, - subject=subject, - session=session, - ) - for subject in get_subjects(config) - for session in get_sessions(config) - ) - save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py b/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py index eeb22cf36..aa8d97b5a 100644 --- 
a/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py +++ b/mne_bids_pipeline/steps/preprocessing/_06b_run_ssp.py @@ -1,36 +1,50 @@ -"""Run Signal Subspace Projections (SSP) for artifact correction. +"""Compute SSP. +Signal subspace projections (SSP) vectors are computed from EOG and ECG signals. These are often also referred to as PCA vectors. """ -from typing import Optional from types import SimpleNamespace import mne -from mne.preprocessing import create_eog_epochs, create_ecg_epochs -from mne import compute_proj_evoked, compute_proj_epochs +from mne import compute_proj_epochs, compute_proj_evoked +from mne.preprocessing import find_ecg_events, find_eog_events from mne_bids import BIDSPath -from ..._config_utils import ( - get_runs, - get_sessions, - get_subjects, +from mne_bids_pipeline._config_import import ConfigError +from mne_bids_pipeline._config_utils import ( _bids_kwargs, _pl, + _proj_path, + get_ecg_channel, + get_runs, + get_subjects_sessions, +) +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._reject import _get_reject +from mne_bids_pipeline._report import _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, ) -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._reject import _get_reject -from ..._report import _open_report -from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files +from mne_bids_pipeline.typing import InFilesT, IntArrayT, OutFilesT + + +def _find_ecg_events(raw: mne.io.Raw, ch_name: str | None) -> IntArrayT: + """Wrap find_ecg_events to use the same defaults as create_ecg_events.""" + out: IntArrayT = find_ecg_events(raw, ch_name=ch_name, l_freq=8, h_freq=16)[0] + return out def get_input_fnames_run_ssp( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesT: bids_basename = BIDSPath( subject=subject, session=session, @@ -47,7 +61,7 @@ def get_input_fnames_run_ssp( for run in cfg.runs: key = f"raw_run-{run}" in_files[key] = bids_basename.copy().update( - run=run, processing="filt", suffix="raw" + run=run, processing=cfg.processing, suffix="raw" ) _update_for_splits(in_files, key, single=True) return in_files @@ -61,28 +75,21 @@ def run_ssp( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: import matplotlib.pyplot as plt - # compute SSP on first run of raw + # compute SSP on all runs of raw raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] - # when saving proj, use run=None - out_files = dict() - out_files["proj"] = ( - raw_fnames[0] - .copy() - .update(run=None, suffix="proj", split=None, processing=None, check=False) - ) - + out_files = dict(proj=_proj_path(cfg=cfg, subject=subject, session=session)) msg = ( f"Input{_pl(raw_fnames)} ({len(raw_fnames)}): " - f'{raw_fnames[0].basename}{_pl(raw_fnames, pl=" ...")}' + f"{raw_fnames[0].basename}{_pl(raw_fnames, pl=' ...')}" ) logger.info(**gen_log_kwargs(message=msg)) - msg = f'Output: {out_files["proj"].basename}' + msg = f"Output: {out_files['proj'].basename}" logger.info(**gen_log_kwargs(message=msg)) raw = mne.concatenate_raws( @@ -90,38 +97,64 @@ def run_ssp( ) del raw_fnames - projs = dict() + projs: dict[str, 
list[mne.Projection]] = dict() proj_kinds = ("ecg", "eog") rate_names = dict(ecg="heart", eog="blink") - epochs_fun = dict(ecg=create_ecg_epochs, eog=create_eog_epochs) + events_fun = dict(ecg=_find_ecg_events, eog=find_eog_events) minimums = dict(ecg=cfg.min_ecg_epochs, eog=cfg.min_eog_epochs) rejects = dict(ecg=cfg.ssp_reject_ecg, eog=cfg.ssp_reject_eog) avg = dict(ecg=cfg.ecg_proj_from_average, eog=cfg.eog_proj_from_average) n_projs = dict(ecg=cfg.n_proj_ecg, eog=cfg.n_proj_eog) - ch_name = dict(ecg=None, eog=None) + ch_name: dict[str, str | list[str] | None] = dict(ecg=None, eog=None) if cfg.eog_channels: ch_name["eog"] = cfg.eog_channels + assert ch_name["eog"] is not None assert all(ch_name in raw.ch_names for ch_name in ch_name["eog"]) if cfg.ssp_ecg_channel: - ch_name["ecg"] = cfg.ssp_ecg_channel - assert ch_name["ecg"] in raw.ch_names, ch_name["ecg"] + ch_name["ecg"] = get_ecg_channel(config=cfg, subject=subject, session=session) + if ch_name["ecg"] not in raw.ch_names: + raise ConfigError( + f"SSP ECG channel '{ch_name['ecg']}' not found in data for " + f"subject {subject}, session {session}" + ) if cfg.ssp_meg == "auto": cfg.ssp_meg = "combined" if cfg.use_maxwell_filter else "separate" for kind in proj_kinds: projs[kind] = [] if not any(n_projs[kind].values()): continue - proj_epochs = epochs_fun[kind]( - raw, - ch_name=ch_name[kind], - decim=cfg.epochs_decim, - ) - n_orig = len(proj_epochs.selection) + events = events_fun[kind](raw=raw, ch_name=ch_name[kind]) + n_orig = len(events) rate = n_orig / raw.times[-1] * 60 bpm_msg = f"{rate:5.1f} bpm" msg = f"Detected {rate_names[kind]} rate: {bpm_msg}" logger.info(**gen_log_kwargs(message=msg)) - # Enough to start + # Enough to create epochs + if len(events) < minimums[kind]: + msg = ( + f"No {kind.upper()} projectors computed: got " + f"{len(events)} original events < {minimums[kind]} {bpm_msg}" + ) + logger.warning(**gen_log_kwargs(message=msg)) + continue + out_files[f"events_{kind}"] = ( + out_files["proj"] + .copy() + .update(suffix=f"{kind}-eve", split=None, check=False, extension=".txt") + ) + mne.write_events(out_files[f"events_{kind}"], events, overwrite=True) + proj_epochs = mne.Epochs( + raw, + events=events, + event_id=events[0, 2], + tmin=-0.5, + tmax=0.5, + proj=False, + baseline=(None, None), + reject_by_annotation=True, + preload=True, + decim=cfg.epochs_decim, + ) if len(proj_epochs) >= minimums[kind]: reject_ = _get_reject( subject=subject, @@ -134,7 +167,6 @@ def run_ssp( proj_epochs.drop_bad(reject=reject_) # Still enough after rejection if len(proj_epochs) >= minimums[kind]: - proj_epochs.apply_baseline((None, None)) use = proj_epochs.average() if avg[kind] else proj_epochs fun = compute_proj_evoked if avg[kind] else compute_proj_epochs desc_prefix = ( @@ -162,6 +194,7 @@ def run_ssp( mne.write_proj(out_files["proj"], sum(projs.values(), []), overwrite=True) assert len(in_files) == 0, in_files.keys() + del projs # Report with _open_report( @@ -175,13 +208,15 @@ def run_ssp( msg = f"Adding {kind.upper()} SSP to report." 
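# --- Editor's illustrative aside (not part of the diff above) -----------------
# Hedged sketch of the SSP flow used in this step: detect heartbeat events,
# build fixed-window epochs around them, and compute projection vectors that
# can later be added to the data. The file name and the projector counts are
# assumptions for illustration only.
import mne
from mne.preprocessing import find_ecg_events

raw = mne.io.read_raw_fif("sub-01_task-main_proc-filt_raw.fif", preload=True)  # hypothetical path
events, _, _ = find_ecg_events(raw, l_freq=8, h_freq=16)
ecg_epochs = mne.Epochs(
    raw, events, event_id=int(events[0, 2]), tmin=-0.5, tmax=0.5,
    baseline=(None, None), proj=False, preload=True,
)
projs = mne.compute_proj_epochs(ecg_epochs, n_grad=1, n_mag=1, n_eeg=0, desc_prefix="ECG-")
raw.add_proj(projs)  # applied later via raw.apply_proj() or at epoching time
# -------------------------------------------------------------------------------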
logger.info(**gen_log_kwargs(message=msg)) proj_epochs = mne.read_epochs(out_files[f"epochs_{kind}"]) - projs = mne.read_proj(out_files["proj"]) - projs = [p for p in projs if kind.upper() in p["desc"]] - assert len(projs), len(projs) # should exist if the epochs do - picks_trace = None + these_projs: list[mne.Projection] = mne.read_proj(out_files["proj"]) + these_projs = [p for p in these_projs if kind.upper() in p["desc"]] + assert len(these_projs), len(these_projs) # should exist if the epochs do + picks_trace: str | list[str] | None = None if kind == "ecg": if cfg.ssp_ecg_channel: - picks_trace = [cfg.ssp_ecg_channel] + picks_trace = [ + get_ecg_channel(config=cfg, subject=subject, session=session) + ] elif "ecg" in proj_epochs: picks_trace = "ecg" else: @@ -191,7 +226,7 @@ def run_ssp( elif "eog" in proj_epochs: picks_trace = "eog" fig = mne.viz.plot_projs_joint( - projs, proj_epochs.average(picks="all"), picks_trace=picks_trace + these_projs, proj_epochs.average(picks="all"), picks_trace=picks_trace ) caption = ( f"Computed using {len(proj_epochs)} epochs " @@ -229,6 +264,7 @@ def get_config( epochs_decim=config.epochs_decim, use_maxwell_filter=config.use_maxwell_filter, runs=get_runs(config=config, subject=subject), + processing="filt" if config.regress_artifact is None else "regress", **_bids_kwargs(config=config), ) return cfg @@ -253,7 +289,7 @@ def main(*, config: SimpleNamespace) -> None: subject=subject, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_05_make_epochs.py b/mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py similarity index 83% rename from mne_bids_pipeline/steps/preprocessing/_05_make_epochs.py rename to mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py index d4deb4078..0a08c8aa0 100644 --- a/mne_bids_pipeline/steps/preprocessing/_05_make_epochs.py +++ b/mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py @@ -7,38 +7,39 @@ To save space, the epoch data can be decimated. 
""" +import inspect from types import SimpleNamespace -from typing import Optional +from typing import Any import mne from mne_bids import BIDSPath -from ..._config_utils import ( - get_runs, - get_subjects, - get_eeg_reference, - get_sessions, +from mne_bids_pipeline._config_utils import ( _bids_kwargs, + get_eeg_reference, + get_runs, + get_subjects_sessions, ) -from ..._import_data import make_epochs, annotations_to_events -from ..._logging import gen_log_kwargs, logger -from ..._report import _open_report -from ..._run import ( +from mne_bids_pipeline._import_data import annotations_to_events, make_epochs +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _sanitize_callable, + _update_for_splits, failsafe_run, save_logs, - _update_for_splits, - _sanitize_callable, - _prep_out_files, ) -from ..._parallel import parallel_func, get_parallel_backend +from mne_bids_pipeline.typing import InFilesT, IntArrayT, OutFilesT def get_input_fnames_epochs( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesT: """Get paths of files required by filter_data function.""" # Construct the basenames of the files we wish to load, and of the empty- # room recording we wish to save. @@ -54,7 +55,7 @@ def get_input_fnames_epochs( extension=".fif", datatype=cfg.datatype, root=cfg.deriv_root, - processing="filt", + processing=cfg.processing, ).update(suffix="raw", check=False) # Generate a list of raw data paths (i.e., paths of individual runs) @@ -78,9 +79,9 @@ def run_epochs( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: """Extract epochs for one subject.""" raw_fnames = [in_files.pop(f"raw_run-{run}") for run in cfg.runs] bids_path_in = raw_fnames[0].copy().update(processing=None, run=None, split=None) @@ -124,6 +125,7 @@ def run_epochs( conditions=cfg.conditions, tmin=cfg.epochs_tmin, tmax=cfg.epochs_tmax, + custom_metadata=cfg.epochs_custom_metadata, metadata_tmin=cfg.epochs_metadata_tmin, metadata_tmax=cfg.epochs_metadata_tmax, metadata_keep_first=cfg.epochs_metadata_keep_first, @@ -149,7 +151,14 @@ def run_epochs( if cfg.use_maxwell_filter: # Keep track of the info corresponding to the run with the smallest # data rank. 
- new_rank = mne.compute_rank(epochs, rank="info")["meg"] + if "grad" in epochs: + if "mag" in epochs: + type_sel = "meg" + else: + type_sel = "grad" + else: + type_sel = "mag" + new_rank = mne.compute_rank(epochs, rank="info")[type_sel] if (smallest_rank is None) or (new_rank < smallest_rank): smallest_rank = new_rank smallest_rank_info = epochs.info.copy() @@ -162,7 +171,7 @@ def run_epochs( if cfg.use_maxwell_filter and cfg.noise_cov == "rest": raw_rest_filt = mne.io.read_raw(in_files.pop("raw_rest")) - rank_rest = mne.compute_rank(raw_rest_filt, rank="info")["meg"] + rank_rest = mne.compute_rank(raw_rest_filt, rank="info")[type_sel] if rank_rest < smallest_rank: msg = ( f"The MEG rank of the resting state data ({rank_rest}) is " @@ -187,10 +196,8 @@ def run_epochs( assert epochs.info["ch_names"] == smallest_rank_info["ch_names"] with epochs.info._unlock(): epochs.info["proc_history"] = smallest_rank_info["proc_history"] - rank_epochs_new = mne.compute_rank(epochs, rank="info")["meg"] - msg = ( - f'The MEG rank of the "{cfg.task}" epochs is now: ' f"{rank_epochs_new}" - ) + rank_epochs_new = mne.compute_rank(epochs, rank="info")[type_sel] + msg = f'The MEG rank of the "{cfg.task}" epochs is now: {rank_epochs_new}' logger.warning(**gen_log_kwargs(message=msg)) # Set an EEG reference @@ -206,15 +213,17 @@ def run_epochs( ) logger.info(**gen_log_kwargs(message=msg)) msg = ( - f"Selected {len(epochs)} epochs via metadata query: " - f"{cfg.epochs_metadata_query}" + f"Selected {len(epochs)} epochs via metadata query: {cfg.epochs_metadata_query}" ) logger.info(**gen_log_kwargs(message=msg)) msg = f"Writing {len(epochs)} epochs to disk." logger.info(**gen_log_kwargs(message=msg)) out_files = dict() out_files["epochs"] = bids_path_in.copy().update( - suffix="epo", processing=None, check=False + suffix="epo", + processing=None, + check=False, + split=None, ) epochs.save( out_files["epochs"], @@ -246,16 +255,14 @@ def run_epochs( msg = "Adding uncleaned epochs to report." 
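# --- Editor's illustrative aside (not part of the diff above) -----------------
# Hedged sketch of why the rank lookup above is keyed by sensor type:
# mne.compute_rank() returns a dict keyed by channel type, and depending on
# which MEG sensor types are present (and whether Maxwell filtering combined
# them), the MEG entry may appear as "meg", "grad", or "mag" — hence the
# type_sel selection in the diff. The file name is an assumption.
import mne

epochs = mne.read_epochs("sub-01_task-main_epo.fif")  # hypothetical path
ranks = mne.compute_rank(epochs, rank="info")  # e.g. {"meg": 72, "eeg": 59}
if "grad" in epochs:
    type_sel = "meg" if "mag" in epochs else "grad"
else:
    type_sel = "mag"
print(type_sel, ranks.get(type_sel))
# -------------------------------------------------------------------------------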
logger.info(**gen_log_kwargs(message=msg)) # Add PSD plots for 30s of data or all epochs if we have less available - if len(epochs) * (epochs.tmax - epochs.tmin) < 30: - psd = True - else: - psd = 30 + psd = True if len(epochs) * (epochs.tmax - epochs.tmin) < 30 else 30.0 report.add_epochs( epochs=epochs, title="Epochs: before cleaning", psd=psd, drop_log_ignore=(), replace=True, + **_add_epochs_image_kwargs(cfg), ) # Interactive @@ -266,8 +273,18 @@ def run_epochs( return _prep_out_files(exec_params=exec_params, out_files=out_files) +def _add_epochs_image_kwargs(cfg: SimpleNamespace) -> dict[str, dict[str, Any]]: + arg_spec = inspect.getfullargspec(mne.Report.add_epochs) + kwargs = dict() + if cfg.report_add_epochs_image_kwargs and "image_kwargs" in arg_spec.kwonlyargs: + kwargs["image_kwargs"] = cfg.report_add_epochs_image_kwargs + return kwargs + + # TODO: ideally we wouldn't need this anymore and could refactor the code above -def _get_events(cfg, subject, session): +def _get_events( + cfg: SimpleNamespace, subject: str, session: str | None +) -> tuple[IntArrayT, dict[str, int], float, int]: raws_filt = [] raw_fname = BIDSPath( subject=subject, @@ -276,7 +293,7 @@ def _get_events(cfg, subject, session): acquisition=cfg.acq, recording=cfg.rec, space=cfg.space, - processing="filt", + processing=cfg.processing, suffix="raw", extension=".fif", datatype=cfg.datatype, @@ -299,7 +316,7 @@ def _get_events(cfg, subject, session): def get_config( *, - config, + config: SimpleNamespace, subject: str, ) -> SimpleNamespace: cfg = SimpleNamespace( @@ -308,6 +325,7 @@ def get_config( conditions=config.conditions, epochs_tmin=config.epochs_tmin, epochs_tmax=config.epochs_tmax, + epochs_custom_metadata=config.epochs_custom_metadata, epochs_metadata_tmin=config.epochs_metadata_tmin, epochs_metadata_tmax=config.epochs_metadata_tmax, epochs_metadata_keep_first=config.epochs_metadata_keep_first, @@ -315,6 +333,7 @@ def get_config( epochs_metadata_query=config.epochs_metadata_query, event_repeated=config.event_repeated, epochs_decim=config.epochs_decim, + report_add_epochs_image_kwargs=config.report_add_epochs_image_kwargs, ch_types=config.ch_types, noise_cov=_sanitize_callable(config.noise_cov), eeg_reference=get_eeg_reference(config), @@ -322,12 +341,13 @@ def get_config( rest_epochs_overlap=config.rest_epochs_overlap, _epochs_split_size=config._epochs_split_size, runs=get_runs(config=config, subject=subject), + processing="filt" if config.regress_artifact is None else "regress", **_bids_kwargs(config=config), ) return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run epochs.""" with get_parallel_backend(config.exec_params): parallel, run_func = parallel_func(run_epochs, exec_params=config.exec_params) @@ -341,7 +361,7 @@ def main(*, config) -> None: subject=subject, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py deleted file mode 100644 index 4b906a106..000000000 --- a/mne_bids_pipeline/steps/preprocessing/_07a_apply_ica.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Apply ICA and obtain the cleaned epochs. - -Blinks and ECG artifacts are automatically detected and the corresponding ICA -components are removed from the data. 
-This relies on the ICAs computed in 04-run_ica.py - -!! If you manually add components to remove (config.rejcomps_man), -make sure you did not re-run the ICA in the meantime. Otherwise (especially if -the random state was not set, or you used a different machine, the component -order might differ). - -""" - -from types import SimpleNamespace -from typing import Optional - -import pandas as pd -import mne -from mne.preprocessing import read_ica -from mne.report import Report - -from mne_bids import BIDSPath - -from ..._config_utils import ( - get_subjects, - get_sessions, - _bids_kwargs, -) -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _agg_backend -from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files - - -def get_input_fnames_apply_ica( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], -) -> dict: - bids_basename = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - datatype=cfg.datatype, - root=cfg.deriv_root, - check=False, - ) - in_files = dict() - in_files["ica"] = bids_basename.copy().update(suffix="ica", extension=".fif") - in_files["components"] = bids_basename.copy().update( - processing="ica", suffix="components", extension=".tsv" - ) - in_files["epochs"] = bids_basename.copy().update(suffix="epo", extension=".fif") - _update_for_splits(in_files, "epochs", single=True) - return in_files - - -@failsafe_run( - get_input_fnames=get_input_fnames_apply_ica, -) -def apply_ica( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], - in_files: dict, -) -> dict: - bids_basename = in_files["ica"].copy().update(processing=None) - out_files = dict() - out_files["epochs"] = in_files["epochs"].copy().update(processing="ica") - out_files["report"] = bids_basename.copy().update( - processing="ica", suffix="report", extension=".html" - ) - - title = f"ICA artifact removal – sub-{subject}" - if session is not None: - title += f", ses-{session}" - if cfg.task is not None: - title += f", task-{cfg.task}" - - # Load ICA. - msg = f"Reading ICA: {in_files['ica']}" - logger.debug(**gen_log_kwargs(message=msg)) - ica = read_ica(fname=in_files.pop("ica")) - - # Select ICs to remove. - tsv_data = pd.read_csv(in_files.pop("components"), sep="\t") - ica.exclude = tsv_data.loc[tsv_data["status"] == "bad", "component"].to_list() - - # Load epochs. - msg = f'Input: {in_files["epochs"].basename}' - logger.info(**gen_log_kwargs(message=msg)) - msg = f'Output: {out_files["epochs"].basename}' - logger.info(**gen_log_kwargs(message=msg)) - - epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) - - # Now actually reject the components. - msg = f'Rejecting ICs: {", ".join([str(ic) for ic in ica.exclude])}' - logger.info(**gen_log_kwargs(message=msg)) - epochs_cleaned = ica.apply(epochs.copy()) # Copy b/c works in-place! - - msg = "Saving reconstructed epochs after ICA." - logger.info(**gen_log_kwargs(message=msg)) - epochs_cleaned.save( - out_files["epochs"], - overwrite=True, - split_naming="bids", - split_size=cfg._epochs_split_size, - ) - _update_for_splits(out_files, "epochs") - - # Compare ERP/ERF before and after ICA artifact rejection. The evoked - # response is calculated across ALL epochs, just like ICA was run on - # all epochs, regardless of their respective experimental condition. 
- # - # We apply baseline correction here to (hopefully!) make the effects of - # ICA easier to see. Otherwise, individual channels might just have - # arbitrary DC shifts, and we wouldn't be able to easily decipher what's - # going on! - report = Report(out_files["report"], title=title, verbose=False) - picks = ica.exclude if ica.exclude else None - with _agg_backend(): - report.add_ica( - ica=ica, - title="Effects of ICA cleaning", - inst=epochs.copy().apply_baseline(cfg.baseline), - picks=picks, - replace=True, - n_jobs=1, # avoid automatic parallelization - ) - report.save( - out_files["report"], - overwrite=True, - open_browser=exec_params.interactive, - ) - - assert len(in_files) == 0, in_files.keys() - - # Report - kwargs = dict() - if ica.exclude: - msg = "Adding ICA to report." - else: - msg = "Skipping ICA addition to report, no components marked as bad." - kwargs["emoji"] = "skip" - logger.info(**gen_log_kwargs(message=msg, **kwargs)) - if ica.exclude: - with _open_report( - cfg=cfg, exec_params=exec_params, subject=subject, session=session - ) as report: - report.add_ica( - ica=ica, - title="ICA", - inst=epochs, - picks=ica.exclude, - # TODO upstream - # captions=f'Evoked response (across all epochs) ' - # f'before and after ICA ' - # f'({len(ica.exclude)} ICs removed)' - replace=True, - ) - - return _prep_out_files(exec_params=exec_params, out_files=out_files) - - -def get_config( - *, - config: SimpleNamespace, -) -> SimpleNamespace: - cfg = SimpleNamespace( - baseline=config.baseline, - ica_reject=config.ica_reject, - ch_types=config.ch_types, - _epochs_split_size=config._epochs_split_size, - **_bids_kwargs(config=config), - ) - return cfg - - -def main(*, config: SimpleNamespace) -> None: - """Apply ICA.""" - if not config.spatial_filter == "ica": - msg = "Skipping …" - logger.info(**gen_log_kwargs(message=msg, emoji="skip")) - return - - with get_parallel_backend(config.exec_params): - parallel, run_func = parallel_func(apply_ica, exec_params=config.exec_params) - logs = parallel( - run_func( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject=subject, - session=session, - ) - for subject in get_subjects(config) - for session in get_sessions(config) - ) - save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py b/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py deleted file mode 100644 index 65fc27b70..000000000 --- a/mne_bids_pipeline/steps/preprocessing/_07b_apply_ssp.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Apply SSP projections and obtain the cleaned epochs. - -Blinks and ECG artifacts are automatically detected and the corresponding SSP -projections components are removed from the data. 
- -""" - -from types import SimpleNamespace -from typing import Optional - -import mne -from mne_bids import BIDSPath - -from ..._config_utils import ( - get_sessions, - get_subjects, - _bids_kwargs, -) -from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files -from ..._parallel import parallel_func, get_parallel_backend - - -def get_input_fnames_apply_ssp( - *, - cfg: SimpleNamespace, - subject: str, - session: Optional[str], -) -> dict: - bids_basename = BIDSPath( - subject=subject, - session=session, - task=cfg.task, - acquisition=cfg.acq, - recording=cfg.rec, - space=cfg.space, - datatype=cfg.datatype, - root=cfg.deriv_root, - extension=".fif", - check=False, - ) - in_files = dict() - in_files["epochs"] = bids_basename.copy().update(suffix="epo", check=False) - _update_for_splits(in_files, "epochs", single=True) - in_files["proj"] = bids_basename.copy().update(suffix="proj", check=False) - return in_files - - -@failsafe_run( - get_input_fnames=get_input_fnames_apply_ssp, -) -def apply_ssp( - *, - cfg: SimpleNamespace, - exec_params: SimpleNamespace, - subject: str, - session: Optional[str], - in_files: dict, -) -> dict: - # load epochs to reject ICA components - # compute SSP on first run of raw - out_files = dict() - out_files["epochs"] = ( - in_files["epochs"].copy().update(processing="ssp", split=None, check=False) - ) - msg = f"Input epochs: {in_files['epochs'].basename}" - logger.info(**gen_log_kwargs(message=msg)) - msg = f'Input SSP: {in_files["proj"].basename}' - logger.info(**gen_log_kwargs(message=msg)) - msg = f"Output: {out_files['epochs'].basename}" - logger.info(**gen_log_kwargs(message=msg)) - epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) - projs = mne.read_proj(in_files.pop("proj")) - epochs_cleaned = epochs.copy().add_proj(projs).apply_proj() - epochs_cleaned.save( - out_files["epochs"], - overwrite=True, - split_naming="bids", - split_size=cfg._epochs_split_size, - ) - _update_for_splits(out_files, "epochs") - assert len(in_files) == 0, in_files.keys() - return _prep_out_files(exec_params=exec_params, out_files=out_files) - - -def get_config( - *, - config: SimpleNamespace, -) -> SimpleNamespace: - cfg = SimpleNamespace( - _epochs_split_size=config._epochs_split_size, - **_bids_kwargs(config=config), - ) - return cfg - - -def main(*, config: SimpleNamespace) -> None: - """Apply ssp.""" - if not config.spatial_filter == "ssp": - msg = "Skipping …" - logger.info(**gen_log_kwargs(message=msg, emoji="skip")) - return - - with get_parallel_backend(config.exec_params): - parallel, run_func = parallel_func(apply_ssp, exec_params=config.exec_params) - logs = parallel( - run_func( - cfg=get_config( - config=config, - ), - exec_params=config.exec_params, - subject=subject, - session=session, - ) - for subject in get_subjects(config) - for session in get_sessions(config) - ) - save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py new file mode 100644 index 000000000..be3b5ac44 --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py @@ -0,0 +1,304 @@ +"""Apply ICA. + +!! If you manually add components to remove, make sure you did not re-run the ICA in +the meantime. Otherwise (especially if the random state was not set, or you used a +different machine) the component order might differ. 
+""" + +from types import SimpleNamespace + +import mne +import pandas as pd +from mne.preprocessing import read_ica +from mne_bids import BIDSPath + +from mne_bids_pipeline._config_utils import get_runs_tasks, get_subjects_sessions +from mne_bids_pipeline._import_data import _get_run_rest_noise_path, _import_data_kwargs +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _add_raw, _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, +) +from mne_bids_pipeline.typing import InFilesT, OutFilesT + + +def _ica_paths( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> InFilesT: + bids_basename = BIDSPath( + subject=subject, + session=session, + task=cfg.task, + acquisition=cfg.acq, + recording=cfg.rec, + space=cfg.space, + datatype=cfg.datatype, + root=cfg.deriv_root, + check=False, + ) + in_files = dict() + in_files["ica"] = bids_basename.copy().update( + processing="ica", + suffix="ica", + extension=".fif", + ) + in_files["components"] = bids_basename.copy().update( + processing="ica", suffix="components", extension=".tsv" + ) + return in_files + + +def _read_ica_and_exclude( + in_files: InFilesT, +) -> mne.preprocessing.ICA: + ica = read_ica(fname=in_files.pop("ica")) + tsv_data = pd.read_csv(in_files.pop("components"), sep="\t") + ica.exclude = tsv_data.loc[tsv_data["status"] == "bad", "component"].to_list() + return ica + + +def get_input_fnames_apply_ica_epochs( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> InFilesT: + in_files = _ica_paths(cfg=cfg, subject=subject, session=session) + in_files["epochs"] = ( + in_files["ica"] + .copy() + .update( + suffix="epo", + extension=".fif", + processing=None, + ) + ) + _update_for_splits(in_files, "epochs", single=True) + return in_files + + +def get_input_fnames_apply_ica_raw( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, +) -> InFilesT: + in_files = _get_run_rest_noise_path( + cfg=cfg, + subject=subject, + session=session, + run=run, + task=task, + kind="filt", + mf_reference_run=cfg.mf_reference_run, + ) + assert len(in_files) + in_files.update(_ica_paths(cfg=cfg, subject=subject, session=session)) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_apply_ica_epochs, +) +def apply_ica_epochs( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: InFilesT, +) -> OutFilesT: + out_files = dict() + out_files["epochs"] = in_files["epochs"].copy().update(processing="ica", split=None) + + title = f"ICA artifact removal – sub-{subject}" + if session is not None: + title += f", ses-{session}" + if cfg.task is not None: + title += f", task-{cfg.task}" + + # Load ICA. + msg = f"Reading ICA: {in_files['ica']}" + logger.debug(**gen_log_kwargs(message=msg)) + ica = _read_ica_and_exclude(in_files) + + # Load epochs. + msg = f"Input: {in_files['epochs'].basename}" + logger.info(**gen_log_kwargs(message=msg)) + msg = f"Output: {out_files['epochs'].basename}" + logger.info(**gen_log_kwargs(message=msg)) + + epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) + + # Now actually reject the components. 
+ msg = ( + f"Rejecting ICs with the following indices: " + f"{', '.join([str(i) for i in ica.exclude])}" + ) + logger.info(**gen_log_kwargs(message=msg)) + epochs_cleaned = ica.apply(epochs.copy()) # Copy b/c works in-place! + + msg = f"Saving {len(epochs)} reconstructed epochs after ICA." + logger.info(**gen_log_kwargs(message=msg)) + epochs_cleaned.save( + out_files["epochs"], + overwrite=True, + split_naming="bids", + split_size=cfg._epochs_split_size, + ) + _update_for_splits(out_files, "epochs") + assert len(in_files) == 0, in_files.keys() + + # Report + kwargs = dict() + if ica.exclude: + msg = "Adding ICA to report." + else: + msg = "Skipping ICA addition to report, no components marked as bad." + kwargs["emoji"] = "skip" + logger.info(**gen_log_kwargs(message=msg, **kwargs)) + if ica.exclude: + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + ) as report: + report.add_ica( + ica=ica, + title="ICA: removals", + inst=epochs, + picks=ica.exclude, + # TODO upstream + # captions=f'Evoked response (across all epochs) ' + # f'before and after ICA ' + # f'({len(ica.exclude)} ICs removed)' + replace=True, + n_jobs=1, # avoid automatic parallelization + ) + + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +@failsafe_run( + get_input_fnames=get_input_fnames_apply_ica_raw, +) +def apply_ica_raw( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, + in_files: InFilesT, +) -> OutFilesT: + ica = _read_ica_and_exclude(in_files) + in_key = list(in_files)[0] + assert in_key.startswith("raw"), in_key + raw_fname = in_files.pop(in_key) + assert len(in_files) == 0, in_files + out_files = dict() + out_files[in_key] = raw_fname.copy().update(processing="clean", split=None) + msg = f"Writing {out_files[in_key].basename} …" + logger.info(**gen_log_kwargs(message=msg)) + raw = mne.io.read_raw_fif(raw_fname, preload=True) + ica.apply(raw) + raw.save(out_files[in_key], overwrite=True, split_size=cfg._raw_split_size) + _update_for_splits(out_files, in_key) + # Report + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) as report: + msg = "Adding cleaned raw data to report" + logger.info(**gen_log_kwargs(message=msg)) + _add_raw( + cfg=cfg, + report=report, + bids_path_in=out_files[in_key], + title="Raw (clean)", + tags=("clean",), + raw=raw, + ) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, +) -> SimpleNamespace: + cfg = SimpleNamespace( + baseline=config.baseline, + ica_reject=config.ica_reject, + processing="filt" if config.regress_artifact is None else "regress", + _epochs_split_size=config._epochs_split_size, + **_import_data_kwargs(config=config, subject=subject), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Apply ICA.""" + if not config.spatial_filter == "ica": + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + # Epochs + parallel, run_func = parallel_func( + apply_ica_epochs, exec_params=config.exec_params + ) + logs = parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions + ) + # 
Raw + parallel, run_func = parallel_func( + apply_ica_raw, exec_params=config.exec_params + ) + logs += parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions + for run, task in get_runs_tasks( + config=config, + subject=subject, + session=session, + ) + ) + save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py b/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py new file mode 100644 index 000000000..6e8d2c26e --- /dev/null +++ b/mne_bids_pipeline/steps/preprocessing/_08b_apply_ssp.py @@ -0,0 +1,213 @@ +"""Apply SSP. + +Blinks and ECG artifacts are automatically detected and the corresponding SSP +projections components are removed from the data. +""" + +from types import SimpleNamespace + +import mne + +from mne_bids_pipeline._config_utils import ( + _proj_path, + get_runs_tasks, + get_subjects_sessions, +) +from mne_bids_pipeline._import_data import _get_run_rest_noise_path, _import_data_kwargs +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _add_raw, _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, +) +from mne_bids_pipeline.typing import InFilesT, OutFilesT + + +def get_input_fnames_apply_ssp_epochs( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, +) -> InFilesT: + in_files = dict() + in_files["proj"] = _proj_path(cfg=cfg, subject=subject, session=session) + in_files["epochs"] = in_files["proj"].copy().update(suffix="epo", check=False) + _update_for_splits(in_files, "epochs", single=True) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_apply_ssp_epochs, +) +def apply_ssp_epochs( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + in_files: InFilesT, +) -> OutFilesT: + out_files = dict() + out_files["epochs"] = ( + in_files["epochs"].copy().update(processing="ssp", split=None, check=False) + ) + msg = f"Input epochs: {in_files['epochs'].basename}" + logger.info(**gen_log_kwargs(message=msg)) + msg = f"Input SSP: {in_files['proj'].basename}" + logger.info(**gen_log_kwargs(message=msg)) + msg = f"Output: {out_files['epochs'].basename}" + logger.info(**gen_log_kwargs(message=msg)) + epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) + projs = mne.read_proj(in_files.pop("proj")) + epochs_cleaned = epochs.copy().add_proj(projs).apply_proj() + + msg = f"Saving {len(epochs_cleaned)} reconstructed epochs after SSP." 
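[Editor's aside: for orientation, the add_proj/apply_proj pattern used in apply_ssp_epochs above reduces to the following standalone sketch; the paths are hypothetical, the pipeline derives them from the BIDS layout:]

    import mne

    epochs = mne.read_epochs("sub-01_task-aud_epo.fif", preload=True)
    projs = mne.read_proj("sub-01_task-aud_proj.fif")
    # add_proj() only attaches the projectors; apply_proj() actually modifies the data.
    epochs_clean = epochs.copy().add_proj(projs).apply_proj()
    epochs_clean.save("sub-01_task-aud_proc-ssp_epo.fif", overwrite=True)
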
+ logger.info(**gen_log_kwargs(message=msg)) + + epochs_cleaned.save( + out_files["epochs"], + overwrite=True, + split_naming="bids", + split_size=cfg._epochs_split_size, + ) + _update_for_splits(out_files, "epochs") + assert len(in_files) == 0, in_files.keys() + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_input_fnames_apply_ssp_raw( + *, + cfg: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, +) -> InFilesT: + in_files = _get_run_rest_noise_path( + cfg=cfg, + subject=subject, + session=session, + run=run, + task=task, + kind="filt", + mf_reference_run=cfg.mf_reference_run, + ) + assert len(in_files) + in_files["proj"] = _proj_path(cfg=cfg, subject=subject, session=session) + return in_files + + +@failsafe_run( + get_input_fnames=get_input_fnames_apply_ssp_raw, +) +def apply_ssp_raw( + *, + cfg: SimpleNamespace, + exec_params: SimpleNamespace, + subject: str, + session: str | None, + run: str, + task: str | None, + in_files: InFilesT, +) -> OutFilesT: + projs = mne.read_proj(in_files.pop("proj")) + in_key = list(in_files.keys())[0] + assert in_key.startswith("raw"), in_key + raw_fname = in_files.pop(in_key) + assert len(in_files) == 0, in_files.keys() + raw = mne.io.read_raw_fif(raw_fname) + raw.add_proj(projs) + out_files = dict() + out_files[in_key] = raw_fname.copy().update(processing="clean", split=None) + msg = f"Writing {out_files[in_key].basename} …" + logger.info(**gen_log_kwargs(message=msg)) + raw.save(out_files[in_key], overwrite=True, split_size=cfg._raw_split_size) + _update_for_splits(out_files, in_key) + # Report + with _open_report( + cfg=cfg, + exec_params=exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) as report: + msg = "Adding cleaned raw data to report" + logger.info(**gen_log_kwargs(message=msg)) + _add_raw( + cfg=cfg, + report=report, + bids_path_in=out_files[in_key], + title="Raw (clean)", + tags=("clean",), + raw=raw, + ) + return _prep_out_files(exec_params=exec_params, out_files=out_files) + + +def get_config( + *, + config: SimpleNamespace, + subject: str, +) -> SimpleNamespace: + cfg = SimpleNamespace( + processing="filt" if config.regress_artifact is None else "regress", + _epochs_split_size=config._epochs_split_size, + **_import_data_kwargs(config=config, subject=subject), + ) + return cfg + + +def main(*, config: SimpleNamespace) -> None: + """Apply ssp.""" + if not config.spatial_filter == "ssp": + msg = "Skipping …" + logger.info(**gen_log_kwargs(message=msg, emoji="skip")) + return + + with get_parallel_backend(config.exec_params): + # Epochs + parallel, run_func = parallel_func( + apply_ssp_epochs, exec_params=config.exec_params + ) + logs = parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + ) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions + ) + # Raw + parallel, run_func = parallel_func( + apply_ssp_raw, exec_params=config.exec_params + ) + logs += parallel( + run_func( + cfg=get_config( + config=config, + subject=subject, + ), + exec_params=config.exec_params, + subject=subject, + session=session, + run=run, + task=task, + ) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions + for run, task in get_runs_tasks( + config=config, + subject=subject, + session=session, + ) + ) + save_logs(config=config, logs=logs) diff --git 
a/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py b/mne_bids_pipeline/steps/preprocessing/_09_ptp_reject.py similarity index 74% rename from mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py rename to mne_bids_pipeline/steps/preprocessing/_09_ptp_reject.py index 3c02ad91c..2e5390f62 100644 --- a/mne_bids_pipeline/steps/preprocessing/_08_ptp_reject.py +++ b/mne_bids_pipeline/steps/preprocessing/_09_ptp_reject.py @@ -1,6 +1,6 @@ -"""Remove epochs based on peak-to-peak (PTP) amplitudes. +"""Remove epochs based on PTP amplitudes. -Epochs containing peak-to-peak above the thresholds defined +Epochs containing peak-to-peak (PTP) above the thresholds defined in the 'reject' parameter are removed from the data. This step will drop epochs containing non-biological artifacts @@ -9,32 +9,34 @@ """ from types import SimpleNamespace -from typing import Optional -import numpy as np import autoreject - import mne +import numpy as np from mne_bids import BIDSPath -from ..._config_utils import ( - get_sessions, - get_subjects, - _bids_kwargs, +from mne_bids_pipeline._config_utils import _bids_kwargs, get_subjects_sessions +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._reject import _get_reject +from mne_bids_pipeline._report import _open_report +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, ) -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._reject import _get_reject -from ..._report import _open_report -from ..._run import failsafe_run, _update_for_splits, save_logs, _prep_out_files +from mne_bids_pipeline.typing import InFilesT, OutFilesT + +from ._07_make_epochs import _add_epochs_image_kwargs def get_input_fnames_drop_ptp( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesT: bids_path = BIDSPath( subject=subject, session=session, @@ -63,9 +65,11 @@ def drop_ptp( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: + import matplotlib.pyplot as plt + out_files = dict() out_files["epochs"] = ( in_files["epochs"] @@ -75,16 +79,19 @@ def drop_ptp( split=None, ) ) - msg = f'Input: {in_files["epochs"].basename}' + msg = f"Input: {in_files['epochs'].basename}" logger.info(**gen_log_kwargs(message=msg)) - msg = f'Output: {out_files["epochs"].basename}' + msg = f"Output: {out_files['epochs'].basename}" logger.info(**gen_log_kwargs(message=msg)) # Get rejection parameters and drop bad epochs epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) if cfg.reject == "autoreject_local": - msg = "Using autoreject to find and repair bad epochs" + msg = ( + "Using autoreject to find and repair bad epochs (interpolating bad " + "segments)" + ) logger.info(**gen_log_kwargs(message=msg)) ar = autoreject.AutoReject( @@ -140,14 +147,18 @@ def drop_ptp( logger.info(**gen_log_kwargs(message=msg)) reject[ch_type] = threshold - msg = f"Using PTP rejection thresholds: {reject}" - logger.info(**gen_log_kwargs(message=msg)) - n_epochs_before_reject = len(epochs) epochs.reject_tmin = cfg.reject_tmin epochs.reject_tmax = cfg.reject_tmax epochs.drop_bad(reject=reject) n_epochs_after_reject = len(epochs) + n_epochs_rejected = n_epochs_before_reject - n_epochs_after_reject + + msg 
= ( + f"Removed {n_epochs_rejected} of {n_epochs_before_reject} epochs via PTP " + f"rejection thresholds: {reject}" + ) + logger.info(**gen_log_kwargs(message=msg)) if 0 < n_epochs_after_reject < 0.5 * n_epochs_before_reject: msg = ( @@ -165,7 +176,7 @@ def drop_ptp( f"No epochs remaining after {rejection_type} rejection. Cannot continue." ) - msg = "Saving cleaned, baseline-corrected epochs …" + msg = f"Saving {n_epochs_after_reject} cleaned, baseline-corrected epochs …" epochs.apply_baseline(cfg.baseline) epochs.save( @@ -181,10 +192,10 @@ def drop_ptp( msg = "Adding cleaned epochs to report." logger.info(**gen_log_kwargs(message=msg)) # Add PSD plots for 30s of data or all epochs if we have less available - if len(epochs) * (epochs.tmax - epochs.tmin) < 30: - psd = True - else: - psd = 30 + psd = True if len(epochs) * (epochs.tmax - epochs.tmin) < 30 else 30.0 + tags = ("epochs", "clean") + kind = cfg.reject if isinstance(cfg.reject, str) else "Rejection" + title = "Epochs: after cleaning" with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session ) as report: @@ -195,23 +206,34 @@ def drop_ptp( f"{ar.n_interpolate_} channels were bad (cross-validated n_interpolate " f"limit; excluding globally bad and non-data channels, shown in white)." ) + fig = reject_log.plot(orientation="horizontal", aspect="auto", show=False) report.add_figure( - fig=reject_log.plot( - orientation="horizontal", aspect="auto", show=False - ), - title="Epochs: Autoreject cleaning", + fig=fig, + title=f"{kind} cleaning", caption=caption, - tags=("epochs", "autoreject"), + section=title, + tags=tags, replace=True, ) + plt.close(fig) del caption + else: + report.add_html( + html=f"{reject}", + title=f"{kind} thresholds", + section=title, + replace=True, + tags=tags, + ) report.add_epochs( epochs=epochs, - title="Epochs: after cleaning", + title=title, psd=psd, drop_log_ignore=(), + tags=tags, replace=True, + **_add_epochs_image_kwargs(cfg=cfg), ) return _prep_out_files(exec_params=exec_params, out_files=out_files) @@ -231,6 +253,7 @@ def get_config( random_state=config.random_state, ch_types=config.ch_types, _epochs_split_size=config._epochs_split_size, + report_add_epochs_image_kwargs=config.report_add_epochs_image_kwargs, **_bids_kwargs(config=config), ) return cfg @@ -250,7 +273,7 @@ def main(*, config: SimpleNamespace) -> None: subject=subject, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/preprocessing/__init__.py b/mne_bids_pipeline/steps/preprocessing/__init__.py index 95637ecab..f9072617c 100644 --- a/mne_bids_pipeline/steps/preprocessing/__init__.py +++ b/mne_bids_pipeline/steps/preprocessing/__init__.py @@ -1,25 +1,31 @@ """Preprocessing.""" -from . import _01_data_quality -from . import _02_head_pos -from . import _03_maxfilter -from . import _04_frequency_filter -from . import _05_make_epochs -from . import _06a_run_ica -from . import _06b_run_ssp -from . import _07a_apply_ica -from . import _07b_apply_ssp -from . import _08_ptp_reject +from . 
import ( + _01_data_quality, + _02_head_pos, + _03_maxfilter, + _04_frequency_filter, + _05_regress_artifact, + _06a1_fit_ica, + _06a2_find_ica_artifacts, + _06b_run_ssp, + _07_make_epochs, + _08a_apply_ica, + _08b_apply_ssp, + _09_ptp_reject, +) _STEPS = ( _01_data_quality, _02_head_pos, _03_maxfilter, _04_frequency_filter, - _05_make_epochs, - _06a_run_ica, + _05_regress_artifact, + _06a1_fit_ica, + _06a2_find_ica_artifacts, _06b_run_ssp, - _07a_apply_ica, - _07b_apply_ssp, - _08_ptp_reject, + _07_make_epochs, + _08a_apply_ica, + _08b_apply_ssp, + _09_ptp_reject, ) diff --git a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py index 2ec0ea714..eb4e8c71e 100644 --- a/mne_bids_pipeline/steps/sensor/_01_make_evoked.py +++ b/mne_bids_pipeline/steps/sensor/_01_make_evoked.py @@ -1,37 +1,37 @@ """Extract evoked data for each condition.""" from types import SimpleNamespace -from typing import Optional import mne from mne_bids import BIDSPath -from ..._config_utils import ( - get_sessions, - get_subjects, - get_all_contrasts, +from mne_bids_pipeline._config_utils import ( _bids_kwargs, - _restrict_analyze_channels, _pl, + _restrict_analyze_channels, + get_all_contrasts, + get_eeg_reference, + get_subjects_sessions, ) -from ..._logging import gen_log_kwargs, logger -from ..._parallel import parallel_func, get_parallel_backend -from ..._report import _open_report, _sanitize_cond_tag, _all_conditions -from ..._run import ( - failsafe_run, - save_logs, - _sanitize_callable, +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _all_conditions, _open_report, _sanitize_cond_tag +from mne_bids_pipeline._run import ( _prep_out_files, + _sanitize_callable, _update_for_splits, + failsafe_run, + save_logs, ) +from mne_bids_pipeline.typing import InFilesT, OutFilesT def get_input_fnames_evoked( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesT: fname_epochs = BIDSPath( subject=subject, session=session, @@ -61,9 +61,9 @@ def run_evoked( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: out_files = dict() out_files["evoked"] = ( in_files["epochs"] @@ -76,9 +76,9 @@ def run_evoked( ) ) - msg = f'Input: {in_files["epochs"].basename}' + msg = f"Input: {in_files['epochs'].basename}" logger.info(**gen_log_kwargs(message=msg)) - msg = f'Output: {out_files["evoked"].basename}' + msg = f"Output: {out_files['evoked'].basename}" logger.info(**gen_log_kwargs(message=msg)) epochs = mne.read_epochs(in_files.pop("epochs"), preload=True) @@ -130,7 +130,7 @@ def run_evoked( for condition, evoked in all_evoked.items(): _restrict_analyze_channels(evoked, cfg) - tags = ("evoked", _sanitize_cond_tag(condition)) + tags: tuple[str, ...] = ("evoked", _sanitize_cond_tag(condition)) if condition in cfg.conditions: title = f"Condition: {condition}" else: # It's a contrast of two conditions. 
@@ -172,6 +172,7 @@ def get_config( contrasts=get_all_contrasts(config), noise_cov=_sanitize_callable(config.noise_cov), analyze_channels=config.analyze_channels, + eeg_reference=get_eeg_reference(config), ch_types=config.ch_types, report_evoked_n_time_points=config.report_evoked_n_time_points, **_bids_kwargs(config=config), @@ -197,7 +198,7 @@ def main(*, config: SimpleNamespace) -> None: subject=subject, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py index d1d8157a1..209a3d4ec 100644 --- a/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py +++ b/mne_bids_pipeline/steps/sensor/_02_decoding_full_epochs.py @@ -10,48 +10,51 @@ import os.path as op from types import SimpleNamespace -from typing import Optional +import mne import numpy as np import pandas as pd -from scipy.io import savemat, loadmat - -from sklearn.model_selection import cross_val_score -from sklearn.pipeline import make_pipeline -from sklearn.model_selection import StratifiedKFold - -import mne -from mne.decoding import Scaler, Vectorizer +from mne.decoding import Vectorizer from mne_bids import BIDSPath +from scipy.io import loadmat, savemat +from sklearn.model_selection import StratifiedKFold, cross_val_score +from sklearn.pipeline import make_pipeline -from ..._config_utils import ( - get_sessions, - get_subjects, - get_eeg_reference, - get_decoding_contrasts, +from mne_bids_pipeline._config_utils import ( _bids_kwargs, + _get_decoding_proc, _restrict_analyze_channels, + get_decoding_contrasts, + get_eeg_reference, + get_subjects_sessions, ) -from ..._logging import gen_log_kwargs, logger -from ..._decoding import LogReg -from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits -from ..._report import ( - _open_report, +from mne_bids_pipeline._decoding import LogReg, _decoding_preproc_steps +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import ( _contrasts_to_names, + _open_report, _plot_full_epochs_decoding_scores, _sanitize_cond_tag, ) +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, +) +from mne_bids_pipeline.typing import InFilesT, OutFilesT def get_input_fnames_epochs_decoding( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, condition1: str, condition2: str, -) -> dict: +) -> InFilesT: + proc = _get_decoding_proc(config=cfg) fname_epochs = BIDSPath( subject=subject, session=session, @@ -60,7 +63,7 @@ def get_input_fnames_epochs_decoding( run=None, recording=cfg.rec, space=cfg.space, - processing="clean", + processing=proc, suffix="epo", extension=".fif", datatype=cfg.datatype, @@ -81,11 +84,11 @@ def run_epochs_decoding( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, condition1: str, condition2: str, - in_files: dict, -) -> dict: + in_files: InFilesT, +) -> OutFilesT: import matplotlib.pyplot as plt msg = f"Contrasting conditions: {condition1} – {condition2}" @@ -114,6 +117,16 @@ def run_epochs_decoding( # Crop to the desired analysis 
interval. Do it only after the concatenation to work # around https://github.com/mne-tools/mne-python/issues/12153 epochs.crop(cfg.decoding_epochs_tmin, cfg.decoding_epochs_tmax) + # omit bad channels and reference MEG sensors + pick_idx = mne.pick_types( + epochs.info, meg=True, eeg=True, ref_meg=False, exclude="bads" + ) + epochs.pick(pick_idx) + pre_steps = _decoding_preproc_steps( + subject=subject, + session=session, + epochs=epochs, + ) n_cond1 = len(epochs[epochs_conds[0]]) n_cond2 = len(epochs[epochs_conds[1]]) @@ -121,9 +134,9 @@ def run_epochs_decoding( X = epochs.get_data() y = np.r_[np.ones(n_cond1), np.zeros(n_cond2)] - classification_pipeline = make_pipeline( - Scaler(scalings="mean"), - Vectorizer(), # So we can pass the data to scikit-learn + clf = make_pipeline( + *pre_steps, + Vectorizer(), LogReg( solver="liblinear", # much faster than the default random_state=cfg.random_state, @@ -139,7 +152,13 @@ def run_epochs_decoding( n_splits=cfg.decoding_n_splits, ) scores = cross_val_score( - estimator=classification_pipeline, X=X, y=y, cv=cv, scoring="roc_auc", n_jobs=1 + estimator=clf, + X=X, + y=y, + cv=cv, + scoring="roc_auc", + n_jobs=1, + error_score="raise", ) # Save the scores @@ -189,7 +208,7 @@ def run_epochs_decoding( all_contrasts.append(contrast) del fname_decoding, processing, a_vs_b, decoding_data - fig, caption = _plot_full_epochs_decoding_scores( + fig, caption, _ = _plot_full_epochs_decoding_scores( contrast_names=_contrasts_to_names(all_contrasts), scores=all_decoding_scores, metric=cfg.decoding_metric, @@ -204,7 +223,7 @@ def run_epochs_decoding( "contrast", "decoding", *[ - f"{_sanitize_cond_tag(cond_1)}–" f"{_sanitize_cond_tag(cond_2)}" + f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}" for cond_1, cond_2 in cfg.contrasts ], ), @@ -225,6 +244,7 @@ def get_config( conditions=config.conditions, contrasts=get_decoding_contrasts(config), decode=config.decode, + decoding_which_epochs=config.decoding_which_epochs, decoding_metric=config.decoding_metric, decoding_epochs_tmin=config.decoding_epochs_tmin, decoding_epochs_tmax=config.decoding_epochs_tmax, @@ -263,8 +283,8 @@ def main(*, config: SimpleNamespace) -> None: condition2=cond_2, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions for (cond_1, cond_2) in get_decoding_contrasts(config) ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py index 02ec357dd..781ebc13c 100644 --- a/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py +++ b/mne_bids_pipeline/steps/sensor/_03_decoding_time_by_time.py @@ -13,50 +13,56 @@ import os.path as op from types import SimpleNamespace -from typing import Optional +import mne import numpy as np import pandas as pd -from scipy.io import savemat, loadmat - -import mne -from mne.decoding import GeneralizingEstimator, SlidingEstimator, cross_val_multiscore - +from mne.decoding import ( + GeneralizingEstimator, + SlidingEstimator, + Vectorizer, + cross_val_multiscore, +) from mne_bids import BIDSPath - -from sklearn.preprocessing import StandardScaler -from sklearn.pipeline import make_pipeline +from scipy.io import loadmat, savemat from sklearn.model_selection import StratifiedKFold +from sklearn.pipeline import make_pipeline -from ..._config_utils import ( - get_sessions, - get_subjects, - get_eeg_reference, - 
get_decoding_contrasts, +from mne_bids_pipeline._config_utils import ( _bids_kwargs, + _get_decoding_proc, _restrict_analyze_channels, + get_decoding_contrasts, + get_eeg_reference, + get_subjects_sessions, ) -from ..._decoding import LogReg -from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits -from ..._parallel import get_parallel_backend, get_parallel_backend_name -from ..._report import ( +from mne_bids_pipeline._decoding import LogReg, _decoding_preproc_steps +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, get_parallel_backend_name +from mne_bids_pipeline._report import ( _open_report, _plot_decoding_time_generalization, - _sanitize_cond_tag, _plot_time_by_time_decoding_scores, + _sanitize_cond_tag, +) +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, ) +from mne_bids_pipeline.typing import InFilesT, OutFilesT def get_input_fnames_time_decoding( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, condition1: str, condition2: str, -) -> dict: - # TODO: Shouldn't this at least use the PTP-rejected epochs if available? +) -> InFilesT: + proc = _get_decoding_proc(config=cfg) fname_epochs = BIDSPath( subject=subject, session=session, @@ -65,7 +71,7 @@ def get_input_fnames_time_decoding( run=None, recording=cfg.rec, space=cfg.space, - processing="clean", + processing=proc, suffix="epo", extension=".fif", datatype=cfg.datatype, @@ -86,11 +92,11 @@ def run_time_decoding( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, condition1: str, condition2: str, - in_files: dict, -) -> dict: + in_files: InFilesT, +) -> OutFilesT: import matplotlib.pyplot as plt if cfg.decoding_time_generalization: @@ -124,6 +130,25 @@ def run_time_decoding( epochs = mne.concatenate_epochs([epochs[epochs_conds[0]], epochs[epochs_conds[1]]]) n_cond1 = len(epochs[epochs_conds[0]]) n_cond2 = len(epochs[epochs_conds[1]]) + pick_idx = mne.pick_types( + epochs.info, meg=True, eeg=True, ref_meg=False, exclude="bads" + ) + epochs.pick(pick_idx) + # We can't use the full rank here because the number of samples can just be the + # number of epochs (which can be fewer than the number of channels) + pre_steps = _decoding_preproc_steps( + subject=subject, + session=session, + epochs=epochs, + pca=False, + ) + # At some point we might want to enable this, but it's really slow and arguably + # unnecessary so let's omit it for now: + # pre_steps.append( + # mne.decoding.UnsupervisedSpatialFilter( + # PCA(n_components=0.999, whiten=True), + # ) + # ) decim = cfg.decoding_time_generalization_decim if cfg.decoding_time_generalization and decim > 1: @@ -135,7 +160,8 @@ def run_time_decoding( verbose = get_parallel_backend_name(exec_params=exec_params) != "dask" with get_parallel_backend(exec_params): clf = make_pipeline( - StandardScaler(), + *pre_steps, + Vectorizer(), LogReg( solver="liblinear", # much faster than the default random_state=cfg.random_state, @@ -226,8 +252,7 @@ def run_time_decoding( "epochs", "contrast", "decoding", - f"{_sanitize_cond_tag(contrast[0])}–" - f"{_sanitize_cond_tag(contrast[1])}", + f"{_sanitize_cond_tag(contrast[0])}–{_sanitize_cond_tag(contrast[1])}", ) processing = f"{a_vs_b}+TimeByTime+{cfg.decoding_metric}" @@ -299,6 +324,7 @@ def get_config( conditions=config.conditions, 
contrasts=get_decoding_contrasts(config), decode=config.decode, + decoding_which_epochs=config.decoding_which_epochs, decoding_metric=config.decoding_metric, decoding_n_splits=config.decoding_n_splits, decoding_time_generalization=config.decoding_time_generalization, @@ -337,8 +363,8 @@ def main(*, config: SimpleNamespace) -> None: condition2=cond_2, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions for cond_1, cond_2 in get_decoding_contrasts(config) ] save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py index e1e7b440c..c2bc96c0f 100644 --- a/mne_bids_pipeline/steps/sensor/_04_time_frequency.py +++ b/mne_bids_pipeline/steps/sensor/_04_time_frequency.py @@ -5,34 +5,36 @@ """ from types import SimpleNamespace -from typing import Optional - -import numpy as np import mne - +import numpy as np from mne_bids import BIDSPath -from ..._config_utils import ( - get_sessions, - get_subjects, - get_eeg_reference, - sanitize_cond_name, +from mne_bids_pipeline._config_utils import ( _bids_kwargs, _restrict_analyze_channels, + get_eeg_reference, + get_subjects_sessions, + sanitize_cond_name, +) +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _open_report, _sanitize_cond_tag +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, ) -from ..._logging import gen_log_kwargs, logger -from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits -from ..._parallel import get_parallel_backend, parallel_func -from ..._report import _open_report, _sanitize_cond_tag +from mne_bids_pipeline.typing import InFilesT, OutFilesT def get_input_fnames_time_frequency( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesT: fname_epochs = BIDSPath( subject=subject, session=session, @@ -62,9 +64,9 @@ def run_time_frequency( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: import matplotlib.pyplot as plt epochs_path = in_files.pop("epochs") @@ -198,7 +200,7 @@ def main(*, config: SimpleNamespace) -> None: subject=subject, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py index 75b24a854..50bdd0ef9 100644 --- a/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py +++ b/mne_bids_pipeline/steps/sensor/_05_decoding_csp.py @@ -1,50 +1,56 @@ -""" -Decoding based on common spatial patterns (CSP). 
-""" +"""Decoding based on common spatial patterns (CSP).""" import os.path as op from types import SimpleNamespace -from typing import Dict, Optional, Tuple +import matplotlib.transforms import mne import numpy as np import pandas as pd -import matplotlib.transforms -from mne.decoding import CSP, UnsupervisedSpatialFilter +from mne.decoding import CSP from mne_bids import BIDSPath -from sklearn.decomposition import PCA from sklearn.model_selection import StratifiedKFold, cross_val_score from sklearn.pipeline import make_pipeline -from ..._config_utils import ( - get_sessions, - get_subjects, - get_eeg_reference, - get_decoding_contrasts, +from mne_bids_pipeline._config_utils import ( _bids_kwargs, + _get_decoding_proc, _restrict_analyze_channels, + get_decoding_contrasts, + get_eeg_reference, + get_subjects_sessions, +) +from mne_bids_pipeline._decoding import ( + LogReg, + _decoding_preproc_steps, + _handle_csp_args, ) -from ..._decoding import LogReg, _handle_csp_args -from ..._logging import logger, gen_log_kwargs -from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits -from ..._report import ( +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import ( + _imshow_tf, _open_report, - _sanitize_cond_tag, _plot_full_epochs_decoding_scores, - _imshow_tf, + _sanitize_cond_tag, +) +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, ) +from mne_bids_pipeline.typing import InFilesT, IntArrayT, OutFilesT -def _prepare_labels(*, epochs: mne.BaseEpochs, contrast: Tuple[str, str]) -> np.ndarray: +def _prepare_labels(*, epochs: mne.BaseEpochs, contrast: tuple[str, str]) -> IntArrayT: """Return the projection of the events_id on a boolean vector. This projection is useful in the case of hierarchical events: we project the different events contained in one condition into just one label. - Returns: - -------- + Returns + ------- A boolean numpy array containing the labels. """ epochs_cond_0 = epochs[contrast[0]] @@ -52,7 +58,7 @@ def _prepare_labels(*, epochs: mne.BaseEpochs, contrast: Tuple[str, str]) -> np. epochs_cond_1 = epochs[contrast[1]] event_codes_condition_1 = set(epochs_cond_1.events[:, 2]) - y = epochs.events[:, 2].copy() + y: IntArrayT = epochs.events[:, 2].copy() for i in range(len(y)): if y[i] in event_codes_condition_0 and y[i] in event_codes_condition_1: msg = ( @@ -78,24 +84,23 @@ def _prepare_labels(*, epochs: mne.BaseEpochs, contrast: Tuple[str, str]) -> np. def prepare_epochs_and_y( - *, epochs: mne.BaseEpochs, contrast: Tuple[str, str], cfg, fmin: float, fmax: float -) -> Tuple[mne.BaseEpochs, np.ndarray]: + *, + epochs: mne.BaseEpochs, + contrast: tuple[str, str], + cfg: SimpleNamespace, + fmin: float, + fmax: float, +) -> tuple[mne.BaseEpochs, IntArrayT]: """Band-pass between, sub-select the desired epochs, and prepare y.""" - epochs_filt = epochs.copy().pick(["meg", "eeg"]) - - # We only take mag to speed up computation - # because the information is redundant between grad and mag - if cfg.datatype == "meg" and cfg.use_maxwell_filter: - epochs_filt.pick("mag") - # filtering out the conditions we are not interested in, to ensure here we # have a valid partition between the condition of the contrast. 
- # + # XXX Hack for handling epochs selection via metadata + # This also makes a copy if contrast[0].startswith("event_name.isin"): - epochs_filt = epochs_filt[f"{contrast[0]} or {contrast[1]}"] + epochs_filt = epochs[f"{contrast[0]} or {contrast[1]}"] else: - epochs_filt = epochs_filt[contrast] + epochs_filt = epochs[contrast] # Filtering is costly, so do it last, after the selection of the channels # and epochs. We know that often the filter will be longer than the signal, @@ -110,9 +115,10 @@ def get_input_fnames_csp( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - contrast: Tuple[str], -) -> dict: + session: str | None, + contrast: tuple[str], +) -> InFilesT: + proc = _get_decoding_proc(config=cfg) fname_epochs = BIDSPath( subject=subject, session=session, @@ -121,7 +127,7 @@ def get_input_fnames_csp( run=None, recording=cfg.rec, space=cfg.space, - processing="clean", + processing=proc, suffix="epo", extension=".fif", datatype=cfg.datatype, @@ -141,9 +147,9 @@ def one_subject_decoding( exec_params: SimpleNamespace, subject: str, session: str, - contrast: Tuple[str, str], - in_files: Dict[str, BIDSPath], -) -> dict: + contrast: tuple[str, str], + in_files: InFilesT, +) -> OutFilesT: """Run one subject. There are two steps in this function: @@ -159,30 +165,27 @@ def one_subject_decoding( bids_path = in_files["epochs"].copy().update(processing=None, split=None) epochs = mne.read_epochs(in_files.pop("epochs")) _restrict_analyze_channels(epochs, cfg) + pick_idx = mne.pick_types( + epochs.info, meg=True, eeg=True, ref_meg=False, exclude="bads" + ) + epochs.pick(pick_idx) if cfg.time_frequency_subtract_evoked: epochs.subtract_evoked() - # Perform rank reduction via PCA. - # - # Select the channel type with the smallest rank. - # Limit it to a maximum of 100. - ranks = mne.compute_rank(inst=epochs, rank="info") - ch_type_smallest_rank = min(ranks, key=ranks.get) - rank = min(ranks[ch_type_smallest_rank], 100) - del ch_type_smallest_rank, ranks - - msg = f"Reducing data dimension via PCA; new rank: {rank}." 
- logger.info(**gen_log_kwargs(msg)) - pca = UnsupervisedSpatialFilter(PCA(rank), average=False) + preproc_steps = _decoding_preproc_steps( + subject=subject, + session=session, + epochs=epochs, + ) # Classifier csp = CSP( n_components=4, # XXX revisit reg=0.1, # XXX revisit - rank="info", ) clf = make_pipeline( + *preproc_steps, csp, LogReg( solver="liblinear", # much faster than the default @@ -197,7 +200,7 @@ def one_subject_decoding( ) # Loop over frequencies (all time points lumped together) - freq_name_to_bins_map = _handle_csp_args( + freq_name_to_bins_map, time_bins = _handle_csp_args( cfg.decoding_csp_times, cfg.decoding_csp_freqs, cfg.decoding_metric, @@ -229,7 +232,15 @@ def one_subject_decoding( ) del freq_decoding_table_rows - def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=None): + def _fmt_contrast( + cond1: str, + cond2: str, + fmin: float, + fmax: float, + freq_range_name: str, + tmin: float | None = None, + tmax: float | None = None, + ) -> str: msg = ( f"Contrast: {cond1} – {cond2}, " f"{fmin:4.1f}–{fmax:4.1f} Hz ({freq_range_name})" @@ -239,6 +250,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non return msg for idx, row in freq_decoding_table.iterrows(): + assert isinstance(row, pd.Series) fmin = row["f_min"] fmax = row["f_max"] cond1 = row["cond_1"] @@ -256,32 +268,23 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non # Get the data for all time points X = epochs_filt.get_data() - # We apply PCA before running CSP: - # - much faster CSP processing - # - reduced risk of numerical instabilities. - X_pca = pca.fit_transform(X) - del X - cv_scores = cross_val_score( estimator=clf, - X=X_pca, + X=X, y=y, scoring=cfg.decoding_metric, cv=cv, n_jobs=1, + error_score="raise", ) freq_decoding_table.loc[idx, "mean_crossval_score"] = cv_scores.mean() freq_decoding_table.at[idx, "scores"] = cv_scores + del fmin, fmax, cond1, cond2, freq_range_name # Loop over times x frequencies # # Note: We don't support varying time ranges for different frequency # ranges to avoid leaking of information. 
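[Editor's aside: to make the classifier construction above concrete, here is a self-contained sketch of the same CSP + logistic-regression cross-validation for a single frequency bin. Paths, condition names, and band edges are placeholders, and plain scikit-learn LogisticRegression stands in for the pipeline's LogReg helper:]

    import mne
    import numpy as np
    from mne.decoding import CSP
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import StratifiedKFold, cross_val_score
    from sklearn.pipeline import make_pipeline

    epochs = mne.read_epochs("sub-01_task-aud_proc-clean_epo.fif", preload=True)
    epochs.filter(l_freq=8.0, h_freq=13.0)  # one frequency bin, e.g. alpha

    # Group the two conditions so the labels below line up with the data.
    epochs_a, epochs_b = epochs["auditory"], epochs["visual"]
    X = np.concatenate([epochs_a.get_data(), epochs_b.get_data()])
    y = np.r_[np.ones(len(epochs_a)), np.zeros(len(epochs_b))]

    clf = make_pipeline(
        CSP(n_components=4, reg=0.1),
        LogisticRegression(solver="liblinear"),
    )
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    scores = cross_val_score(clf, X=X, y=y, cv=cv, scoring="roc_auc", n_jobs=1)
    print(f"ROC AUC: {scores.mean():.3f}")
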
- time_bins = np.array(cfg.decoding_csp_times) - if time_bins.ndim == 1: - time_bins = np.array(list(zip(time_bins[:-1], time_bins[1:]))) - assert time_bins.ndim == 2 - tf_decoding_table_rows = [] for freq_range_name, freq_bins in freq_name_to_bins_map.items(): @@ -305,13 +308,19 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non } tf_decoding_table_rows.append(row) - tf_decoding_table = pd.concat( - [pd.DataFrame.from_dict(row) for row in tf_decoding_table_rows], - ignore_index=True, - ) + if len(tf_decoding_table_rows): + tf_decoding_table = pd.concat( + [pd.DataFrame.from_dict(row) for row in tf_decoding_table_rows], + ignore_index=True, + ) + else: + tf_decoding_table = pd.DataFrame() del tf_decoding_table_rows for idx, row in tf_decoding_table.iterrows(): + if len(row) == 0: + break # no data + assert isinstance(row, pd.Series) tmin = row["t_min"] tmax = row["t_max"] fmin = row["f_min"] @@ -326,18 +335,16 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non # Crop data to the time window of interest if tmax is not None: # avoid warnings about outside the interval tmax = min(tmax, epochs_filt.times[-1]) - epochs_filt.crop(tmin, tmax) - X = epochs_filt.get_data() - X_pca = pca.transform(X) - del X - + X = epochs_filt.crop(tmin, tmax).get_data() + del epochs_filt cv_scores = cross_val_score( estimator=clf, - X=X_pca, + X=X, y=y, scoring=cfg.decoding_metric, cv=cv, n_jobs=1, + error_score="raise", ) score = cv_scores.mean() tf_decoding_table.loc[idx, "mean_crossval_score"] = score @@ -345,6 +352,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non msg = _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin, tmax) msg += f": {cfg.decoding_metric}={score:0.3f}" logger.info(**gen_log_kwargs(msg)) + del tmin, tmax, fmin, fmax, cond1, cond2, freq_range_name # Write each DataFrame to a different Excel worksheet. a_vs_b = f"{condition1}+{condition2}".replace(op.sep, "") @@ -356,8 +364,10 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non ) with pd.ExcelWriter(fname_results) as w: freq_decoding_table.to_excel(w, sheet_name="CSP Frequency", index=False) - tf_decoding_table.to_excel(w, sheet_name="CSP Time-Frequency", index=False) + if not tf_decoding_table.empty: + tf_decoding_table.to_excel(w, sheet_name="CSP Time-Frequency", index=False) out_files = {"csp-excel": fname_results} + del freq_decoding_table # Report with _open_report( @@ -366,15 +376,6 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non msg = "Adding CSP decoding results to the report." 
logger.info(**gen_log_kwargs(message=msg)) section = "Decoding: CSP" - freq_name_to_bins_map = _handle_csp_args( - cfg.decoding_csp_times, - cfg.decoding_csp_freqs, - cfg.decoding_metric, - epochs_tmin=cfg.epochs_tmin, - epochs_tmax=cfg.epochs_tmax, - time_frequency_freq_min=cfg.time_frequency_freq_min, - time_frequency_freq_max=cfg.time_frequency_freq_max, - ) all_csp_tf_results = dict() for contrast in cfg.decoding_contrasts: cond_1, cond_2 = contrast @@ -397,14 +398,15 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non csp_freq_results["scores"] = csp_freq_results["scores"].apply( lambda x: np.array(x[1:-1].split(), float) ) - csp_tf_results = pd.read_excel( - fname_decoding, sheet_name="CSP Time-Frequency" - ) - csp_tf_results["scores"] = csp_tf_results["scores"].apply( - lambda x: np.array(x[1:-1].split(), float) - ) - all_csp_tf_results[contrast] = csp_tf_results - del csp_tf_results + if not tf_decoding_table.empty: + csp_tf_results = pd.read_excel( + fname_decoding, sheet_name="CSP Time-Frequency" + ) + csp_tf_results["scores"] = csp_tf_results["scores"].apply( + lambda x: np.array(x[1:-1].split(), float) + ) + all_csp_tf_results[contrast] = csp_tf_results + del csp_tf_results all_decoding_scores = list() contrast_names = list() @@ -419,9 +421,9 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non f_min = float(freq_bin[0]) f_max = float(freq_bin[1]) contrast_names.append( - f"{freq_range_name}\n" f"({f_min:0.1f}-{f_max:0.1f} Hz)" + f"{freq_range_name}\n({f_min:0.1f}-{f_max:0.1f} Hz)" ) - fig, caption = _plot_full_epochs_decoding_scores( + fig, caption, _ = _plot_full_epochs_decoding_scores( contrast_names=contrast_names, scores=all_decoding_scores, metric=cfg.decoding_metric, @@ -452,33 +454,39 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non f"{_sanitize_cond_tag(cond_1)}–{_sanitize_cond_tag(cond_2)}", ) results = all_csp_tf_results[contrast] - mean_crossval_scores = list() - tmin, tmax, fmin, fmax = list(), list(), list(), list() - mean_crossval_scores.extend(results["mean_crossval_score"].ravel()) - tmin.extend(results["t_min"].ravel()) - tmax.extend(results["t_max"].ravel()) - fmin.extend(results["f_min"].ravel()) - fmax.extend(results["f_max"].ravel()) - mean_crossval_scores = np.array(mean_crossval_scores, float) + mean_crossval_scores: list[float] = list() + tmin_list: list[float] = list() + tmax_list: list[float] = list() + fmin_list: list[float] = list() + fmax_list: list[float] = list() + mean_crossval_scores.extend( + results["mean_crossval_score"].to_numpy().ravel().tolist() + ) + tmin_list.extend(results["t_min"].to_numpy().ravel()) + tmax_list.extend(results["t_max"].to_numpy().ravel()) + fmin_list.extend(results["f_min"].to_numpy().ravel()) + fmax_list.extend(results["f_max"].to_numpy().ravel()) + mean_crossval_scores_array = np.array(mean_crossval_scores, float) + del mean_crossval_scores fig, ax = plt.subplots(constrained_layout=True) # XXX Add support for more metrics assert cfg.decoding_metric == "roc_auc" metric = "ROC AUC" vmax = ( max( - np.abs(mean_crossval_scores.min() - 0.5), - np.abs(mean_crossval_scores.max() - 0.5), + np.abs(mean_crossval_scores_array.min() - 0.5), + np.abs(mean_crossval_scores_array.max() - 0.5), ) + 0.5 ) vmin = 0.5 - (vmax - 0.5) img = _imshow_tf( - mean_crossval_scores, + mean_crossval_scores_array, ax, - tmin=tmin, - tmax=tmax, - fmin=fmin, - fmax=fmax, + tmin=np.array(tmin_list, float), + tmax=np.array(tmax_list, float), + 
fmin=np.array(fmin_list, float), + fmax=np.array(fmax_list, float), vmin=vmin, vmax=vmax, ) @@ -487,7 +495,7 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non ) for freq_range_name, bins in freq_name_to_bins_map.items(): ax.text( - tmin[0], + tmin_list[0], 0.5 * bins[0][0] + 0.5 * bins[-1][1], freq_range_name, transform=offset, @@ -495,8 +503,8 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non va="center", rotation=90, ) - ax.set_xlim([np.min(tmin), np.max(tmax)]) - ax.set_ylim([np.min(fmin), np.max(fmax)]) + ax.set_xlim([np.min(tmin_list), np.max(tmax_list)]) + ax.set_ylim([np.min(fmin_list), np.max(fmax_list)]) ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") cbar = fig.colorbar( @@ -511,13 +519,15 @@ def _fmt_contrast(cond1, cond2, fmin, fmax, freq_range_name, tmin=None, tmax=Non tags=tags, replace=True, ) + plt.close(fig) + del fig, title assert len(in_files) == 0, in_files.keys() return _prep_out_files(exec_params=exec_params, out_files=out_files) def get_config( - *, config: SimpleNamespace, subject: str, session: Optional[str] + *, config: SimpleNamespace, subject: str, session: str | None ) -> SimpleNamespace: cfg = SimpleNamespace( # Data parameters @@ -531,6 +541,7 @@ def get_config( time_frequency_freq_min=config.time_frequency_freq_min, time_frequency_freq_max=config.time_frequency_freq_max, time_frequency_subtract_evoked=config.time_frequency_subtract_evoked, + decoding_which_epochs=config.decoding_which_epochs, decoding_metric=config.decoding_metric, decoding_csp_freqs=config.decoding_csp_freqs, decoding_csp_times=config.decoding_csp_times, @@ -567,8 +578,8 @@ def main(*, config: SimpleNamespace) -> None: session=session, contrast=contrast, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions for contrast in get_decoding_contrasts(config) ) save_logs(logs=logs, config=config) diff --git a/mne_bids_pipeline/steps/sensor/_06_make_cov.py b/mne_bids_pipeline/steps/sensor/_06_make_cov.py index 2cb3b8ebf..d20a1b6ba 100644 --- a/mne_bids_pipeline/steps/sensor/_06_make_cov.py +++ b/mne_bids_pipeline/steps/sensor/_06_make_cov.py @@ -3,38 +3,38 @@ Covariance matrices are computed and saved. 
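[Editor's aside on the covariance step below: the two paths handled in this file (baseline epochs vs. a resting-state/empty-room raw recording) map onto the following MNE calls. This is a rough sketch with hypothetical file names, using the "shrunk" estimator that the new noise_cov_method option generalizes:]

    import mne

    # From epochs: use the pre-stimulus baseline interval as the noise estimate.
    epochs = mne.read_epochs("sub-01_task-aud_proc-clean_epo.fif", preload=True)
    cov_baseline = mne.compute_covariance(
        epochs, tmin=None, tmax=0.0, method="shrunk", rank="info"
    )

    # From a noise recording: empty-room or resting-state raw data.
    raw_noise = mne.io.read_raw_fif("sub-01_task-noise_proc-clean_raw.fif", preload=True)
    cov_noise = mne.compute_raw_covariance(raw_noise, method="shrunk", rank="info")

    cov_baseline.save("sub-01_task-aud_cov.fif", overwrite=True)
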
""" -from typing import Optional from types import SimpleNamespace import mne from mne_bids import BIDSPath -from ..._config_utils import ( - get_sessions, - get_subjects, - get_noise_cov_bids_path, +from mne_bids_pipeline._config_import import _import_config +from mne_bids_pipeline._config_utils import ( _bids_kwargs, + _restrict_analyze_channels, + get_eeg_reference, + get_noise_cov_bids_path, + get_subjects_sessions, ) -from ..._config_import import _import_config -from ..._config_utils import _restrict_analyze_channels -from ..._logging import gen_log_kwargs, logger -from ..._parallel import get_parallel_backend, parallel_func -from ..._report import _open_report, _sanitize_cond_tag, _all_conditions -from ..._run import ( - failsafe_run, - save_logs, - _sanitize_callable, +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _all_conditions, _open_report, _sanitize_cond_tag +from mne_bids_pipeline._run import ( _prep_out_files, + _sanitize_callable, _update_for_splits, + failsafe_run, + save_logs, ) +from mne_bids_pipeline.typing import InFilesT, OutFilesT def get_input_fnames_cov( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesT: cov_type = _get_cov_type(cfg) in_files = dict() fname_epochs = BIDSPath( @@ -71,15 +71,14 @@ def get_input_fnames_cov( run=None, recording=cfg.rec, space=cfg.space, - processing="filt", + processing="clean", suffix="raw", extension=".fif", datatype=cfg.datatype, root=cfg.deriv_root, check=False, ) - run_type = "resting-state" if cfg.noise_cov == "rest" else "empty-room" - if run_type == "resting-state": + if cfg.noise_cov == "rest": bids_path_raw_noise.task = "rest" else: bids_path_raw_noise.task = "noise" @@ -93,14 +92,14 @@ def get_input_fnames_cov( def compute_cov_from_epochs( *, - tmin: Optional[float], - tmax: Optional[float], + tmin: float | None, + tmax: float | None, cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, - out_files: dict, + session: str | None, + in_files: InFilesT, + out_files: InFilesT, ) -> mne.Covariance: epo_fname = in_files.pop("epochs") @@ -116,7 +115,7 @@ def compute_cov_from_epochs( epochs, tmin=tmin, tmax=tmax, - method="shrunk", + method=cfg.noise_cov_method, rank="info", verbose="error", # TODO: not baseline corrected, maybe problematic? ) @@ -128,21 +127,25 @@ def compute_cov_from_raw( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, - out_files: dict, + session: str | None, + in_files: InFilesT, + out_files: InFilesT, ) -> mne.Covariance: fname_raw = in_files.pop("raw") - run_type = "resting-state" if fname_raw.task == "rest" else "empty-room" - msg = f"Computing regularized covariance based on {run_type} recording." + run_msg = "resting-state" if fname_raw.task == "rest" else "empty-room" + msg = f"Computing regularized covariance based on {run_msg} recording." 
logger.info(**gen_log_kwargs(message=msg)) msg = f"Input: {fname_raw.basename}" logger.info(**gen_log_kwargs(message=msg)) - msg = f'Output: {out_files["cov"].basename}' + msg = f"Output: {out_files['cov'].basename}" logger.info(**gen_log_kwargs(message=msg)) raw_noise = mne.io.read_raw_fif(fname_raw, preload=True) - cov = mne.compute_raw_covariance(raw_noise, method="shrunk", rank="info") + cov = mne.compute_raw_covariance( + raw_noise, + method=cfg.noise_cov_method, + rank="info", + ) return cov @@ -151,9 +154,9 @@ def retrieve_custom_cov( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, - out_files: dict, + session: str | None, + in_files: InFilesT, + out_files: InFilesT, ) -> mne.Covariance: # This should be the only place we use config.noise_cov (rather than cfg.* # entries) @@ -173,7 +176,7 @@ def retrieve_custom_cov( task=cfg.task, acquisition=cfg.acq, run=None, - processing=cfg.proc, + processing="clean", recording=cfg.rec, space=cfg.space, suffix="ave", @@ -183,9 +186,9 @@ def retrieve_custom_cov( check=False, ) - msg = "Retrieving noise covariance matrix from custom user-supplied " "function" + msg = "Retrieving noise covariance matrix from custom user-supplied function" logger.info(**gen_log_kwargs(message=msg)) - msg = f'Output: {out_files["cov"].basename}' + msg = f"Output: {out_files['cov'].basename}" logger.info(**gen_log_kwargs(message=msg)) cov = config.noise_cov(evoked_bids_path) @@ -193,7 +196,7 @@ def retrieve_custom_cov( return cov -def _get_cov_type(cfg): +def _get_cov_type(cfg: SimpleNamespace) -> str: if cfg.noise_cov == "custom": return "custom" elif cfg.noise_cov == "rest": @@ -212,9 +215,9 @@ def run_covariance( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str] = None, - in_files: dict, -) -> dict: + session: str | None = None, + in_files: InFilesT, +) -> OutFilesT: import matplotlib.pyplot as plt out_files = dict() @@ -222,23 +225,38 @@ def run_covariance( cfg=cfg, subject=subject, session=session ) cov_type = _get_cov_type(cfg) - kwargs = dict( - cfg=cfg, - subject=subject, - session=session, - in_files=in_files, - out_files=out_files, - exec_params=exec_params, - ) fname_info = in_files.pop("report_info") fname_evoked = in_files.pop("evoked", None) if cov_type == "custom": - cov = retrieve_custom_cov(**kwargs) + cov = retrieve_custom_cov( + cfg=cfg, + subject=subject, + session=session, + in_files=in_files, + out_files=out_files, + exec_params=exec_params, + ) elif cov_type == "raw": - cov = compute_cov_from_raw(**kwargs) + cov = compute_cov_from_raw( + cfg=cfg, + subject=subject, + session=session, + in_files=in_files, + out_files=out_files, + exec_params=exec_params, + ) else: tmin, tmax = cfg.noise_cov - cov = compute_cov_from_epochs(tmin=tmin, tmax=tmax, **kwargs) + cov = compute_cov_from_epochs( + tmin=tmin, + tmax=tmax, + cfg=cfg, + subject=subject, + session=session, + in_files=in_files, + out_files=out_files, + exec_params=exec_params, + ) cov.save(out_files["cov"], overwrite=True) # Report @@ -262,7 +280,11 @@ def run_covariance( section = "Noise covariance" for evoked, condition in zip(all_evoked, conditions): _restrict_analyze_channels(evoked, cfg) - tags = ("evoked", "covariance", _sanitize_cond_tag(condition)) + tags: tuple[str, ...] 
= ( + "evoked", + "covariance", + _sanitize_cond_tag(condition), + ) title = f"Whitening: {condition}" if condition not in cfg.conditions: tags = tags + ("contrast",) @@ -292,6 +314,8 @@ def get_config( conditions=config.conditions, contrasts=config.contrasts, analyze_channels=config.analyze_channels, + eeg_reference=get_eeg_reference(config), + noise_cov_method=config.noise_cov_method, **_bids_kwargs(config=config), ) return cfg @@ -323,7 +347,7 @@ def main(*, config: SimpleNamespace) -> None: subject=subject, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/sensor/_99_group_average.py b/mne_bids_pipeline/steps/sensor/_99_group_average.py index a05a85a96..454fc852f 100644 --- a/mne_bids_pipeline/steps/sensor/_99_group_average.py +++ b/mne_bids_pipeline/steps/sensor/_99_group_average.py @@ -6,52 +6,58 @@ import os import os.path as op from functools import partial -from typing import Optional, List, Tuple from types import SimpleNamespace -from ...typing import TypedDict +import mne import numpy as np import pandas as pd -from scipy.io import loadmat, savemat - -import mne from mne_bids import BIDSPath +from scipy.io import loadmat, savemat -from ..._config_utils import ( - get_sessions, - get_subjects, - get_eeg_reference, - get_decoding_contrasts, +from mne_bids_pipeline._config_utils import ( _bids_kwargs, - _restrict_analyze_channels, _pl, + _restrict_analyze_channels, + get_decoding_contrasts, + get_eeg_reference, + get_sessions, + get_subjects, + get_subjects_given_session, ) -from ..._decoding import _handle_csp_args -from ..._logging import gen_log_kwargs, logger -from ..._parallel import get_parallel_backend, parallel_func -from ..._run import failsafe_run, save_logs, _prep_out_files, _update_for_splits -from ..._report import ( +from mne_bids_pipeline._decoding import _handle_csp_args +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import ( + _all_conditions, + _contrasts_to_names, _open_report, - _sanitize_cond_tag, - add_event_counts, - add_csp_grand_average, + _plot_decoding_time_generalization, _plot_full_epochs_decoding_scores, _plot_time_by_time_decoding_scores_gavg, + _sanitize_cond_tag, + add_csp_grand_average, + add_event_counts, plot_time_by_time_decoding_t_values, - _plot_decoding_time_generalization, - _contrasts_to_names, - _all_conditions, ) +from mne_bids_pipeline._run import ( + _prep_out_files, + _update_for_splits, + failsafe_run, + save_logs, +) +from mne_bids_pipeline.typing import FloatArrayT, InFilesT, OutFilesT, TypedDict def get_input_fnames_average_evokeds( *, cfg: SimpleNamespace, subject: str, - session: Optional[dict], -) -> dict: + session: str | None, +) -> InFilesT: in_files = dict() - for this_subject in cfg.subjects: + # for each session, only use subjects who actually have data for that session + subjects = get_subjects_given_session(cfg, session) + for this_subject in subjects: in_files[f"evoked-{this_subject}"] = BIDSPath( subject=this_subject, session=session, @@ -77,29 +83,35 @@ def average_evokeds( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: 
logger.info(**gen_log_kwargs(message="Creating grand averages")) # Container for all conditions: conditions = _all_conditions(cfg=cfg) - evokeds = [list() for _ in range(len(conditions))] + evokeds_nested: list[list[mne.Evoked]] = [list() for _ in range(len(conditions))] keys = list(in_files) + subjects_in_grand_avg = list() for key in keys: - if not key.startswith("evoked-"): + if key.startswith("evoked-"): + subjects_in_grand_avg.append(key.replace("evoked-", "")) + else: continue fname_in = in_files.pop(key) these_evokeds = mne.read_evokeds(fname_in) for idx, evoked in enumerate(these_evokeds): - evokeds[idx].append(evoked) # Insert into the container - - for idx, these_evokeds in enumerate(evokeds): - evokeds[idx] = mne.grand_average( - these_evokeds, interpolate_bads=cfg.interpolate_bads_grand_average - ) # Combine subjects + evokeds_nested[idx].append(evoked) # Insert into the container + + evokeds: list[mne.Evoked] = list() + for these_evokeds in evokeds_nested: + evokeds.append( + mne.grand_average( + these_evokeds, interpolate_bads=cfg.interpolate_bads_grand_average + ) # Combine subjects + ) # Keep condition in comment - evokeds[idx].comment = "Grand average: " + these_evokeds[0].comment + evokeds[-1].comment = "Grand average: " + these_evokeds[0].comment out_files = dict() fname_out = out_files["evokeds"] = BIDSPath( @@ -108,7 +120,7 @@ def average_evokeds( task=cfg.task, acquisition=cfg.acq, run=None, - processing=cfg.proc, + processing="clean", recording=cfg.rec, space=cfg.space, suffix="ave", @@ -152,12 +164,16 @@ def average_evokeds( else: msg = "No evoked conditions or contrasts found." logger.info(**gen_log_kwargs(message=msg)) + # construct the common part of the titles + _title = f"N = {len(subjects_in_grand_avg)}" + if n_missing := (len(cfg.subjects) - len(subjects_in_grand_avg)): + _title += f"{n_missing} subjects excluded due to missing session data" for condition, evoked in zip(conditions, evokeds): - tags = ("evoked", _sanitize_cond_tag(condition)) + tags: tuple[str, ...] = ("evoked", _sanitize_cond_tag(condition)) if condition in cfg.conditions: - title = f"Average (sensor): {condition}" + title = f"Average (sensor): {condition}, {_title}" else: # It's a contrast of two conditions. - title = f"Average (sensor) contrast: {condition}" + title = f"Average (sensor) contrast: {condition}, {_title}" tags = tags + ("contrast",) report.add_evokeds( @@ -176,17 +192,17 @@ def average_evokeds( class ClusterAcrossTime(TypedDict): - times: np.ndarray + times: FloatArrayT p_value: float def _decoding_cluster_permutation_test( - scores: np.ndarray, - times: np.ndarray, - cluster_forming_t_threshold: Optional[float], + scores: FloatArrayT, + times: FloatArrayT, + cluster_forming_t_threshold: float | None, n_permutations: int, random_seed: int, -) -> Tuple[np.ndarray, List[ClusterAcrossTime], int]: +) -> tuple[FloatArrayT, list[ClusterAcrossTime], int]: """Perform a cluster permutation test on decoding scores. The clusters are formed across time points. @@ -218,11 +234,13 @@ def _get_epochs_in_files( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesT: in_files = dict() + # here we just need one subject's worth of Epochs, to get the time domain. 
But we + # still must be careful that the subject actually has data for the requested session in_files["epochs"] = BIDSPath( - subject=cfg.subjects[0], + subject=get_subjects_given_session(cfg, session)[0], session=session, task=cfg.task, acquisition=cfg.acq, @@ -243,14 +261,20 @@ def _decoding_out_fname( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - cond_1: str, - cond_2: str, + session: str | None, + cond_1: str | None, + cond_2: str | None, kind: str, extension: str = ".mat", -): +) -> BIDSPath: + if cond_1 is None: + assert cond_2 is None + processing = "" + else: + assert cond_2 is not None + processing = f"{cond_1}+{cond_2}+" processing = ( - f"{cond_1}+{cond_2}+{kind}+{cfg.decoding_metric}".replace(op.sep, "") + f"{processing}{kind}+{cfg.decoding_metric}".replace(op.sep, "") .replace("_", "-") .replace("-", "") ) @@ -275,13 +299,13 @@ def _get_input_fnames_decoding( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, cond_1: str, cond_2: str, kind: str, extension: str = ".mat", -) -> dict: - in_files = _get_epochs_in_files(cfg=cfg, subject=subject, session=session) +) -> InFilesT: + in_files = _get_epochs_in_files(cfg=cfg, subject="ignored", session=session) for this_subject in cfg.subjects: in_files[f"scores-{this_subject}"] = _decoding_out_fname( cfg=cfg, @@ -306,11 +330,11 @@ def average_time_by_time_decoding( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, cond_1: str, cond_2: str, - in_files: dict, -) -> dict: + in_files: InFilesT, +) -> OutFilesT: logger.info(**gen_log_kwargs(message="Averaging time-by-time decoding results")) # Get the time points from the very first subject. They are identical # across all subjects and conditions, so this should suffice. @@ -321,10 +345,9 @@ def average_time_by_time_decoding( times = epochs.times del epochs + time_points_shape: tuple[int, ...] = (len(times),) if cfg.decoding_time_generalization: - time_points_shape = (len(times), len(times)) - else: - time_points_shape = (len(times),) + time_points_shape += (len(times),) n_subjects = len(cfg.subjects) contrast_score_stats = { @@ -347,7 +370,7 @@ def average_time_by_time_decoding( } # Extract mean CV scores from all subjects. - mean_scores = np.empty((n_subjects, *time_points_shape)) + mean_scores: FloatArrayT = np.empty((n_subjects, *time_points_shape)) # Remaining in_files are all decoding data assert len(in_files) == n_subjects, list(in_files.keys()) @@ -448,7 +471,7 @@ def average_time_by_time_decoding( ) savemat(out_files["mat"], contrast_score_stats) - section = "Decoding: time-by-time" + section = f"Decoding: time-by-time, N = {len(cfg.subjects)}" with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session ) as report: @@ -469,7 +492,7 @@ def average_time_by_time_decoding( decoding_data=decoding_data, ) caption = ( - f'Based on N={decoding_data["N"].squeeze()} ' + f"Based on N={decoding_data['N'].squeeze()} " f"subjects. Standard error and confidence interval " f"of the mean were bootstrapped with {cfg.n_boot} " f"resamples. 
CI must not be used for statistical inference here, " @@ -480,7 +503,7 @@ def average_time_by_time_decoding( f" Time periods with decoding performance significantly above " f"chance, if any, were derived with a one-tailed " f"cluster-based permutation test " - f'({decoding_data["cluster_n_permutations"].squeeze()} ' + f"({decoding_data['cluster_n_permutations'].squeeze()} " f"permutations) and are highlighted in yellow." ) title = f"Decoding over time: {cond_1} vs. {cond_2}" @@ -522,7 +545,7 @@ def average_time_by_time_decoding( f"Time generalization (generalization across time, GAT): " f"each classifier is trained on each time point, and tested " f"on all other time points. The results were averaged across " - f'N={decoding_data["N"].item()} subjects.' + f"N={decoding_data['N'].item()} subjects." ) title = f"Time generalization: {cond_1} vs. {cond_2}" report.add_figure( @@ -549,11 +572,11 @@ def average_full_epochs_decoding( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, cond_1: str, cond_2: str, - in_files: dict, -) -> dict: + in_files: InFilesT, +) -> OutFilesT: n_subjects = len(cfg.subjects) in_files.pop("epochs") # not used but okay to include @@ -624,9 +647,9 @@ def get_input_files_average_full_epochs_report( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], - decoding_contrasts: List[List[str]], -) -> dict: + session: str | None, + decoding_contrasts: list[list[str]], +) -> InFilesT: in_files = dict() for contrast in decoding_contrasts: in_files[f"decoding-full-epochs-{contrast}"] = _decoding_out_fname( @@ -648,11 +671,22 @@ def average_full_epochs_report( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - decoding_contrasts: List[List[str]], - in_files: dict, -) -> dict: + session: str | None, + decoding_contrasts: list[list[str]], + in_files: InFilesT, +) -> OutFilesT: """Add decoding results to the grand average report.""" + out_files = dict() + out_files["cluster"] = _decoding_out_fname( + cfg=cfg, + subject=subject, + session=session, + cond_1=None, + cond_2=None, + kind="FullEpochs", + extension=".xlsx", + ) + with _open_report( cfg=cfg, exec_params=exec_params, subject=subject, session=session ) as report: @@ -671,16 +705,18 @@ def average_full_epochs_report( all_decoding_scores.append(np.atleast_1d(decoding_data["scores"].squeeze())) del decoding_data - fig, caption = _plot_full_epochs_decoding_scores( + fig, caption, data = _plot_full_epochs_decoding_scores( contrast_names=_contrasts_to_names(decoding_contrasts), scores=all_decoding_scores, metric=cfg.decoding_metric, kind="grand-average", ) + with pd.ExcelWriter(out_files["cluster"]) as w: + data.to_excel(w, sheet_name="FullEpochs", index=False) report.add_figure( fig=fig, title="Full-epochs decoding", - section="Decoding: full-epochs", + section=f"Decoding: full-epochs, N = {len(cfg.subjects)}", caption=caption, tags=( "epochs", @@ -695,7 +731,7 @@ def average_full_epochs_report( ) # close figure to save memory plt.close(fig) - return _prep_out_files(exec_params=exec_params, out_files=dict()) + return _prep_out_files(exec_params=exec_params, out_files=out_files) @failsafe_run( @@ -710,11 +746,11 @@ def average_csp_decoding( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, cond_1: str, cond_2: str, - in_files: dict, -): + in_files: InFilesT, +) -> OutFilesT: msg = f"Summarizing CSP results: {cond_1} - {cond_2}." 
logger.info(**gen_log_kwargs(message=msg)) in_files.pop("epochs") @@ -723,18 +759,20 @@ def average_csp_decoding( all_decoding_data_time_freq = [] for key in list(in_files): fname_xlsx = in_files.pop(key) - decoding_data_freq = pd.read_excel( - fname_xlsx, - sheet_name="CSP Frequency", - dtype={"subject": str}, # don't drop trailing zeros - ) - decoding_data_time_freq = pd.read_excel( - fname_xlsx, - sheet_name="CSP Time-Frequency", - dtype={"subject": str}, # don't drop trailing zeros - ) - all_decoding_data_freq.append(decoding_data_freq) - all_decoding_data_time_freq.append(decoding_data_time_freq) + with pd.ExcelFile(fname_xlsx) as xf: + decoding_data_freq = pd.read_excel( + xf, + sheet_name="CSP Frequency", + dtype={"subject": str}, # don't drop trailing zeros + ) + all_decoding_data_freq.append(decoding_data_freq) + if "CSP Time-Frequency" in xf.sheet_names: + decoding_data_time_freq = pd.read_excel( + xf, + sheet_name="CSP Time-Frequency", + dtype={"subject": str}, # don't drop trailing zeros + ) + all_decoding_data_time_freq.append(decoding_data_time_freq) del fname_xlsx # Now calculate descriptes and bootstrap CIs. @@ -744,12 +782,15 @@ def average_csp_decoding( session=session, data=all_decoding_data_freq, ) - grand_average_time_freq = _average_csp_time_freq( - cfg=cfg, - subject=subject, - session=session, - data=all_decoding_data_time_freq, - ) + if len(all_decoding_data_time_freq): + grand_average_time_freq = _average_csp_time_freq( + cfg=cfg, + subject=subject, + session=session, + data=all_decoding_data_time_freq, + ) + else: + grand_average_time_freq = None out_files = dict() out_files["freq"] = _decoding_out_fname( @@ -763,17 +804,15 @@ def average_csp_decoding( ) with pd.ExcelWriter(out_files["freq"]) as w: grand_average_freq.to_excel(w, sheet_name="CSP Frequency", index=False) - grand_average_time_freq.to_excel( - w, sheet_name="CSP Time-Frequency", index=False - ) + if grand_average_time_freq is not None: + grand_average_time_freq.to_excel( + w, sheet_name="CSP Time-Frequency", index=False + ) + del grand_average_time_freq # Perform a cluster-based permutation test. 
subjects = cfg.subjects - time_bins = np.array(cfg.decoding_csp_times) - if time_bins.ndim == 1: - time_bins = np.array(list(zip(time_bins[:-1], time_bins[1:]))) - time_bins = pd.DataFrame(time_bins, columns=["t_min", "t_max"]) - freq_name_to_bins_map = _handle_csp_args( + freq_name_to_bins_map, time_bins = _handle_csp_args( cfg.decoding_csp_times, cfg.decoding_csp_freqs, cfg.decoding_metric, @@ -782,78 +821,85 @@ def average_csp_decoding( time_frequency_freq_min=cfg.time_frequency_freq_min, time_frequency_freq_max=cfg.time_frequency_freq_max, ) - data_for_clustering = {} - for freq_range_name in freq_name_to_bins_map: - a = np.empty( - shape=( - len(subjects), - len(time_bins), - len(freq_name_to_bins_map[freq_range_name]), + if not len(time_bins): + fname_csp_cluster_results = None + else: + time_bins_df = pd.DataFrame(time_bins, columns=["t_min", "t_max"]) + del time_bins + data_for_clustering = {} + for freq_range_name in freq_name_to_bins_map: + a = np.empty( + shape=( + len(subjects), + len(time_bins_df), + len(freq_name_to_bins_map[freq_range_name]), + ) ) + a.fill(np.nan) + data_for_clustering[freq_range_name] = a + + g = pd.concat(all_decoding_data_time_freq).groupby( + ["subject", "freq_range_name", "t_min", "t_max"] ) - a.fill(np.nan) - data_for_clustering[freq_range_name] = a - g = pd.concat(all_decoding_data_time_freq).groupby( - ["subject", "freq_range_name", "t_min", "t_max"] - ) + for (subject_, freq_range_name, t_min, t_max), df in g: + scores = df["mean_crossval_score"] + sub_idx = subjects.index(subject_) + time_bin_idx = time_bins_df.loc[ + (np.isclose(time_bins_df["t_min"], t_min)) + & (np.isclose(time_bins_df["t_max"], t_max)), + :, + ].index + assert len(time_bin_idx) == 1 + time_bin_idx = time_bin_idx[0] + data_for_clustering[freq_range_name][sub_idx][time_bin_idx] = scores - for (subject_, freq_range_name, t_min, t_max), df in g: - scores = df["mean_crossval_score"] - sub_idx = subjects.index(subject_) - time_bin_idx = time_bins.loc[ - (np.isclose(time_bins["t_min"], t_min)) - & (np.isclose(time_bins["t_max"], t_max)), - :, - ].index - assert len(time_bin_idx) == 1 - time_bin_idx = time_bin_idx[0] - data_for_clustering[freq_range_name][sub_idx][time_bin_idx] = scores - - if cfg.cluster_forming_t_threshold is None: - import scipy.stats - - cluster_forming_t_threshold = scipy.stats.t.ppf( - 1 - 0.05, len(cfg.subjects) - 1 # one-sided test - ) - else: - cluster_forming_t_threshold = cfg.cluster_forming_t_threshold + if cfg.cluster_forming_t_threshold is None: + import scipy.stats - cluster_permutation_results = {} - for freq_range_name, X in data_for_clustering.items(): - if len(X) < 2: - t_vals = np.full(X.shape[1:], np.nan) - H0 = all_clusters = cluster_p_vals = np.array([]) - else: - ( - t_vals, - all_clusters, - cluster_p_vals, - H0, - ) = mne.stats.permutation_cluster_1samp_test( # noqa: E501 - X=X - 0.5, # One-sample test against zero. 
- threshold=cluster_forming_t_threshold, - n_permutations=cfg.cluster_n_permutations, - adjacency=None, # each time & freq bin connected to its neighbors - out_type="mask", - tail=1, # one-sided: significantly above chance level - seed=cfg.random_state, + cluster_forming_t_threshold = scipy.stats.t.ppf( + 1 - 0.05, + len(cfg.subjects) - 1, # one-sided test ) - n_permutations = H0.size - 1 - all_clusters = np.array(all_clusters) # preserve "empty" 0th dimension - cluster_permutation_results[freq_range_name] = { - "mean_crossval_scores": X.mean(axis=0), - "t_vals": t_vals, - "clusters": all_clusters, - "cluster_p_vals": cluster_p_vals, - "cluster_t_threshold": cluster_forming_t_threshold, - "n_permutations": n_permutations, - "time_bin_edges": cfg.decoding_csp_times, - "freq_bin_edges": cfg.decoding_csp_freqs[freq_range_name], - } - - out_files["cluster"] = out_files["freq"].copy().update(extension=".mat") - savemat(file_name=out_files["cluster"], mdict=cluster_permutation_results) + else: + cluster_forming_t_threshold = cfg.cluster_forming_t_threshold + + cluster_permutation_results = {} + for freq_range_name, X in data_for_clustering.items(): + if len(X) < 2: + t_vals = np.full(X.shape[1:], np.nan) + H0 = all_clusters = cluster_p_vals = np.array([]) + else: + ( + t_vals, + all_clusters, + cluster_p_vals, + H0, + ) = mne.stats.permutation_cluster_1samp_test( # noqa: E501 + X=X - 0.5, # One-sample test against zero. + threshold=cluster_forming_t_threshold, + n_permutations=cfg.cluster_n_permutations, + adjacency=None, # each time & freq bin connected to its neighbors + out_type="mask", + tail=1, # one-sided: significantly above chance level + seed=cfg.random_state, + ) + n_permutations = H0.size - 1 + all_clusters = np.array(all_clusters) # preserve "empty" 0th dimension + cluster_permutation_results[freq_range_name] = { + "mean_crossval_scores": X.mean(axis=0), + "t_vals": t_vals, + "clusters": all_clusters, + "cluster_p_vals": cluster_p_vals, + "cluster_t_threshold": cluster_forming_t_threshold, + "n_permutations": n_permutations, + "time_bin_edges": cfg.decoding_csp_times, + "freq_bin_edges": cfg.decoding_csp_freqs[freq_range_name], + } + + out_files["cluster"] = out_files["freq"].copy().update(extension=".mat") + savemat(file_name=out_files["cluster"], mdict=cluster_permutation_results) + fname_csp_cluster_results = out_files["cluster"] assert subject == "average" with _open_report( @@ -867,7 +913,7 @@ def average_csp_decoding( cond_1=cond_1, cond_2=cond_2, fname_csp_freq_results=out_files["freq"], - fname_csp_cluster_results=out_files["cluster"], + fname_csp_cluster_results=fname_csp_cluster_results, ) return _prep_out_files(out_files=out_files, exec_params=exec_params) @@ -876,7 +922,7 @@ def _average_csp_time_freq( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, data: pd.DataFrame, ) -> pd.DataFrame: # Prepare a dataframe for storing the results. 
@@ -936,11 +982,11 @@ def _average_csp_time_freq( def get_config( *, - config, + config: SimpleNamespace, ) -> SimpleNamespace: - dtg_decim = config.decoding_time_generalization_decim cfg = SimpleNamespace( subjects=get_subjects(config), + allow_missing_sessions=config.allow_missing_sessions, task_is_rest=config.task_is_rest, conditions=config.conditions, contrasts=config.contrasts, @@ -952,7 +998,7 @@ def get_config( decoding_metric=config.decoding_metric, decoding_n_splits=config.decoding_n_splits, decoding_time_generalization=config.decoding_time_generalization, - decoding_time_generalization_decim=dtg_decim, + decoding_time_generalization_decim=config.decoding_time_generalization_decim, decoding_csp=config.decoding_csp, decoding_csp_freqs=config.decoding_csp_freqs, decoding_csp_times=config.decoding_csp_times, diff --git a/mne_bids_pipeline/steps/sensor/__init__.py b/mne_bids_pipeline/steps/sensor/__init__.py index fc76bf551..848efadf8 100644 --- a/mne_bids_pipeline/steps/sensor/__init__.py +++ b/mne_bids_pipeline/steps/sensor/__init__.py @@ -1,12 +1,14 @@ """Sensor-space analysis.""" -from . import _01_make_evoked -from . import _02_decoding_full_epochs -from . import _03_decoding_time_by_time -from . import _04_time_frequency -from . import _05_decoding_csp -from . import _06_make_cov -from . import _99_group_average +from . import ( + _01_make_evoked, + _02_decoding_full_epochs, + _03_decoding_time_by_time, + _04_time_frequency, + _05_decoding_csp, + _06_make_cov, + _99_group_average, +) _STEPS = ( _01_make_evoked, diff --git a/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py b/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py index fc4051c9f..a3b2e687d 100644 --- a/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py +++ b/mne_bids_pipeline/steps/source/_01_make_bem_surfaces.py @@ -6,25 +6,25 @@ import glob from pathlib import Path from types import SimpleNamespace -from typing import Optional import mne -from ..._config_utils import ( - get_fs_subject, - get_subjects, - get_sessions, +from mne_bids_pipeline._config_utils import ( + _bids_kwargs, _get_bem_conductivity, + _has_session_specific_anat, + get_fs_subject, get_fs_subjects_dir, - _bids_kwargs, + get_subjects_sessions, ) -from ..._logging import logger, gen_log_kwargs -from ..._parallel import get_parallel_backend, parallel_func -from ..._run import failsafe_run, save_logs, _prep_out_files -from ..._report import _open_report, _render_bem +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _open_report, _render_bem +from mne_bids_pipeline._run import _prep_out_files, failsafe_run, save_logs +from mne_bids_pipeline.typing import InFilesPathT, OutFilesT -def _get_bem_params(cfg: SimpleNamespace): +def _get_bem_params(cfg: SimpleNamespace) -> tuple[str, Path, Path]: mri_dir = Path(cfg.fs_subjects_dir) / cfg.fs_subject / "mri" flash_dir = mri_dir / "flash" / "parameter_maps" if cfg.bem_mri_images == "FLASH" and not flash_dir.exists(): @@ -42,13 +42,13 @@ def get_input_fnames_make_bem_surfaces( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesPathT: in_files = dict() mri_images, mri_dir, flash_dir = _get_bem_params(cfg) in_files["t1"] = mri_dir / "T1.mgz" if mri_images == "FLASH": - flash_fnames = sorted(glob.glob(str(flash_dir / "mef*_*.mgz"))) + flash_fnames = sorted(Path(p) for p in glob.glob(str(flash_dir / 
"mef*_*.mgz"))) # We could check for existence here, but make_flash_bem does it later for fname in flash_fnames: in_files[fname.stem] = fname @@ -59,10 +59,11 @@ def get_output_fnames_make_bem_surfaces( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesPathT: out_files = dict() conductivity, _ = _get_bem_conductivity(cfg) + assert conductivity is not None n_layers = len(conductivity) bem_dir = Path(cfg.fs_subjects_dir) / cfg.fs_subject / "bem" for surf in ("inner_skull", "outer_skull", "outer_skin")[:n_layers]: @@ -79,9 +80,9 @@ def make_bem_surfaces( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesPathT, +) -> OutFilesT: mri_images, _, _ = _get_bem_params(cfg) in_files.clear() # assume we use everything we add if mri_images == "FLASH": @@ -112,16 +113,22 @@ def make_bem_surfaces( subject=subject, session=session, ) - return _prep_out_files(exec_params=exec_params, out_files=out_files) + return _prep_out_files( + exec_params=exec_params, + out_files=out_files, + check_relative=cfg.fs_subjects_dir, + bids_only=False, + ) def get_config( *, config: SimpleNamespace, subject: str, + session: str | None = None, ) -> SimpleNamespace: cfg = SimpleNamespace( - fs_subject=get_fs_subject(config=config, subject=subject), + fs_subject=get_fs_subject(config=config, subject=subject, session=session), fs_subjects_dir=get_fs_subjects_dir(config=config), bem_mri_images=config.bem_mri_images, freesurfer_verbose=config.freesurfer_verbose, @@ -147,6 +154,17 @@ def main(*, config: SimpleNamespace) -> None: mne.datasets.fetch_fsaverage(get_fs_subjects_dir(config)) return + # check for session-specific MRIs within subject, and add entries to `subj_sess` for + # each combination of subject+session that has its own MRI + subjects_dir = Path(get_fs_subjects_dir(config)) + subj_sess = set() + for _subj, sessions in get_subjects_sessions(config).items(): + for sess in sessions: + _sess = ( + sess if _has_session_specific_anat(_subj, sess, subjects_dir) else None + ) + subj_sess.add((_subj, _sess)) + with get_parallel_backend(config.exec_params): parallel, run_func = parallel_func( make_bem_surfaces, exec_params=config.exec_params @@ -156,12 +174,13 @@ def main(*, config: SimpleNamespace) -> None: cfg=get_config( config=config, subject=subject, + session=session, ), exec_params=config.exec_params, subject=subject, - session=get_sessions(config)[0], + session=session, force_run=config.recreate_bem, ) - for subject in get_subjects(config) + for subject, session in sorted(subj_sess) ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/source/_02_make_bem_solution.py b/mne_bids_pipeline/steps/source/_02_make_bem_solution.py index 67f0c2737..a7f170504 100644 --- a/mne_bids_pipeline/steps/source/_02_make_bem_solution.py +++ b/mne_bids_pipeline/steps/source/_02_make_bem_solution.py @@ -8,24 +8,27 @@ import mne -from ..._config_utils import ( +from mne_bids_pipeline._config_utils import ( _get_bem_conductivity, - get_fs_subjects_dir, get_fs_subject, - get_subjects, + get_fs_subjects_dir, + get_subjects_sessions, ) -from ..._logging import logger, gen_log_kwargs -from ..._parallel import parallel_func, get_parallel_backend -from ..._run import failsafe_run, save_logs, _prep_out_files +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from 
mne_bids_pipeline._run import _prep_out_files, failsafe_run, save_logs +from mne_bids_pipeline.typing import InFilesPathT, OutFilesT def get_input_fnames_make_bem_solution( *, cfg: SimpleNamespace, subject: str, -) -> dict: + session: str | None = None, +) -> InFilesPathT: in_files = dict() conductivity, _ = _get_bem_conductivity(cfg) + assert conductivity is not None n_layers = len(conductivity) bem_dir = Path(cfg.fs_subjects_dir) / cfg.fs_subject / "bem" for surf in ("inner_skull", "outer_skull", "outer_skin")[:n_layers]: @@ -37,7 +40,8 @@ def get_output_fnames_make_bem_solution( *, cfg: SimpleNamespace, subject: str, -) -> dict: + session: str | None = None, +) -> InFilesPathT: out_files = dict() bem_dir = Path(cfg.fs_subjects_dir) / cfg.fs_subject / "bem" _, tag = _get_bem_conductivity(cfg) @@ -55,10 +59,11 @@ def make_bem_solution( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - in_files: dict, -) -> dict: + in_files: InFilesPathT, + session: str | None = None, +) -> OutFilesT: msg = "Calculating BEM solution" - logger.info(**gen_log_kwargs(message=msg, subject=subject)) + logger.info(**gen_log_kwargs(message=msg, subject=subject, session=session)) conductivity, _ = _get_bem_conductivity(cfg) bem_model = mne.make_bem_model( subject=cfg.fs_subject, @@ -69,16 +74,22 @@ def make_bem_solution( out_files = get_output_fnames_make_bem_solution(cfg=cfg, subject=subject) mne.write_bem_surfaces(out_files["model"], bem_model, overwrite=True) mne.write_bem_solution(out_files["sol"], bem_sol, overwrite=True) - return _prep_out_files(exec_params=exec_params, out_files=out_files) + return _prep_out_files( + exec_params=exec_params, + out_files=out_files, + check_relative=cfg.fs_subjects_dir, + bids_only=False, + ) def get_config( *, config: SimpleNamespace, subject: str, + session: str | None = None, ) -> SimpleNamespace: cfg = SimpleNamespace( - fs_subject=get_fs_subject(config=config, subject=subject), + fs_subject=get_fs_subject(config=config, subject=subject, session=session), fs_subjects_dir=get_fs_subjects_dir(config), ch_types=config.ch_types, use_template_mri=config.use_template_mri, @@ -86,7 +97,7 @@ def get_config( return cfg -def main(*, config) -> None: +def main(*, config: SimpleNamespace) -> None: """Run BEM solution calculation.""" if not config.run_source_estimation: msg = "Skipping, run_source_estimation is set to False …" @@ -94,7 +105,7 @@ def main(*, config) -> None: return if config.use_template_mri is not None: - msg = "Skipping, BEM solution computation not needed for " "MRI template …" + msg = "Skipping, BEM solution computation not needed for MRI template …" logger.info(**gen_log_kwargs(message=msg, emoji="skip")) if config.use_template_mri == "fsaverage": # Ensure we have the BEM @@ -107,11 +118,13 @@ def main(*, config) -> None: ) logs = parallel( run_func( - cfg=get_config(config=config, subject=subject), + cfg=get_config(config=config, subject=subject, session=session), exec_params=config.exec_params, subject=subject, + session=session, force_run=config.recreate_bem, ) - for subject in get_subjects(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/source/_03_setup_source_space.py b/mne_bids_pipeline/steps/source/_03_setup_source_space.py index 4710750f9..bfd327610 100644 --- a/mne_bids_pipeline/steps/source/_03_setup_source_space.py +++ b/mne_bids_pipeline/steps/source/_03_setup_source_space.py @@ -7,13 +7,21 @@ import mne -from 
..._config_utils import get_fs_subject, get_fs_subjects_dir, get_subjects -from ..._logging import logger, gen_log_kwargs -from ..._run import failsafe_run, save_logs, _prep_out_files -from ..._parallel import parallel_func, get_parallel_backend +from mne_bids_pipeline._config_utils import ( + get_fs_subject, + get_fs_subjects_dir, + get_sessions, + get_subjects_sessions, +) +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._run import _prep_out_files, failsafe_run, save_logs +from mne_bids_pipeline.typing import InFilesPathT, OutFilesT -def get_input_fnames_setup_source_space(*, cfg, subject): +def get_input_fnames_setup_source_space( + *, cfg: SimpleNamespace, subject: str +) -> InFilesPathT: in_files = dict() surf_path = cfg.fs_subjects_dir / cfg.fs_subject / "surf" for hemi in ("lh", "rh"): @@ -22,7 +30,9 @@ def get_input_fnames_setup_source_space(*, cfg, subject): return in_files -def get_output_fnames_setup_source_space(*, cfg, subject): +def get_output_fnames_setup_source_space( + *, cfg: SimpleNamespace, subject: str +) -> InFilesPathT: out_files = dict() out_files["src"] = ( cfg.fs_subjects_dir @@ -42,8 +52,8 @@ def run_setup_source_space( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - in_files: dict, -) -> dict: + in_files: InFilesPathT, +) -> OutFilesT: msg = f"Creating source space with spacing {repr(cfg.spacing)}" logger.info(**gen_log_kwargs(message=msg, subject=subject)) src = mne.setup_source_space( @@ -55,18 +65,24 @@ def run_setup_source_space( in_files.clear() # all used by setup_source_space out_files = get_output_fnames_setup_source_space(cfg=cfg, subject=subject) mne.write_source_spaces(out_files["src"], src, overwrite=True) - return _prep_out_files(exec_params=exec_params, out_files=out_files) + return _prep_out_files( + exec_params=exec_params, + out_files=out_files, + check_relative=cfg.fs_subjects_dir, + bids_only=False, + ) def get_config( *, config: SimpleNamespace, subject: str, + session: str | None = None, ) -> SimpleNamespace: cfg = SimpleNamespace( spacing=config.spacing, use_template_mri=config.use_template_mri, - fs_subject=get_fs_subject(config=config, subject=subject), + fs_subject=get_fs_subject(config=config, subject=subject, session=session), fs_subjects_dir=get_fs_subjects_dir(config), ) return cfg @@ -80,9 +96,9 @@ def main(*, config: SimpleNamespace) -> None: return if config.use_template_mri is not None: - subjects = [config.use_template_mri] + sub_ses = {config.use_template_mri: get_sessions(config=config)} else: - subjects = get_subjects(config=config) + sub_ses = get_subjects_sessions(config=config) with get_parallel_backend(config.exec_params): parallel, run_func = parallel_func( @@ -93,10 +109,12 @@ def main(*, config: SimpleNamespace) -> None: cfg=get_config( config=config, subject=subject, + session=session, ), exec_params=config.exec_params, subject=subject, ) - for subject in subjects + for subject, sessions in sub_ses.items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/source/_04_make_forward.py b/mne_bids_pipeline/steps/source/_04_make_forward.py index a2c1fc211..8ca2e7c7a 100644 --- a/mne_bids_pipeline/steps/source/_04_make_forward.py +++ b/mne_bids_pipeline/steps/source/_04_make_forward.py @@ -4,36 +4,38 @@ """ from types import SimpleNamespace -from typing import Optional - -import numpy as np import mne +import numpy as np from mne.coreg 
import Coregistration from mne_bids import BIDSPath, get_head_mri_trans -from ..._config_utils import ( - get_fs_subject, - get_subjects, +from mne_bids_pipeline._config_utils import ( + _bids_kwargs, _get_bem_conductivity, + _meg_in_ch_types, + get_fs_subject, get_fs_subjects_dir, get_runs, - _meg_in_ch_types, - get_sessions, - _bids_kwargs, + get_subjects_sessions, +) +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _open_report, _render_bem +from mne_bids_pipeline._run import ( + _prep_out_files, + _sanitize_callable, + failsafe_run, + save_logs, ) -from ..._config_import import _import_config -from ..._logging import logger, gen_log_kwargs -from ..._parallel import get_parallel_backend, parallel_func -from ..._report import _open_report, _render_bem -from ..._run import failsafe_run, save_logs, _prep_out_files +from mne_bids_pipeline.typing import InFilesT, OutFilesT def _prepare_trans_template( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, info: mne.Info, ) -> mne.transforms.Transform: assert isinstance(cfg.use_template_mri, str) @@ -69,51 +71,35 @@ def _prepare_trans_subject( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, bids_path: BIDSPath, ) -> mne.transforms.Transform: # Generate a head ↔ MRI transformation matrix from the # electrophysiological and MRI sidecar files, and save it to an MNE # "trans" file in the derivatives folder. - # TODO: This breaks our encapsulation - config = _import_config( - config_path=exec_params.config_path, - check=False, - log=False, - ) - if config.mri_t1_path_generator is None: - t1_bids_path = None - else: - t1_bids_path = BIDSPath(subject=subject, session=session, root=cfg.bids_root) - t1_bids_path = config.mri_t1_path_generator(t1_bids_path.copy()) - if t1_bids_path.suffix is None: - t1_bids_path.update(suffix="T1w") - if t1_bids_path.datatype is None: - t1_bids_path.update(datatype="anat") - - if config.mri_landmarks_kind is None: - landmarks_kind = None - else: - landmarks_kind = config.mri_landmarks_kind( - BIDSPath(subject=subject, session=session) - ) - msg = "Computing head ↔ MRI transform from matched fiducials" logger.info(**gen_log_kwargs(message=msg)) trans = get_head_mri_trans( - bids_path.copy().update(run=cfg.runs[0], root=cfg.bids_root, extension=None), - t1_bids_path=t1_bids_path, + bids_path.copy().update( + run=cfg.runs[0], + root=cfg.bids_root, + processing=cfg.proc, + extension=None, + ), + t1_bids_path=cfg.t1_bids_path, fs_subject=cfg.fs_subject, fs_subjects_dir=cfg.fs_subjects_dir, - kind=landmarks_kind, + kind=cfg.landmarks_kind, ) return trans -def get_input_fnames_forward(*, cfg, subject, session): +def get_input_fnames_forward( + *, cfg: SimpleNamespace, subject: str, session: str | None +) -> InFilesT: bids_path = BIDSPath( subject=subject, session=session, @@ -128,7 +114,18 @@ def get_input_fnames_forward(*, cfg, subject, session): check=False, ) in_files = dict() - in_files["info"] = bids_path.copy().update(**cfg.source_info_path_update) + # for consistency with 05_make_inverse, read the info from the + # data used for the noise_cov + if cfg.source_info_path_update is None: + if cfg.noise_cov in ("rest", "noise"): + source_info_path_update = dict( + processing="clean", suffix="raw", task=cfg.noise_cov + ) + else: + source_info_path_update = dict(suffix="ave") + else: + 
source_info_path_update = cfg.source_info_path_update + in_files["info"] = bids_path.copy().update(**source_info_path_update) bem_path = cfg.fs_subjects_dir / cfg.fs_subject / "bem" _, tag = _get_bem_conductivity(cfg) in_files["bem"] = bem_path / f"{cfg.fs_subject}-{tag}-bem-sol.fif" @@ -144,9 +141,12 @@ def run_forward( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: + # Do not use processing=cfg.proc here because the forward could actually be + # influenced by previous steps (e.g., Maxwell filtering), so just make sure we + # use cfg.proc when figuring out the head<->MRI transform bids_path = BIDSPath( subject=subject, session=session, @@ -201,6 +201,12 @@ def run_forward( fwd = mne.make_forward_solution( info, trans=trans, src=src, bem=bem, mindist=cfg.mindist ) + if fwd["src"]._subject != cfg.fs_subject: + raise RuntimeError( + f"subject in the SourceSpace ({fwd['src']._subject}) does not match " + f"expected subject ({cfg.fs_subject}). This should not happen and probably " + f"indicates an error in the SourceSpace loaded from ({str(src)})." + ) out_files = dict() out_files["trans"] = bids_path.copy().update(suffix="trans") out_files["forward"] = bids_path.copy().update(suffix="fwd") @@ -243,7 +249,24 @@ def get_config( *, config: SimpleNamespace, subject: str, + session: str | None, ) -> SimpleNamespace: + if config.mri_t1_path_generator is None: + t1_bids_path = None + else: + t1_bids_path = BIDSPath(subject=subject, session=session, root=config.bids_root) + t1_bids_path = config.mri_t1_path_generator(t1_bids_path.copy()) + if t1_bids_path.suffix is None: + t1_bids_path.update(suffix="T1w") + if t1_bids_path.datatype is None: + t1_bids_path.update(datatype="anat") + if config.mri_landmarks_kind is None: + landmarks_kind = None + else: + landmarks_kind = config.mri_landmarks_kind( + BIDSPath(subject=subject, session=session) + ) + cfg = SimpleNamespace( runs=get_runs(config=config, subject=subject), mindist=config.mindist, @@ -251,9 +274,12 @@ def get_config( use_template_mri=config.use_template_mri, adjust_coreg=config.adjust_coreg, source_info_path_update=config.source_info_path_update, + noise_cov=_sanitize_callable(config.noise_cov), ch_types=config.ch_types, - fs_subject=get_fs_subject(config=config, subject=subject), + fs_subject=get_fs_subject(config=config, subject=subject, session=session), fs_subjects_dir=get_fs_subjects_dir(config), + t1_bids_path=t1_bids_path, + landmarks_kind=landmarks_kind, **_bids_kwargs(config=config), ) return cfg @@ -270,12 +296,12 @@ def main(*, config: SimpleNamespace) -> None: parallel, run_func = parallel_func(run_forward, exec_params=config.exec_params) logs = parallel( run_func( - cfg=get_config(config=config, subject=subject), + cfg=get_config(config=config, subject=subject, session=session), exec_params=config.exec_params, subject=subject, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/source/_05_make_inverse.py b/mne_bids_pipeline/steps/source/_05_make_inverse.py index 449675817..cd6501b92 100644 --- a/mne_bids_pipeline/steps/source/_05_make_inverse.py +++ b/mne_bids_pipeline/steps/source/_05_make_inverse.py @@ -4,37 +4,41 @@ """ from types import SimpleNamespace -from typing import Optional import 
mne from mne.minimum_norm import ( - make_inverse_operator, apply_inverse, + make_inverse_operator, write_inverse_operator, ) from mne_bids import BIDSPath -from ..._config_utils import ( +from mne_bids_pipeline._config_utils import ( + _bids_kwargs, + get_fs_subject, + get_fs_subjects_dir, get_noise_cov_bids_path, - get_subjects, + get_subjects_sessions, sanitize_cond_name, - get_sessions, - get_fs_subjects_dir, - get_fs_subject, - _bids_kwargs, ) -from ..._logging import logger, gen_log_kwargs -from ..._parallel import get_parallel_backend, parallel_func -from ..._report import _open_report, _sanitize_cond_tag, _all_conditions -from ..._run import failsafe_run, save_logs, _sanitize_callable, _prep_out_files +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _all_conditions, _open_report, _sanitize_cond_tag +from mne_bids_pipeline._run import ( + _prep_out_files, + _sanitize_callable, + failsafe_run, + save_logs, +) +from mne_bids_pipeline.typing import InFilesT, OutFilesT def get_input_fnames_inverse( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -): + session: str | None, +) -> InFilesT: bids_path = BIDSPath( subject=subject, session=session, @@ -49,7 +53,19 @@ def get_input_fnames_inverse( check=False, ) in_files = dict() - in_files["info"] = bids_path.copy().update(**cfg.source_info_path_update) + # make sure the info matches the data from which the noise cov + # is computed to avoid rank-mismatch + if cfg.source_info_path_update is None: + if cfg.noise_cov in ("rest", "noise"): + source_info_path_update = dict( + processing="clean", suffix="raw", task=cfg.noise_cov + ) + else: + source_info_path_update = dict(suffix="ave") + # XXX is this the right solution also for noise_cov = 'ad-hoc'? + else: + source_info_path_update = cfg.source_info_path_update + in_files["info"] = bids_path.copy().update(**source_info_path_update) in_files["forward"] = bids_path.copy().update(suffix="fwd") if cfg.noise_cov != "ad-hoc": in_files["cov"] = get_noise_cov_bids_path( @@ -68,9 +84,9 @@ def run_inverse( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: # TODO: Eventually we should maybe loop over ch_types, e.g., to create # MEG, EEG, and MEG+EEG inverses and STCs msg = "Computing inverse solutions" @@ -129,7 +145,10 @@ def run_inverse( for condition in conditions: msg = f"Rendering inverse solution for {condition}" logger.info(**gen_log_kwargs(message=msg)) - tags = ("source-estimate", _sanitize_cond_tag(condition)) + tags: tuple[str, ...] 
= ( + "source-estimate", + _sanitize_cond_tag(condition), + ) if condition not in cfg.conditions: tags = tags + ("contrast",) report.add_stc( @@ -150,6 +169,7 @@ def get_config( *, config: SimpleNamespace, subject: str, + session: str | None = None, ) -> SimpleNamespace: cfg = SimpleNamespace( source_info_path_update=config.source_info_path_update, @@ -162,7 +182,7 @@ def get_config( inverse_method=config.inverse_method, noise_cov=_sanitize_callable(config.noise_cov), report_stc_n_time_points=config.report_stc_n_time_points, - fs_subject=get_fs_subject(config=config, subject=subject), + fs_subject=get_fs_subject(config=config, subject=subject, session=session), fs_subjects_dir=get_fs_subjects_dir(config), **_bids_kwargs(config=config), ) @@ -183,12 +203,13 @@ def main(*, config: SimpleNamespace) -> None: cfg=get_config( config=config, subject=subject, + session=session, ), exec_params=config.exec_params, subject=subject, session=session, ) - for subject in get_subjects(config) - for session in get_sessions(config) + for subject, sessions in get_subjects_sessions(config).items() + for session in sessions ) save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/source/_99_group_average.py b/mne_bids_pipeline/steps/source/_99_group_average.py index 9e855d6df..1c9d1ac4c 100644 --- a/mne_bids_pipeline/steps/source/_99_group_average.py +++ b/mne_bids_pipeline/steps/source/_99_group_average.py @@ -4,40 +4,42 @@ """ from types import SimpleNamespace -from typing import Optional - -import numpy as np import mne +import numpy as np from mne_bids import BIDSPath -from ..._config_utils import ( +from mne_bids_pipeline._config_utils import ( + _bids_kwargs, + get_fs_subject, get_fs_subjects_dir, + get_sessions, get_subjects, + get_subjects_given_session, + get_subjects_sessions, sanitize_cond_name, - get_fs_subject, - get_sessions, - _bids_kwargs, ) -from ..._logging import logger, gen_log_kwargs -from ..._parallel import get_parallel_backend, parallel_func -from ..._report import _all_conditions, _open_report -from ..._run import failsafe_run, save_logs, _prep_out_files +from mne_bids_pipeline._logging import gen_log_kwargs, logger +from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func +from mne_bids_pipeline._report import _all_conditions, _open_report +from mne_bids_pipeline._run import _prep_out_files, failsafe_run, save_logs +from mne_bids_pipeline.typing import InFilesT, OutFilesT def _stc_path( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], + session: str | None, condition: str, morphed: bool, ) -> BIDSPath: cond_str = sanitize_cond_name(condition) - suffix = [cond_str, cfg.inverse_method, "hemi"] + suffix_list = [cond_str, cfg.inverse_method, "hemi"] if morphed: - suffix.insert(2, "morph2fsaverage") - suffix = "+".join(suffix) + suffix_list.insert(2, "morph2fsaverage") + suffix = "+".join(suffix_list) + del suffix_list return BIDSPath( subject=subject, session=session, @@ -59,8 +61,8 @@ def get_input_fnames_morph_stc( cfg: SimpleNamespace, subject: str, fs_subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesT: in_files = dict() for condition in _all_conditions(cfg=cfg): in_files[f"original-{condition}"] = _stc_path( @@ -82,9 +84,9 @@ def morph_stc( exec_params: SimpleNamespace, subject: str, fs_subject: str, - session: Optional[str], - in_files: dict, -) -> dict: + session: str | None, + in_files: InFilesT, +) -> OutFilesT: out_files = dict() for condition in _all_conditions(cfg=cfg): fname_stc = 
in_files.pop(f"original-{condition}") @@ -114,12 +116,14 @@ def get_input_fnames_run_average( *, cfg: SimpleNamespace, subject: str, - session: Optional[str], -) -> dict: + session: str | None, +) -> InFilesT: in_files = dict() assert subject == "average" + # for each session, only use subjects who actually have data for that session + subjects = get_subjects_given_session(cfg, session) for condition in _all_conditions(cfg=cfg): - for this_subject in cfg.subjects: + for this_subject in subjects: in_files[f"{this_subject}-{condition}"] = _stc_path( cfg=cfg, subject=this_subject, @@ -138,17 +142,19 @@ def run_average( cfg: SimpleNamespace, exec_params: SimpleNamespace, subject: str, - session: Optional[str], - in_files: dict, -): + session: str | None, + in_files: InFilesT, +) -> OutFilesT: assert subject == "average" out_files = dict() conditions = _all_conditions(cfg=cfg) + # for each session, only use subjects who actually have data for that session + subjects = get_subjects_given_session(cfg, session) for condition in conditions: stc = np.array( [ mne.read_source_estimate(in_files.pop(f"{this_subject}-{condition}")) - for this_subject in cfg.subjects + for this_subject in subjects ] ).mean(axis=0) out_files[condition] = _stc_path( @@ -171,7 +177,7 @@ def run_average( msg = f"Rendering inverse solution for {condition}" logger.info(**gen_log_kwargs(message=msg)) cond_str = sanitize_cond_name(condition) - tags = ("source-estimate", cond_str) + tags: tuple[str, ...] = ("source-estimate", cond_str) if condition in cfg.conditions: title = f"Average (source): {condition}" else: # It's a contrast of two conditions. @@ -204,6 +210,7 @@ def get_config( subjects=get_subjects(config=config), exclude_subjects=config.exclude_subjects, sessions=get_sessions(config), + allow_missing_sessions=config.allow_missing_sessions, use_template_mri=config.use_template_mri, contrasts=config.contrasts, report_stc_n_time_points=config.report_stc_n_time_points, @@ -223,8 +230,7 @@ def main(*, config: SimpleNamespace) -> None: mne.datasets.fetch_fsaverage(subjects_dir=get_fs_subjects_dir(config)) cfg = get_config(config=config) exec_params = config.exec_params - subjects = get_subjects(config) - sessions = get_sessions(config) + all_sessions = get_sessions(config) logs = list() with get_parallel_backend(exec_params): @@ -234,10 +240,10 @@ def main(*, config: SimpleNamespace) -> None: cfg=cfg, exec_params=exec_params, subject=subject, - fs_subject=get_fs_subject(config=cfg, subject=subject), + fs_subject=get_fs_subject(config=cfg, subject=subject, session=session), session=session, ) - for subject in subjects + for subject, sessions in get_subjects_sessions(config).items() for session in sessions ) logs += [ @@ -247,6 +253,6 @@ def main(*, config: SimpleNamespace) -> None: session=session, subject="average", ) - for session in sessions + for session in all_sessions ] save_logs(config=config, logs=logs) diff --git a/mne_bids_pipeline/steps/source/__init__.py b/mne_bids_pipeline/steps/source/__init__.py index c748f7f8b..89b757670 100644 --- a/mne_bids_pipeline/steps/source/__init__.py +++ b/mne_bids_pipeline/steps/source/__init__.py @@ -1,11 +1,13 @@ """Source-space analysis.""" -from . import _01_make_bem_surfaces -from . import _02_make_bem_solution -from . import _03_setup_source_space -from . import _04_make_forward -from . import _05_make_inverse -from . import _99_group_average +from . 
import ( + _01_make_bem_surfaces, + _02_make_bem_solution, + _03_setup_source_space, + _04_make_forward, + _05_make_inverse, + _99_group_average, +) _STEPS = ( _01_make_bem_surfaces, diff --git a/mne_bids_pipeline/tests/configs/README.md b/mne_bids_pipeline/tests/configs/README.md new file mode 100644 index 000000000..6e0985124 --- /dev/null +++ b/mne_bids_pipeline/tests/configs/README.md @@ -0,0 +1,9 @@ +# Config files for test datasets + +!!! warning + + The documentation build scripts (`docs/source/examples/gen_examples.py`) + assume a config file name of `config_{name_of_dataset}.py` or + `config_{name-of-dataset}.py`. If you want a dataset to be shown in the docs as an + example dataset, you **must** name the config file accordingly (and also, add the + dataset to `docs/mkdocs.yml` in the list of `Examples`). diff --git a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py index 47fcb5846..e536450fa 100644 --- a/mne_bids_pipeline/tests/configs/config_ERP_CORE.py +++ b/mne_bids_pipeline/tests/configs/config_ERP_CORE.py @@ -1,7 +1,6 @@ -""" -ERP CORE +"""ERP CORE. -This example demonstrate how to process 5 participants from the +This example demonstrates how to process 5 participants from the [ERP CORE](https://erpinfo.org/erp-core) dataset. It shows how to obtain 7 ERP components from a total of 6 experimental tasks: @@ -23,11 +22,12 @@ event-related potential research. *NeuroImage* 225: 117465. [https://doi.org/10.1016/j.neuroimage.2020.117465](https://doi.org/10.1016/j.neuroimage.2020.117465) """ + import argparse -import mne import sys -study_name = "ERP-CORE" +import mne + bids_root = "~/mne_data/ERP_CORE" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ERP_CORE" @@ -94,7 +94,7 @@ on_rename_missing_events = "ignore" parallel_backend = "dask" -dask_worker_memory_limit = "2G" +dask_worker_memory_limit = "2.5G" n_jobs = 4 if task == "N400": @@ -147,7 +147,6 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) epochs_tmin = -0.6 epochs_tmax = 0.4 baseline = (-0.4, -0.2) @@ -180,7 +179,6 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) epochs_tmin = -0.8 epochs_tmax = 0.2 baseline = (None, -0.6) @@ -193,7 +191,6 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) epochs_tmin = -0.2 epochs_tmax = 0.8 baseline = (None, 0) @@ -214,7 +211,41 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - len(eeg_reference) + # Analyze all EEG channels -- we only specify the channels here for the purpose of + # demonstration + analyze_channels = [ + "FP1", + "F3", + "F7", + "FC3", + "C3", + "C5", + "P3", + "P7", + "P9", + "PO7", + "PO3", + "O1", + "Oz", + "Pz", + "CPz", + "FP2", + "Fz", + "F4", + "F8", + "FC4", + "FCz", + "Cz", + "C4", + "C6", + "P4", + "P8", + "P10", + "PO8", + "PO4", + "O2", + ] + epochs_tmin = -0.2 epochs_tmax = 0.8 baseline = (None, 0) @@ -227,6 +258,41 @@ } eeg_reference = "average" + # Analyze all EEG channels -- we only specify the channels here for the purpose of + # demonstration + analyze_channels = [ + "FP1", + "F3", + "F7", + "FC3", + "C3", + "C5", + "P3", + "P7", + "P9", + "PO7", + "PO3", + "O1", + "Oz", + "Pz", + "CPz", + "FP2", + "Fz", + "F4", + "F8", + "FC4", + "FCz", + "Cz", + "C4", + "C6", + "P4", + "P8", + "P10", + "PO8", + "PO4", + "O2", + ] + ica_n_components = 30 - 1 for i in range(1, 180 + 1): orig_name = f"stimulus/{i}" @@ -281,7 +347,6 @@ } eeg_reference = ["P9", "P10"] - ica_n_components = 30 - 
len(eeg_reference)
     epochs_tmin = -0.2
     epochs_tmax = 0.8
     baseline = (None, 0)
diff --git a/mne_bids_pipeline/tests/configs/config_MNE_funloc_data.py b/mne_bids_pipeline/tests/configs/config_MNE_funloc_data.py
new file mode 100644
index 000000000..e83f9e157
--- /dev/null
+++ b/mne_bids_pipeline/tests/configs/config_MNE_funloc_data.py
@@ -0,0 +1,41 @@
+"""Funloc data."""
+
+from pathlib import Path
+
+data_root = Path("~/mne_data").expanduser().resolve()
+bids_root = data_root / "MNE-funloc-data"
+deriv_root = data_root / "derivatives" / "mne-bids-pipeline" / "MNE-funloc-data"
+subjects_dir = bids_root / "derivatives" / "freesurfer" / "subjects"
+task = "funloc"
+ch_types = ["meg", "eeg"]
+data_type = "meg"
+
+# filter
+l_freq = None
+h_freq = 50.0
+# maxfilter
+use_maxwell_filter: bool = True
+crop_runs = (40, 190)
+mf_st_duration = 60.0
+# SSP
+spatial_filter = "ssp"
+ssp_ecg_channel = {"sub-01": "MEG0111", "sub-02": "MEG0141"}
+n_proj_eog = dict(n_mag=1, n_grad=1, n_eeg=2)
+n_proj_ecg = dict(n_mag=1, n_grad=1, n_eeg=0)
+
+# Epochs
+epochs_tmin = -0.2
+epochs_tmax = 0.5
+epochs_decim = 5  # 1000 -> 200 Hz
+baseline = (None, 0)
+conditions = [
+    "auditory/standard",
+    # "auditory/deviant",
+    "visual/standard",
+    # "visual/deviant",
+]
+decode = False
+decoding_time_generalization = False
+
+# contrasts
+# contrasts = [("auditory", "visual")]
diff --git a/mne_bids_pipeline/tests/configs/config_MNE_phantom_KIT_data.py b/mne_bids_pipeline/tests/configs/config_MNE_phantom_KIT_data.py
new file mode 100644
index 000000000..49689bd3e
--- /dev/null
+++ b/mne_bids_pipeline/tests/configs/config_MNE_phantom_KIT_data.py
@@ -0,0 +1,27 @@
+"""
+KIT phantom data.
+
+https://mne.tools/dev/documentation/datasets.html#kit-phantom-dataset
+"""
+
+bids_root = "~/mne_data/MNE-phantom-KIT-data"
+deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/MNE-phantom-KIT-data"
+task = "phantom"
+ch_types = ["meg"]
+
+# Preprocessing
+l_freq = None
+h_freq = 40.0
+regress_artifact = dict(
+    picks="meg", picks_artifact=["MISC 001", "MISC 002", "MISC 003"]
+)
+
+# Epochs
+epochs_tmin = -0.08
+epochs_tmax = 0.18
+epochs_decim = 10  # 2000->200 Hz
+baseline = (None, 0)
+conditions = ["dip01", "dip13", "dip25", "dip37", "dip49"]
+
+# Decoding
+decode = True  # should be very good performance
diff --git a/mne_bids_pipeline/tests/configs/config_ds000117.py b/mne_bids_pipeline/tests/configs/config_ds000117.py
index b46db99bd..a4c5d2e85 100644
--- a/mne_bids_pipeline/tests/configs/config_ds000117.py
+++ b/mne_bids_pipeline/tests/configs/config_ds000117.py
@@ -1,8 +1,5 @@
-"""
-Faces dataset
-"""
+"""Faces dataset."""

-study_name = "ds000117"
 bids_root = "~/mne_data/ds000117"
 deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000117"

@@ -18,10 +15,13 @@
 find_flat_channels_meg = True
 find_noisy_channels_meg = True
 use_maxwell_filter = True
+process_empty_room = True

 mf_reference_run = "02"
 mf_cal_fname = bids_root + "/derivatives/meg_derivatives/sss_cal.dat"
 mf_ctc_fname = bids_root + "/derivatives/meg_derivatives/ct_sparse.fif"
+mf_int_order = 9
+mf_ext_order = 2

 reject = {"grad": 4000e-13, "mag": 4e-12}
 conditions = ["Famous", "Unfamiliar", "Scrambled"]
diff --git a/mne_bids_pipeline/tests/configs/config_ds000246.py b/mne_bids_pipeline/tests/configs/config_ds000246.py
index 6cb3a8148..1aa58f244 100644
--- a/mne_bids_pipeline/tests/configs/config_ds000246.py
+++ b/mne_bids_pipeline/tests/configs/config_ds000246.py
@@ -1,11 +1,9 @@
-"""
-Brainstorm - Auditory Dataset.
+"""Brainstorm - Auditory Dataset.
See https://openneuro.org/datasets/ds000246/versions/1.0.0 for more information. """ -study_name = "ds000246" bids_root = "~/mne_data/ds000246" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000246" @@ -19,6 +17,7 @@ ch_types = ["meg"] reject = dict(mag=4e-12, eog=250e-6) conditions = ["standard", "deviant", "button"] +epochs_metadata_tmin = ["standard", "deviant"] # for testing only contrasts = [("deviant", "standard")] decode = True decoding_time_generalization = True diff --git a/mne_bids_pipeline/tests/configs/config_ds000247.py b/mne_bids_pipeline/tests/configs/config_ds000247.py index 8d2b0451f..fc4f42464 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000247.py +++ b/mne_bids_pipeline/tests/configs/config_ds000247.py @@ -1,12 +1,9 @@ -""" -OMEGA Resting State Sample Data -""" -import numpy as np +"""OMEGA Resting State Sample Data.""" +import numpy as np -study_name = "ds000247" -bids_root = f"~/mne_data/{study_name}" -deriv_root = f"~/mne_data/derivatives/mne-bids-pipeline/{study_name}" +bids_root = "~/mne_data/ds000247" +deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000247" subjects = ["0002"] sessions = ["01"] diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_FLASH_BEM.py b/mne_bids_pipeline/tests/configs/config_ds000248_FLASH_BEM.py index f09fdc6d5..5d37fde67 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_FLASH_BEM.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_FLASH_BEM.py @@ -1,7 +1,5 @@ -""" -MNE Sample Data: BEM from FLASH images -""" -study_name = "ds000248" +"""MNE Sample Data: BEM from FLASH images.""" + bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_FLASH_BEM" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_T1_BEM.py b/mne_bids_pipeline/tests/configs/config_ds000248_T1_BEM.py index df315e035..0fdfdbf76 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_T1_BEM.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_T1_BEM.py @@ -1,8 +1,5 @@ -""" -MNE Sample Data: BEM from T1 images -""" +"""MNE Sample Data: BEM from T1 images.""" -study_name = "ds000248" bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_T1_BEM" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_base.py b/mne_bids_pipeline/tests/configs/config_ds000248_base.py index b80b6f0f0..7a5137e97 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_base.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_base.py @@ -1,9 +1,8 @@ -""" -MNE Sample Data: M/EEG combined processing -""" +"""MNE Sample Data: M/EEG combined processing.""" + import mne +import mne_bids -study_name = "ds000248" bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_base" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" @@ -23,9 +22,10 @@ use_maxwell_filter = True -def noise_cov(bp): +def noise_cov(bp: mne_bids.BIDSPath) -> mne.Covariance: + """Estimate the noise covariance.""" # Use pre-stimulus period as noise source - bp = bp.copy().update(processing="clean", suffix="epo") + bp = bp.copy().update(suffix="epo") if not bp.fpath.exists(): bp.update(split="01") epo = mne.read_epochs(bp) @@ -47,6 +47,7 @@ def noise_cov(bp): n_jobs = 2 -def mri_t1_path_generator(bids_path): +def mri_t1_path_generator(bids_path: mne_bids.BIDSPath) -> mne_bids.BIDSPath: + """Return 
the path to a T1 image.""" # don't really do any modifications – just for testing! return bids_path diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_coreg_surfaces.py b/mne_bids_pipeline/tests/configs/config_ds000248_coreg_surfaces.py index 9262fdcb8..dba51f97d 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_coreg_surfaces.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_coreg_surfaces.py @@ -1,8 +1,5 @@ -""" -MNE Sample Data: Head surfaces from FreeSurfer surfaces for coregistration step -""" +"""MNE Sample Data: Head surfaces from FreeSurfer surfaces for coregistration step.""" -study_name = "ds000248" bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_coreg_surfaces" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_ica.py b/mne_bids_pipeline/tests/configs/config_ds000248_ica.py index 176a2f592..6dfb49c7b 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_ica.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_ica.py @@ -1,7 +1,5 @@ -""" -MNE Sample Data: ICA -""" -study_name = 'MNE "sample" dataset' +"""MNE Sample Data: ICA.""" + bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_ica" diff --git a/mne_bids_pipeline/tests/configs/config_ds000248_no_mri.py b/mne_bids_pipeline/tests/configs/config_ds000248_no_mri.py index 9941d2842..08b98e9bb 100644 --- a/mne_bids_pipeline/tests/configs/config_ds000248_no_mri.py +++ b/mne_bids_pipeline/tests/configs/config_ds000248_no_mri.py @@ -1,8 +1,5 @@ -""" -MNE Sample Data: Using the `fsaverage` template MRI -""" +"""MNE Sample Data: Using the `fsaverage` template MRI.""" -study_name = "ds000248" bids_root = "~/mne_data/ds000248" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds000248_no_mri" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds001810.py b/mne_bids_pipeline/tests/configs/config_ds001810.py index 508a99e64..f9790dec0 100644 --- a/mne_bids_pipeline/tests/configs/config_ds001810.py +++ b/mne_bids_pipeline/tests/configs/config_ds001810.py @@ -1,8 +1,8 @@ -""" -tDCS EEG -""" +"""tDCS EEG.""" + +import numpy as np +import pandas as pd -study_name = "ds001810" bids_root = "~/mne_data/ds001810" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds001810" @@ -13,7 +13,7 @@ reject = dict(eeg=100e-6) baseline = (None, 0) conditions = ["61450", "61511"] -contrasts = [("61450", "61511")] +contrasts = [("61450", "61511"), ("letter=='a'", "letter=='b'")] decode = True decoding_n_splits = 3 # only for testing, use 5 otherwise @@ -24,3 +24,42 @@ interpolate_bads_grand_average = False n_jobs = 4 + +epochs_custom_metadata = { + "ses-anodalpost": pd.DataFrame( + { + "ones": np.ones(253), + "letter": ["a" for x in range(150)] + ["b" for x in range(103)], + } + ), + "ses-anodalpre": pd.DataFrame( + { + "ones": np.ones(268), + "letter": ["a" for x in range(150)] + ["b" for x in range(118)], + } + ), + "ses-anodaltDCS": pd.DataFrame( + { + "ones": np.ones(269), + "letter": ["a" for x in range(150)] + ["b" for x in range(119)], + } + ), + "ses-cathodalpost": pd.DataFrame( + { + "ones": np.ones(290), + "letter": ["a" for x in range(150)] + ["b" for x in range(140)], + } + ), + "ses-cathodalpre": pd.DataFrame( + { + "ones": np.ones(267), + "letter": ["a" for x in range(150)] + ["b" for x in range(117)], + } + ), + "ses-cathodaltDCS": pd.DataFrame( + { + "ones": np.ones(297), 
+ "letter": ["a" for x in range(150)] + ["b" for x in range(147)], + } + ), +} # number of rows are hand-set diff --git a/mne_bids_pipeline/tests/configs/config_ds001971.py b/mne_bids_pipeline/tests/configs/config_ds001971.py index 7a64f940d..349dbe23e 100644 --- a/mne_bids_pipeline/tests/configs/config_ds001971.py +++ b/mne_bids_pipeline/tests/configs/config_ds001971.py @@ -3,8 +3,6 @@ See ds001971 on OpenNeuro: https://github.com/OpenNeuroDatasets/ds001971 """ - -study_name = "ds001971" bids_root = "~/mne_data/ds001971" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds001971" diff --git a/mne_bids_pipeline/tests/configs/config_ds003104.py b/mne_bids_pipeline/tests/configs/config_ds003104.py index c88d07161..d47a0a64c 100644 --- a/mne_bids_pipeline/tests/configs/config_ds003104.py +++ b/mne_bids_pipeline/tests/configs/config_ds003104.py @@ -1,6 +1,5 @@ -"""Somato -""" -study_name = "MNE-somato-data-anonymized" +"""Somato.""" + bids_root = "~/mne_data/ds003104" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds003104" subjects_dir = f"{bids_root}/derivatives/freesurfer/subjects" diff --git a/mne_bids_pipeline/tests/configs/config_ds003392.py b/mne_bids_pipeline/tests/configs/config_ds003392.py index edc30228f..eb1d500c2 100644 --- a/mne_bids_pipeline/tests/configs/config_ds003392.py +++ b/mne_bids_pipeline/tests/configs/config_ds003392.py @@ -1,18 +1,21 @@ -""" -hMT+ Localizer -""" -study_name = "localizer" +"""hMT+ Localizer.""" + bids_root = "~/mne_data/ds003392" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds003392" subjects = ["01"] task = "localizer" -find_flat_channels_meg = True -find_noisy_channels_meg = True +# usually a good idea to use True, but we know no bads are detected for this dataset +find_flat_channels_meg = False +find_noisy_channels_meg = False use_maxwell_filter = True +mf_extra_kws = {"bad_condition": "warning"} ch_types = ["meg"] +mf_cal_missing = "warn" +mf_ctc_missing = "warn" + l_freq = 1.0 h_freq = 40.0 raw_resample_sfreq = 250 @@ -21,10 +24,9 @@ # Artifact correction. spatial_filter = "ica" ica_algorithm = "picard-extended_infomax" -ica_max_iterations = 500 +ica_max_iterations = 1000 ica_l_freq = 1.0 ica_n_components = 0.99 -ica_reject_components = "auto" # Epochs epochs_tmin = -0.2 @@ -39,6 +41,11 @@ decoding_time_generalization = True decoding_time_generalization_decim = 4 contrasts = [("incoherent", "coherent")] +decoding_csp = True +decoding_csp_times = [] +decoding_csp_freqs = { + "alpha": (8, 12), +} # Noise estimation noise_cov = "emptyroom" diff --git a/mne_bids_pipeline/tests/configs/config_ds003775.py b/mne_bids_pipeline/tests/configs/config_ds003775.py index 4dae88993..219b1e23a 100644 --- a/mne_bids_pipeline/tests/configs/config_ds003775.py +++ b/mne_bids_pipeline/tests/configs/config_ds003775.py @@ -1,8 +1,5 @@ -""" -SRM Resting-state EEG -""" +"""SRM Resting-state EEG.""" -study_name = "ds003775" bids_root = "~/mne_data/ds003775" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds003775" diff --git a/mne_bids_pipeline/tests/configs/config_ds004107.py b/mne_bids_pipeline/tests/configs/config_ds004107.py index 7a32d952c..0dd70a5ef 100644 --- a/mne_bids_pipeline/tests/configs/config_ds004107.py +++ b/mne_bids_pipeline/tests/configs/config_ds004107.py @@ -1,5 +1,4 @@ -""" -MIND DATA +"""MIND DATA. M.P. Weisend, F.M. Hanlon, R. Montaño, S.P. Ahlfors, A.C. Leuthold, D. Pantazis, J.C. Mosher, A.P. Georgopoulos, M.S. Hämäläinen, C.J. 
@@ -7,11 +6,11 @@ Paving the way for cross-site pooling of magnetoencephalography (MEG) data. International Congress Series, Volume 1300, Pages 615-618. """ + # This has auditory, median, indx, visual, rest, and emptyroom but let's just # process the auditory (it's the smallest after rest) -study_name = "ds004107" -bids_root = f"~/mne_data/{study_name}" -deriv_root = f"~/mne_data/derivatives/mne-bids-pipeline/{study_name}" +bids_root = "~/mne_data/ds004107" +deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds004107" subjects = ["mind002"] sessions = ["01"] conditions = ["left", "right"] # there are also tone and noise diff --git a/mne_bids_pipeline/tests/configs/config_ds004229.py b/mne_bids_pipeline/tests/configs/config_ds004229.py index e4ca6d449..46e5891c2 100644 --- a/mne_bids_pipeline/tests/configs/config_ds004229.py +++ b/mne_bids_pipeline/tests/configs/config_ds004229.py @@ -1,12 +1,11 @@ -""" -Single-subject infant dataset for testing maxwell_filter with movecomp. +"""Single-subject infant dataset for testing maxwell_filter with movecomp. https://openneuro.org/datasets/ds004229 """ + import mne import numpy as np -study_name = "amnoise" bids_root = "~/mne_data/ds004229" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/ds004229" @@ -20,17 +19,33 @@ z=0.055, ) @ mne.transforms.rotation(x=np.deg2rad(-15)) mf_mc = True +mf_destination = "twa" mf_st_duration = 10 mf_int_order = 6 # lower for smaller heads mf_mc_t_step_min = 0.5 # just for speed! mf_mc_t_window = 0.2 # cleaner cHPI filtering on this dataset -mf_filter_chpi = False # for speed, not needed as we low-pass anyway +mf_filter_chpi = True # even though we lowpass, set to True for test coverage mf_mc_rotation_velocity_limit = 30.0 # deg/s for annotations mf_mc_translation_velocity_limit = 20e-3 # m/s mf_esss = 8 mf_esss_reject = {"grad": 10000e-13, "mag": 40000e-15} ch_types = ["meg"] +# test extra kws +find_bad_channels_extra_kws = { + "ignore_ref": True, +} +mf_extra_kws = { + "ignore_ref": True, +} +notch_extra_kws = { + "method": "spectrum_fit", +} +bandpass_extra_kws = { + "fir_window": "blackman", +} + + l_freq = None h_freq = 40.0 @@ -46,6 +61,10 @@ epochs_tmax = 1 epochs_decim = 6 # 1200->200 Hz baseline = (None, 0) +report_add_epochs_image_kwargs = { + "grad": {"vmin": 0, "vmax": 1e13 * reject["grad"]}, # fT/cm + "mag": {"vmin": 0, "vmax": 1e15 * reject["mag"]}, # fT +} # Conditions / events to consider when epoching conditions = ["auditory"] diff --git a/mne_bids_pipeline/tests/configs/config_eeg_matchingpennies.py b/mne_bids_pipeline/tests/configs/config_eeg_matchingpennies.py index 643e51799..5d0f77614 100644 --- a/mne_bids_pipeline/tests/configs/config_eeg_matchingpennies.py +++ b/mne_bids_pipeline/tests/configs/config_eeg_matchingpennies.py @@ -1,8 +1,5 @@ -""" -Matchingpennies EEG experiment -""" +"""Matchingpennies EEG experiment.""" -study_name = "eeg_matchingpennies" bids_root = "~/mne_data/eeg_matchingpennies" deriv_root = "~/mne_data/derivatives/mne-bids-pipeline/eeg_matchingpennies" @@ -16,3 +13,8 @@ decode = True interpolate_bads_grand_average = False + +l_freq = None +h_freq = 100 +zapline_fline = 50 +zapline_iter = False diff --git a/mne_bids_pipeline/tests/conftest.py b/mne_bids_pipeline/tests/conftest.py index 295b2309a..66bd603bd 100644 --- a/mne_bids_pipeline/tests/conftest.py +++ b/mne_bids_pipeline/tests/conftest.py @@ -1,7 +1,10 @@ """Pytest config.""" +import pytest -def pytest_addoption(parser): + +def pytest_addoption(parser: pytest.Parser) -> None: + """Add pytest command line 
options.""" parser.addoption( "--download", action="store_true", @@ -9,7 +12,8 @@ def pytest_addoption(parser): ) -def pytest_configure(config): +def pytest_configure(config: pytest.Config) -> None: + """Add pytest configuration settings.""" # register an additional marker config.addinivalue_line("markers", "dataset_test: mark that a test runs a dataset") warning_lines = r""" @@ -39,12 +43,45 @@ def pytest_configure(config): ignore:use_inf_as_na option is deprecated.*:FutureWarning # Dask distributed with jsonschema 4.18 ignore:jsonschema\.RefResolver is deprecated.*:DeprecationWarning + ignore:.*apply_async.*is deprecated.*:DeprecationWarning # seaborn->pandas ignore:is_categorical_dtype is deprecated.*:FutureWarning ignore:use_inf_as_na option is deprecated.*:FutureWarning ignore:All-NaN axis encountered.*:RuntimeWarning # sklearn class not enough samples for cv=5 always:The least populated class in y has only.*:UserWarning + # constrained layout fails on ds003392 + # mne_bids_pipeline/steps/preprocessing/_06a_run_ica.py:551: in run_ica + # report.add_ica( + #../python_env/lib/python3.10/site-packages/mne/report/report.py:1974: in add_ica + # self._add_ica( + #../python_env/lib/python3.10/site-packages/mne/report/report.py:1872: in _add_ica + # self._add_ica_artifact_sources( + #../python_env/lib/python3.10/site-packages/mne/report/report.py:1713: + # in _add_ica_artifact_sources + # self._add_figure( + always:constrained_layout not applied.*:UserWarning + ignore:.*ing of figures upon backend switching.*: + ignore:datetime\.datetime\.utcfromtimestamp.*:DeprecationWarning + ignore:datetime\.datetime\.utcnow.*:DeprecationWarning + # pandas with no good workaround + ignore:The behavior of DataFrame concatenation with empty.*:FutureWarning + # joblib on Windows sometimes + ignore:Persisting input arguments took.*:UserWarning + # matplotlib needs to update + ignore:Conversion of an array with ndim.*:DeprecationWarning + # scipy + ignore:nperseg .* is greater.*:UserWarning + # NumPy 2.0 + ignore:__array_wrap__ must accept context.*:DeprecationWarning + ignore:__array__ implementation doesn't accept.*:DeprecationWarning + # Seaborn + ignore:.*bool was deprecated in Matplotlib.*:DeprecationWarning + ignore:.*bool will be deprecated.*:PendingDeprecationWarning + # sklearn + ignore:.*Liblinear failed to converge.*: + # json-tricks + ignore:json-tricks.*numpy scalar serialization.*:UserWarning """ for warning_line in warning_lines.split("\n"): warning_line = warning_line.strip() diff --git a/mne_bids_pipeline/tests/datasets.py b/mne_bids_pipeline/tests/datasets.py index 60ace0c48..6183e1567 100644 --- a/mne_bids_pipeline/tests/datasets.py +++ b/mne_bids_pipeline/tests/datasets.py @@ -1,36 +1,36 @@ """Definition of the testing datasets.""" -from typing import Dict, List, TypedDict +from typing import TypedDict # If not supplied below, the effective defaults are listed in comments class DATASET_OPTIONS_T(TypedDict, total=False): + """A container for sources, hash, include and excludes of a dataset.""" + git: str # "" openneuro: str # "" osf: str # "" web: str # "" - include: List[str] # [] - exclude: List[str] # [] + mne: str # "" + include: list[str] # [] + exclude: list[str] # [] hash: str # "" + processor: str # "" + fsaverage: bool # False + config_path_extra: str # "" -DATASET_OPTIONS: Dict[str, DATASET_OPTIONS_T] = { +# We can autodetect the need for fsaverage for openneuro datasets based on +# "derivatives/freesurfer/subjects" being in the include list, but for osf.io we +# need to manually 
mark it with fsaverage=True +DATASET_OPTIONS: dict[str, DATASET_OPTIONS_T] = { "ERP_CORE": { # original dataset: "osf": "9f5w7" - "web": "https://osf.io/3zk6n/download?version=2", + "web": "https://osf.io/download/3zk6n?version=2", "hash": "sha256:ddc94a7c9ba1922637f2770592dd51c019d341bf6bc8558e663e1979a4cb002f", # noqa: E501 + "fsaverage": False, # avoid autodetection via config import (which fails) }, "eeg_matchingpennies": { - # This dataset started out on osf.io as dataset https://osf.io/cj2dr - # then moved to g-node.org. As of 2023/02/28 when we download it via - # datalad it's too (~200 kB/sec!) and times out at the end: - # - # "git": "https://gin.g-node.org/sappelhoff/eeg_matchingpennies", - # "web": "", - # "include": ["sub-05"], - # - # So now we mirror this datalad-fetched git repo back on osf.io! - # original dataset: "osf": "cj2dr" "web": "https://osf.io/download/8rbfk?version=1", "hash": "sha256:06bfbe52c50b9343b6b8d2a5de3dd33e66ad9303f7f6bfbe6868c3c7c375fafd", # noqa: E501 }, @@ -59,21 +59,22 @@ class DATASET_OPTIONS_T(TypedDict, total=False): "openneuro": "ds000248", "include": ["sub-01", "sub-emptyroom", "derivatives/freesurfer/subjects"], "exclude": [ - "derivatives/freesurfer/subjects/fsaverage/mri/aparc.a2005s+aseg.mgz", # noqa: E501 + "derivatives/freesurfer/subjects/fsaverage/mri/aparc.a2005s+aseg.mgz", "derivatives/freesurfer/subjects/fsaverage/mri/aparc+aseg.mgz", - "derivatives/freesurfer/subjects/fsaverage/mri/aparc.a2009s+aseg.mgz", # noqa: E501 - "derivatives/freesurfer/subjects/fsaverage/xhemi/mri/aparc+aseg.mgz", # noqa: E501 + "derivatives/freesurfer/subjects/fsaverage/mri/aparc.a2009s+aseg.mgz", + "derivatives/freesurfer/subjects/fsaverage/xhemi/mri/aparc+aseg.mgz", "derivatives/freesurfer/subjects/sub-01/mri/aparc+aseg.mgz", - "derivatives/freesurfer/subjects/sub-01/mri/aparc.DKTatlas+aseg.mgz", # noqa: E501 - "derivatives/freesurfer/subjects/sub-01/mri/aparc.DKTatlas+aseg.mgz", # noqa: E501 + "derivatives/freesurfer/subjects/sub-01/mri/aparc.DKTatlas+aseg.mgz", + "derivatives/freesurfer/subjects/sub-01/mri/aparc.DKTatlas+aseg.mgz", "derivatives/freesurfer/subjects/sub-01/mri/aparc.a2009s+aseg.mgz", ], + "config_path_extra": "_base", # not just config_ds000248.py }, "ds000117": { "openneuro": "ds000117", "include": [ - "sub-01/ses-meg/meg/sub-01_ses-meg_task-facerecognition_run-01_*", # noqa: E501 - "sub-01/ses-meg/meg/sub-01_ses-meg_task-facerecognition_run-02_*", # noqa: E501 + "sub-01/ses-meg/meg/sub-01_ses-meg_task-facerecognition_run-01_*", + "sub-01/ses-meg/meg/sub-01_ses-meg_task-facerecognition_run-02_*", "sub-01/ses-meg/meg/sub-01_ses-meg_headshape.pos", "sub-01/ses-meg/*.tsv", "sub-01/ses-meg/*.json", @@ -85,6 +86,8 @@ class DATASET_OPTIONS_T(TypedDict, total=False): "ds003775": { "openneuro": "ds003775", "include": ["sub-010"], + # See https://github.com/OpenNeuroOrg/openneuro/issues/2976 + "exclude": ["sub-010/ses-t1/sub-010_ses-t1_scans.tsv"], }, "ds001810": { "openneuro": "ds001810", @@ -120,4 +123,13 @@ class DATASET_OPTIONS_T(TypedDict, total=False): "sub-emptyroom/ses-20000101", ], }, + "MNE-phantom-KIT-data": { + "mne": "phantom_kit", + }, + "MNE-funloc-data": { + "web": "https://osf.io/download/upj3h?version=1", + "hash": "sha256:67dbd38f7207db5c93c540d9c7c92ec2ac09ee1bd1b5d5e5cdd8866c08ec4858", # noqa: E501 + "processor": "untar", + "fsaverage": True, + }, } diff --git a/mne_bids_pipeline/tests/sub-010_ses-t1_scans.tsv b/mne_bids_pipeline/tests/sub-010_ses-t1_scans.tsv new file mode 100644 index 000000000..54b711284 --- /dev/null +++ 
b/mne_bids_pipeline/tests/sub-010_ses-t1_scans.tsv @@ -0,0 +1,2 @@ +filename acq_time +eeg/sub-010_ses-t1_task-resteyesc_eeg.edf 2017-05-09T12:11:44 diff --git a/mne_bids_pipeline/tests/test_cli.py b/mne_bids_pipeline/tests/test_cli.py index 607cbdd67..e0020bc12 100644 --- a/mne_bids_pipeline/tests/test_cli.py +++ b/mne_bids_pipeline/tests/test_cli.py @@ -2,11 +2,15 @@ import importlib import sys +from pathlib import Path + import pytest + from mne_bids_pipeline._main import main -def test_config_generation(tmp_path, monkeypatch): +def test_config_generation(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """Test the generation of a default config file.""" cmd = ["mne_bids_pipeline", "--create-config"] monkeypatch.setattr(sys, "argv", cmd) with pytest.raises(SystemExit, match="2"): @@ -15,6 +19,6 @@ def test_config_generation(tmp_path, monkeypatch): cmd.append(str(cfg_path)) main() assert cfg_path.is_file() - spec = importlib.util.spec_from_file_location(cfg_path) + spec = importlib.util.spec_from_file_location(str(cfg_path)) varnames = [v for v in dir(spec) if not v.startswith("__")] assert varnames == [] diff --git a/mne_bids_pipeline/tests/test_documented.py b/mne_bids_pipeline/tests/test_documented.py index 097fc1032..bc3eee1e3 100644 --- a/mne_bids_pipeline/tests/test_documented.py +++ b/mne_bids_pipeline/tests/test_documented.py @@ -1,61 +1,97 @@ """Test that all config values are documented.""" + import ast -from pathlib import Path import os import re +import sys +from pathlib import Path + +import pytest import yaml +from mne_bids_pipeline._config_import import _get_default_config, _import_config +from mne_bids_pipeline._config_template import create_template_config +from mne_bids_pipeline._docs import _EXECUTION_OPTIONS, _ParseConfigSteps from mne_bids_pipeline.tests.datasets import DATASET_OPTIONS from mne_bids_pipeline.tests.test_run import TEST_SUITE -from mne_bids_pipeline._config_import import _get_default_config root_path = Path(__file__).parent.parent -def test_options_documented(): +def test_options_documented() -> None: """Test that all options are suitably documented.""" # use ast to parse _config.py for assignments - with open(root_path / "_config.py", "r") as fid: - contents = fid.read() - contents = ast.parse(contents) - in_config = [ - item.target.id for item in contents.body if isinstance(item, ast.AnnAssign) + with open(root_path / "_config.py") as fid: + contents_str = fid.read() + contents = ast.parse(contents_str) + assert isinstance(contents, ast.Module), type(contents) + in_config_list = [ + item.target.id + for item in contents.body + if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name) ] - assert len(set(in_config)) == len(in_config) - in_config = set(in_config) + assert len(set(in_config_list)) == len(in_config_list) + in_config = set(in_config_list) + del in_config_list # ensure we clean our namespace correctly config = _get_default_config() config_names = set(d for d in dir(config) if not d.startswith("_")) assert in_config == config_names settings_path = root_path.parent / "docs" / "source" / "settings" + sys.path.append(str(settings_path)) + try: + from gen_settings import main # pyright: ignore [reportMissingImports] + finally: + sys.path.pop() + main() assert settings_path.is_dir() - in_doc = set() + in_doc: dict[str, set[str]] = dict() key = " - " - allowed_duplicates = set( - [ - "source_info_path_update", - ] - ) for dirpath, _, fnames in os.walk(settings_path): for fname in fnames: if not fname.endswith(".md"): 
continue # This is a .md file - with open(Path(dirpath) / fname, "r") as fid: + # convert to relative path + fname = os.path.join(os.path.relpath(dirpath, settings_path), fname) + assert fname not in in_doc + in_doc[fname] = set() + with open(settings_path / fname) as fid: for line in fid: if not line.startswith(key): continue # The line starts with our magic key val = line[len(key) :].strip() - if val not in allowed_duplicates: - assert val not in in_doc, "Duplicate documentation" - in_doc.add(val) + for other in in_doc: + why = f"Duplicate docs in {fname} and {other} for {val}" + assert val not in in_doc[other], why + in_doc[fname].add(val) what = "docs/source/settings doc" - assert in_doc.difference(in_config) == set(), f"Extra values in {what}" - assert in_config.difference(in_doc) == set(), f"Values missing from {what}" + in_doc_all = set() + for vals in in_doc.values(): + in_doc_all.update(vals) + assert in_doc_all.difference(in_config) == set(), f"Extra values in {what}" + assert in_config.difference(in_doc_all) == set(), f"Values missing from {what}" + + +def test_config_options_used() -> None: + """Test that all config options are used somewhere.""" + config = _get_default_config() + config_names = set(d for d in dir(config) if not d.startswith("__")) + for key in ("_epochs_split_size", "_raw_split_size"): + config_names.add(key) + for key in _EXECUTION_OPTIONS: + config_names.remove(key) + pcs = _ParseConfigSteps(force_empty=()) + missing_from_config = sorted(set(pcs.steps) - config_names) + assert missing_from_config == [], f"Missing from config: {missing_from_config}" + missing_from_steps = sorted(config_names - set(pcs.steps)) + assert missing_from_steps == [], f"Missing from steps: {missing_from_steps}" + for key, val in pcs.steps.items(): + assert val, f"No steps for {key}" -def test_datasets_in_doc(): +def test_datasets_in_doc() -> None: """Test that all datasets in tests are in the doc.""" # There are four things to keep in sync: # @@ -67,15 +103,15 @@ def test_datasets_in_doc(): # So let's make sure they stay in sync. # 1. Read cache, test, etc. entries from CircleCI - with open(root_path.parent / ".circleci" / "config.yml", "r") as fid: + with open(root_path.parent / ".circleci" / "config.yml") as fid: circle_yaml_src = fid.read() circle_yaml = yaml.safe_load(circle_yaml_src) - caches = [job[6:] for job in circle_yaml["jobs"] if job.startswith("cache_")] - assert len(caches) == len(set(caches)) - caches = set(caches) - tests = [job[5:] for job in circle_yaml["jobs"] if job.startswith("test_")] - assert len(tests) == len(set(tests)) - tests = set(tests) + caches_list = [job[6:] for job in circle_yaml["jobs"] if job.startswith("cache_")] + caches = set(caches_list) + assert len(caches_list) == len(caches) + tests_list = [job[5:] for job in circle_yaml["jobs"] if job.startswith("test_")] + assert len(tests_list) == len(set(tests_list)) + tests = set(tests_list) # Rather than going circle_yaml['workflows']['commit']['jobs'] and # make sure everything is consistent there (too much work), let's at least # check that we get the correct number using `.count`. @@ -127,14 +163,16 @@ def test_datasets_in_doc(): # 3. 
Read examples from docs (being careful about tags we can't read) class SafeLoaderIgnoreUnknown(yaml.SafeLoader): - def ignore_unknown(self, node): + def ignore_unknown(self, node: yaml.Node) -> None: return None SafeLoaderIgnoreUnknown.add_constructor( - None, SafeLoaderIgnoreUnknown.ignore_unknown + # PyYAML stubs have an error -- this can be None but mypy says it can't + None, # type: ignore + SafeLoaderIgnoreUnknown.ignore_unknown, ) - with open(root_path.parent / "docs" / "mkdocs.yml", "r") as fid: + with open(root_path.parent / "docs" / "mkdocs.yml") as fid: examples = yaml.load(fid.read(), Loader=SafeLoaderIgnoreUnknown) examples = [n for n in examples["nav"] if list(n)[0] == "Examples"][0] examples = [ex for ex in examples["Examples"] if isinstance(ex, str)] @@ -143,12 +181,14 @@ def ignore_unknown(self, node): examples = set(examples) # 4. DATASET_OPTIONS - dataset_names = list(DATASET_OPTIONS) - assert len(dataset_names) == len(set(dataset_names)) + dataset_names_list = list(DATASET_OPTIONS) + dataset_names = set(dataset_names_list) + assert len(dataset_names_list) == len(dataset_names) # 5. TEST_SUITE - test_names = list(TEST_SUITE) - assert len(test_names) == len(set(test_names)) + test_names_list = list(TEST_SUITE) + test_names = set(test_names_list) + assert len(test_names_list) == len(test_names) # Some have been split into multiple test runs, so trim down to the same # set as caches @@ -171,3 +211,39 @@ def ignore_unknown(self, node): assert tests == examples, "CircleCI tests != docs/mkdocs.yml Examples" assert tests == dataset_names, "CircleCI tests != tests/datasets.py" assert tests == test_names, "CircleCI tests != tests/test_run.py" + + +def _replace_config_value_in_file(fpath: Path, config_key: str, new_value: str) -> None: + """Assign a value to a config key in a file, and uncomment the line if needed.""" + lines = fpath.read_text().split("\n") + pattern = re.compile(rf"(?:# )?({config_key}: .* = )(?:.*)") + for ix, line in enumerate(lines): + if pattern.match(line): + lines[ix] = pattern.sub( + rf"\1{new_value}", + line, # omit comment marker, change default value to `new_value` + ) + break + fpath.write_text("\n".join(lines)) + + +def test_config_template_valid(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure our config template is syntactically valid (importable).""" + monkeypatch.setenv("BIDS_ROOT", str(tmp_path)) + fpath = tmp_path / "foo.py" + create_template_config(fpath) + # `ch_types` fails pydantic validation (its default is `[]` but its annotation + # requires length > 0) + with pytest.raises( + ValueError, match="ch_types\n Value should have at least 1 item" + ): + _import_config(config_path=fpath, log=False) + # Give `ch_types` a value so pydantic will succeed... 
+ _replace_config_value_in_file(fpath, "ch_types", '["meg"]') + # ...but now `_check_config` will raise an error that `conditions` cannot be None + # unless `task_is_rest = True` (which defaults to False) + with pytest.raises(ValueError, match="the `conditions` parameter is empty"): + _import_config(config_path=fpath, log=False) + # give a non-None value for `conditions`, now importing the config should work + _replace_config_value_in_file(fpath, "conditions", '["foo"]') + _import_config(config_path=fpath, log=False) diff --git a/mne_bids_pipeline/tests/test_functions.py b/mne_bids_pipeline/tests/test_functions.py new file mode 100644 index 000000000..6bd4c7177 --- /dev/null +++ b/mne_bids_pipeline/tests/test_functions.py @@ -0,0 +1,65 @@ +"""Test some properties of our core processing-step functions.""" + +import ast +import inspect + +import pytest + +from mne_bids_pipeline._config_utils import _get_step_modules + +# mne_bids_pipeline.init._01_init_derivatives_dir: +FLAT_MODULES = {x.__name__: x for x in sum(_get_step_modules().values(), ())} + + +@pytest.mark.parametrize("module_name", list(FLAT_MODULES)) +def test_all_functions_return(module_name: str) -> None: + """Test that all functions decorated with failsafe_run return a dict.""" + # Find the functions within the module that use the failsafe_run decorator + module = FLAT_MODULES[module_name] + funcs = list() + for name in dir(module): + obj = getattr(module, name) + if not callable(obj): + continue + if getattr(obj, "__module__", None) != module_name: + continue + if not hasattr(obj, "__wrapped__"): + continue + # All our failsafe_run decorated functions should look like this + assert "__mne_bids_pipeline_failsafe_wrapper__" in repr(obj.__code__) + funcs.append(obj) + # Some module names we know don't have any + if module_name.split(".")[-1] in ("_01_recon_all",): + assert len(funcs) == 0 + return + + assert len(funcs) != 0, f"No failsafe_runs functions found in {module_name}" + + # Adapted from numpydoc RT01 validation + def get_returns_not_on_nested_functions(node: ast.AST) -> list[ast.Return]: + returns = [node] if isinstance(node, ast.Return) else [] + for child in ast.iter_child_nodes(node): + # Ignore nested functions and its subtrees. + if not isinstance(child, ast.FunctionDef): + child_returns = get_returns_not_on_nested_functions(child) + returns.extend(child_returns) + return returns + + for func in funcs: + what = f"{module_name}.{func.__name__}" + tree = ast.parse(inspect.getsource(func.__wrapped__)).body + if func.__closure__[-1].cell_contents is False: + continue # last closure node is require_output=False + assert tree, f"Failed to parse source code for {what}" + returns = get_returns_not_on_nested_functions(tree[0]) + return_values = [r.value for r in returns] + # Replace Constant nodes valued None for None. 
+ for i, v in enumerate(return_values): + if isinstance(v, ast.Constant) and v.value is None: + return_values[i] = None + assert len(return_values), f"Function does not return anything: {what}" + for r in return_values: + what = f"Function does _prep_out_files: {what}" + assert isinstance(r, ast.Call), what + assert isinstance(r.func, ast.Name), what + assert r.func.id == "_prep_out_files", what diff --git a/mne_bids_pipeline/tests/test_run.py b/mne_bids_pipeline/tests/test_run.py index eb07233b1..a0f65c2a4 100644 --- a/mne_bids_pipeline/tests/test_run.py +++ b/mne_bids_pipeline/tests/test_run.py @@ -1,14 +1,21 @@ """Download test data and run a test suite.""" -import sys + +import os +import re import shutil +import sys +from collections.abc import Collection +from contextlib import nullcontext from pathlib import Path -from typing import Collection, Dict, Optional, TypedDict -import os +from typing import Any, TypedDict import pytest +from h5io import read_hdf5 +from mne_bids import BIDSPath, get_bids_path_from_fname -from mne_bids_pipeline._main import main +from mne_bids_pipeline._config_import import _import_config from mne_bids_pipeline._download import main as download_main +from mne_bids_pipeline._main import main BIDS_PIPELINE_DIR = Path(__file__).absolute().parents[1] @@ -23,13 +30,13 @@ class _TestOptionsT(TypedDict, total=False): dataset: str # key.split("_")[0] config: str # f"config_{key}.py" steps: Collection[str] # ("preprocessing", "sensor") - task: Optional[str] # None - env: Dict[str, str] # {} + task: str | None # None + env: dict[str, str] # {} requires: Collection[str] # () extra_config: str # "" -TEST_SUITE: Dict[str, _TestOptionsT] = { +TEST_SUITE: dict[str, _TestOptionsT] = { "ds003392": {}, "ds004229": {}, "ds001971": {}, @@ -59,7 +66,13 @@ class _TestOptionsT(TypedDict, total=False): _n_jobs = {"preprocessing/_05_make_epochs": 1} """, }, - "ds000248_ica": {}, + "ds000248_ica": { + "extra_config": """ +_raw_split_size = "60MB" +_epochs_split_size = "30MB" +_n_jobs = {} +""" + }, "ds000248_T1_BEM": { "steps": ("source/make_bem_surfaces",), "requires": ("freesurfer",), @@ -123,11 +136,19 @@ class _TestOptionsT(TypedDict, total=False): "config": "config_ERP_CORE.py", "task": "P3", }, + "MNE-phantom-KIT-data": { + "config": "config_MNE_phantom_KIT_data.py", + }, + "MNE-funloc-data": { + "config": "config_MNE_funloc_data.py", + "steps": ["init", "preprocessing", "sensor", "source"], + }, } @pytest.fixture() -def dataset_test(request): +def dataset_test(request: pytest.FixtureRequest) -> None: + """Provide a defined context for our dataset tests.""" # There is probably a cleaner way to get this param, but this works for now capsys = request.getfixturevalue("capsys") dataset = request.getfixturevalue("dataset") @@ -144,7 +165,13 @@ def dataset_test(request): @pytest.mark.dataset_test @pytest.mark.parametrize("dataset", list(TEST_SUITE)) -def test_run(dataset, monkeypatch, dataset_test, capsys, tmp_path): +def test_run( + dataset: str, + monkeypatch: pytest.MonkeyPatch, + dataset_test: Any, + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: """Test running a dataset.""" test_options = TEST_SUITE[dataset] config = test_options.get("config", f"config_{dataset}.py") @@ -169,11 +196,20 @@ def test_run(dataset, monkeypatch, dataset_test, capsys, tmp_path): src=fix_path / "ds001971_participants.tsv", dst=DATA_DIR / "ds001971" / "participants.tsv", ) + elif dataset == "ds003775": + shutil.copy( + src=fix_path / "sub-010_ses-t1_scans.tsv", + dst=DATA_DIR + / 
"ds003775" + / "sub-010" + / "ses-t1" + / "sub-010_ses-t1_scans.tsv", + ) # Run the tests. steps = test_options.get("steps", ("preprocessing", "sensor")) task = test_options.get("task", None) - command = ["mne_bids_pipeline", str(config_path), f'--steps={",".join(steps)}'] + command = ["mne_bids_pipeline", str(config_path), f"--steps={','.join(steps)}"] if task: command.append(f"--task={task}") if "--pdb" in sys.argv: @@ -183,3 +219,194 @@ def test_run(dataset, monkeypatch, dataset_test, capsys, tmp_path): with capsys.disabled(): print() main() + + +@pytest.mark.parametrize("allow_missing_sessions", (False, True)) +def test_missing_sessions( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + capsys: pytest.CaptureFixture[str], + allow_missing_sessions: bool, +) -> None: + """Test the `allow_missing_sessions` config variable.""" + dataset = "fake" + bids_root = tmp_path / dataset + files = ( + "dataset_description.json", + *(f"participants.{x}" for x in ("json", "tsv")), + *(f"sub-1/sub-1_sessions.{x}" for x in ("json", "tsv")), + *( + f"sub-1/ses-a/eeg/sub-1_ses-a_task-foo_{x}.tsv" + for x in ("channels", "events") + ), + *( + f"sub-1/ses-a/eeg/sub-1_ses-a_task-foo_eeg.{x}" + for x in ("eeg", "json", "vhdr", "vmrk") + ), + ) + for _file in files: + path = bids_root / _file + path.parent.mkdir(parents=True, exist_ok=True) + path.touch() + # fake a config file (can't use static file because `bids_root` is in `tmp_path`) + config = f""" +bids_root = "{bids_root}" +deriv_root = "{tmp_path / "derivatives" / "mne-bids-pipeline" / dataset}" +interactive = False +subjects = ["1"] +sessions = ["a", "b"] +ch_types = ["eeg"] +conditions = ["zzz"] +allow_missing_sessions = {allow_missing_sessions} +""" + config_path = tmp_path / "fake_config_missing_session.py" + with open(config_path, "w") as fid: + fid.write(config) + # set up the context handler + context = ( + nullcontext() + if allow_missing_sessions + else pytest.raises(RuntimeError, match=r"Subject 1 is missing session \['b'\]") + ) + # run + command = [ + "mne_bids_pipeline", + str(config_path), + "--steps=init/_01_init_derivatives_dir", + ] + if "--pdb" in sys.argv: + command.append("--n_jobs=1") + monkeypatch.setenv("_MNE_BIDS_STUDY_TESTING", "true") + monkeypatch.setattr(sys, "argv", command) + with capsys.disabled(): + print() + with context: + main() + + +@pytest.mark.dataset_test +def test_session_specific_mri( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + """Test of (faked) session-specific MRIs.""" + dataset = "MNE-funloc-data" + test_options = TEST_SUITE[dataset] + config = test_options.get("config", f"config_{dataset}.py") + config_path = BIDS_PIPELINE_DIR / "tests" / "configs" / config + config_obj = _import_config(config_path=config_path) + # copy the dataset to a tmpdir, and in the destination location make it + # seem like there's only one subj with different MRIs for different sessions + new_bids_path = BIDSPath(root=tmp_path / dataset, subject="01", session="a") + # sub-01/* → sub-01/ses-a/* ; sub-02/* → sub-01/ses-b/* + for src_subj, dst_sess in (("01", "a"), ("02", "b")): + src_dir = config_obj.bids_root / f"sub-{src_subj}" + dst_dir = new_bids_path.root / "sub-01" / f"ses-{dst_sess}" + for walk_root, dirs, files in src_dir.walk(): + offset = walk_root.relative_to(src_dir) + for _dir in dirs: + (dst_dir / offset / _dir).mkdir(parents=True) + for _file in files: + bp = get_bids_path_from_fname(walk_root / _file) + bp.update(root=new_bids_path.root, subject="01", 
session=dst_sess) + # rewrite scans.tsv files to have correct filenames in it + if _file.endswith("scans.tsv"): + lines = [ + line.replace(f"sub-{src_subj}", f"sub-01_ses-{dst_sess}") + for line in (walk_root / _file).read_text().split("\n") + ] + (dst_dir / offset / bp.basename).write_text("\n".join(lines)) + # For all other files, a simple copy suffices; rewriting + # `raw.info["subject_info"]["his_id"]` is not necessary because MNE-BIDS + # overwrites it with the value in `participants.tsv` anyway. + else: + shutil.copyfile( + src=walk_root / _file, dst=dst_dir / offset / bp.basename + ) + # emptyroom + src_dir = config_obj.bids_root / "sub-emptyroom" + dst_dir = new_bids_path.root / "sub-emptyroom" + shutil.copytree(src=src_dir, dst=dst_dir) + # root-level files (dataset description, etc) + src_dir = config_obj.bids_root + dst_dir = new_bids_path.root + files = [f for f in src_dir.iterdir() if f.is_file()] + for _file in files: + # in theory we should rewrite `participants.tsv` to remove the `sub-02` line, + # but in practice it will just get ignored so we won't bother. + shutil.copyfile(src=_file, dst=dst_dir / _file.name) + # derivatives (freesurfer files) + src_dir = config_obj.bids_root / "derivatives" / "freesurfer" / "subjects" + dst_dir = new_bids_path.root / "derivatives" / "freesurfer" / "subjects" + dst_dir.mkdir(parents=True) + freesurfer_subject_mapping = {"sub-01": "sub-01_ses-a", "sub-02": "sub-01_ses-b"} + for walk_root, dirs, files in src_dir.walk(): + # change "root" so that in later steps of the walk when we're inside a subject's + # dir, the "offset" (folders between dst_dir and filename) will be correct + new_root = walk_root + if "sub-01" in walk_root.parts or "sub-02" in walk_root.parts: + new_root = Path( + *[freesurfer_subject_mapping.get(p, p) for p in new_root.parts] + ) + offset = new_root.relative_to(src_dir) + # the actual subject dirs need their names changed + for _dir in dirs: + _dir = freesurfer_subject_mapping.get(_dir, _dir) + (dst_dir / offset / _dir).mkdir() + # for filenames that contain the subject identifier (BEM files, morph maps), + # we need to change the filename too, not just parent folder name + for _file in files: + dst_file = _file + for subj in freesurfer_subject_mapping: + if subj in dst_file: + dst_file = dst_file.replace(subj, freesurfer_subject_mapping[subj]) + break + shutil.copyfile(src=walk_root / _file, dst=dst_dir / offset / dst_file) + # update config so that `subjects_dir` and `deriv_root` also point to the tempdir + extra_config = f""" +from pathlib import Path +subjects_dir = "{new_bids_path.root / "derivatives" / "freesurfer" / "subjects"}" +deriv_root = Path("{new_bids_path.root}") / "derivatives" / "mne-bids-pipeline" / "MNE-funloc-data" +""" # noqa E501 + extra_path = tmp_path / "extra_config.py" + extra_path.write_text(extra_config) + monkeypatch.setenv("_MNE_BIDS_STUDY_TESTING_EXTRA_CONFIG", str(extra_path)) + # Run the tests. 
+ steps = test_options.get("steps", ()) + command = ["mne_bids_pipeline", str(config_path), f"--steps={','.join(steps)}"] + # hack in the new bids_root + command.append(f"--root-dir={new_bids_path.root}") + if "--pdb" in sys.argv: + command.append("--n_jobs=1") + monkeypatch.setenv("_MNE_BIDS_STUDY_TESTING", "true") + monkeypatch.setattr(sys, "argv", command) + with capsys.disabled(): + print() + main() + # check some things that are indicative of different MRIs being used in each session + results = list() + for sess in ("a", "b"): + fname = ( + new_bids_path.root + / "derivatives" + / "mne-bids-pipeline" + / "MNE-funloc-data" + / "sub-01" + / f"ses-{sess}" + / "meg" + / f"sub-01_ses-{sess}_task-funloc_report.h5" + ) + report = read_hdf5(fname, title="mnepython") + coregs = next( + filter(lambda x: x["dom_id"] == "Sensor_alignment", report["_content"]) + ) + pattern = re.compile( + r"Average distance from (?P\d+) digitized points to head: " + r"(?P\d+(?:\.\d+)?) mm" + ) + result = pattern.search(coregs["html"]) + assert result is not None + assert float(result.group("dist")) < 3 # fit between pts and outer_skin < 3 mm + results.append(result.groups()) + assert results[0] != results[1] # different npts and/or different mean distance diff --git a/mne_bids_pipeline/tests/test_validation.py b/mne_bids_pipeline/tests/test_validation.py index 25d5abdaa..611957acd 100644 --- a/mne_bids_pipeline/tests/test_validation.py +++ b/mne_bids_pipeline/tests/test_validation.py @@ -1,8 +1,14 @@ +"""Test the pipeline configuration import validator.""" + +from pathlib import Path +from shutil import rmtree + import pytest -from mne_bids_pipeline._config_import import _import_config + +from mne_bids_pipeline._config_import import ConfigError, _import_config -def test_validation(tmp_path, capsys): +def test_validation(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None: """Test that misspellings are caught by our config import validator.""" config_path = tmp_path / "config.py" bad_text = "" @@ -13,7 +19,7 @@ def test_validation(tmp_path, capsys): bad_text += f"bids_root = '{tmp_path}'\n" # no ch_types config_path.write_text(bad_text) - with pytest.raises(ValueError, match="Please specify ch_types"): + with pytest.raises(ValueError, match="Value should have at least 1 item"): _import_config(config_path=config_path) bad_text += "ch_types = ['eeg']\n" # conditions @@ -43,6 +49,55 @@ def test_validation(tmp_path, capsys): _import_config(config_path=config_path) msg, err = capsys.readouterr() assert msg == err == "" # no new message + # TWA headpos without movement compensation + bad_text = working_text + "mf_destination = 'twa'\n" + config_path.write_text(bad_text) + with pytest.raises(ConfigError, match="cannot compute time-weighted average head"): + _import_config(config_path=config_path) + # maxfilter extra kwargs + bad_text = working_text + "mf_extra_kws = {'calibration': 'x', 'head_pos': False}\n" + config_path.write_text(bad_text) + with pytest.raises(ConfigError, match="contains keys calibration, head_pos that"): + _import_config(config_path=config_path) + # ecg_channel_dict key validation (all subjects have channels specified) + try: + # these must exist for dict check to work + for sub, ses in {"1": "a", "2": "b"}.items(): + _dir = tmp_path / f"sub-{sub}" / f"ses-{ses}" / "eeg" + _dir.mkdir(parents=True) + (_dir / f"sub-{sub}_ses-{ses}_eeg.fif").touch() + except Exception: + raise + else: + # test the config import when sessions = "all" (default) + bad_text = ( + working_text + "subjects = ['1', 
'2']\n" + "allow_missing_sessions = True\n" + "ssp_ecg_channel = {'sub-1': 'MEG0111'}\n" + ) # OK to omit session from sub-1, but entry for sub-2 is missing + config_path.write_text(bad_text) + with pytest.raises(ConfigError, match=r"Missing entries.*\n sub-2_ses-b"): + _import_config(config_path=config_path) + # test when single session specified in config + bad_text = ( + working_text + "subjects = ['1', '2']\n" + "sessions = ['a']\n" + "allow_missing_sessions = True\n" + "ssp_ecg_channel = {'sub-1_ses-b': 'MEG0111'}\n" # no entry for sub-1_ses-a + ) + config_path.write_text(bad_text) + with pytest.raises(ConfigError, match=r"Missing entries.*\n sub-1_ses-a"): + _import_config(config_path=config_path) + # clean up + finally: + for sub in ("1", "2"): + rmtree(tmp_path / f"sub-{sub}") + + # ecg_channel_dict key validation (keys in dict are well-formed) + bad_text = working_text + "ssp_ecg_channel = {'sub-0_1': 'MEG0111'}\n" # underscore + config_path.write_text(bad_text) + with pytest.raises(ConfigError, match="Malformed keys in ssp_ecg_channel dict:.*"): + _import_config(config_path=config_path) # old values bad_text = working_text bad_text += "debug = True\n" diff --git a/mne_bids_pipeline/typing.py b/mne_bids_pipeline/typing.py index 7b989309c..ebe2bcec3 100644 --- a/mne_bids_pipeline/typing.py +++ b/mne_bids_pipeline/typing.py @@ -1,43 +1,68 @@ -"""Typing.""" +"""Custom data types for MNE-BIDS-Pipeline.""" import pathlib import sys -from typing import Union, List, Dict -from typing_extensions import Annotated +from typing import Annotated, Any, Literal, TypeAlias if sys.version_info < (3, 12): from typing_extensions import TypedDict else: from typing import TypedDict +import mne import numpy as np +from mne_bids import BIDSPath from numpy.typing import ArrayLike from pydantic import PlainValidator -import mne -PathLike = Union[str, pathlib.Path] +PathLike = str | pathlib.Path + +__all__ = [ + "ArbitraryContrast", + "DigMontageType", + "FloatArrayLike", + "FloatArrayT", + "LogKwargsT", + "OutFilesT", + "PathLike", + "RunTypeT", + "RunKindT", + "TypedDict", +] + + +ShapeT: TypeAlias = tuple[int, ...] | tuple[int] +IntArrayT: TypeAlias = np.ndarray[ShapeT, np.dtype[np.integer[Any]]] +FloatArrayT: TypeAlias = np.ndarray[ShapeT, np.dtype[np.floating[Any]]] +OutFilesT: TypeAlias = dict[str, tuple[str, str | float]] +InFilesT: TypeAlias = dict[str, BIDSPath] # Only BIDSPath +InFilesPathT: TypeAlias = dict[str, BIDSPath | pathlib.Path] # allow generic Path too class ArbitraryContrast(TypedDict): + """Statistical contrast with arbitrary weights.""" + name: str - conditions: List[str] - weights: List[float] + conditions: list[str] + weights: list[float] class LogKwargsT(TypedDict): + """Container for logger keyword arguments.""" + msg: str - extra: Dict[str, str] + extra: dict[str, str] -class ReferenceRunParams(TypedDict): - montage: mne.channels.DigMontage - dev_head_t: mne.Transform +RunTypeT = Literal["experimental", "empty-room", "resting-state"] +RunKindT = Literal["orig", "sss", "filt"] -def assert_float_array_like(val): +def assert_float_array_like(val: Any) -> FloatArrayT: + """Convert the input into a NumPy float array.""" # https://docs.pydantic.dev/latest/errors/errors/#custom-errors # Should raise ValueError or AssertionError... 
NumPy should do this for us - return np.array(val, dtype="float") + return np.array(val, dtype=np.float64) FloatArrayLike = Annotated[ @@ -47,7 +72,8 @@ def assert_float_array_like(val): ] -def assert_dig_montage(val): +def assert_dig_montage(val: mne.channels.DigMontage) -> mne.channels.DigMontage: + """Assert that the input is a DigMontage.""" assert isinstance(val, mne.channels.DigMontage) return val diff --git a/pyproject.toml b/pyproject.toml index 5d217e36d..e72787611 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,78 +1,83 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + [project] name = "mne-bids-pipeline" # Keep in sync with README.md: description = "A full-flegded processing pipeline for your MEG and EEG data" readme = "README.md" -requires-python = ">=3.8" -license = {file = "LICENSE.txt"} +requires-python = ">=3.10" +license = { file = "LICENSE.txt" } keywords = ["science", "neuroscience", "psychology"] authors = [ - {name = "Eric Larson"}, - {name = "Alexandre Gramfort"}, - {name = "Mainak Jas"}, - {name = "Richard Höchenberger", email = "richard.hoechenberger@gmail.com"}, + { name = "Eric Larson" }, + { name = "Alexandre Gramfort" }, + { name = "Mainak Jas" }, + { name = "Richard Höchenberger", email = "richard.hoechenberger@gmail.com" }, ] classifiers = [ "Intended Audience :: Science/Research", - "Programming Language :: Python" + "Programming Language :: Python", ] dependencies = [ - "typing_extensions; python_version < '3.8'", - "importlib_metadata; python_version < '3.8'", - "psutil", # for joblib - "packaging", - "numpy", - "scipy", - "matplotlib", - "nibabel", - "joblib >= 0.14", - "threadpoolctl", - "dask[distributed]", - "bokeh < 3", # for distributed dashboard - "jupyter-server-proxy", # to have dask and jupyter working together - "scikit-learn", - "pandas", - "seaborn", - "json_tricks", - "pydantic >= 2.0.0", - "rich", - "python-picard", - "qtpy", - "pyvista", - "pyvistaqt", - "openpyxl", - "autoreject", - "mne[hdf5] >=1.2", - "mne-bids[full]", - "filelock", - "setuptools >=65", + "psutil", # for joblib + "packaging", + "numpy", + "scipy", + "matplotlib", + "nibabel", + "joblib >= 0.14", + "threadpoolctl", + "dask[distributed]", + "bokeh", # for distributed dashboard + "jupyter-server-proxy", # to have dask and jupyter working together + "scikit-learn", + "pandas", + "pyarrow", # from pandas + "seaborn", + "json_tricks", + "pydantic >= 2.0.0", + "annotated-types", + "rich", + "python-picard", + "qtpy", + "pyvista", + "pyvistaqt", + "openpyxl", + "autoreject", + "mne[hdf5] >=1.7", + "mne-bids[full]", + "filelock", + "meegkit" ] dynamic = ["version"] [project.optional-dependencies] tests = [ - "pytest", - "pytest-cov", - "pooch", - "psutil", - "datalad", - "ruff", - "mkdocs", - "mkdocs-material >= 9.0.4", - "mkdocs-material-extensions", - "mkdocs-macros-plugin", - "mkdocs-include-markdown-plugin", - "mkdocs-exclude", - "mkdocstrings-python", - "mike", - "jinja2", - "black", # function signature formatting - "livereload", - "openneuro-py >= 2022.2.0", - "httpx >= 0.20", - "tqdm", - "Pygments", - "pyyaml", + "pytest", + "pytest-cov", + "pooch", + "psutil", + "ruff", + "jinja2", + "openneuro-py >= 2022.2.0", + "httpx >= 0.20", + "tqdm", + "Pygments", + "pyyaml", +] +docs = [ + "mkdocs", + "mkdocs-material >= 9.0.4", + "mkdocs-material-extensions", + "mkdocs-macros-plugin", + "mkdocs-include-markdown-plugin", + "mkdocs-exclude", + "mkdocstrings-python", + "mike", + "livereload", + "black", # docstring reformatting ] 
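Before the packaging changes continue below, a brief aside on the `FloatArrayLike` helper added to `mne_bids_pipeline/typing.py` earlier in this patch: it pairs `typing.Annotated` with a pydantic `PlainValidator`, so list or tuple values coming from a user config are coerced to `float64` NumPy arrays during validation. A minimal, self-contained sketch of that pattern follows, assuming pydantic v2; the `Baseline` model and `_as_float_array` helper are illustrative only, not pipeline code.

```python
from typing import Annotated, Any

import numpy as np
from pydantic import BaseModel, ConfigDict, PlainValidator


def _as_float_array(val: Any) -> np.ndarray:
    # Coerce lists/tuples/arrays to float64; bad input raises and fails validation
    return np.array(val, dtype=np.float64)


# The PlainValidator replaces pydantic's own handling of the annotated type
FloatArrayLike = Annotated[np.ndarray, PlainValidator(_as_float_array)]


class Baseline(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    window: FloatArrayLike


print(Baseline(window=(-0.2, 0.0)).window)  # [-0.2  0. ]
```

The `assert_dig_montage` validator in the same module follows the same idea but only type-checks its input (an `mne.channels.DigMontage`) rather than coercing it.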
 [project.scripts]
@@ -83,19 +88,21 @@ homepage = "https://mne.tools/mne-bids-pipeline"
 repository = "https://github.com/mne-tools/mne-bids-pipeline"
 changelog = "http://mne.tools/mne-bids-pipeline/changes.html"

-[build-system]
-requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "wheel"]
-build-backend = "setuptools.build_meta"
-
-[tool.setuptools_scm]
-tag_regex = "^(?P<prefix>v)?(?P<version>[0-9.]+)(?P<suffix>.*)?$"
-version_scheme = "release-branch-semver"
+[tool.hatch.version]
+source = "vcs"
+raw-options = { version_scheme = "release-branch-semver" }

-[tool.setuptools.packages.find]
-exclude = ["false"]  # on CircleCI this folder appears during pip install -ve. for an unknown reason
-
-[tool.setuptools.package-data]
-"mne_bids_pipeline.steps.freesurfer.contrib" = ["version"]
+[tool.hatch.build]
+exclude = [
+    "/.*",
+    "/codecov.yml",
+    "**/tests",
+    "/docs",
+    "/docs/source/examples/gen_examples.py",  # specify explicitly because its exclusion is negated in .gitignore
+    "/Makefile",
+    "/CONTRIBUTING.md",
+    "ignore_words.txt",
+]

 [tool.codespell]
 skip = "docs/site/*,*.html,steps/freesurfer/contrib/*"
@@ -108,13 +115,48 @@ count = ""

 [tool.pytest.ini_options]
 addopts = "-ra -vv --tb=short --cov=mne_bids_pipeline --cov-report= --junit-xml=junit-results.xml --durations=10"
-testpaths = [
-    "mne_bids_pipeline",
-]
+testpaths = ["mne_bids_pipeline"]
 junit_family = "xunit2"

 [tool.ruff]
-exclude = ["**/freesurfer/contrib", "dist/" , "build/"]
+exclude = ["**/freesurfer/contrib", "dist/", "build/", "**/.*cache"]
+
+[tool.ruff.lint]
+select = ["A", "B006", "D", "E", "F", "I", "W", "UP", "TID252"]
+ignore = [
+    "D104", # Missing docstring in public package
+]
+
+[tool.ruff.lint.per-file-ignores]
+"mne_bids_pipeline/typing.py" = ["A005"]
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
+
+[tool.mypy]
+ignore_errors = false
+scripts_are_modules = true
+disable_error_code = [
+    # For libraries like matplotlib that we don't have types for
+    "import-not-found",
+    "import-untyped",
+]
+strict = true
+modules = ["mne_bids_pipeline", "docs.source"]

-[tool.black]
-exclude = "(.*/freesurfer/contrib/.*)|(dist/)|(build/)"
+[[tool.mypy.overrides]]
+module = ["mne_bids_pipeline.steps.freesurfer.contrib.*"]
+ignore_errors = true  # not our code, don't validate
+
+[[tool.mypy.overrides]]
+module = ["mne_bids_pipeline.tests.*"]
+disable_error_code = [
+    "misc", # Untyped decorator makes function "test_all_functions_return" untyped
+]
+
+[[tool.mypy.overrides]]
+module = ['mne_bids_pipeline.tests.configs.*']
+disable_error_code = [
+    "assignment", # Incompatible types in assignment
+    "var-annotated", # Need type annotation for "plot_psd_for_runs"
+]
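A closing note on the ds001810 test config earlier in this patch: it attaches per-session `pandas.DataFrame` objects via `epochs_custom_metadata` and adds the contrast `("letter=='a'", "letter=='b'")`. Those strings are pandas-style queries evaluated against the epochs metadata. A small standalone sketch of how such queries select epochs in MNE-Python, using synthetic data (the channel name and metadata column below are made up for illustration):

```python
import mne
import numpy as np
import pandas as pd

# Four synthetic epochs (1 EEG channel, 50 samples each) with a "letter"
# metadata column; query strings like "letter == 'a'" then act as epoch selectors.
rng = np.random.default_rng(0)
info = mne.create_info(["EEG 001"], sfreq=100.0, ch_types="eeg")
data = rng.standard_normal((4, 1, 50))
events = np.column_stack(
    [np.arange(4) * 50, np.zeros(4, dtype=int), np.ones(4, dtype=int)]
)
metadata = pd.DataFrame({"letter": ["a", "a", "b", "b"]})
epochs = mne.EpochsArray(data, info, events=events, metadata=metadata)

print(epochs["letter == 'a'"])  # keeps only the two "a" epochs
```

This is also why the hand-set row counts in that config matter: MNE requires `len(metadata)` to equal the number of epochs in each session.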