Skip to content

Commit 56b079a

Browse files
committed
Handle channel name mismatch
1 parent fe851a4 commit 56b079a

File tree

3 files changed

+120
-6
lines changed

3 files changed

+120
-6
lines changed

mne_bids/read.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ def _get_bads_from_tsv_data(tsv_data):
766766
return bads
767767

768768

769-
def _handle_channels_reading(channels_fname, raw):
769+
def _handle_channels_reading(channels_fname, raw, ch_name_mismatch="raise"):
770770
"""Read associated channels.tsv and populate raw.
771771
772772
Updates status (bad) and types of channels.
@@ -847,7 +847,21 @@ def _handle_channels_reading(channels_fname, raw):
847847
f"set channel names."
848848
)
849849
else:
850-
raw.rename_channels(dict(zip(raw.ch_names, ch_names_tsv)))
850+
orig_names = list(raw.ch_names)
851+
if orig_names != ch_names_tsv:
852+
if ch_name_mismatch == "raise":
853+
raise RuntimeError(
854+
f"Channel mismatch between {channels_fname} and the raw data file detected."
855+
)
856+
warn(f"Channel mismatch between {channels_fname} and the raw data file detected. Using mismatch strategy: {ch_name_mismatch}.")
857+
if ch_name_mismatch == "reorder":
858+
raw.reorder_channels(ch_names_tsv)
859+
elif ch_name_mismatch == "rename":
860+
raw.rename_channels(dict(zip(raw.ch_names, ch_names_tsv)))
861+
else:
862+
raise ValueError(
863+
"ch_name_mismatch must be one of {'reorder','raise','rename'}"
864+
)
851865

852866
# Set the channel types in the raw data according to channels.tsv
853867
channel_type_bids_mne_map_available_channels = {
@@ -887,7 +901,7 @@ def _handle_channels_reading(channels_fname, raw):
887901

888902
@verbose
889903
def read_raw_bids(
890-
bids_path, extra_params=None, *, return_event_dict=False, verbose=None
904+
bids_path, extra_params=None, *, return_event_dict=False, ch_name_mismatch="raise", verbose=None
891905
):
892906
"""Read BIDS compatible data.
893907
@@ -918,6 +932,12 @@ def read_raw_bids(
918932
event IDs, in addition to the :class:`~mne.io.Raw` object. If a ``value`` column
919933
is present in the ``*_events.tsv`` file, it will be used as the source of the
920934
integer event ID values (events with ``value="n/a"`` will be omitted).
935+
ch_name_mismatch : str
936+
What to do if the channel names in the channels.tsv file do not match the channel names in the raw data file.
937+
Must be one of {'raise', 'reorder', 'rename'}. Default is 'raise'.
938+
- 'raise' will raise a RuntimeError if there is a channel mismatch.
939+
- 'reorder' will reorder the channels in the raw data file to match the channel names in the channels.tsv file.
940+
- 'rename' will rename the channels in the raw data file to match the channel names in the channels.tsv file.
921941
%(verbose)s
922942
923943
Returns
@@ -944,6 +964,8 @@ def read_raw_bids(
944964
ValueError
945965
If the specified ``datatype`` cannot be found in the dataset.
946966
967+
RuntimeError
968+
If channels.tsv and raw data file have a channel name mismatch and ch_name_mismatch is 'raise'.
947969
"""
948970
if not isinstance(bids_path, BIDSPath):
949971
raise RuntimeError(
@@ -1082,7 +1104,7 @@ def read_raw_bids(
10821104
bids_path, suffix="channels", extension=".tsv", on_error="warn"
10831105
)
10841106
if channels_fname is not None:
1085-
raw = _handle_channels_reading(channels_fname, raw)
1107+
raw = _handle_channels_reading(channels_fname, raw, ch_name_mismatch=ch_name_mismatch)
10861108

10871109
# Try to find an associated electrodes.tsv and coordsystem.json
10881110
# to get information about the status and type of present channels

mne_bids/tests/test_read.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from mne_bids.read import (
3636
_handle_events_reading,
3737
_handle_scans_reading,
38+
_handle_channels_reading,
3839
_read_raw,
3940
events_file_to_annotation_kwargs,
4041
get_head_mri_trans,
@@ -1463,7 +1464,13 @@ def test_channels_tsv_raw_mismatch(tmp_path):
14631464
raw.reorder_channels(ch_names_new)
14641465
raw.save(raw_path, overwrite=True)
14651466

1466-
raw = read_raw_bids(bids_path)
1467+
with (
1468+
pytest.warns(
1469+
RuntimeWarning,
1470+
match=r"Channel mismatch between .*channels\.tsv and the raw data file detected\. Using mismatch strategy: reorder\.",
1471+
),
1472+
):
1473+
raw = read_raw_bids(bids_path, ch_name_mismatch="reorder")
14671474
assert raw.ch_names == ch_names_orig
14681475

14691476

@@ -1509,6 +1516,89 @@ def test_gsr_and_temp_reading():
15091516
assert raw.get_channel_types(["Temperature"]) == ["temperature"]
15101517

15111518

1519+
def _setup_nirs_channel_mismatch(tmp_path):
1520+
ch_order_snirf = ["S1_D1 760", "S1_D2 760", "S1_D1 850", "S1_D2 850"]
1521+
ch_types = ["fnirs_cw_amplitude"] * len(ch_order_snirf)
1522+
info = mne.create_info(ch_order_snirf, sfreq=10, ch_types=ch_types)
1523+
data = np.arange(len(ch_order_snirf) * 10.0).reshape(len(ch_order_snirf), 10)
1524+
raw = mne.io.RawArray(data, info)
1525+
1526+
for i, ch_name in enumerate(raw.ch_names):
1527+
loc = np.zeros(12)
1528+
if "S1" in ch_name:
1529+
loc[3:6] = np.array([0, 0, 0])
1530+
if "D1" in ch_name:
1531+
loc[6:9] = np.array([1, 0, 0])
1532+
elif "D2" in ch_name:
1533+
loc[6:9] = np.array([0, 1, 0])
1534+
loc[9] = int(ch_name.split(" ")[1])
1535+
loc[0:3] = (loc[3:6] + loc[6:9]) / 2
1536+
raw.info["chs"][i]["loc"] = loc
1537+
1538+
orig_name_to_loc = {
1539+
name: raw.info["chs"][i]["loc"].copy() for i, name in enumerate(raw.ch_names)
1540+
}
1541+
orig_name_to_data = {
1542+
name: raw.get_data(picks=i).copy() for i, name in enumerate(raw.ch_names)
1543+
}
1544+
1545+
ch_order_bids = ["S1_D1 760", "S1_D1 850", "S1_D2 760", "S1_D2 850"]
1546+
ch_types_bids = ["NIRSCWAMPLITUDE"] * len(ch_order_bids)
1547+
channels_dict = OrderedDict([("name", ch_order_bids), ("type", ch_types_bids)])
1548+
channels_fname = tmp_path / "channels.tsv"
1549+
_to_tsv(channels_dict, channels_fname)
1550+
1551+
return raw, ch_order_snirf, ch_order_bids, channels_fname, orig_name_to_loc, orig_name_to_data
1552+
1553+
1554+
def test_channel_mismatch_raise(tmp_path):
1555+
raw, _, _, channels_fname, _, _ = _setup_nirs_channel_mismatch(tmp_path)
1556+
with pytest.raises(
1557+
RuntimeError,
1558+
match=r"Channel mismatch between .*channels\.tsv and the raw data file detected\.",
1559+
):
1560+
_handle_channels_reading(channels_fname, raw.copy(), ch_name_mismatch="raise")
1561+
1562+
1563+
def test_channel_mismatch_reorder(tmp_path):
1564+
raw, _, ch_order_bids, channels_fname, orig_name_to_loc, orig_name_to_data = _setup_nirs_channel_mismatch(tmp_path)
1565+
with (
1566+
pytest.warns(
1567+
RuntimeWarning,
1568+
match=r"Channel mismatch between .*channels\.tsv and the raw data file detected\. Using mismatch strategy: reorder\.",
1569+
),
1570+
):
1571+
raw_out = _handle_channels_reading(channels_fname, raw, ch_name_mismatch="reorder")
1572+
assert raw_out.ch_names == ch_order_bids
1573+
for i, new_name in enumerate(raw_out.ch_names):
1574+
np.testing.assert_allclose(
1575+
raw_out.info["chs"][i]["loc"], orig_name_to_loc[new_name]
1576+
)
1577+
np.testing.assert_allclose(
1578+
raw_out.get_data(picks=i), orig_name_to_data[new_name]
1579+
)
1580+
1581+
1582+
def test_channel_mismatch_rename(tmp_path):
1583+
raw, ch_order_snirf, ch_order_bids, channels_fname, orig_name_to_loc, orig_name_to_data = _setup_nirs_channel_mismatch(tmp_path)
1584+
with (
1585+
pytest.warns(
1586+
RuntimeWarning,
1587+
match=r"Channel mismatch between .*channels\.tsv and the raw data file detected\. Using mismatch strategy: rename\.",
1588+
),
1589+
):
1590+
raw_out_rename = _handle_channels_reading(channels_fname, raw.copy(), ch_name_mismatch="rename")
1591+
assert raw_out_rename.ch_names == ch_order_bids
1592+
for i in range(len(ch_order_bids)):
1593+
orig_name_at_i = ch_order_snirf[i]
1594+
np.testing.assert_allclose(
1595+
raw_out_rename.info["chs"][i]["loc"], orig_name_to_loc[orig_name_at_i]
1596+
)
1597+
np.testing.assert_allclose(
1598+
raw_out_rename.get_data(picks=i), orig_name_to_data[orig_name_at_i]
1599+
)
1600+
1601+
15121602
def test_events_file_to_annotation_kwargs(tmp_path):
15131603
"""Test that events file is read correctly."""
15141604
bids_path = BIDSPath(

mne_bids/tests/test_write.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
no_hand=r"ignore:Unable to map.*\n.*subject handedness.*:RuntimeWarning:mne",
9797
no_montage=r"ignore:Not setting position of.*channel found in "
9898
r"montage.*:RuntimeWarning:mne",
99+
channel_mismatch="ignore:Channel mismatch between .*channels\\.tsv and the raw data file detected\\.:RuntimeWarning:mne",
99100
)
100101

101102

@@ -1398,6 +1399,7 @@ def test_vhdr(_bids_validate, tmp_path):
13981399
warning_str["cnt_warning2"],
13991400
warning_str["no_hand"],
14001401
warning_str["no_montage"],
1402+
warning_str["channel_mismatch"],
14011403
)
14021404
@testing.requires_testing_data
14031405
def test_eegieeg(dir_name, fname, reader, _bids_validate, tmp_path):
@@ -1488,7 +1490,7 @@ def test_eegieeg(dir_name, fname, reader, _bids_validate, tmp_path):
14881490
# Reading the file back should still work, even though we've renamed
14891491
# some channels (there's now a mismatch between BIDS and Raw channel
14901492
# names, and BIDS should take precedence)
1491-
raw_read = read_raw_bids(bids_path=bids_path)
1493+
raw_read = read_raw_bids(bids_path=bids_path, ch_name_mismatch="rename")
14921494
assert raw_read.ch_names[0] == "EOGtest"
14931495
assert raw_read.ch_names[1] == "EMG"
14941496

0 commit comments

Comments
 (0)