diff --git a/doc/changes/devel/13144.newfeature.rst b/doc/changes/devel/13144.newfeature.rst new file mode 100644 index 00000000000..3e05c5c6c06 --- /dev/null +++ b/doc/changes/devel/13144.newfeature.rst @@ -0,0 +1 @@ +Allow for ``topomap`` plotting of optically pumped MEG (OPM) sensors with overlapping channel locations. When channel locations overlap, plot the most radially oriented channel. By :newcontrib:`Harrison Ritz`. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 22b992a10cc..2e4027b82b7 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -107,6 +107,7 @@ .. _Hamid Maymandi: https://github.com/HamidMandi .. _Hamza Abdelhedi: https://github.com/BabaSanfour .. _Hari Bharadwaj: https://github.com/haribharadwaj +.. _Harrison Ritz: https://github.com/harrisonritz .. _Hasrat Ali Arzoo: https://github.com/hasrat17 .. _Henrich Kolkhorst: https://github.com/hekolk .. _Hongjiang Ye: https://github.com/hongjiang-ye diff --git a/mne/channels/layout.py b/mne/channels/layout.py index 31d0650037e..fb12b509997 100644 --- a/mne/channels/layout.py +++ b/mne/channels/layout.py @@ -902,7 +902,7 @@ def _auto_topomap_coords(info, picks, ignore_overlap, to_sphere, sphere): # Use channel locations if available locs3d = np.array([ch["loc"][:3] for ch in chs]) - # If electrode locations are not available, use digization points + # If electrode locations are not available, use digitization points if not _check_ch_locs(info=info, picks=picks): logging.warning( "Did not find any electrode locations (in the info " @@ -1089,7 +1089,7 @@ def _pair_grad_sensors( return picks -def _merge_ch_data(data, ch_type, names, method="rms"): +def _merge_ch_data(data, ch_type, names, method="rms", *, modality="opm"): """Merge data from channel pairs. Parameters @@ -1102,6 +1102,8 @@ def _merge_ch_data(data, ch_type, names, method="rms"): List of channel names. method : str Can be 'rms' or 'mean'. + modality : str + The modality of the data, either 'grad', 'fnirs', or 'opm' Returns ------- @@ -1112,9 +1114,13 @@ def _merge_ch_data(data, ch_type, names, method="rms"): """ if ch_type == "grad": data = _merge_grad_data(data, method) - else: - assert ch_type in _FNIRS_CH_TYPES_SPLIT + elif modality == "fnirs" or ch_type in _FNIRS_CH_TYPES_SPLIT: data, names = _merge_nirs_data(data, names) + elif modality == "opm" and ch_type == "mag": + data, names = _merge_opm_data(data, names) + else: + raise ValueError(f"Unknown modality {modality} for channel type {ch_type}") + return data, names @@ -1180,6 +1186,37 @@ def _merge_nirs_data(data, merged_names): return data, merged_names +def _merge_opm_data(data, merged_names): + """Merge data from multiple opm channel by just using the radial component. + + Channel names that end in "MERGE_REMOVE" (ie non-radial channels) will be + removed. Only the the radial channel is kept. + + Parameters + ---------- + data : array, shape = (n_channels, ..., n_times) + Data for channels. + merged_names : list + List of strings containing the channel names. Channels that are to be + removed end in "MERGE_REMOVE". + + Returns + ------- + data : array + Data for channels with requested channels merged. Channels used in the + merge are removed from the array. + """ + to_remove = np.empty(0, dtype=np.int32) + for idx, ch in enumerate(merged_names): + if ch.endswith("MERGE-REMOVE"): + to_remove = np.append(to_remove, idx) + to_remove = np.unique(to_remove) + for rem in sorted(to_remove, reverse=True): + del merged_names[rem] + data = np.delete(data, to_remove, axis=0) + return data, merged_names + + def generate_2d_layout( xy, w=0.07, diff --git a/mne/datasets/config.py b/mne/datasets/config.py index 75eff184cd1..98467aa7e36 100644 --- a/mne/datasets/config.py +++ b/mne/datasets/config.py @@ -87,7 +87,7 @@ # update the checksum in the MNE_DATASETS dict below, and change version # here: ↓↓↓↓↓↓↓↓ RELEASES = dict( - testing="0.156", + testing="0.161", misc="0.27", phantom_kit="0.2", ucl_opm_auditory="0.2", @@ -115,7 +115,7 @@ # Testing and misc are at the top as they're updated most often MNE_DATASETS["testing"] = dict( archive_name=f"{TESTING_VERSIONED}.tar.gz", - hash="md5:d94fe9f3abe949a507eaeb865fb84a3f", + hash="md5:a32cfb9e098dec39a5f3ed6c0833580d", url=( "https://codeload.github.com/mne-tools/mne-testing-data/" f"tar.gz/{RELEASES['testing']}" diff --git a/mne/preprocessing/tests/test_ica.py b/mne/preprocessing/tests/test_ica.py index d925665e48f..91dffe93078 100644 --- a/mne/preprocessing/tests/test_ica.py +++ b/mne/preprocessing/tests/test_ica.py @@ -1027,7 +1027,6 @@ def f(x, y): def test_get_explained_variance_ratio(tmp_path, short_raw_epochs): """Test ICA.get_explained_variance_ratio().""" - pytest.importorskip("sklearn") raw, epochs, _ = short_raw_epochs ica = ICA(max_iter=1) diff --git a/mne/viz/tests/test_ica.py b/mne/viz/tests/test_ica.py index a3ea7b89e58..973ccc83dbc 100644 --- a/mne/viz/tests/test_ica.py +++ b/mne/viz/tests/test_ica.py @@ -17,8 +17,10 @@ pick_types, read_cov, read_events, + read_evokeds, ) -from mne.io import read_raw_fif +from mne.datasets import testing +from mne.io import RawArray, read_raw_fif from mne.preprocessing import ICA, create_ecg_epochs, create_eog_epochs from mne.utils import _record_warnings, catch_logging from mne.viz.ica import _create_properties_layout, plot_ica_properties @@ -32,6 +34,9 @@ event_id, tmin, tmax = 1, -0.1, 0.2 raw_ctf_fname = base_dir / "test_ctf_raw.fif" +testing_path = testing.data_path(download=False) +opm_fname = testing_path / "OPM" / "opm-evoked-ave.fif" + pytest.importorskip("sklearn") @@ -526,3 +531,15 @@ def test_plot_instance_components(browser_backend): fig._fake_click((x, y), xform="data") fig._click_ch_name(ch_index=0, button=1) fig._fake_keypress("escape") + + +@pytest.mark.slowtest +@pytest.mark.filterwarnings("ignore:.*did not converge.*:") +@testing.requires_testing_data +def test_plot_components_opm(): + """Test for gh-12934.""" + evoked = read_evokeds(opm_fname, kind="average")[0] + ica = ICA(max_iter=1, random_state=0, n_components=10) + ica.fit(RawArray(evoked.data, evoked.info), picks="mag", verbose="error") + fig = ica.plot_components() + assert len(fig.axes) == 10 diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index b87d0d39f89..64a13b46006 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -62,6 +62,8 @@ subjects_dir = data_dir / "subjects" ecg_fname = data_dir / "MEG" / "sample" / "sample_audvis_ecg-proj.fif" triux_fname = data_dir / "SSS" / "TRIUX" / "triux_bmlhus_erm_raw.fif" +opm_fname = data_dir / "OPM" / "opm-evoked-ave.fif" + base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" evoked_fname = base_dir / "test-ave.fif" @@ -776,6 +778,19 @@ def test_plot_topomap_bads_grad(): plot_topomap(data, info, res=8) +@testing.requires_testing_data +def test_plot_topomap_opm(): + """Test plotting topomap with OPM data.""" + # load data + evoked = read_evokeds(opm_fname, kind="average")[0] + + # plot evoked topomap + fig_evoked = evoked.plot_topomap( + times=[-0.1, 0, 0.1, 0.2], ch_type="mag", show=False + ) + assert len(fig_evoked.axes) == 5 + + def test_plot_topomap_nirs_overlap(fnirs_epochs): """Test plotting nirs topomap with overlapping channels (gh-7414).""" fig = fnirs_epochs["A"].average(picks="hbo").plot_topomap() diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index bb180a3f299..375b7256957 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -20,6 +20,7 @@ from scipy.spatial import Delaunay, Voronoi from scipy.spatial.distance import pdist, squareform +from .._fiff.constants import FIFF from .._fiff.meas_info import Info, _simplify_info from .._fiff.pick import ( _MEG_CH_TYPES_SPLIT, @@ -76,6 +77,7 @@ ) _fnirs_types = ("hbo", "hbr", "fnirs_cw_amplitude", "fnirs_od") +_opm_coils = (FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG, FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2) # 3.8+ uses a single Collection artist rather than .collections @@ -123,6 +125,13 @@ def _prepare_topomap_plot(inst, ch_type, sphere=None): info["bads"] = _clean_names(info["bads"]) info._check_consistency() + if any(ch["coil_type"] in _opm_coils for ch in info["chs"]): + modality = "opm" + elif ch_type in _fnirs_types: + modality = "fnirs" + else: + modality = "other" + # special case for merging grad channels layout = find_layout(info) if ( @@ -136,10 +145,9 @@ def _prepare_topomap_plot(inst, ch_type, sphere=None): picks, _ = _pair_grad_sensors(info, layout) pos = _find_topomap_coords(info, picks[::2], sphere=sphere) merge_channels = True - elif ch_type in _fnirs_types: - # fNIRS data commonly has overlapping channels, so deal with separately - picks, pos, merge_channels, overlapping_channels = _average_fnirs_overlaps( - info, ch_type, sphere + elif modality != "other": + picks, pos, merge_channels, overlapping_channels = _find_overlaps( + info, ch_type, sphere, modality=modality ) else: merge_channels = False @@ -162,7 +170,7 @@ def _prepare_topomap_plot(inst, ch_type, sphere=None): pos = _find_topomap_coords(info, picks, sphere=sphere) ch_names = [info["ch_names"][k] for k in picks] - if ch_type in _fnirs_types: + if modality == "fnirs": # Remove the chroma label type for cleaner labeling. ch_names = [k[:-4] for k in ch_names] @@ -171,24 +179,38 @@ def _prepare_topomap_plot(inst, ch_type, sphere=None): # change names so that vectorview combined grads appear as MEG014x # instead of MEG0142 or MEG0143 which are the 2 planar grads. ch_names = [ch_names[k][:-1] + "x" for k in range(0, len(ch_names), 2)] - else: - assert ch_type in _fnirs_types - # Modify the nirs channel names to indicate they are to be merged + elif modality == "fnirs": + # Modify the channel names to indicate they are to be merged # New names will have the form S1_D1xS2_D2 # More than two channels can overlap and be merged for set_ in overlapping_channels: idx = ch_names.index(set_[0][:-4]) new_name = "x".join(s[:-4] for s in set_) ch_names[idx] = new_name + elif modality == "opm": + # indicate that non-radial changes are to be removed + for set_ in overlapping_channels: + for set_ch in set_[1:]: + idx = ch_names.index(set_ch) + new_name = set_ch + "_MERGE-REMOVE" + ch_names[idx] = new_name pos = np.array(pos)[:, :2] # 2D plot, otherwise interpolation bugs return picks, pos, merge_channels, ch_names, ch_type, sphere, clip_origin -def _average_fnirs_overlaps(info, ch_type, sphere): +def _find_overlaps(info, ch_type, sphere, modality="fnirs"): + """Find overlapping channels.""" from ..channels.layout import _find_topomap_coords - picks = pick_types(info, meg=False, ref_meg=False, fnirs=ch_type, exclude="bads") + if modality == "fnirs": + picks = pick_types( + info, meg=False, ref_meg=False, fnirs=ch_type, exclude="bads" + ) + elif modality == "opm": + picks = pick_types(info, meg=True, ref_meg=False, exclude="bads") + else: + raise ValueError(f"Invalid modality for colocated sensors: {modality}") chs = [info["chs"][i] for i in picks] locs3d = np.array([ch["loc"][:3] for ch in chs]) dist = pdist(locs3d) @@ -212,33 +234,79 @@ def _average_fnirs_overlaps(info, ch_type, sphere): overlapping_set = [ chs[i]["ch_name"] for i in np.where(overlapping_mask[chan_idx])[0] ] - overlapping_set = np.insert( - overlapping_set, 0, (chs[chan_idx]["ch_name"]) - ) + if modality == "fnirs": + overlapping_set = np.insert( + overlapping_set, 0, (chs[chan_idx]["ch_name"]) + ) + elif modality == "opm": + overlapping_set = np.insert( + overlapping_set, 0, (chs[chan_idx]["ch_name"]) + ) + rad_channel = _find_radial_channel(info, overlapping_set) + # Make sure the radial channel is first in the overlapping set + overlapping_set = np.array( + [ch for ch in overlapping_set if ch != rad_channel] + ) + overlapping_set = np.insert(overlapping_set, 0, rad_channel) overlapping_channels.append(overlapping_set) channels_to_exclude.append(overlapping_set[1:]) exclude = list(itertools.chain.from_iterable(channels_to_exclude)) [exclude.append(bad) for bad in info["bads"]] - picks = pick_types( - info, meg=False, ref_meg=False, fnirs=ch_type, exclude=exclude - ) - pos = _find_topomap_coords(info, picks, sphere=sphere) - picks = pick_types(info, meg=False, ref_meg=False, fnirs=ch_type) + if modality == "fnirs": + picks = pick_types( + info, meg=False, ref_meg=False, fnirs=ch_type, exclude=exclude + ) + pos = _find_topomap_coords(info, picks, sphere=sphere) + picks = pick_types(info, meg=False, ref_meg=False, fnirs=ch_type) + elif modality == "opm": + picks = pick_types(info, meg=True, ref_meg=False, exclude=exclude) + pos = _find_topomap_coords(info, picks, sphere=sphere) + picks = pick_types(info, meg=True, ref_meg=False) + # Overload the merge_channels variable as this is returned to calling # function and indicates that merging of data is required merge_channels = overlapping_channels else: - picks = pick_types( - info, meg=False, ref_meg=False, fnirs=ch_type, exclude="bads" - ) + if modality == "fnirs": + picks = pick_types( + info, meg=False, ref_meg=False, fnirs=ch_type, exclude="bads" + ) + elif modality == "opm": + picks = pick_types(info, meg=True, ref_meg=False, exclude="bads") + merge_channels = False pos = _find_topomap_coords(info, picks, sphere=sphere) return picks, pos, merge_channels, overlapping_channels +def _find_radial_channel(info, overlapping_set): + """Find the most radial channel in the overlapping set.""" + if len(overlapping_set) == 1: + return overlapping_set[0] + elif len(overlapping_set) < 1: + raise ValueError("No overlapping channels found.") + + radial_score = np.zeros(len(overlapping_set)) + for s, sens in enumerate(overlapping_set): + ch_idx = pick_channels(info["ch_names"], [sens])[0] + radial_direction = info["chs"][ch_idx]["loc"][0:3].copy() + radial_direction /= np.linalg.norm(radial_direction) + + orientation_vector = info["chs"][ch_idx]["loc"][9:12] + if info["dev_head_t"] is not None: + orientation_vector = apply_trans( + info["dev_head_t"], orientation_vector, move=False + ) + radial_score[s] = np.abs(np.dot(radial_direction, orientation_vector)) + + radial_sensor = overlapping_set[np.argmax(radial_score)] + + return radial_sensor + + def _plot_update_evoked_topomap(params, bools): """Update topomaps.""" from ..channels.layout import _merge_ch_data @@ -1636,9 +1704,8 @@ def plot_ica_components( sphere, clip_origin, ) = _prepare_topomap_plot(ica, ch_type, sphere=sphere) - cmap = _setup_cmap(cmap, n_axes=len(picks)) - names = _prepare_sensor_names(names, show_names) + disp_names = _prepare_sensor_names(names, show_names) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) data = np.dot( @@ -1675,7 +1742,7 @@ def plot_ica_components( pos, ch_type=ch_type, sensors=sensors, - names=names, + names=disp_names, contours=contours, outlines=outlines, sphere=sphere, @@ -2296,8 +2363,17 @@ def plot_evoked_topomap( # apply scalings and merge channels data *= scaling if merge_channels: - data, ch_names = _merge_ch_data(data, ch_type, ch_names) - if ch_type in _fnirs_types: + # check modality + if any(ch["coil_type"] in _opm_coils for ch in evoked.info["chs"]): + modality = "opm" + elif ch_type in _fnirs_types: + modality = "fnirs" + else: + modality = "other" + # merge data + data, ch_names = _merge_ch_data(data, ch_type, ch_names, modality=modality) + # if ch_type in _fnirs_types: + if modality != "other": merge_channels = False # apply mask if requested if mask is not None: diff --git a/tutorials/preprocessing/80_opm_processing.py b/tutorials/preprocessing/80_opm_processing.py index 25c30778a42..ba44a797a51 100644 --- a/tutorials/preprocessing/80_opm_processing.py +++ b/tutorials/preprocessing/80_opm_processing.py @@ -243,8 +243,7 @@ ) evoked = epochs.average() t_peak = evoked.times[np.argmax(np.std(evoked.copy().pick("meg").data, axis=0))] -fig = evoked.plot() -fig.axes[0].axvline(t_peak, color="red", ls="--", lw=1) +fig = evoked.plot_joint(picks="mag") # %% # Visualizing coregistration