Skip to content

Colocated topomaps for OPM #13144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
70b3a4c
detect overlap; commandeer fnirs functions
harrisonritz Mar 3, 2025
7f936b0
start to merge channels
harrisonritz Mar 4, 2025
8297974
opm tests
harrisonritz Mar 5, 2025
75e40c0
fix selection of radial channels; better selection of OPM sensors; su…
harrisonritz Mar 5, 2025
ab419e3
always use Z-axis orientation (now matches UCL's *-TAN sensors)
harrisonritz Mar 6, 2025
109e1de
update test_topomap
harrisonritz Mar 10, 2025
b86351e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2025
0f92da6
update testing data
harrisonritz Mar 10, 2025
dcf3e2e
update testing data hash
harrisonritz Mar 10, 2025
8067e4d
inst.info --> info
harrisonritz Mar 15, 2025
953e098
add `-ave` suffix
harrisonritz Mar 15, 2025
0483ad1
update dataset config
harrisonritz Mar 17, 2025
5460d71
properly load evoked
harrisonritz Mar 17, 2025
72822a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 17, 2025
b751d3e
Merge branch 'main' into colocated_topo
harrisonritz Apr 22, 2025
a72b7f8
Update mne/channels/layout.py
harrisonritz Apr 22, 2025
a65961c
Update mne/channels/layout.py
harrisonritz Apr 22, 2025
7343021
append with MERGE_REMOVE. use FIFF constants
harrisonritz Apr 22, 2025
217c579
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2025
0a02fb9
new append
harrisonritz Apr 22, 2025
bfb6152
merge-remove
harrisonritz Apr 23, 2025
32b90b9
Merge branch 'main' into colocated_topo
harrisonritz Apr 23, 2025
798d730
add changelog
harrisonritz Apr 28, 2025
7743397
update docs, include topo in opm preproc tutorial
harrisonritz Apr 28, 2025
64d3d26
Merge branch 'main' into colocated_topo
harrisonritz Apr 28, 2025
972e24d
literal instead of link
harrisonritz Apr 28, 2025
eee01b0
Merge branch 'main' into colocated_topo
harrisonritz Apr 28, 2025
6bbc071
Apply suggestions from code review
larsoner Apr 29, 2025
13387a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2025
b541bf5
Merge branch 'main' into colocated_topo
larsoner Apr 29, 2025
988e8f6
MAINT: Smoke test
larsoner Apr 29, 2025
f073d0c
FIX: Test
larsoner Apr 29, 2025
723809d
FIX: Scope
larsoner Apr 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions mne/channels/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ def _auto_topomap_coords(info, picks, ignore_overlap, to_sphere, sphere):
# Use channel locations if available
locs3d = np.array([ch["loc"][:3] for ch in chs])

# If electrode locations are not available, use digization points
# If electrode locations are not available, use digitization points
if not _check_ch_locs(info=info, picks=picks):
logging.warning(
"Did not find any electrode locations (in the info "
Expand Down Expand Up @@ -1089,7 +1089,7 @@ def _pair_grad_sensors(
return picks


def _merge_ch_data(data, ch_type, names, method="rms"):
def _merge_ch_data(data, ch_type, names, method="rms", modality="opm"):
"""Merge data from channel pairs.

Parameters
Expand All @@ -1102,6 +1102,8 @@ def _merge_ch_data(data, ch_type, names, method="rms"):
List of channel names.
method : str
Can be 'rms' or 'mean'.
modality : str
The modality of the data, either 'fnirs', 'opm', or 'other'

Returns
-------
Expand All @@ -1112,9 +1114,13 @@ def _merge_ch_data(data, ch_type, names, method="rms"):
"""
if ch_type == "grad":
data = _merge_grad_data(data, method)
else:
assert ch_type in _FNIRS_CH_TYPES_SPLIT
elif ch_type in _FNIRS_CH_TYPES_SPLIT:
data, names = _merge_nirs_data(data, names)
elif modality == "opm" and ch_type == "mag":
data, names = _merge_opm_data(data, names)
else:
raise ValueError(f"Unknown modality {modality} for channel type {ch_type}")

return data, names


Expand Down Expand Up @@ -1180,6 +1186,42 @@ def _merge_nirs_data(data, merged_names):
return data, merged_names


def _merge_opm_data(data, merged_names):
"""Merge data from multiple opm channel using the mean.

Channel names that have an x in them will be merged. The first channel in
the name is replaced with the mean of all listed channels. The other
channels are removed.

Parameters
----------
data : array, shape = (n_channels, ..., n_times)
Data for channels.
merged_names : list
List of strings containing the channel names. Channels that are to be
merged contain an x between them.

Returns
-------
data : array
Data for channels with requested channels merged. Channels used in the
merge are removed from the array.
"""
to_remove = np.empty(0, dtype=np.int32)
for idx, ch in enumerate(merged_names):
if "." in ch:
indices = np.empty(0, dtype=np.int32)
channels = ch.split(".")
for sub_ch in channels[1:]:
indices = np.append(indices, merged_names.index(sub_ch))
to_remove = np.append(to_remove, indices)
to_remove = np.unique(to_remove)
for rem in sorted(to_remove, reverse=True):
del merged_names[rem]
data = np.delete(data, rem, 0)
return data, merged_names


def generate_2d_layout(
xy,
w=0.07,
Expand Down
4 changes: 2 additions & 2 deletions mne/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
# update the checksum in the MNE_DATASETS dict below, and change version
# here: ↓↓↓↓↓↓↓↓
RELEASES = dict(
testing="0.156",
testing="0.161",
misc="0.27",
phantom_kit="0.2",
ucl_opm_auditory="0.2",
Expand Down Expand Up @@ -115,7 +115,7 @@
# Testing and misc are at the top as they're updated most often
MNE_DATASETS["testing"] = dict(
archive_name=f"{TESTING_VERSIONED}.tar.gz",
hash="md5:d94fe9f3abe949a507eaeb865fb84a3f",
hash="md5:a32cfb9e098dec39a5f3ed6c0833580d",
url=(
"https://codeload.github.com/mne-tools/mne-testing-data/"
f"tar.gz/{RELEASES['testing']}"
Expand Down
15 changes: 15 additions & 0 deletions mne/viz/tests/test_topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
subjects_dir = data_dir / "subjects"
ecg_fname = data_dir / "MEG" / "sample" / "sample_audvis_ecg-proj.fif"
triux_fname = data_dir / "SSS" / "TRIUX" / "triux_bmlhus_erm_raw.fif"
opm_fname = data_dir / "OPM" / "opm-evoked-ave.fif"


base_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
evoked_fname = base_dir / "test-ave.fif"
Expand Down Expand Up @@ -776,6 +778,19 @@ def test_plot_topomap_bads_grad():
plot_topomap(data, info, res=8)


@testing.requires_testing_data
def test_plot_topomap_opm():
"""Test plotting topomap with OPM data."""
# load data
evoked = read_evokeds(opm_fname, kind="average")[0]

# plot evoked topomap
fig_evoked = evoked.plot_topomap(
times=[-0.1, 0, 0.1, 0.2], ch_type="mag", show=False
)
assert len(fig_evoked.axes) == 5


def test_plot_topomap_nirs_overlap(fnirs_epochs):
"""Test plotting nirs topomap with overlapping channels (gh-7414)."""
fig = fnirs_epochs["A"].average(picks="hbo").plot_topomap()
Expand Down
125 changes: 99 additions & 26 deletions mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
)

_fnirs_types = ("hbo", "hbr", "fnirs_cw_amplitude", "fnirs_od")
_opm_coils = (8002,)


# 3.8+ uses a single Collection artist rather than .collections
Expand Down Expand Up @@ -123,6 +124,13 @@ def _prepare_topomap_plot(inst, ch_type, sphere=None):
info["bads"] = _clean_names(info["bads"])
info._check_consistency()

if any(ch["coil_type"] in _opm_coils for ch in info["chs"]):
modality = "opm"
elif ch_type in _fnirs_types:
modality = "fnirs"
else:
modality = "other"

# special case for merging grad channels
layout = find_layout(info)
if (
Expand All @@ -136,10 +144,9 @@ def _prepare_topomap_plot(inst, ch_type, sphere=None):
picks, _ = _pair_grad_sensors(info, layout)
pos = _find_topomap_coords(info, picks[::2], sphere=sphere)
merge_channels = True
elif ch_type in _fnirs_types:
# fNIRS data commonly has overlapping channels, so deal with separately
picks, pos, merge_channels, overlapping_channels = _average_fnirs_overlaps(
info, ch_type, sphere
elif modality != "other":
picks, pos, merge_channels, overlapping_channels = _find_overlaps(
info, ch_type, sphere, modality=modality
)
else:
merge_channels = False
Expand All @@ -162,7 +169,7 @@ def _prepare_topomap_plot(inst, ch_type, sphere=None):
pos = _find_topomap_coords(info, picks, sphere=sphere)

ch_names = [info["ch_names"][k] for k in picks]
if ch_type in _fnirs_types:
if modality == "fnirs":
# Remove the chroma label type for cleaner labeling.
ch_names = [k[:-4] for k in ch_names]

Expand All @@ -171,24 +178,38 @@ def _prepare_topomap_plot(inst, ch_type, sphere=None):
# change names so that vectorview combined grads appear as MEG014x
# instead of MEG0142 or MEG0143 which are the 2 planar grads.
ch_names = [ch_names[k][:-1] + "x" for k in range(0, len(ch_names), 2)]
else:
assert ch_type in _fnirs_types
# Modify the nirs channel names to indicate they are to be merged
elif modality == "fnirs":
# Modify the channel names to indicate they are to be merged
# New names will have the form S1_D1xS2_D2
# More than two channels can overlap and be merged
for set_ in overlapping_channels:
idx = ch_names.index(set_[0][:-4])
new_name = "x".join(s[:-4] for s in set_)
ch_names[idx] = new_name
elif modality == "opm":
# Modify the channel names to indicate they are to be merged
# New names will have the form S1xS2
for set_ in overlapping_channels:
idx = ch_names.index(set_[0])
new_name = ".".join(s for s in set_)
ch_names[idx] = new_name

pos = np.array(pos)[:, :2] # 2D plot, otherwise interpolation bugs
return picks, pos, merge_channels, ch_names, ch_type, sphere, clip_origin


def _average_fnirs_overlaps(info, ch_type, sphere):
def _find_overlaps(info, ch_type, sphere, modality="fnirs"):
"""Find overlapping channels."""
from ..channels.layout import _find_topomap_coords

picks = pick_types(info, meg=False, ref_meg=False, fnirs=ch_type, exclude="bads")
if modality == "fnirs":
picks = pick_types(
info, meg=False, ref_meg=False, fnirs=ch_type, exclude="bads"
)
elif modality == "opm":
picks = pick_types(info, meg=True, ref_meg=False, exclude="bads")
else:
raise ValueError(f"Invalid modality for colocated sensors: {modality}")
chs = [info["chs"][i] for i in picks]
locs3d = np.array([ch["loc"][:3] for ch in chs])
dist = pdist(locs3d)
Expand All @@ -212,33 +233,77 @@ def _average_fnirs_overlaps(info, ch_type, sphere):
overlapping_set = [
chs[i]["ch_name"] for i in np.where(overlapping_mask[chan_idx])[0]
]
overlapping_set = np.insert(
overlapping_set, 0, (chs[chan_idx]["ch_name"])
)
if modality == "fnirs":
overlapping_set = np.insert(
overlapping_set, 0, (chs[chan_idx]["ch_name"])
)
elif modality == "opm":
overlapping_set = np.insert(
overlapping_set, 0, (chs[chan_idx]["ch_name"])
)
rad_channel = _find_radial_channel(info, overlapping_set)
# Make sure the radial channel is first in the overlapping set
overlapping_set = np.array(
[ch for ch in overlapping_set if ch != rad_channel]
)
overlapping_set = np.insert(overlapping_set, 0, rad_channel)
overlapping_channels.append(overlapping_set)
channels_to_exclude.append(overlapping_set[1:])

exclude = list(itertools.chain.from_iterable(channels_to_exclude))
[exclude.append(bad) for bad in info["bads"]]
picks = pick_types(
info, meg=False, ref_meg=False, fnirs=ch_type, exclude=exclude
)
pos = _find_topomap_coords(info, picks, sphere=sphere)
picks = pick_types(info, meg=False, ref_meg=False, fnirs=ch_type)
if modality == "fnirs":
picks = pick_types(
info, meg=False, ref_meg=False, fnirs=ch_type, exclude=exclude
)
pos = _find_topomap_coords(info, picks, sphere=sphere)
picks = pick_types(info, meg=False, ref_meg=False, fnirs=ch_type)
elif modality == "opm":
picks = pick_types(info, meg=True, ref_meg=False, exclude=exclude)
pos = _find_topomap_coords(info, picks, sphere=sphere)
picks = pick_types(info, meg=True, ref_meg=False)

# Overload the merge_channels variable as this is returned to calling
# function and indicates that merging of data is required
merge_channels = overlapping_channels

else:
picks = pick_types(
info, meg=False, ref_meg=False, fnirs=ch_type, exclude="bads"
)
if modality == "fnirs":
picks = pick_types(
info, meg=False, ref_meg=False, fnirs=ch_type, exclude="bads"
)
elif modality == "opm":
picks = pick_types(info, meg=True, ref_meg=False, exclude="bads")

merge_channels = False
pos = _find_topomap_coords(info, picks, sphere=sphere)

return picks, pos, merge_channels, overlapping_channels


def _find_radial_channel(info, overlapping_set):
"""Find the most radial channel in the overlapping set."""
if len(overlapping_set) == 1:
return overlapping_set[0]
elif len(overlapping_set) < 1:
raise ValueError("No overlapping channels found.")

radial_score = np.zeros(len(overlapping_set))
for s, sens in enumerate(overlapping_set):
ch_idx = pick_channels(info["ch_names"], [sens])[0]
radial_direction = copy.copy(info["chs"][ch_idx]["loc"][0:3])
radial_direction /= np.linalg.norm(radial_direction)

orientation_vector = info["chs"][ch_idx]["loc"][9:12]
if info["dev_head_t"] is not None:
orientation_vector = apply_trans(info["dev_head_t"], orientation_vector)
radial_score[s] = np.abs(np.dot(radial_direction, orientation_vector))

radial_sensor = overlapping_set[np.argmax(radial_score)]

return radial_sensor


def _plot_update_evoked_topomap(params, bools):
"""Update topomaps."""
from ..channels.layout import _merge_ch_data
Expand Down Expand Up @@ -1636,9 +1701,8 @@ def plot_ica_components(
sphere,
clip_origin,
) = _prepare_topomap_plot(ica, ch_type, sphere=sphere)

cmap = _setup_cmap(cmap, n_axes=len(picks))
names = _prepare_sensor_names(names, show_names)
disp_names = _prepare_sensor_names(names, show_names)
outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)

data = np.dot(
Expand Down Expand Up @@ -1675,7 +1739,7 @@ def plot_ica_components(
pos,
ch_type=ch_type,
sensors=sensors,
names=names,
names=disp_names,
contours=contours,
outlines=outlines,
sphere=sphere,
Expand Down Expand Up @@ -2296,8 +2360,17 @@ def plot_evoked_topomap(
# apply scalings and merge channels
data *= scaling
if merge_channels:
data, ch_names = _merge_ch_data(data, ch_type, ch_names)
if ch_type in _fnirs_types:
# check modality
if any(ch["coil_type"] in _opm_coils for ch in evoked.info["chs"]):
modality = "opm"
elif ch_type in _fnirs_types:
modality = "fnirs"
else:
modality = "other"
# merge data
data, ch_names = _merge_ch_data(data, ch_type, ch_names, modality=modality)
# if ch_type in _fnirs_types:
if modality != "other":
merge_channels = False
# apply mask if requested
if mask is not None:
Expand Down