diff --git a/CITATION.cff b/CITATION.cff index d969fa753..21c25f272 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -225,6 +225,10 @@ authors: family-names: Aristimunha affiliation: 'Université Paris-Saclay, LISN, Inria TAU/INRIA Nerv, France' orcid: 'https://orcid.org/0000-0001-5258-2995' + - given-names: Kalle + family-names: Mäkelä + affiliation: 'University of Helsinki, Helsinki, Finland' + orcid: 'https://orcid.org/0009-0005-5706-0842' - given-names: Alexandre family-names: Gramfort affiliation: 'Université Paris-Saclay, Inria, CEA, Palaiseau, France' diff --git a/doc/authors.rst b/doc/authors.rst index d1f211154..5087a6aa7 100644 --- a/doc/authors.rst +++ b/doc/authors.rst @@ -32,6 +32,7 @@ .. _Julia Guiomar Niso Galán: https://github.com/guiomar .. _Julius Welzel: https://github.com/JuliusWelzel .. _Kaare Mikkelsen: https://github.com/kaare-mikkelsen +.. _Kalle Mäkelä: https://github.com/Kallemakela .. _Kambiz Tavabi: https://github.com/ktavabi .. _Laetitia Fesselier: https://github.com/laemtl .. _Mainak Jas: https://jasmainak.github.io/ diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 4fb7ce347..bbb422ed9 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -22,6 +22,7 @@ The following authors contributed for the first time. Thank you so much! 🤩 * `Julius Welzel`_ * `Alex Lopez Marquez`_ * `Bruno Aristimunha`_ +* `Kalle Mäkelä`_ The following authors had contributed before. Thank you for sticking around! 🤘 @@ -44,6 +45,7 @@ Detailed list of changes ^^^^^^^^^^^^^^^^^^^^^^^^^^^ - `tracksys` accepted as argument in :class:`mne_bids.BIDSPath()` by `Julius Welzel`_ (:gh:`1430`) +- :func:`mne_bids.read_raw_bids()` has a new parameter ``on_ch_mismatch`` that controls behaviour when there is a mismatch between channel names in ``channels.tsv`` and the raw data; accepted values are ``'raise'`` (default), ``'reorder'``, and ``'rename'``, by `Kalle Mäkelä`_. 🛠 Requirements ^^^^^^^^^^^^^^^ diff --git a/mne_bids/read.py b/mne_bids/read.py index c59ed93dd..1c690d872 100644 --- a/mne_bids/read.py +++ b/mne_bids/read.py @@ -771,7 +771,28 @@ def _get_bads_from_tsv_data(tsv_data): return bads -def _handle_channels_reading(channels_fname, raw): +def _handle_channel_mismatch(raw, on_ch_mismatch, ch_names_tsv, channels_fname): + """Handle mismatch between channels.tsv and raw channel names.""" + if on_ch_mismatch == "raise": + raise RuntimeError( + f"Channel mismatch between {channels_fname} and the raw data file detected." + f"Either align channel names in channels.tsv with the raw file, or call " + f"read_raw_bids(on_ch_mismatch='reorder'|'rename') to proceed." + ) + logger.info( + "Channel mismatch between " + f"{channels_fname} and the raw data file detected. " + f"Using mismatch strategy: {on_ch_mismatch}." + ) + if on_ch_mismatch == "reorder": + raw.reorder_channels(ch_names_tsv) + elif on_ch_mismatch == "rename": + raw.rename_channels(dict(zip(raw.ch_names, ch_names_tsv))) + else: + raise ValueError("on_ch_mismatch must be one of {'reorder','raise','rename'}") + + +def _handle_channels_reading(channels_fname, raw, on_ch_mismatch="raise"): """Read associated channels.tsv and populate raw. Updates status (bad) and types of channels. @@ -852,7 +873,9 @@ def _handle_channels_reading(channels_fname, raw): f"set channel names." ) else: - raw.rename_channels(dict(zip(raw.ch_names, ch_names_tsv))) + orig_names = list(raw.ch_names) + if orig_names != ch_names_tsv: + _handle_channel_mismatch(raw, on_ch_mismatch, ch_names_tsv, channels_fname) # Set the channel types in the raw data according to channels.tsv channel_type_bids_mne_map_available_channels = { @@ -892,7 +915,12 @@ def _handle_channels_reading(channels_fname, raw): @verbose def read_raw_bids( - bids_path, extra_params=None, *, return_event_dict=False, verbose=None + bids_path, + extra_params=None, + *, + return_event_dict=False, + on_ch_mismatch="raise", + verbose=None, ): """Read BIDS compatible data. @@ -923,6 +951,17 @@ def read_raw_bids( event IDs, in addition to the :class:`~mne.io.Raw` object. If a ``value`` column is present in the ``*_events.tsv`` file, it will be used as the source of the integer event ID values (events with ``value="n/a"`` will be omitted). + on_ch_mismatch : str + How to handle a mismatch between channel names in channels.tsv file + and channel names in the raw data file. + Must be one of ``'raise'``, ``'reorder'``, ``'rename'`` (default ``'raise'``). + + * ``'raise'`` will raise a RuntimeError if there is a channel mismatch. + * ``'reorder'`` will reorder the channels in the raw data file to match the + channel order in the channels.tsv file. + * ``'rename'`` will rename the channels in the raw data file to match the + channel names in the channels.tsv file. + %(verbose)s Returns @@ -949,6 +988,9 @@ def read_raw_bids( ValueError If the specified ``datatype`` cannot be found in the dataset. + RuntimeError + If channels.tsv and the raw file have a channel-name mismatch + and ``on_ch_mismatch`` is 'raise'. """ if not isinstance(bids_path, BIDSPath): raise RuntimeError( @@ -1087,7 +1129,9 @@ def read_raw_bids( bids_path, suffix="channels", extension=".tsv", on_error="warn" ) if channels_fname is not None: - raw = _handle_channels_reading(channels_fname, raw) + raw = _handle_channels_reading( + channels_fname, raw, on_ch_mismatch=on_ch_mismatch + ) # Try to find an associated electrodes.tsv and coordsystem.json # to get information about the status and type of present channels diff --git a/mne_bids/tests/test_read.py b/mne_bids/tests/test_read.py index 14f67bfaa..5fb556d5d 100644 --- a/mne_bids/tests/test_read.py +++ b/mne_bids/tests/test_read.py @@ -34,6 +34,7 @@ ) from mne_bids.path import _find_matching_sidecar from mne_bids.read import ( + _handle_channels_reading, _handle_events_reading, _handle_scans_reading, _read_raw, @@ -1587,7 +1588,7 @@ def test_channels_tsv_raw_mismatch(tmp_path): raw.reorder_channels(ch_names_new) raw.save(raw_path, overwrite=True) - raw = read_raw_bids(bids_path) + raw = read_raw_bids(bids_path, on_ch_mismatch="reorder") assert raw.ch_names == ch_names_orig @@ -1633,6 +1634,105 @@ def test_gsr_and_temp_reading(): assert raw.get_channel_types(["Temperature"]) == ["temperature"] +def _setup_nirs_channel_mismatch(tmp_path): + ch_order_snirf = ["S1_D1 760", "S1_D2 760", "S1_D1 850", "S1_D2 850"] + ch_types = ["fnirs_cw_amplitude"] * len(ch_order_snirf) + info = mne.create_info(ch_order_snirf, sfreq=10, ch_types=ch_types) + data = np.arange(len(ch_order_snirf) * 10.0).reshape(len(ch_order_snirf), 10) + raw = mne.io.RawArray(data, info) + + for i, ch_name in enumerate(raw.ch_names): + loc = np.zeros(12) + if "S1" in ch_name: + loc[3:6] = np.array([0, 0, 0]) + if "D1" in ch_name: + loc[6:9] = np.array([1, 0, 0]) + elif "D2" in ch_name: + loc[6:9] = np.array([0, 1, 0]) + loc[9] = int(ch_name.split(" ")[1]) + loc[0:3] = (loc[3:6] + loc[6:9]) / 2 + raw.info["chs"][i]["loc"] = loc + + orig_name_to_loc = { + name: raw.info["chs"][i]["loc"].copy() for i, name in enumerate(raw.ch_names) + } + orig_name_to_data = { + name: raw.get_data(picks=i).copy() for i, name in enumerate(raw.ch_names) + } + + ch_order_bids = ["S1_D1 760", "S1_D1 850", "S1_D2 760", "S1_D2 850"] + ch_types_bids = ["NIRSCWAMPLITUDE"] * len(ch_order_bids) + channels_dict = OrderedDict([("name", ch_order_bids), ("type", ch_types_bids)]) + channels_fname = tmp_path / "channels.tsv" + _to_tsv(channels_dict, channels_fname) + + return ( + raw, + ch_order_snirf, + ch_order_bids, + channels_fname, + orig_name_to_loc, + orig_name_to_data, + ) + + +def test_channel_mismatch_raise(tmp_path): + """Raise error when ``on_ch_mismatch='raise'`` and names differ.""" + raw, _, _, channels_fname, _, _ = _setup_nirs_channel_mismatch(tmp_path) + with pytest.raises( + RuntimeError, + match=("Channel mismatch between .*channels"), + ): + _handle_channels_reading(channels_fname, raw.copy(), on_ch_mismatch="raise") + + +def test_channel_mismatch_reorder(tmp_path): + """Reorder channels to match ``channels.tsv`` ordering.""" + raw, _, ch_order_bids, channels_fname, orig_name_to_loc, orig_name_to_data = ( + _setup_nirs_channel_mismatch(tmp_path) + ) + raw_out = _handle_channels_reading(channels_fname, raw, on_ch_mismatch="reorder") + assert raw_out.ch_names == ch_order_bids + for i, new_name in enumerate(raw_out.ch_names): + np.testing.assert_allclose( + raw_out.info["chs"][i]["loc"], orig_name_to_loc[new_name] + ) + np.testing.assert_allclose( + raw_out.get_data(picks=i), orig_name_to_data[new_name] + ) + + +def test_channel_mismatch_rename(tmp_path): + """Rename channels to match ``channels.tsv`` names.""" + ( + raw, + ch_order_snirf, + ch_order_bids, + channels_fname, + orig_name_to_loc, + orig_name_to_data, + ) = _setup_nirs_channel_mismatch(tmp_path) + raw_out_rename = _handle_channels_reading( + channels_fname, raw.copy(), on_ch_mismatch="rename" + ) + assert raw_out_rename.ch_names == ch_order_bids + for i in range(len(ch_order_bids)): + orig_name_at_i = ch_order_snirf[i] + np.testing.assert_allclose( + raw_out_rename.info["chs"][i]["loc"], orig_name_to_loc[orig_name_at_i] + ) + np.testing.assert_allclose( + raw_out_rename.get_data(picks=i), orig_name_to_data[orig_name_at_i] + ) + + +def test_channel_mismatch_invalid_option(tmp_path): + """Invalid ``on_ch_mismatch`` value should raise ``ValueError``.""" + raw, _, _, channels_fname, _, _ = _setup_nirs_channel_mismatch(tmp_path) + with pytest.raises(ValueError, match="on_ch_mismatch must be one of"): + _handle_channels_reading(channels_fname, raw.copy(), on_ch_mismatch="invalid") + + def test_events_file_to_annotation_kwargs(tmp_path): """Test that events file is read correctly.""" bids_path = BIDSPath( diff --git a/mne_bids/tests/test_write.py b/mne_bids/tests/test_write.py index 8213a54b4..d8bb60637 100644 --- a/mne_bids/tests/test_write.py +++ b/mne_bids/tests/test_write.py @@ -97,6 +97,10 @@ no_hand=r"ignore:Unable to map.*\n.*subject handedness.*:RuntimeWarning:mne", no_montage=r"ignore:Not setting position of.*channel found in " r"montage.*:RuntimeWarning:mne", + channel_mismatch=( + "ignore:Channel mismatch between .*channels\\.tsv and the raw data file " + "detected\\.:RuntimeWarning:mne" + ), ) @@ -1435,6 +1439,7 @@ def test_vhdr(_bids_validate, tmp_path): warning_str["cnt_warning2"], warning_str["no_hand"], warning_str["no_montage"], + warning_str["channel_mismatch"], ) @testing.requires_testing_data def test_eegieeg(dir_name, fname, reader, _bids_validate, tmp_path): @@ -1525,7 +1530,7 @@ def test_eegieeg(dir_name, fname, reader, _bids_validate, tmp_path): # Reading the file back should still work, even though we've renamed # some channels (there's now a mismatch between BIDS and Raw channel # names, and BIDS should take precedence) - raw_read = read_raw_bids(bids_path=bids_path) + raw_read = read_raw_bids(bids_path=bids_path, on_ch_mismatch="rename") assert raw_read.ch_names[0] == "EOGtest" assert raw_read.ch_names[1] == "EMG"