|
34 | 34 | ) |
35 | 35 | from mne_bids.path import _find_matching_sidecar |
36 | 36 | from mne_bids.read import ( |
| 37 | + _handle_channels_reading, |
37 | 38 | _handle_events_reading, |
38 | 39 | _handle_scans_reading, |
39 | 40 | _read_raw, |
@@ -1587,7 +1588,7 @@ def test_channels_tsv_raw_mismatch(tmp_path): |
1587 | 1588 | raw.reorder_channels(ch_names_new) |
1588 | 1589 | raw.save(raw_path, overwrite=True) |
1589 | 1590 |
|
1590 | | - raw = read_raw_bids(bids_path) |
| 1591 | + raw = read_raw_bids(bids_path, on_ch_mismatch="reorder") |
1591 | 1592 | assert raw.ch_names == ch_names_orig |
1592 | 1593 |
|
1593 | 1594 |
|
@@ -1633,6 +1634,105 @@ def test_gsr_and_temp_reading(): |
1633 | 1634 | assert raw.get_channel_types(["Temperature"]) == ["temperature"] |
1634 | 1635 |
|
1635 | 1636 |
|
| 1637 | +def _setup_nirs_channel_mismatch(tmp_path): |
| 1638 | + ch_order_snirf = ["S1_D1 760", "S1_D2 760", "S1_D1 850", "S1_D2 850"] |
| 1639 | + ch_types = ["fnirs_cw_amplitude"] * len(ch_order_snirf) |
| 1640 | + info = mne.create_info(ch_order_snirf, sfreq=10, ch_types=ch_types) |
| 1641 | + data = np.arange(len(ch_order_snirf) * 10.0).reshape(len(ch_order_snirf), 10) |
| 1642 | + raw = mne.io.RawArray(data, info) |
| 1643 | + |
| 1644 | + for i, ch_name in enumerate(raw.ch_names): |
| 1645 | + loc = np.zeros(12) |
| 1646 | + if "S1" in ch_name: |
| 1647 | + loc[3:6] = np.array([0, 0, 0]) |
| 1648 | + if "D1" in ch_name: |
| 1649 | + loc[6:9] = np.array([1, 0, 0]) |
| 1650 | + elif "D2" in ch_name: |
| 1651 | + loc[6:9] = np.array([0, 1, 0]) |
| 1652 | + loc[9] = int(ch_name.split(" ")[1]) |
| 1653 | + loc[0:3] = (loc[3:6] + loc[6:9]) / 2 |
| 1654 | + raw.info["chs"][i]["loc"] = loc |
| 1655 | + |
| 1656 | + orig_name_to_loc = { |
| 1657 | + name: raw.info["chs"][i]["loc"].copy() for i, name in enumerate(raw.ch_names) |
| 1658 | + } |
| 1659 | + orig_name_to_data = { |
| 1660 | + name: raw.get_data(picks=i).copy() for i, name in enumerate(raw.ch_names) |
| 1661 | + } |
| 1662 | + |
| 1663 | + ch_order_bids = ["S1_D1 760", "S1_D1 850", "S1_D2 760", "S1_D2 850"] |
| 1664 | + ch_types_bids = ["NIRSCWAMPLITUDE"] * len(ch_order_bids) |
| 1665 | + channels_dict = OrderedDict([("name", ch_order_bids), ("type", ch_types_bids)]) |
| 1666 | + channels_fname = tmp_path / "channels.tsv" |
| 1667 | + _to_tsv(channels_dict, channels_fname) |
| 1668 | + |
| 1669 | + return ( |
| 1670 | + raw, |
| 1671 | + ch_order_snirf, |
| 1672 | + ch_order_bids, |
| 1673 | + channels_fname, |
| 1674 | + orig_name_to_loc, |
| 1675 | + orig_name_to_data, |
| 1676 | + ) |
| 1677 | + |
| 1678 | + |
| 1679 | +def test_channel_mismatch_raise(tmp_path): |
| 1680 | + """Raise error when ``on_ch_mismatch='raise'`` and names differ.""" |
| 1681 | + raw, _, _, channels_fname, _, _ = _setup_nirs_channel_mismatch(tmp_path) |
| 1682 | + with pytest.raises( |
| 1683 | + RuntimeError, |
| 1684 | + match=("Channel mismatch between .*channels"), |
| 1685 | + ): |
| 1686 | + _handle_channels_reading(channels_fname, raw.copy(), on_ch_mismatch="raise") |
| 1687 | + |
| 1688 | + |
| 1689 | +def test_channel_mismatch_reorder(tmp_path): |
| 1690 | + """Reorder channels to match ``channels.tsv`` ordering.""" |
| 1691 | + raw, _, ch_order_bids, channels_fname, orig_name_to_loc, orig_name_to_data = ( |
| 1692 | + _setup_nirs_channel_mismatch(tmp_path) |
| 1693 | + ) |
| 1694 | + raw_out = _handle_channels_reading(channels_fname, raw, on_ch_mismatch="reorder") |
| 1695 | + assert raw_out.ch_names == ch_order_bids |
| 1696 | + for i, new_name in enumerate(raw_out.ch_names): |
| 1697 | + np.testing.assert_allclose( |
| 1698 | + raw_out.info["chs"][i]["loc"], orig_name_to_loc[new_name] |
| 1699 | + ) |
| 1700 | + np.testing.assert_allclose( |
| 1701 | + raw_out.get_data(picks=i), orig_name_to_data[new_name] |
| 1702 | + ) |
| 1703 | + |
| 1704 | + |
| 1705 | +def test_channel_mismatch_rename(tmp_path): |
| 1706 | + """Rename channels to match ``channels.tsv`` names.""" |
| 1707 | + ( |
| 1708 | + raw, |
| 1709 | + ch_order_snirf, |
| 1710 | + ch_order_bids, |
| 1711 | + channels_fname, |
| 1712 | + orig_name_to_loc, |
| 1713 | + orig_name_to_data, |
| 1714 | + ) = _setup_nirs_channel_mismatch(tmp_path) |
| 1715 | + raw_out_rename = _handle_channels_reading( |
| 1716 | + channels_fname, raw.copy(), on_ch_mismatch="rename" |
| 1717 | + ) |
| 1718 | + assert raw_out_rename.ch_names == ch_order_bids |
| 1719 | + for i in range(len(ch_order_bids)): |
| 1720 | + orig_name_at_i = ch_order_snirf[i] |
| 1721 | + np.testing.assert_allclose( |
| 1722 | + raw_out_rename.info["chs"][i]["loc"], orig_name_to_loc[orig_name_at_i] |
| 1723 | + ) |
| 1724 | + np.testing.assert_allclose( |
| 1725 | + raw_out_rename.get_data(picks=i), orig_name_to_data[orig_name_at_i] |
| 1726 | + ) |
| 1727 | + |
| 1728 | + |
| 1729 | +def test_channel_mismatch_invalid_option(tmp_path): |
| 1730 | + """Invalid ``on_ch_mismatch`` value should raise ``ValueError``.""" |
| 1731 | + raw, _, _, channels_fname, _, _ = _setup_nirs_channel_mismatch(tmp_path) |
| 1732 | + with pytest.raises(ValueError, match="on_ch_mismatch must be one of"): |
| 1733 | + _handle_channels_reading(channels_fname, raw.copy(), on_ch_mismatch="invalid") |
| 1734 | + |
| 1735 | + |
1636 | 1736 | def test_events_file_to_annotation_kwargs(tmp_path): |
1637 | 1737 | """Test that events file is read correctly.""" |
1638 | 1738 | bids_path = BIDSPath( |
|
0 commit comments