
[MRG] PCA flip for volumetric source estimates #13092

Open
wants to merge 30 commits into main
Changes from 16 commits
Commits (30)
bcc07f1
example description update.
Shrecki Oct 15, 2024
99b593a
pca_flip allowed for volumetric
Shrecki Oct 16, 2024
00e0487
[FEAT] Volumetric PCA flip implementation
Shrecki Jan 30, 2025
3b3b754
Merge remote-tracking branch 'upstream/main'
Shrecki Jan 30, 2025
d88591f
Merge branch 'main' into pca_flip_volume
Shrecki Jan 30, 2025
2e56dcb
Removed source_space_custom_atlas example - should be object of separ…
Shrecki Jan 31, 2025
a677854
[MISC] Changelog update
Shrecki Jan 31, 2025
8a74ffe
[MISC] Fixed changelog
Shrecki Jan 31, 2025
dd15522
[FIX] Removed erroneous path from test case
Shrecki Jan 31, 2025
6839d73
[autofix.ci] apply automated fixes
autofix-ci[bot] Jan 31, 2025
379614e
[FEAT] Simplify label code and remove cruft code
Shrecki Mar 13, 2025
d3523b7
Merge branch 'pca_flip_volume' of github.com:Shrecki/mne-python into …
Shrecki Mar 13, 2025
33f911d
[FIX] Removed trivial branch
Shrecki Mar 13, 2025
256331b
Merge remote-tracking branch 'upstream/main'
Shrecki Mar 13, 2025
1626e5f
Merge branch 'main' into pca_flip_volume
Shrecki Mar 13, 2025
ec82986
[FIX] label_sign_flip incorrectly handled hemispheres
Shrecki Mar 14, 2025
0ca43cf
Imports moved up top
Shrecki Mar 25, 2025
ee4a174
Updating mri_name to save volumetric source
Shrecki May 15, 2025
3ce6b38
Fix of PCA flip in volume: returned constant 0 as flips meaningless i…
Shrecki May 15, 2025
7ae37e5
Fixed pca flip branch
Shrecki May 15, 2025
d9580da
Handling of flip being an int
Shrecki May 15, 2025
fd71779
Using numpy svd instead of scipy
Shrecki May 15, 2025
69937d2
PCA flip for volumetric is now using randomized SVD to manage to run …
Shrecki May 16, 2025
8888c13
Simplification of PCA flip
Shrecki May 16, 2025
775ec80
Logging
Shrecki May 16, 2025
0fd4be5
Found a trick to make everything much faster with only two svds
Shrecki May 16, 2025
1a89099
Flip handling in _compute_pca_quantitites
Shrecki May 16, 2025
b02c14c
Feat: montage now supports .pos information file
Shrecki May 23, 2025
8a96fec
Float convert in digitization
Shrecki May 23, 2025
1074c1a
Convert dig points to numpy array
Shrecki May 23, 2025
1 change: 1 addition & 0 deletions doc/changes/devel/13092.newfeature.rst
@@ -0,0 +1 @@
Add PCA-flip to pool sources in source reconstruction in :func:`mne.extract_label_time_course`, by :newcontrib:`Fabrice Guibert`.
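For orientation, a minimal usage sketch of the feature described by this changelog entry; stc, src, and fname_aseg are placeholders assumed to come from an existing volumetric inverse solution and a FreeSurfer aseg atlas, and only mode="pca_flip" on a volume source estimate is new here:

import mne

# stc: a VolSourceEstimate, src: the volume SourceSpaces it was computed on,
# fname_aseg: path to the subject's aseg.mgz atlas (all assumed to exist).
label_tc = mne.extract_label_time_course(
    stc,
    labels=fname_aseg,
    src=src,
    mode="pca_flip",  # PCA-based pooling, newly allowed for volume source estimates
    allow_empty="ignore",
)
# label_tc has shape (n_labels, n_times): one pooled time course per atlas region.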
1 change: 1 addition & 0 deletions doc/changes/names.inc
@@ -86,6 +86,7 @@
.. _Evgeny Goldstein: https://github.com/evgenygoldstein
.. _Ezequiel Mikulan: https://github.com/ezemikulan
.. _Fabrice Guibert: https://github.com/Shrecki
.. _Fahimeh Mamashli: https://github.com/fmamashli
.. _Farzin Negahbani: https://github.com/Farzin-Negahbani
.. _Federico Raimondo: https://github.com/fraimondo
46 changes: 34 additions & 12 deletions mne/label.py
@@ -1460,22 +1460,44 @@ def label_sign_flip(label, src):
flip : array
Sign flip vector (contains 1 or -1).
"""
if len(src) != 2:
raise ValueError("Only source spaces with 2 hemisphers are accepted")
if len(src) > 2 or len(src) == 0:
Member

A better / more modern check would be something like:

_validate_type(src, SourceSpaces, "src")
_check_option("source space kind", src.kind, ("volume", "surface"))
if src.kind == "volume" and len(src) != 1:
    raise ValueError("Only single-segment volumes, are supported, got labelized volume source space")

And incidentally I think eventually we could add support for segmented volume source spaces, as well as mixed source spaces (once surface + volume are fully supported, mixed isn't too bad after that). But probably not as part of this PR!

Author

Probably! To support this, the (clarified) code in label_sign_flip with the hemis dictionary might be recycled:

hemis = {}

# Build hemisphere info dictionary
if label.hemi == "both":
    hemis["lh"] = {"id": 0, "vertno": src[0]["vertno"]}
    hemis["rh"] = {"id": 1, "vertno": src[1]["vertno"]}
elif label.hemi in ("lh", "rh"):
    hemis[label.hemi] = {"id": 0, "vertno": src[0]["vertno"]}
else:
    raise Exception(f'Unknown hemisphere type "{label.hemi}"')

raise ValueError(
"Only source spaces with between one and two "
+ "hemispheres are accepted, was {len(src)}"
)

if len(src) == 1 and label.hemi == "both":
raise ValueError(
'Cannot use hemisphere label "both" when source '
+ "space contains a single hemisphere."
)

lh_vertno = src[0]["vertno"]
rh_vertno = src[1]["vertno"]
hemis = {}

# Build hemisphere info dictionary
if label.hemi == "both":
hemis["lh"] = {"id": 0, "vertno": src[0]["vertno"]}
hemis["rh"] = {"id": 1, "vertno": src[1]["vertno"]}
elif label.hemi in ("lh", "rh"):
# If two sources available, the hemisphere's ID must be looked up.
# If only a single source, the ID is zero.
index_ = ("lh", "rh").index(label.hemi) if len(src) == 2 else 0
hemis[label.hemi] = {"id": index_, "vertno": src[index_]["vertno"]}
else:
raise Exception(f'Unknown hemisphere type "{label.hemi}"')

# get source orientations
ori = list()
if label.hemi in ("lh", "both"):
vertices = label.vertices if label.hemi == "lh" else label.lh.vertices
vertno_sel = np.intersect1d(lh_vertno, vertices)
ori.append(src[0]["nn"][vertno_sel])
if label.hemi in ("rh", "both"):
vertices = label.vertices if label.hemi == "rh" else label.rh.vertices
vertno_sel = np.intersect1d(rh_vertno, vertices)
ori.append(src[1]["nn"][vertno_sel])
for hemi, hemi_infos in hemis.items():
# When the label is lh or rh, get vertices directly
if label.hemi == hemi:
vertices = label.vertices
# In the case where label is "both", get label.hemi.vertices
# (so either label.lh.vertices or label.rh.vertices)
else:
vertices = getattr(label, hemi).vertices
vertno_sel = np.intersect1d(hemi_infos["vertno"], vertices)
ori.append(src[hemi_infos["id"]]["nn"][vertno_sel])
if len(ori) == 0:
raise Exception(f'Unknown hemisphere type "{label.hemi}"')
ori = np.concatenate(ori, axis=0)
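Side note for readers of this hunk: the remainder of label_sign_flip (not shown here) derives the flip vector by, roughly, comparing each source normal against the label's dominant orientation (its first right singular vector). A rough standalone numpy illustration of that idea, not the library code:

import numpy as np

rng = np.random.default_rng(0)
# Fake source normals for one label: mostly aligned, every fifth one inverted
nn = rng.normal([0.0, 0.0, 1.0], 0.2, size=(20, 3))
nn[::5] *= -1
nn /= np.linalg.norm(nn, axis=1, keepdims=True)

# Dominant orientation of the label = first right singular vector of the normals
_, _, Vh = np.linalg.svd(nn, full_matrices=False)
flip = np.sign(nn @ Vh[0])  # +1 / -1 per source

# After flipping, all normals point into the same half-space and no longer cancel
assert np.all((nn * flip[:, None]) @ Vh[0] > 0)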
26 changes: 18 additions & 8 deletions mne/source_estimate.py
@@ -3376,12 +3376,19 @@ def _get_ico_tris(grade, verbose=None, return_surf=False):


def _pca_flip(flip, data):
U, s, V = _safe_svd(data, full_matrices=False)
# determine sign-flip
sign = np.sign(np.dot(U[:, 0], flip))
# use average power in label for scaling
scale = np.linalg.norm(s) / np.sqrt(len(data))
return sign * scale * V[0]
result = None
if flip is None:
result = 0
elif data.shape[0] < 2:
result = data.mean(axis=0) # Trivial accumulator
else:
U, s, V = _safe_svd(data, full_matrices=False)
# determine sign-flip
sign = np.sign(np.dot(U[:, 0], flip))
# use average power in label for scaling
scale = np.linalg.norm(s) / np.sqrt(len(data))
result = sign * scale * V[0]
return result


_label_funcs = {
@@ -3433,6 +3440,8 @@ def _prepare_label_extraction(stc, labels, src, mode, allow_empty, use_sparse):
# only computes vertex indices and label_flip will be list of None.
from .label import BiHemiLabel, Label, label_sign_flip

logger.debug(f"Selected mode: {mode}")

# if source estimate provided in stc, get vertices from source space and
# check that they are the same as in the stcs
_check_stc_src(stc, src)
@@ -3644,8 +3653,10 @@ def _get_default_label_modes():


def _get_allowed_label_modes(stc):
if isinstance(stc, _BaseVolSourceEstimate | _BaseVectorSourceEstimate):
if isinstance(stc, _BaseVectorSourceEstimate):
return ("mean", "max", "auto")
elif isinstance(stc, _BaseVolSourceEstimate):
return ("mean", "pca_flip", "max", "auto")
else:
return _get_default_label_modes()

@@ -3733,7 +3744,6 @@ def _gen_extract_label_time_course(
else:
this_data = stc.data[vertidx]
label_tc[i] = func(flip, this_data)

if mode is not None:
offset = nvert[:-n_mean].sum() # effectively :2 or :0
for i, nv in enumerate(nvert[2:]):
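For reference, a standalone numpy sketch of the pooling performed by _pca_flip above, on toy data; it uses numpy's SVD in place of the library's _safe_svd and assumes a non-None flip vector and at least two sources:

import numpy as np

rng = np.random.default_rng(0)
n_src, n_times = 10, 50
flip = np.ones(n_src)                      # sign-flip vector from label_sign_flip
data = rng.normal(size=(n_src, n_times))   # time courses of the sources in one label

U, s, Vh = np.linalg.svd(data, full_matrices=False)
sign = np.sign(U[:, 0] @ flip)              # orient the dominant component consistently
scale = np.linalg.norm(s) / np.sqrt(n_src)  # average power in the label
label_tc = sign * scale * Vh[0]             # pooled (n_times,) label time course

assert label_tc.shape == (n_times,)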
144 changes: 144 additions & 0 deletions mne/tests/test_source_estimate.py
@@ -679,6 +679,150 @@ def test_center_of_mass():
assert_equal(np.round(t, 2), 0.12)


@testing.requires_testing_data
@pytest.mark.parametrize(
"label_type, mri_res, test_label, cf, call",
[
(str, False, False, "head", "meth"), # head frame
(str, False, str, "mri", "func"), # fastest, default for testing
(str, True, str, "mri", "func"), # fastest, default for testing
(str, True, False, "mri", "func"), # mri_resolution
(list, True, False, "mri", "func"), # volume label as list
(dict, True, False, "mri", "func"), # volume label as dict
],
)
def test_extract_label_time_course_volume_pca_flip(
src_volume_labels, label_type, mri_res, test_label, cf, call
Member

We always want tests to be fairly complete, but competing against that we also want them to be as quick as possible. Our test suite takes over an hour to run already! 😢

Do we need to iterate over cf here, for example? If we already test that functionality elsewhere and the pca_flip functionality here is independent of it, let's not. The same goes for mri_resolution, which seems unrelated and is most likely tested elsewhere.

Maybe the pca_flip option could just be added somewhere else (maybe as one parameter in an already parametrized function) to really trim it down?

Author

On this one I'm on the fence - integrating pca_flip as an option in another test is certainly possible, much like "mean" for example, but at the same time we lose clarity in the testing if we over-parameterize, no? In particular, it becomes tedious to know what is actually being tested and which cases might or might not be covered...

Member

In this case I think you're over-testing quite a bit. Your code change is only relevant to some minor parts of the code, not coordinate frame, mri resolution, etc. So I think it makes more sense to trim it down (a lot!) or integrate it elsewhere.

):
"""Test extraction of label timecourses on VolumetricSourceEstimate with PCA."""
# Setup of data
src_labels, volume_labels, lut = src_volume_labels
n_tot = 46
assert n_tot == len(src_labels)
inv = read_inverse_operator(fname_inv_vol)
if cf == "head":
src = inv["src"]
else:
src = read_source_spaces(fname_src_vol)
klass = VolVectorSourceEstimate._scalar_class
vertices = [src[0]["vertno"]]
n_verts = len(src[0]["vertno"])
n_times = 50
data = np.arange(1, n_verts + 1)
end_shape = (n_times,)
data = np.repeat(data[..., np.newaxis], n_times, -1)
stcs = [klass(data.astype(float), vertices, 0, 1)]

def eltc(*args, **kwargs):
if call == "func":
return extract_label_time_course(stcs, *args, **kwargs)
else:
return [stcs[0].extract_label_time_course(*args, **kwargs)]

# triage "labels" argument
if mri_res:
# All should be there
missing = []
else:
# Nearest misses these
missing = [
"Left-vessel",
"Right-vessel",
"5th-Ventricle",
"non-WM-hypointensities",
]
n_want = len(src_labels)
if label_type is str:
labels = fname_aseg
elif label_type is list:
labels = (fname_aseg, volume_labels)
else:
assert label_type is dict
labels = (fname_aseg, {k: lut[k] for k in volume_labels})
assert mri_res
assert len(missing) == 0
# we're going to add one that won't exist
missing = ["intentionally_bad"]
labels[1][missing[0]] = 10000
n_want += 1
n_tot += 1
n_want -= len(missing)

# _volume_labels(src, labels, mri_resolution)
# actually do the testing
from mne.source_estimate import _pca_flip, _prepare_label_extraction, _volume_labels

labels_expanded = _volume_labels(src, labels, mri_res)
_, src_flip = _prepare_label_extraction(
stcs[0], labels_expanded, src, "pca_flip", "ignore", bool(mri_res)
)

mode = "pca_flip"
with catch_logging() as log:
label_tc = eltc(
labels,
src,
mode=mode,
allow_empty="ignore",
mri_resolution=mri_res,
verbose=True,
)
log = log.getvalue()
assert re.search("^Reading atlas.*aseg\\.mgz\n", log) is not None
if len(missing):
# assert that the missing ones get logged
assert "does not contain" in log
assert repr(missing) in log
else:
assert "does not contain" not in log
assert f"\n{n_want}/{n_tot} atlas regions had at least" in log
assert len(label_tc) == 1
label_tc = label_tc[0]
assert label_tc.shape == (n_tot,) + end_shape
assert label_tc.shape == (n_tot, n_times)
# let's test some actual values by trusting the masks provided by
# setup_volume_source_space. mri_resolution=True does some
# interpolation so we should not expect equivalence, False does
# nearest so we should.
if mri_res:
rtol = 0.8 # max much more sensitive
else:
rtol = 0.0
for si, s in enumerate(src_labels):
func = _pca_flip
these = data[np.isin(src[0]["vertno"], s["vertno"])]
print(these.shape)
assert len(these) == s["nuse"]
if si == 0 and s["seg_name"] == "Unknown":
continue # unknown is crappy
if s["nuse"] == 0:
want = 0.0
if mri_res:
# this one is totally due to interpolation, so no easy
# test here
continue
else:
if src_flip[si] is None:
want = None
else:
want = func(src_flip[si], these)
if want is not None:
assert_allclose(label_tc[si], want, atol=1e-6, rtol=rtol)
# compare with in_label, only on every fourth for speed
if test_label is not False and si % 4 == 0:
label = s["seg_name"]
if test_label is int:
label = lut[label]
in_label = stcs[0].in_label(label, fname_aseg, src).data
assert in_label.shape == (s["nuse"],) + end_shape
if np.all(want == 0):
assert in_label.shape[0] == 0
else:
if src_flip[si] is not None:
in_label = func(src_flip[si], in_label)
assert_allclose(in_label, want, atol=1e-6, rtol=rtol)


@testing.requires_testing_data
@pytest.mark.parametrize("kind", ("surface", "mixed"))
@pytest.mark.parametrize("vector", (False, True))