Skip to content

[MRG] PCA flip for volumetric source estimates #13092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/13092.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add PCA-flip to pool sources in source reconstruction in :func:`mne.extract_label_time_course`, by :newcontrib:`Fabrice Guibert`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
.. _Evgeny Goldstein: https://github.com/evgenygoldstein
.. _Ezequiel Mikulan: https://github.com/ezemikulan
.. _Ezequiel Mikulan: https://github.com/ezemikulan
.. _Fabrice Guibert: https://github.com/Shrecki
.. _Fahimeh Mamashli: https://github.com/fmamashli
.. _Farzin Negahbani: https://github.com/Farzin-Negahbani
.. _Federico Raimondo: https://github.com/fraimondo
Expand Down
46 changes: 34 additions & 12 deletions mne/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,22 +1460,44 @@ def label_sign_flip(label, src):
flip : array
Sign flip vector (contains 1 or -1).
"""
if len(src) != 2:
raise ValueError("Only source spaces with 2 hemisphers are accepted")
if len(src) > 2 or len(src) == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better / more modern check would be something like:

_validate_type(src, SourceSpaces, "src")
_check_option("source space kind", src.kind, ("volume", "surface"))
if src.kind == "volume" and len(src) != 1:
    raise ValueError("Only single-segment volumes, are supported, got labelized volume source space")

And incidentally I think eventually we could add support for segmented volume source spaces, as well as mixed source spaces (once surface + volume are fully supported, mixed isn't too bad after that). But probably not as part of this PR!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably! To support this, the (clarified) code in label_sign_flip with the hemis dictionary might be recycled:

hemis = {}

# Build hemisphere info dictionary
if label.hemi == "both":
    hemis["lh"] = {"id": 0, "vertno": src[0]["vertno"]}
    hemis["rh"] = {"id": 1, "vertno": src[1]["vertno"]}
elif label.hemi in ("lh", "rh"):
    hemis[label.hemi] = {"id": 0, "vertno": src[0]["vertno"]}
else:
    raise Exception(f'Unknown hemisphere type "{label.hemi}"')

raise ValueError(
"Only source spaces with between one and two "
+ "hemispheres are accepted, was {len(src)}"
)

if len(src) == 1 and label.hemi == "both":
raise ValueError(
'Cannot use hemisphere label "both" when source'
+ "space contains a single hemisphere."
)

lh_vertno = src[0]["vertno"]
rh_vertno = src[1]["vertno"]
hemis = {}

# Build hemisphere info dictionary
if label.hemi == "both":
hemis["lh"] = {"id": 0, "vertno": src[0]["vertno"]}
hemis["rh"] = {"id": 1, "vertno": src[1]["vertno"]}
elif label.hemi in ("lh", "rh"):
# If two sources available, the hemisphere's ID must be looked up.
# If only a single source, the ID is zero.
index_ = ("lh", "rh").index(label.hemi) if len(src) == 2 else 0
hemis[label.hemi] = {"id": index_, "vertno": src[index_]["vertno"]}
else:
raise Exception(f'Unknown hemisphere type "{label.hemi}"')

# get source orientations
ori = list()
if label.hemi in ("lh", "both"):
vertices = label.vertices if label.hemi == "lh" else label.lh.vertices
vertno_sel = np.intersect1d(lh_vertno, vertices)
ori.append(src[0]["nn"][vertno_sel])
if label.hemi in ("rh", "both"):
vertices = label.vertices if label.hemi == "rh" else label.rh.vertices
vertno_sel = np.intersect1d(rh_vertno, vertices)
ori.append(src[1]["nn"][vertno_sel])
for hemi, hemi_infos in hemis.items():
# When the label is lh or rh, get vertices directly
if label.hemi == hemi:
vertices = label.vertices
# In the case where label is "both", get label.hemi.vertices
# (so either label.lh.vertices or label.rh.vertices)
else:
vertices = getattr(label, hemi).vertices
vertno_sel = np.intersect1d(hemi_infos["vertno"], vertices)
ori.append(src[hemi_infos["id"]]["nn"][vertno_sel])
if len(ori) == 0:
raise Exception(f'Unknown hemisphere type "{label.hemi}"')
ori = np.concatenate(ori, axis=0)
Expand Down
26 changes: 18 additions & 8 deletions mne/source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3376,12 +3376,19 @@ def _get_ico_tris(grade, verbose=None, return_surf=False):


def _pca_flip(flip, data):
U, s, V = _safe_svd(data, full_matrices=False)
# determine sign-flip
sign = np.sign(np.dot(U[:, 0], flip))
# use average power in label for scaling
scale = np.linalg.norm(s) / np.sqrt(len(data))
return sign * scale * V[0]
result = None
if flip is None:
result = 0
elif data.shape[0] < 2:
result = data.mean(axis=0) # Trivial accumulator
else:
U, s, V = _safe_svd(data, full_matrices=False)
# determine sign-flip
sign = np.sign(np.dot(U[:, 0], flip))
# use average power in label for scaling
scale = np.linalg.norm(s) / np.sqrt(len(data))
result = sign * scale * V[0]
return result


_label_funcs = {
Expand Down Expand Up @@ -3433,6 +3440,8 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
# only computes vertex indices and label_flip will be list of None.
from .label import BiHemiLabel, Label, label_sign_flip

logger.debug(f"Selected mode: {mode}")

# if source estimate provided in stc, get vertices from source space and
# check that they are the same as in the stcs
_check_stc_src(stc, src)
Expand Down Expand Up @@ -3644,8 +3653,10 @@ def _get_default_label_modes():


def _get_allowed_label_modes(stc):
if isinstance(stc, _BaseVolSourceEstimate | _BaseVectorSourceEstimate):
if isinstance(stc, _BaseVectorSourceEstimate):
return ("mean", "max", "auto")
elif isinstance(stc, _BaseVolSourceEstimate):
return ("mean", "pca_flip", "max", "auto")
else:
return _get_default_label_modes()

Expand Down Expand Up @@ -3733,7 +3744,6 @@ def _gen_extract_label_time_course(
else:
this_data = stc.data[vertidx]
label_tc[i] = func(flip, this_data)

if mode is not None:
offset = nvert[:-n_mean].sum() # effectively :2 or :0
for i, nv in enumerate(nvert[2:]):
Expand Down
144 changes: 144 additions & 0 deletions mne/tests/test_source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,150 @@ def test_center_of_mass():
assert_equal(np.round(t, 2), 0.12)


@testing.requires_testing_data
@pytest.mark.parametrize(
"label_type, mri_res, test_label, cf, call",
[
(str, False, False, "head", "meth"), # head frame
(str, False, str, "mri", "func"), # fastest, default for testing
(str, True, str, "mri", "func"), # fastest, default for testing
(str, True, False, "mri", "func"), # mri_resolution
(list, True, False, "mri", "func"), # volume label as list
(dict, True, False, "mri", "func"), # volume label as dict
],
)
def test_extract_label_time_course_volume_pca_flip(
src_volume_labels, label_type, mri_res, test_label, cf, call
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We always want tests to be fairly complete, but competing against that we also want them to be as quick as possible. Our test suite takes over an hour to run already! 😢

Do we need to iterate over cf here for example? If we already test that functionality elsewhere and the pca_flip functionality here is independent of that, let's not. Same thing with mri_resolution, this seems like something that is unrelated and tested elsewhere most likely.

Maybe the pca_flip option could just be added somewhere else (maybe as one parameter in an already parametrized function) to really trim it down?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On this one, I'm on the fence - integrating pca_flip as an option in another test is certainly an option, much like "mean" for example, but at the same time we lose clarity in the testing if we over parameterize no? In particular, it becomes tedious to know what is actually being tested and which cases might or might not be covered...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case I think you're over-testing quite a bit. Your code change is only relevant to some minor parts of the code, not coordinate frame, mri resolution, etc. So I think it makes more sense to trim it down (a lot!) or integrate elsewhere

):
"""Test extraction of label timecourses on VolumetricSourceEstimate with PCA."""
# Setup of data
src_labels, volume_labels, lut = src_volume_labels
n_tot = 46
assert n_tot == len(src_labels)
inv = read_inverse_operator(fname_inv_vol)
if cf == "head":
src = inv["src"]
else:
src = read_source_spaces(fname_src_vol)
klass = VolVectorSourceEstimate._scalar_class
vertices = [src[0]["vertno"]]
n_verts = len(src[0]["vertno"])
n_times = 50
data = np.arange(1, n_verts + 1)
end_shape = (n_times,)
data = np.repeat(data[..., np.newaxis], n_times, -1)
stcs = [klass(data.astype(float), vertices, 0, 1)]

def eltc(*args, **kwargs):
if call == "func":
return extract_label_time_course(stcs, *args, **kwargs)
else:
return [stcs[0].extract_label_time_course(*args, **kwargs)]

# triage "labels" argument
if mri_res:
# All should be there
missing = []
else:
# Nearest misses these
missing = [
"Left-vessel",
"Right-vessel",
"5th-Ventricle",
"non-WM-hypointensities",
]
n_want = len(src_labels)
if label_type is str:
labels = fname_aseg
elif label_type is list:
labels = (fname_aseg, volume_labels)
else:
assert label_type is dict
labels = (fname_aseg, {k: lut[k] for k in volume_labels})
assert mri_res
assert len(missing) == 0
# we're going to add one that won't exist
missing = ["intentionally_bad"]
labels[1][missing[0]] = 10000
n_want += 1
n_tot += 1
n_want -= len(missing)

# _volume_labels(src, labels, mri_resolution)
# actually do the testing
from mne.source_estimate import _pca_flip, _prepare_label_extraction, _volume_labels

labels_expanded = _volume_labels(src, labels, mri_res)
_, src_flip = _prepare_label_extraction(
stcs[0], labels_expanded, src, "pca_flip", "ignore", bool(mri_res)
)

mode = "pca_flip"
with catch_logging() as log:
label_tc = eltc(
labels,
src,
mode=mode,
allow_empty="ignore",
mri_resolution=mri_res,
verbose=True,
)
log = log.getvalue()
assert re.search("^Reading atlas.*aseg\\.mgz\n", log) is not None
if len(missing):
# assert that the missing ones get logged
assert "does not contain" in log
assert repr(missing) in log
else:
assert "does not contain" not in log
assert f"\n{n_want}/{n_tot} atlas regions had at least" in log
assert len(label_tc) == 1
label_tc = label_tc[0]
assert label_tc.shape == (n_tot,) + end_shape
assert label_tc.shape == (n_tot, n_times)
# let's test some actual values by trusting the masks provided by
# setup_volume_source_space. mri_resolution=True does some
# interpolation so we should not expect equivalence, False does
# nearest so we should.
if mri_res:
rtol = 0.8 # max much more sensitive
else:
rtol = 0.0
for si, s in enumerate(src_labels):
func = _pca_flip
these = data[np.isin(src[0]["vertno"], s["vertno"])]
print(these.shape)
assert len(these) == s["nuse"]
if si == 0 and s["seg_name"] == "Unknown":
continue # unknown is crappy
if s["nuse"] == 0:
want = 0.0
if mri_res:
# this one is totally due to interpolation, so no easy
# test here
continue
else:
if src_flip[si] is None:
want = None
else:
want = func(src_flip[si], these)
if want is not None:
assert_allclose(label_tc[si], want, atol=1e-6, rtol=rtol)
# compare with in_label, only on every fourth for speed
if test_label is not False and si % 4 == 0:
label = s["seg_name"]
if test_label is int:
label = lut[label]
in_label = stcs[0].in_label(label, fname_aseg, src).data
assert in_label.shape == (s["nuse"],) + end_shape
if np.all(want == 0):
assert in_label.shape[0] == 0
else:
if src_flip[si] is not None:
in_label = func(src_flip[si], in_label)
assert_allclose(in_label, want, atol=1e-6, rtol=rtol)


@testing.requires_testing_data
@pytest.mark.parametrize("kind", ("surface", "mixed"))
@pytest.mark.parametrize("vector", (False, True))
Expand Down