|
35 | 35 | from mne_bids.read import ( |
36 | 36 | _handle_events_reading, |
37 | 37 | _handle_scans_reading, |
| 38 | + _handle_channels_reading, |
38 | 39 | _read_raw, |
39 | 40 | events_file_to_annotation_kwargs, |
40 | 41 | get_head_mri_trans, |
@@ -1463,7 +1464,13 @@ def test_channels_tsv_raw_mismatch(tmp_path): |
1463 | 1464 | raw.reorder_channels(ch_names_new) |
1464 | 1465 | raw.save(raw_path, overwrite=True) |
1465 | 1466 |
|
1466 | | - raw = read_raw_bids(bids_path) |
| 1467 | + with ( |
| 1468 | + pytest.warns( |
| 1469 | + RuntimeWarning, |
| 1470 | + match=r"Channel mismatch between .*channels\.tsv and the raw data file detected\. Using mismatch strategy: reorder\.", |
| 1471 | + ), |
| 1472 | + ): |
| 1473 | + raw = read_raw_bids(bids_path, ch_name_mismatch="reorder") |
1467 | 1474 | assert raw.ch_names == ch_names_orig |
1468 | 1475 |
|
1469 | 1476 |
|
@@ -1509,6 +1516,89 @@ def test_gsr_and_temp_reading(): |
1509 | 1516 | assert raw.get_channel_types(["Temperature"]) == ["temperature"] |
1510 | 1517 |
|
1511 | 1518 |
|
| 1519 | +def _setup_nirs_channel_mismatch(tmp_path): |
| 1520 | + ch_order_snirf = ["S1_D1 760", "S1_D2 760", "S1_D1 850", "S1_D2 850"] |
| 1521 | + ch_types = ["fnirs_cw_amplitude"] * len(ch_order_snirf) |
| 1522 | + info = mne.create_info(ch_order_snirf, sfreq=10, ch_types=ch_types) |
| 1523 | + data = np.arange(len(ch_order_snirf) * 10.0).reshape(len(ch_order_snirf), 10) |
| 1524 | + raw = mne.io.RawArray(data, info) |
| 1525 | + |
| 1526 | + for i, ch_name in enumerate(raw.ch_names): |
| 1527 | + loc = np.zeros(12) |
| 1528 | + if "S1" in ch_name: |
| 1529 | + loc[3:6] = np.array([0, 0, 0]) |
| 1530 | + if "D1" in ch_name: |
| 1531 | + loc[6:9] = np.array([1, 0, 0]) |
| 1532 | + elif "D2" in ch_name: |
| 1533 | + loc[6:9] = np.array([0, 1, 0]) |
| 1534 | + loc[9] = int(ch_name.split(" ")[1]) |
| 1535 | + loc[0:3] = (loc[3:6] + loc[6:9]) / 2 |
| 1536 | + raw.info["chs"][i]["loc"] = loc |
| 1537 | + |
| 1538 | + orig_name_to_loc = { |
| 1539 | + name: raw.info["chs"][i]["loc"].copy() for i, name in enumerate(raw.ch_names) |
| 1540 | + } |
| 1541 | + orig_name_to_data = { |
| 1542 | + name: raw.get_data(picks=i).copy() for i, name in enumerate(raw.ch_names) |
| 1543 | + } |
| 1544 | + |
| 1545 | + ch_order_bids = ["S1_D1 760", "S1_D1 850", "S1_D2 760", "S1_D2 850"] |
| 1546 | + ch_types_bids = ["NIRSCWAMPLITUDE"] * len(ch_order_bids) |
| 1547 | + channels_dict = OrderedDict([("name", ch_order_bids), ("type", ch_types_bids)]) |
| 1548 | + channels_fname = tmp_path / "channels.tsv" |
| 1549 | + _to_tsv(channels_dict, channels_fname) |
| 1550 | + |
| 1551 | + return raw, ch_order_snirf, ch_order_bids, channels_fname, orig_name_to_loc, orig_name_to_data |
| 1552 | + |
| 1553 | + |
| 1554 | +def test_channel_mismatch_raise(tmp_path): |
| 1555 | + raw, _, _, channels_fname, _, _ = _setup_nirs_channel_mismatch(tmp_path) |
| 1556 | + with pytest.raises( |
| 1557 | + RuntimeError, |
| 1558 | + match=r"Channel mismatch between .*channels\.tsv and the raw data file detected\.", |
| 1559 | + ): |
| 1560 | + _handle_channels_reading(channels_fname, raw.copy(), ch_name_mismatch="raise") |
| 1561 | + |
| 1562 | + |
| 1563 | +def test_channel_mismatch_reorder(tmp_path): |
| 1564 | + raw, _, ch_order_bids, channels_fname, orig_name_to_loc, orig_name_to_data = _setup_nirs_channel_mismatch(tmp_path) |
| 1565 | + with ( |
| 1566 | + pytest.warns( |
| 1567 | + RuntimeWarning, |
| 1568 | + match=r"Channel mismatch between .*channels\.tsv and the raw data file detected\. Using mismatch strategy: reorder\.", |
| 1569 | + ), |
| 1570 | + ): |
| 1571 | + raw_out = _handle_channels_reading(channels_fname, raw, ch_name_mismatch="reorder") |
| 1572 | + assert raw_out.ch_names == ch_order_bids |
| 1573 | + for i, new_name in enumerate(raw_out.ch_names): |
| 1574 | + np.testing.assert_allclose( |
| 1575 | + raw_out.info["chs"][i]["loc"], orig_name_to_loc[new_name] |
| 1576 | + ) |
| 1577 | + np.testing.assert_allclose( |
| 1578 | + raw_out.get_data(picks=i), orig_name_to_data[new_name] |
| 1579 | + ) |
| 1580 | + |
| 1581 | + |
| 1582 | +def test_channel_mismatch_rename(tmp_path): |
| 1583 | + raw, ch_order_snirf, ch_order_bids, channels_fname, orig_name_to_loc, orig_name_to_data = _setup_nirs_channel_mismatch(tmp_path) |
| 1584 | + with ( |
| 1585 | + pytest.warns( |
| 1586 | + RuntimeWarning, |
| 1587 | + match=r"Channel mismatch between .*channels\.tsv and the raw data file detected\. Using mismatch strategy: rename\.", |
| 1588 | + ), |
| 1589 | + ): |
| 1590 | + raw_out_rename = _handle_channels_reading(channels_fname, raw.copy(), ch_name_mismatch="rename") |
| 1591 | + assert raw_out_rename.ch_names == ch_order_bids |
| 1592 | + for i in range(len(ch_order_bids)): |
| 1593 | + orig_name_at_i = ch_order_snirf[i] |
| 1594 | + np.testing.assert_allclose( |
| 1595 | + raw_out_rename.info["chs"][i]["loc"], orig_name_to_loc[orig_name_at_i] |
| 1596 | + ) |
| 1597 | + np.testing.assert_allclose( |
| 1598 | + raw_out_rename.get_data(picks=i), orig_name_to_data[orig_name_at_i] |
| 1599 | + ) |
| 1600 | + |
| 1601 | + |
1512 | 1602 | def test_events_file_to_annotation_kwargs(tmp_path): |
1513 | 1603 | """Test that events file is read correctly.""" |
1514 | 1604 | bids_path = BIDSPath( |
|
0 commit comments