From 752854dd2ac29e58663abd5a91c72f3c0863895b Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 20 Sep 2023 17:20:09 +0300 Subject: [PATCH 01/65] Add EvokedField.plotter --- mne/viz/evoked_field.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index ba89ebb87b5..2815bae818d 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -196,6 +196,7 @@ def __init__( self._in_brain_figure = False self._renderer.set_interaction(interaction) + self.plotter = self._renderer.plotter self.interaction = interaction # Prepare the surface maps From 1178ebc9881985f4a039498e185c3e3af5428162 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 4 Oct 2023 14:24:45 +0300 Subject: [PATCH 02/65] BUG: Fix bug with sensor_colors --- doc/changes/devel.rst | 2 + mne/defaults.py | 3 + mne/gui/_coreg.py | 7 +- mne/utils/docs.py | 16 ++ mne/viz/_3d.py | 206 ++++++++++++------------- mne/viz/_brain/_brain.py | 6 + mne/viz/backends/_pyvista.py | 23 ++- mne/viz/tests/test_3d.py | 6 + tutorials/clinical/30_ecog.py | 9 +- tutorials/intro/40_sensor_locations.py | 1 - 10 files changed, 152 insertions(+), 127 deletions(-) diff --git a/doc/changes/devel.rst b/doc/changes/devel.rst index a639a64428b..b9a700dc7c4 100644 --- a/doc/changes/devel.rst +++ b/doc/changes/devel.rst @@ -31,6 +31,7 @@ Enhancements - Add the possibility to provide a float between 0 and 1 as ``n_grad``, ``n_mag`` and ``n_eeg`` in `~mne.compute_proj_raw`, `~mne.compute_proj_epochs` and `~mne.compute_proj_evoked` to select the number of vectors based on the cumulative explained variance (:gh:`11919` by `Mathieu Scheltienne`_) - Added support for Artinis fNIRS data files to :func:`mne.io.read_raw_snirf` (:gh:`11926` by `Robert Luke`_) - Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_) +- Add support for passing a :class:`python:dict` as ``sensor_color`` to specify per-channel-type colors in :func:`mne.viz.plot_alignment` (:gh:`12067` by `Eric Larson`) - Add inferring EEGLAB files' montage unit automatically based on estimated head radius using :func:`read_raw_eeglab(..., montage_units="auto") ` (:gh:`11925` by `Jack Zhang`_, :gh:`11951` by `Eric Larson`_) - Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array ` data (:gh:`11803` by `Alex Rockhill`_) - Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.Forward.save` (:gh:`12036` by `Eric Larson`_) @@ -56,6 +57,7 @@ Bugs - Fix bug with axis clip box boundaries in :func:`mne.viz.plot_evoked_topo` and related functions (:gh:`11999` by `Eric Larson`_) - Fix bug with ``subject_info`` when loading data from and exporting to EDF file (:gh:`11952` by `Paul Roujansky`_) - Fix bug with delayed checking of :class:`info["bads"] ` (:gh:`12038` by `Eric Larson`_) +- Fix bug with :func:`mne.viz.plot_alignment` where ``sensor_colors`` were not handled properly on a per-channel-type basis (:gh:`12067` by `Eric Larson`) - Fix handling of channel information in annotations when loading data from and exporting to EDF file (:gh:`11960` :gh:`12017` :gh:`12044` by `Paul Roujansky`_) - Add missing ``overwrite`` and ``verbose`` parameters to :meth:`Transform.save() ` (:gh:`12004` by `Marijn van Vliet`_) - Fix parsing of eye-link :class:`~mne.Annotations` when ``apply_offsets=False`` is provided to :func:`~mne.io.read_raw_eyelink` (:gh:`12003` by `Mathieu Scheltienne`_) diff --git a/mne/defaults.py b/mne/defaults.py index 498312caa15..3d3b4d45761 100644 --- a/mne/defaults.py +++ b/mne/defaults.py @@ -227,6 +227,7 @@ coreg=dict( mri_fid_opacity=1.0, dig_fid_opacity=1.0, + # go from unit scaling (e.g., unit-radius sphere) to meters mri_fid_scale=5e-3, dig_fid_scale=8e-3, extra_scale=4e-3, @@ -235,6 +236,8 @@ eegp_height=0.1, ecog_scale=5e-3, seeg_scale=5e-3, + meg_scale=1.0, # sensors are already in SI units + ref_meg_scale=1.0, dbs_scale=5e-3, fnirs_scale=5e-3, source_scale=5e-3, diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index c44bd71dd75..e11b61ed898 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -835,7 +835,7 @@ def _redraw(self, *, verbose=None): mri_fids=self._add_mri_fiducials, hsp=self._add_head_shape_points, hpi=self._add_hpi_coils, - eeg=self._add_eeg_channels, + eeg=self._add_eeg_fnirs_channels, head_fids=self._add_head_fiducials, helmet=self._add_helmet, ) @@ -1217,7 +1217,7 @@ def _add_head_shape_points(self): hsp_actors = None self._update_actor("head_shape_points", hsp_actors) - def _add_eeg_channels(self): + def _add_eeg_fnirs_channels(self): if self._eeg_channels: eeg = ["original"] picks = pick_types(self._info, eeg=(len(eeg) > 0), fnirs=True) @@ -1240,8 +1240,7 @@ def _add_eeg_channels(self): check_inside=self._check_inside, nearest=self._nearest, ) - sens_actors = actors["eeg"] - sens_actors.extend(actors["fnirs"]) + sens_actors = sum(actors.values(), list()) else: sens_actors = None else: diff --git a/mne/utils/docs.py b/mne/utils/docs.py index d32e1923aa4..b2fc1abf69a 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3966,6 +3966,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): automatically generated, corresponding to all non-zero events. """ +docdict[ + "sensor_colors" +] = """ +sensor_colors : array-like of color | dict | None + Colors to use for the sensor glyphs. Can be None (default) to use default colors. + A dict should provide the colors (values) for each channel type (keys), e.g.:: + + dict(eeg=eeg_colors) + + Where the value (``eeg_colors`` above) can be broadcast to an array of colors with + length that matches the number of channels of that type, i.e., is compatible with + :func:`matplotlib.colors.to_rgba_array`. A few examples of this for the case above + are the string ``"k"``, a list of ``n_eeg`` color strings, or an NumPy ndarray of + shape ``(n_eeg, 3)`` or ``(n_eeg, 4)``. +""" + docdict[ "sensors_topomap" ] = """ diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index f4aa2b1999c..9742d6c3b30 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -9,6 +9,7 @@ # # License: Simplified BSD +from collections import defaultdict import os import os.path as op import warnings @@ -604,11 +605,10 @@ def plot_alignment( .. versionadded:: 0.16 .. versionchanged:: 1.0 Defaults to ``'terrain'``. - sensor_colors : array-like | None - Colors to use for the sensor glyphs. Can be list-like of color strings - (length ``n_sensors``) or array-like of RGB(A) values (shape - ``(n_sensors, 3)`` or ``(n_sensors, 4)``). ``None`` (the default) uses - the default sensor colors for the :func:`~mne.viz.plot_alignment` GUI. + %(sensor_colors)s + + .. versionchanged:: 1.6 + Support for passing a ``dict`` was added. %(verbose)s Returns @@ -1437,29 +1437,16 @@ def _plot_sensors( sensor_colors=None, ): """Render sensors in a 3D scene.""" + from matplotlib.colors import to_rgba_array + defaults = DEFAULTS["coreg"] ch_pos, sources, detectors = _ch_pos_in_coord_frame( pick_info(info, picks), to_cf_t=to_cf_t, warn_meg=warn_meg ) - actors = dict( - meg=list(), - ref_meg=list(), - eeg=list(), - fnirs=list(), - ecog=list(), - seeg=list(), - dbs=list(), - ) - locs = dict( - eeg=list(), - fnirs=list(), - ecog=list(), - seeg=list(), - source=list(), - detector=list(), - ) - scalar = 1 if units == "m" else 1e3 + actors = defaultdict(lambda: list()) + locs = defaultdict(lambda: list()) + unit_scalar = 1 if units == "m" else 1e3 for ch_name, ch_coord in ch_pos.items(): ch_type = channel_type(info, info.ch_names.index(ch_name)) # for default picking @@ -1471,46 +1458,75 @@ def _plot_sensors( plot_sensors = (ch_type != "fnirs" or "channels" in fnirs) and ( ch_type != "eeg" or "original" in eeg ) - color = defaults[ch_type + "_color"] # plot sensors if isinstance(ch_coord, tuple): # is meg, plot coil - verts, triangles = ch_coord - actor, _ = renderer.surface( - surface=dict(rr=verts * scalar, tris=triangles), - color=color, - opacity=0.25, - backface_culling=True, - ) - actors[ch_type].append(actor) - else: - if plot_sensors: - locs[ch_type].append(ch_coord) + ch_coord = dict(rr=ch_coord[0] * unit_scalar, tris=ch_coord[1]) + if plot_sensors: + locs[ch_type].append(ch_coord) if ch_name in sources and "sources" in fnirs: locs["source"].append(sources[ch_name]) if ch_name in detectors and "detectors" in fnirs: locs["detector"].append(detectors[ch_name]) + # Plot these now if ch_name in sources and ch_name in detectors and "pairs" in fnirs: actor, _ = renderer.tube( # array of origin and dest points - origin=sources[ch_name][np.newaxis] * scalar, - destination=detectors[ch_name][np.newaxis] * scalar, - radius=0.001 * scalar, + origin=sources[ch_name][np.newaxis] * unit_scalar, + destination=detectors[ch_name][np.newaxis] * unit_scalar, + radius=0.001 * unit_scalar, ) actors[ch_type].append(actor) + del ch_type - # add sensors - for sensor_type in locs.keys(): - if len(locs[sensor_type]) > 0: - sens_loc = np.array(locs[sensor_type]) - sens_loc = sens_loc[~np.isnan(sens_loc).any(axis=1)] - scale = defaults[sensor_type + "_scale"] - if sensor_colors is None: - color = defaults[sensor_type + "_color"] + # now actually plot the sensors + extra = "" + types = (dict, None) + if len(locs) == 0: + return + elif len(locs) == 1: + # Upsample from array-like to dict when there is one channel type + extra = "(or array-like since only one sensor type is plotted)" + if sensor_colors is not None and not isinstance(sensor_colors, dict): + sensor_colors = { + list(locs)[0]: to_rgba_array(sensor_colors), + } + else: + extra = f"when more than one channel type ({list(locs)}) is plotted" + _validate_type(sensor_colors, types, "sensor_colors", extra=extra) + del extra, types + if sensor_colors is None: + sensor_colors = dict() + assert isinstance(sensor_colors, dict) + for ch_type, sens_loc in locs.items(): + assert len(sens_loc) # should be guaranteed above + colors = to_rgba_array(sensor_colors.get(ch_type, defaults[ch_type + "_color"])) + _check_option( + f"len(sensor_colors[{repr(ch_type)}])", + colors.shape[0], + (len(sens_loc), 1), + ) + scale = defaults[ch_type + "_scale"] * unit_scalar + if isinstance(sens_loc[0], dict): # meg coil + if len(colors) == 1: + colors = [colors[0]] * len(sens_loc) + for surface, color in zip(sens_loc, colors): + actor, _ = renderer.surface( + surface=surface, + color=color[:3], + opacity=0.25 * color[3], + backface_culling=False, # visible from all sides + ) + actors[ch_type].append(actor) + else: + sens_loc = np.array(sens_loc, float) + mask = ~np.isnan(sens_loc).any(axis=1) + if len(colors) == 1: + # Single color mode (one actor) actor, _ = _plot_glyphs( renderer=renderer, - loc=sens_loc * scalar, - color=color, - scale=scale * scalar, - opacity=sensor_opacity, + loc=sens_loc[mask] * unit_scalar, + color=colors[0, :3], + scale=scale, + opacity=sensor_opacity * colors[0, 3], orient_glyphs=orient_glyphs, scale_by_distance=scale_by_distance, project_points=project_points, @@ -1518,31 +1534,18 @@ def _plot_sensors( check_inside=check_inside, nearest=nearest, ) - if sensor_type in ("source", "detector"): - sensor_type = "fnirs" - actors[sensor_type].append(actor) + actors[ch_type].append(actor) else: - actor_list = [] - for idx_sen in range(sens_loc.shape[0]): - sensor_colors = np.asarray(sensor_colors) - if ( - sensor_colors.ndim not in (1, 2) - or sensor_colors.shape[0] != sens_loc.shape[0] - ): - raise ValueError( - "sensor_colors should either be None or be " - "array-like with shape (n_sensors,) or " - "(n_sensors, 3) or (n_sensors, 4). Got shape " - f"{sensor_colors.shape}." - ) - color = sensor_colors[idx_sen] - + # Multi-color mode (multiple actors) + for loc, color, usable in zip(sens_loc, colors, mask): + if not usable: + continue actor, _ = _plot_glyphs( renderer=renderer, - loc=(sens_loc * scalar)[idx_sen, :], - color=color, - scale=scale * scalar, - opacity=sensor_opacity, + loc=loc * unit_scalar, + color=color[:3], + scale=scale, + opacity=sensor_opacity * color[3], orient_glyphs=orient_glyphs, scale_by_distance=scale_by_distance, project_points=project_points, @@ -1550,40 +1553,31 @@ def _plot_sensors( check_inside=check_inside, nearest=nearest, ) - actor_list.append(actor) - if sensor_type in ("source", "detector"): - sensor_type = "fnirs" - actors[sensor_type].append(actor_list) - - # add projected eeg - eeg_indices = pick_types(info, eeg=True) - if eeg_indices.size > 0 and "projected" in eeg: - logger.info("Projecting sensors to the head surface") - eeg_loc = np.array([ch_pos[info.ch_names[idx]] for idx in eeg_indices]) - eeg_loc = eeg_loc[~np.isnan(eeg_loc).any(axis=1)] - eegp_loc, eegp_nn = _project_onto_surface( - eeg_loc, head_surf, project_rrs=True, return_nn=True - )[2:4] - del eeg_loc - eegp_loc *= scalar - scale = defaults["eegp_scale"] * scalar - actor, _ = renderer.quiver3d( - x=eegp_loc[:, 0], - y=eegp_loc[:, 1], - z=eegp_loc[:, 2], - u=eegp_nn[:, 0], - v=eegp_nn[:, 1], - w=eegp_nn[:, 2], - color=defaults["eegp_color"], - mode="cylinder", - scale=scale, - opacity=0.6, - glyph_height=defaults["eegp_height"], - glyph_center=(0.0, -defaults["eegp_height"] / 2.0, 0), - glyph_resolution=20, - backface_culling=True, - ) - actors["eeg"].append(actor) + actors[ch_type].append(actor) + if ch_type == "eeg" and "projected" in eeg: + logger.info("Projecting sensors to the head surface") + eegp_loc, eegp_nn = _project_onto_surface( + sens_loc[mask], head_surf, project_rrs=True, return_nn=True + )[2:4] + eegp_loc *= unit_scalar + actor, _ = renderer.quiver3d( + x=eegp_loc[:, 0], + y=eegp_loc[:, 1], + z=eegp_loc[:, 2], + u=eegp_nn[:, 0], + v=eegp_nn[:, 1], + w=eegp_nn[:, 2], + color=defaults["eegp_color"], + mode="cylinder", + scale=defaults["eegp_scale"] * unit_scalar, + opacity=0.6, + glyph_height=defaults["eegp_height"], + glyph_center=(0.0, -defaults["eegp_height"] / 2.0, 0), + glyph_resolution=20, + backface_culling=True, + ) + actors["eeg"].append(actor) + actors = dict(actors) # get rid of defaultdict return actors diff --git a/mne/viz/_brain/_brain.py b/mne/viz/_brain/_brain.py index 49d4bbcc45a..f4d3a90eb0a 100644 --- a/mne/viz/_brain/_brain.py +++ b/mne/viz/_brain/_brain.py @@ -2763,6 +2763,8 @@ def add_sensors( seeg=True, dbs=True, max_dist=0.004, + *, + sensor_colors=None, verbose=None, ): """Add mesh objects to represent sensor positions. @@ -2778,6 +2780,9 @@ def add_sensors( %(seeg)s %(dbs)s %(max_dist_ieeg)s + %(sensor_colors)s + + .. versionadded:: 1.6 %(verbose)s Notes @@ -2832,6 +2837,7 @@ def add_sensors( warn_meg, head_surf, self._units, + sensor_colors=sensor_colors, ) for item, actors in sensors_actors.items(): for actor in actors: diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index bad105d36e2..700ff9e6870 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -373,17 +373,18 @@ def polydata( polygon_offset=None, **kwargs, ): + from matplotlib.colors import to_rgba_array + with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) rgba = False - if color is not None and len(color) == mesh.n_points: - if color.shape[1] == 3: - scalars = np.c_[color, np.ones(mesh.n_points)] - else: - scalars = color - scalars = (scalars * 255).astype("ubyte") - color = None - rgba = True + if color is not None: + # See if we need to convert or not + check_color = to_rgba_array(color) + if len(check_color) == mesh.n_points: + scalars = (check_color * 255).astype("ubyte") + color = None + rgba = True if isinstance(colormap, np.ndarray): if colormap.dtype == np.uint8: colormap = colormap.astype(np.float64) / 255.0 @@ -395,24 +396,22 @@ def polydata( mesh.GetPointData().SetActiveNormals("Normals") else: _compute_normals(mesh) - if "rgba" in kwargs: - rgba = kwargs["rgba"] - kwargs.pop("rgba") smooth_shading = self.smooth_shading if representation == "wireframe": smooth_shading = False # never use smooth shading for wf + rgba = kwargs.pop("rgba", rgba) actor = _add_mesh( plotter=self.plotter, mesh=mesh, color=color, scalars=scalars, edge_color=color, - rgba=rgba, opacity=opacity, cmap=colormap, backface_culling=backface_culling, rng=[vmin, vmax], show_scalar_bar=False, + rgba=rgba, smooth_shading=smooth_shading, interpolate_before_map=interpolate_before_map, style=representation, diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 1a769aef2c3..f7993111543 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -278,8 +278,13 @@ def test_plot_alignment_meg(renderer, system): this_info = read_raw_kit(sqd_fname).info meg = ["helmet", "sensors"] + sensor_colors = "k" # should be upsampled to correct shape if system == "KIT": meg.append("ref") + with pytest.raises(TypeError, match="instance of dict"): + plot_alignment(this_info, meg=meg, sensor_colors=sensor_colors) + sensor_colors = dict(meg=sensor_colors) + sensor_colors["ref_meg"] = ["r"] * len(pick_types(this_info, ref_meg=True)) fig = plot_alignment( this_info, read_trans(trans_fname), @@ -287,6 +292,7 @@ def test_plot_alignment_meg(renderer, system): subjects_dir=subjects_dir, meg=meg, eeg=False, + sensor_colors=sensor_colors, ) assert isinstance(fig, Figure3D) # count the number of objects: should be n_meg_ch + 1 (helmet) + 1 (head) diff --git a/tutorials/clinical/30_ecog.py b/tutorials/clinical/30_ecog.py index e839b45365b..b97b44c1036 100644 --- a/tutorials/clinical/30_ecog.py +++ b/tutorials/clinical/30_ecog.py @@ -133,9 +133,9 @@ subjects_dir=subjects_dir, surfaces=["pial"], coord_frame="head", - sensor_colors=None, + sensor_colors=(1.0, 1.0, 1.0, 0.5), ) -mne.viz.set_3d_view(fig, azimuth=0, elevation=70) +mne.viz.set_3d_view(fig, azimuth=0, elevation=70, focalpoint="auto", distance="auto") xy, im = snapshot_brain_montage(fig, raw.info) @@ -165,7 +165,8 @@ gamma_power_at_15s -= gamma_power_at_15s.min() gamma_power_at_15s /= gamma_power_at_15s.max() rgba = colormaps.get_cmap("viridis") -sensor_colors = gamma_power_at_15s.map(rgba).tolist() +sensor_colors = np.array(gamma_power_at_15s.map(rgba).tolist(), float) +sensor_colors[:, 3] = 0.5 fig = plot_alignment( raw.info, @@ -177,7 +178,7 @@ sensor_colors=sensor_colors, ) -mne.viz.set_3d_view(fig, azimuth=0, elevation=70) +mne.viz.set_3d_view(fig, azimuth=0, elevation=70, focalpoint="auto", distance="auto") xy, im = snapshot_brain_montage(fig, raw.info) diff --git a/tutorials/intro/40_sensor_locations.py b/tutorials/intro/40_sensor_locations.py index 0ef663fa810..86fefe1bb80 100644 --- a/tutorials/intro/40_sensor_locations.py +++ b/tutorials/intro/40_sensor_locations.py @@ -9,7 +9,6 @@ MNE-Python handles physical locations of sensors. As usual we'll start by importing the modules we need: """ - # %% from pathlib import Path From 755682348faa8ee6c5e1cef6ffe616a7596b7353 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Wed, 4 Oct 2023 09:16:28 -0400 Subject: [PATCH 03/65] Update devel.rst --- doc/changes/devel.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/changes/devel.rst b/doc/changes/devel.rst index b9a700dc7c4..d4532d60721 100644 --- a/doc/changes/devel.rst +++ b/doc/changes/devel.rst @@ -31,7 +31,7 @@ Enhancements - Add the possibility to provide a float between 0 and 1 as ``n_grad``, ``n_mag`` and ``n_eeg`` in `~mne.compute_proj_raw`, `~mne.compute_proj_epochs` and `~mne.compute_proj_evoked` to select the number of vectors based on the cumulative explained variance (:gh:`11919` by `Mathieu Scheltienne`_) - Added support for Artinis fNIRS data files to :func:`mne.io.read_raw_snirf` (:gh:`11926` by `Robert Luke`_) - Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_) -- Add support for passing a :class:`python:dict` as ``sensor_color`` to specify per-channel-type colors in :func:`mne.viz.plot_alignment` (:gh:`12067` by `Eric Larson`) +- Add support for passing a :class:`python:dict` as ``sensor_color`` to specify per-channel-type colors in :func:`mne.viz.plot_alignment` (:gh:`12067` by `Eric Larson`_) - Add inferring EEGLAB files' montage unit automatically based on estimated head radius using :func:`read_raw_eeglab(..., montage_units="auto") ` (:gh:`11925` by `Jack Zhang`_, :gh:`11951` by `Eric Larson`_) - Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array ` data (:gh:`11803` by `Alex Rockhill`_) - Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.Forward.save` (:gh:`12036` by `Eric Larson`_) @@ -57,7 +57,7 @@ Bugs - Fix bug with axis clip box boundaries in :func:`mne.viz.plot_evoked_topo` and related functions (:gh:`11999` by `Eric Larson`_) - Fix bug with ``subject_info`` when loading data from and exporting to EDF file (:gh:`11952` by `Paul Roujansky`_) - Fix bug with delayed checking of :class:`info["bads"] ` (:gh:`12038` by `Eric Larson`_) -- Fix bug with :func:`mne.viz.plot_alignment` where ``sensor_colors`` were not handled properly on a per-channel-type basis (:gh:`12067` by `Eric Larson`) +- Fix bug with :func:`mne.viz.plot_alignment` where ``sensor_colors`` were not handled properly on a per-channel-type basis (:gh:`12067` by `Eric Larson`_) - Fix handling of channel information in annotations when loading data from and exporting to EDF file (:gh:`11960` :gh:`12017` :gh:`12044` by `Paul Roujansky`_) - Add missing ``overwrite`` and ``verbose`` parameters to :meth:`Transform.save() ` (:gh:`12004` by `Marijn van Vliet`_) - Fix parsing of eye-link :class:`~mne.Annotations` when ``apply_offsets=False`` is provided to :func:`~mne.io.read_raw_eyelink` (:gh:`12003` by `Mathieu Scheltienne`_) From cb5a52062de5959f451cbc0958f5f005ea741b11 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 4 Oct 2023 18:37:12 +0300 Subject: [PATCH 04/65] Add foreground/background parameters to EvokedField plot --- mne/viz/evoked_field.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index 9e314a917ed..eaf102336a5 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -48,8 +48,6 @@ class EvokedField: the average peak latency (across sensor types) is used. time_label : str | None How to print info about the time instant visualized. - %(n_jobs)s - fig : instance of Figure3D | None If None (default), a new figure will be created, otherwise it will plot into the given figure. @@ -89,6 +87,17 @@ class EvokedField: ``True`` if there is more than one time point and ``False`` otherwise. .. versionadded:: 1.6 + background : tuple(int, int, int) + The color definition of the background: (red, green, blue). + + .. versionadded:: 1.6 + foreground : matplotlib color + Color of the foreground (will be used for colorbars and text). + None (default) will use black or white depending on the value + of ``background``. + + .. versionadded:: 1.6 + %(n_jobs)s %(verbose)s Notes @@ -107,7 +116,6 @@ def __init__( *, time=None, time_label="t = %0.0f ms", - n_jobs=None, fig=None, vmax=None, n_contours=21, @@ -116,6 +124,9 @@ def __init__( interpolation="nearest", interaction="terrain", time_viewer="auto", + background="white", + foreground="black", + n_jobs=None, verbose=None, ): from .backends.renderer import _get_renderer, _get_3d_backend @@ -140,6 +151,8 @@ def __init__( self._interaction = _check_option( "interaction", interaction, ["trackball", "terrain"] ) + self._bg_color = background + self._fg_color = foreground surf_map_kinds = [surf_map["kind"] for surf_map in surf_maps] if vmax is None: @@ -190,9 +203,7 @@ def __init__( "is currently not supported inside a notebook." ) else: - self._renderer = _get_renderer( - fig, bgcolor=(0.0, 0.0, 0.0), size=(600, 600) - ) + self._renderer = _get_renderer(fig, bgcolor=background, size=(600, 600)) self._in_brain_figure = False self.plotter = self._renderer.plotter @@ -226,14 +237,17 @@ def current_time_func(): current_time_func=current_time_func, times=evoked.times, ) - if not self._in_brain_figure or "time_slider" not in fig.widgets: + if not self._in_brain_figure: # Draw the time label self._time_label = time_label if time_label is not None: if "%" in time_label: time_label = time_label % np.round(1e3 * time) self._time_label_actor = self._renderer.text2d( - x_window=0.01, y_window=0.01, text=time_label + x_window=0.01, + y_window=0.01, + text=time_label, + color=foreground, ) self._configure_dock() @@ -364,7 +378,10 @@ def _update(self): if "%" in self._time_label: time_label = self._time_label % np.round(1e3 * self._current_time) self._time_label_actor = self._renderer.text2d( - x_window=0.01, y_window=0.01, text=time_label + x_window=0.01, + y_window=0.01, + text=time_label, + color=self._fg_color, ) self._renderer.plotter.update() From d4015d4ea81af758e53f65f33e15b9264c6ba0d5 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 4 Oct 2023 19:52:52 +0300 Subject: [PATCH 05/65] Set foreground automatically --- mne/viz/evoked_field.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index eaf102336a5..5a7dc7cb939 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -124,8 +124,8 @@ def __init__( interpolation="nearest", interaction="terrain", time_viewer="auto", - background="white", - foreground="black", + background="black", + foreground=None, n_jobs=None, verbose=None, ): @@ -151,8 +151,10 @@ def __init__( self._interaction = _check_option( "interaction", interaction, ["trackball", "terrain"] ) - self._bg_color = background - self._fg_color = foreground + self._bg_color = _to_rgb(background, name="background") + if foreground is None: + foreground = "w" if sum(self._bg_color) < 2 else "k" + self._fg_color = _to_rgb(foreground, name="foreground") surf_map_kinds = [surf_map["kind"] for surf_map in surf_maps] if vmax is None: From 2c307fe777481c5dab77332e027b20a0902c1ba1 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 4 Oct 2023 20:15:08 +0300 Subject: [PATCH 06/65] First working version of lasso select in plot_evoked_topo --- mne/viz/topo.py | 32 +++++++--- mne/viz/ui_events.py | 20 +++++++ mne/viz/utils.py | 139 +++++++++++++++++++++++++------------------ 3 files changed, 125 insertions(+), 66 deletions(-) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index a01ee72a0c2..ceab61a0e0c 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -26,6 +26,7 @@ _setup_ax_spines, _check_cov, _plot_masked_image, + SelectFromCollection, ) @@ -195,8 +196,11 @@ def format_coord_multiaxis(x, y, ch_name=None): under_ax.set(xlim=[0, 1], ylim=[0, 1]) axs = list() + + shown_ch_names = [] for idx, name in iter_ch: ch_idx = ch_names.index(name) + shown_ch_names.append(name) if not unified: # old, slow way ax = plt.axes(pos[idx]) ax.patch.set_facecolor(axis_facecolor) @@ -237,15 +241,22 @@ def format_coord_multiaxis(x, y, ch_name=None): ], [1, 0, 2], ) - if not img: - under_ax.add_collection( - collections.PolyCollection( - verts, - facecolor=axis_facecolor, - edgecolor=axis_spinecolor, - linewidth=1.0, - ) - ) # Not needed for image plots. + if not img: # Not needed for image plots. + collection = collections.PolyCollection( + verts, + facecolor=axis_facecolor, + edgecolor=axis_spinecolor, + ) + under_ax.add_collection(collection) + fig.lasso = SelectFromCollection( + ax=under_ax, + collection=collection, + names=shown_ch_names, + alpha_nonselected=0, + alpha_selected=1, + linewidth_nonselected=0, + linewidth_selected=0.7, + ) for ax in axs: yield ax, ax._mne_ch_idx @@ -344,6 +355,9 @@ def _plot_topo_onpick(event, show_func): """Onpick callback that shows a single channel in a new figure.""" # make sure that the swipe gesture in OS-X doesn't open many figures orig_ax = event.inaxes + if orig_ax.figure.canvas._key in ["shift", "alt"]: + return + import matplotlib.pyplot as plt try: diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index ba5b1db9a33..c4a90e44132 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -192,6 +192,26 @@ class Contours(UIEvent): contours: List[str] +@dataclass +@fill_doc +class ChannelsSelect(UIEvent): + """Indicates that the user has selected one or more channels. + + Parameters + ---------- + ch_names : list of str + The names of the channels that were selected. + + Attributes + ---------- + %(ui_event_name_source)s + ch_names : list of str + The names of the channels that were selected. + """ + + ch_names: List[str] + + def _get_event_channel(fig): """Get the event channel associated with a figure. diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 264505b67ad..75ede905d96 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -64,6 +64,7 @@ _check_decim, ) from ..transforms import apply_trans +from .ui_events import publish, subscribe, ChannelsSelect _channel_type_prettyprint = { @@ -1044,7 +1045,7 @@ def plot_sensors( Whether to plot the sensors as 3d, topomap or as an interactive sensor selection dialog. Available options ``'topomap'``, ``'3d'``, ``'select'``. If ``'select'``, a set of channels can be selected - interactively by using lasso selector or clicking while holding control + interactively by using lasso selector or clicking while holding the shift key. The selected channels are returned along with the figure instance. Defaults to ``'topomap'``. ch_type : None | str @@ -1255,7 +1256,7 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): if event.mouseevent.inaxes != ax: return - if event.mouseevent.key == "control" and fig.lasso is not None: + if event.mouseevent.key in ["shift", "alt"] and fig.lasso is not None: for ind in event.ind: fig.lasso.select_one(ind) @@ -1360,7 +1361,7 @@ def _plot_sensors( lw=linewidth, ) if kind == "select": - fig.lasso = SelectFromCollection(ax, pts, ch_names) + fig.lasso = SelectFromCollection(ax, pts, names=ch_names) else: fig.lasso = None @@ -1693,11 +1694,11 @@ def _draw_without_rendering(cbar): class SelectFromCollection: - """Select channels from a matplotlib collection using ``LassoSelector``. + """Select objects from a matplotlib collection using ``LassoSelector``. - Selected channels are saved in the ``selection`` attribute. This tool - highlights selected points by fading other points out (i.e., reducing their - alpha values). + The names of the selected objects are saved in the ``selection`` attribute. + This tool highlights selected objects by fading other objects out (i.e., + reducing their alpha values). Parameters ---------- @@ -1705,60 +1706,83 @@ class SelectFromCollection: Axes to interact with. collection : instance of matplotlib collection Collection you want to select from. - alpha_other : 0 <= float <= 1 - To highlight a selection, this tool sets all selected points to an - alpha value of 1 and non-selected points to ``alpha_other``. - Defaults to 0.3. - linewidth_other : float - Linewidth to use for non-selected sensors. Default is 1. + names : list of str + The names of the object. The selection is returned as a subset of these names. + alpha_selected : float + Alpha for selected objects (0=tranparant, 1=opaque). + alpha_nonselected : float + Alpha for non-selected objects (0=tranparant, 1=opaque). + linewidth_selected : float + Linewidth for the borders of selected objects. + linewidth_nonselected : float + Linewidth for the borders of non-selected objects. Notes ----- - This tool selects collection objects based on their *origins* - (i.e., ``offsets``). Calls all callbacks in self.callbacks when selection - is ready. + This tool selects collection objects which bounding boxes intersect with a lasso + path. Calls all callbacks in self.callbacks when selection is ready. """ def __init__( self, ax, collection, - ch_names, - alpha_other=0.5, - linewidth_other=0.5, + *, + names, alpha_selected=1, + alpha_nonselected=0.5, linewidth_selected=1, + linewidth_nonselected=0.5, ): from matplotlib.widgets import LassoSelector + self.fig = ax.figure self.canvas = ax.figure.canvas self.collection = collection - self.ch_names = ch_names - self.alpha_other = alpha_other - self.linewidth_other = linewidth_other + self.names = names self.alpha_selected = alpha_selected + self.alpha_nonselected = alpha_nonselected self.linewidth_selected = linewidth_selected + self.linewidth_nonselected = linewidth_nonselected + + from matplotlib.collections import PolyCollection + from matplotlib.path import Path - self.xys = collection.get_offsets() - self.Npts = len(self.xys) + if isinstance(collection, PolyCollection): + self.paths = collection.get_paths() + else: + self.paths = [Path([point]) for point in collection.get_offsets()] + self.Npts = len(self.paths) + if self.Npts != len(names): + raise ValueError( + f"Number of names ({len(names)}) does not match the number of objects " + f"in the collection ({self.Npts})." + ) - # Ensure that we have separate colors for each object + # Ensure that we have colors for each object. self.fc = collection.get_facecolors() self.ec = collection.get_edgecolors() - self.lw = collection.get_linewidths() if len(self.fc) == 0: raise ValueError("Collection must have a facecolor") elif len(self.fc) == 1: self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1) + if len(self.ec) == 0: + self.ec = np.zeros((self.Npts, 4)) # all black + elif len(self.ec) == 1: self.ec = np.tile(self.ec, self.Npts).reshape(self.Npts, -1) - self.fc[:, -1] = self.alpha_other # deselect in the beginning - self.ec[:, -1] = self.alpha_other - self.lw = np.full(self.Npts, self.linewidth_other) + self.lw = np.full(self.Npts, float(self.linewidth_nonselected)) + # Initialize the lasso selector line_kw = _prop_kw("line", dict(color="red", linewidth=0.5)) self.lasso = LassoSelector(ax, onselect=self.on_select, **line_kw) self.selection = list() - self.callbacks = list() + self.selection_inds = np.array([], dtype="int") + + # Deselect everything in the beginning. + self.style_objects([]) + + # Respond to UI-Events + subscribe(self.fig, "channels_select", self._on_channels_select) def on_select(self, verts): """Select a subset from the collection.""" @@ -1768,44 +1792,45 @@ def on_select(self, verts): return path = Path(verts) - inds = np.nonzero([path.contains_point(xy) for xy in self.xys])[0] - if self.canvas._key == "control": # Appending selection. - sels = [np.where(self.ch_names == c)[0][0] for c in self.selection] - inters = set(inds) - set(sels) - inds = list(inters.union(set(sels) - set(inds))) + inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0] + if self.canvas._key == "shift": # Appending selection. + self.selection_inds = np.union1d(self.selection_inds, inds) + elif self.canvas._key == "alt": # Removing selection. + self.selection_inds = np.setdiff1d(self.selection_inds, inds) + else: + self.selection_inds = inds + ch_names = [self.names[i] for i in self.selection_inds] + publish(self.fig, ChannelsSelect(ch_names=ch_names)) - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) - self.notify() + def _on_channels_select(self, event): + ch_inds = {name: i for i, name in enumerate(self.names)} + self.selection = [name for name in event.ch_names if name in ch_inds] + self.selection_inds = [ch_inds[name] for name in self.selection] + self.style_objects(self.selection_inds) def select_one(self, ind): """Select or deselect one sensor.""" - ch_name = self.ch_names[ind] - if ch_name in self.selection: - sel_ind = self.selection.index(ch_name) - self.selection.pop(sel_ind) + if self.canvas._key == "shift": + self.selection_inds = np.union1d(self.selection_inds, [ind]) + elif self.canvas._key == "alt": + self.selection_inds = np.setdiff1d(self.selection_inds, [ind]) else: - self.selection.append(ch_name) - inds = np.isin(self.ch_names, self.selection).nonzero()[0] - self.style_sensors(inds) - self.notify() - - def notify(self): - """Notify listeners that a selection has been made.""" - for callback in self.callbacks: - callback() + return # don't notify() + ch_names = [self.names[i] for i in self.selection_inds] + publish(self.fig, ChannelsSelect(ch_names=ch_names)) def select_many(self, inds): """Select many sensors using indices (for predefined selections).""" - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selected_inds = inds + ch_names = [self.names[i] for i in self.selection_inds] + publish(self.fig, ChannelsSelect(ch_names=ch_names)) - def style_sensors(self, inds): + def style_objects(self, inds): """Style selected sensors as "active".""" # reset - self.fc[:, -1] = self.alpha_other - self.ec[:, -1] = self.alpha_other / 2 - self.lw[:] = self.linewidth_other + self.fc[:, -1] = self.alpha_nonselected + self.ec[:, -1] = self.alpha_nonselected / 2 + self.lw[:] = self.linewidth_nonselected # style sensors at `inds` self.fc[inds, -1] = self.alpha_selected self.ec[inds, -1] = self.alpha_selected From 3b37ea3bdcec432e14cb8ee358b0198f76c128b9 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 4 Oct 2023 20:18:53 +0300 Subject: [PATCH 07/65] Add slider to control contour line thickness --- mne/viz/evoked_field.py | 33 +++++++++++++++++++++++++++++++++ mne/viz/ui_events.py | 7 +++++++ 2 files changed, 40 insertions(+) diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index 5a7dc7cb939..55afddbd3ec 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -65,6 +65,10 @@ class EvokedField: The number of contours. .. versionadded:: 0.21 + contour_line_width : float + The line_width of the contour lines. + + .. versionadded:: 1.6 show_density : bool Whether to draw the field density as an overlay on top of the helmet/head surface. Defaults to ``True``. @@ -119,6 +123,7 @@ def __init__( fig=None, vmax=None, n_contours=21, + contour_line_width=1.0, show_density=True, alpha=None, interpolation="nearest", @@ -143,6 +148,7 @@ def __init__( self._vmax = _validate_type(vmax, (None, "numeric", dict), "vmax") self._n_contours = _ensure_int(n_contours, "n_contours") + self._contour_line_width = contour_line_width self._time_interpolation = _check_option( "interpolation", interpolation, @@ -370,6 +376,7 @@ def _update(self): vmin=-surf_map["map_vmax"], vmax=surf_map["map_vmax"], colormap=self._colormap_lines, + width=self._contour_line_width, ) if self._time_label is not None: if hasattr(self, "_time_label_actor"): @@ -461,6 +468,16 @@ def _callback(vmax, type, scaling): callback=self.set_contours, layout=layout, ) + + self._widgets["contours_line_width"] = r._dock_add_slider( + name="Thickness", + value=1, + rng=[0, 10], + callback=self.set_contour_line_width, + double=True, + layout=layout, + ) + r._dock_finalize() def _on_time_change(self, event): @@ -522,9 +539,13 @@ def _on_contours(self, event): break surf_map["contours"] = event.contours self._n_contours = len(event.contours) + if event.line_width is not None: + self._contour_line_width = event.line_width with disable_ui_events(self): if "contours" in self._widgets: self._widgets["contours"].set_value(len(event.contours)) + if "contour_line_width" in self._widgets and event.line_width is not None: + self._widgets["contour_line_width"].set_value(event.line_width) self._update() def set_time(self, time): @@ -559,6 +580,7 @@ def set_contours(self, n_contours): contours=np.linspace( -surf_map["map_vmax"], surf_map["map_vmax"], n_contours ).tolist(), + line_width=self._contour_line_width, ), ) @@ -593,3 +615,14 @@ def _rescale(self): current_data = surf_map["data_interp"](self._current_time) vmax = float(np.max(current_data)) self.set_vmax(vmax, type=surf_map["map_kind"]) + + def set_contour_line_width(self, line_width): + """Set the line_width of the contour lines. + + Parameters + ---------- + line_width : float + The desired line_width of the contour lines. + """ + self._contour_line_width = line_width + self.set_contours(self._n_contours) diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index ba5b1db9a33..932355076b3 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -176,6 +176,9 @@ class Contours(UIEvent): kinds. contours : list of float The new values at which contour lines need to be drawn. + line_width : float | None + The line_width with which to draw the contour lines. Can be ``None`` to + indicate to keep using the current line_width. Attributes ---------- @@ -186,10 +189,14 @@ class Contours(UIEvent): kinds. contours : list of float The new values at which contour lines need to be drawn. + line_width : float | None + The line_width with which to draw the contour lines. Can be ``None`` to + indicate to keep using the current line_width. """ kind: str contours: List[str] + line_width: Optional[float] def _get_event_channel(fig): From 826831567cbd3779d96c50777b8195fd99a39fe1 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Thu, 5 Oct 2023 11:34:21 +0300 Subject: [PATCH 08/65] fix --- mne/viz/_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 9742d6c3b30..ce99f2e6352 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -1512,7 +1512,7 @@ def _plot_sensors( actor, _ = renderer.surface( surface=surface, color=color[:3], - opacity=0.25 * color[3], + opacity=sensor_opacity * color[3], backface_culling=False, # visible from all sides ) actors[ch_type].append(actor) From 7dd77d9ccd538b7c1bff2ef7137ed59c26294642 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Fri, 6 Oct 2023 15:21:06 +0300 Subject: [PATCH 09/65] Some initial stuff --- mne/dipole.py | 2 +- mne/gui/_xfit.py | 339 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 340 insertions(+), 1 deletion(-) create mode 100644 mne/gui/_xfit.py diff --git a/mne/dipole.py b/mne/dipole.py index a71bbc590b9..7d32e903502 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -1504,7 +1504,7 @@ def fit_dipole( if not bem["is_sphere"]: # Find the best-fitting sphere inner_skull = _bem_find_surface(bem, "inner_skull") - inner_skull = inner_skull.copy() + inner_skull = deepcopy(inner_skull) R, r0 = _fit_sphere(inner_skull["rr"], disp=False) # r0 back to head frame for logging r0 = apply_trans(mri_head_t["trans"], r0[np.newaxis, :])[0] diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py new file mode 100644 index 00000000000..f528c44acf3 --- /dev/null +++ b/mne/gui/_xfit.py @@ -0,0 +1,339 @@ +import mne +import pyvista +import numpy as np +from mne.transforms import _get_trans, _get_transforms_to_coord_frame +from functools import partial + +data_path = mne.datasets.sample.data_path() +evoked = mne.read_evokeds( + f"{data_path}/MEG/sample/sample_audvis-ave.fif", condition="Left Auditory" +) +trans = mne.read_trans(f"{data_path}/MEG/sample/sample_audvis_raw-trans.fif") +head_mri_t = _get_trans(trans, "head", "mri")[0] +to_cf_t = _get_transforms_to_coord_frame(evoked.info, head_mri_t, coord_frame="head") + +evoked.apply_baseline((-0.2, 0)) +field_map = mne.make_field_map(evoked, trans=None) +# cov = mne.read_cov( +# f"{data_path}/MEG/sample/sample_audvis-cov.fif" +# ) +# bem = mne.read_bem_solution( +# f"{data_path}/subjects/sample/bem/sample-5120-5120-5120-bem-sol.fif" +# ) +cov = mne.make_ad_hoc_cov(evoked.info) +bem = mne.make_sphere_model("auto", "auto", evoked.info) + +fig = mne.viz.create_3d_figure((1500, 1500), bgcolor="white", show=True) +fig = mne.viz.plot_alignment( + evoked.info, + surfaces=dict(seghead=0.2, pial=0.5), + meg=False, + eeg=False, + subject="sample", + subjects_dir=data_path / "subjects", + trans=trans, + coord_frame="head", + fig=fig, +) +fig = mne.viz.EvokedField( + evoked, + field_map, + time=0.096, + interpolation="linear", + alpha=0, + show_density=False, + foreground="black", + fig=fig, +) +renderer = fig._renderer +sensor_actors = mne.viz._3d._plot_sensors( + renderer=renderer, + info=evoked.info, + to_cf_t=to_cf_t, + picks=mne.pick_types(evoked.info, meg=True), + meg=True, + eeg=False, + fnirs=False, + warn_meg=False, + head_surf=None, + units="m", + sensor_opacity=0.1, + orient_glyphs=False, + scale_by_distance=False, + project_points=False, + surf=None, + check_inside=None, + nearest=None, + sensor_colors="black", +)["meg"] +fig.set_contour_line_width(2) +fig.separate_canvas = False + +helmet = fig._surf_maps[0]["mesh"]._polydata +helmet.compute_normals(inplace=True) + +fig_sensors = list() +dips = list() +dip_timecourses = list() +dip_lines = list() +dipole_actors = list() +green_arrows = list() +green_arrow_coords = list() +green_arrow_pos = list() +green_arrow_actors = list() +colors = mne.viz.utils._get_color_list() +time_line = list() +dipole_box = None + +vertices = np.array( + [ + [0.0, 1.0, 0.0], + [0.3, 0.7, 0.0], + [0.1, 0.7, 0.0], + [0.1, -1.0, 0.0], + [-0.1, -1.0, 0.0], + [-0.1, 0.7, 0.0], + [-0.3, 0.7, 0.0], + ] +) +faces = np.array([[7, 0, 1, 2, 3, 4, 5, 6]]) + + +def setup_mplcanvas(): + if renderer._mplcanvas is None: + renderer._mplcanvas = renderer._window_get_mplcanvas(fig, 0.5, False, False) + renderer._window_adjust_mplcanvas_layout() + if len(time_line) == 0: + time_line.append( + renderer._mplcanvas.plot_time_line( + fig._current_time, + label="time", + color=fig._fg_color, + ) + ) + return renderer._mplcanvas + + +def show_dipole(show, dip_num): + show = bool(show) + if dip_num >= len(dips): + return + dip_lines[dip_num].set_visible(show) + green_arrow_actors[dip_num].visibility = show + renderer._update() + renderer._mplcanvas.update_plot() + + +def on_fit_dipole(): + print("Fitting dipole...") + evoked_picked = evoked + cov_picked = cov + if len(fig_sensors) > 0: + picks = fig_sensors[0].lasso.selection + if len(picks) > 0: + evoked_picked = evoked.copy().pick(picks) + evoked_picked.info.normalize_proj() + cov_picked = cov.copy().pick_channels(picks, ordered=False) + cov_picked["projs"] = evoked_picked.info["projs"] + + dip = mne.fit_dipole( + evoked_picked.copy().crop(fig._current_time, fig._current_time), + cov_picked, + bem, + trans=trans, + min_dist=0, + verbose=False, + )[0] + dips.append(dip) + dip_num = len(dips) - 1 + renderer.plotter.add_arrows(dip.pos, dip.ori, color=colors[dip_num], mag=0.05) + dip_timecourse = mne.fit_dipole( + evoked_picked, + cov_picked, + bem, + pos=dip.pos[0], + ori=dip.ori[0], + trans=trans, + verbose=False, + )[0].data[0] + dip_timecourses.append(dip_timecourse) + draw_arrow(dip, dip_timecourse, color=colors[dip_num]) + + canvas = setup_mplcanvas() + dip_lines.append( + canvas.plot( + evoked.times, dip_timecourse, label=f"dip{dip_num}", color=colors[dip_num] + ) + ) + renderer._dock_add_check_box( + name=f"dip{dip_num}", + value=True, + callback=partial(show_dipole, dip_num=dip_num), + layout=dipole_box, + ) + + +def on_channels_select(event): + selected_channels = set(event.ch_names) + for act, ch_name in zip(sensor_actors, evoked.ch_names): + if ch_name in selected_channels: + act.prop.SetColor(0, 1, 0) + act.prop.SetOpacity(0.5) + else: + act.prop.SetColor(0, 0, 0) + act.prop.SetOpacity(0.1) + renderer._update() + + +def on_sensor_data(): + fig = evoked.plot_topo() + mne.viz.ui_events.subscribe(fig, "channels_select", on_channels_select) + fig_sensors[:] = [fig] + + +def on_time_change(event): + new_time = (np.clip(event.time, evoked.times[0], evoked.times[-1]),) + for i in range(len(green_arrows)): + arrow = green_arrows[i] + arrow_coords = green_arrow_coords[i] + arrow_position = green_arrow_pos[i] + dip_timecourse = dip_timecourses[i] + scaling = ( + np.interp( + new_time, + evoked.times, + dip_timecourse, + ) + * 1e6 + ) + arrow.points = (vertices * scaling) @ arrow_coords + arrow_position + if len(time_line) > 0: + time_line[0].set_xdata([new_time]) + renderer._mplcanvas.update_plot() + renderer._update() + + +def draw_arrow(dip, dip_timecourse, color): + dip_position = dip.pos[0] + + # Get the closest vertex (=point) of the helmet mesh + distances = ((helmet.points - dip_position) * helmet.point_normals).sum(axis=1) + closest_point = np.argmin(distances) + + # Compute the position of the projected dipole + norm = helmet.point_normals[closest_point] + arrow_position = dip_position + (distances[closest_point] + 0.003) * norm + + # Create a cartesian coordinate system where X and Y are tangential to the helmet + tan_coords = mne.surface._normal_orth(norm) + + # Project the orientation of the dipole tangential to the helmet + dip_ori_tan = tan_coords[:2] @ dip.ori[0] @ tan_coords[:2] + + # Rotate the coordinate system such that Y lies along the dipole orientation + arrow_coords = np.array([np.cross(dip_ori_tan, norm), dip_ori_tan, norm]) + arrow_coords /= np.linalg.norm(arrow_coords, axis=1, keepdims=True) + + # Place the arrow inside the new coordinate system + scaling = np.interp(fig._current_time, evoked.times, dip_timecourse) * 1e6 + arrow = pyvista.PolyData( + (vertices * scaling) @ arrow_coords + arrow_position, faces + ) + green_arrows.append(arrow) + green_arrow_coords.append(arrow_coords) + green_arrow_pos.append(arrow_position) + + # Render the arrow + green_arrow_actors.append(renderer.plotter.add_mesh(arrow, color=color)) + + +def set_view(view): + kwargs = dict() + if view == 1: + kwargs = dict(azimuth=-135, roll=45, elevation=60, distance="auto") + elif view == 2: + kwargs = dict(azimuth=270, roll=180, elevation=90, distance="auto") + elif view == 3: + kwargs = dict(azimuth=-45, roll=-45, elevation=60, distance="auto") + elif view == 4: + kwargs = dict(azimuth=180, roll=90, elevation=90, distance="auto") + elif view == 5: + kwargs = dict(azimuth=0, roll=0, elevation=0, distance="auto") + elif view == 6: + kwargs = dict(azimuth=0, roll=-90, elevation=90, distance="auto") + elif view == 7: + kwargs = dict(azimuth=135, roll=90, elevation=60, distance="auto") + elif view == 8: + kwargs = dict(azimuth=90, roll=0, elevation=90, distance="auto") + elif view == 9: + kwargs = dict(azimuth=45, roll=-90, elevation=60, distance="auto") + renderer.set_camera(**kwargs) + + +def add_view_buttons(r): + layout = r._dock_add_group_box("Views") + + hlayout = r._dock_add_layout(vertical=False) + r._dock_add_button( + "๐Ÿข†", callback=partial(set_view, view=7), layout=hlayout, style="pushbutton" + ) + r._dock_add_button( + "๐Ÿขƒ", callback=partial(set_view, view=8), layout=hlayout, style="pushbutton" + ) + r._dock_add_button( + "๐Ÿข‡", callback=partial(set_view, view=9), layout=hlayout, style="pushbutton" + ) + r._layout_add_widget(layout, hlayout) + + hlayout = r._dock_add_layout(vertical=False) + r._dock_add_button( + "๐Ÿข‚", callback=partial(set_view, view=4), layout=hlayout, style="pushbutton" + ) + r._dock_add_button( + "โŠ™", callback=partial(set_view, view=5), layout=hlayout, style="pushbutton" + ) + r._dock_add_button( + "๐Ÿข€", callback=partial(set_view, view=6), layout=hlayout, style="pushbutton" + ) + + r._layout_add_widget(layout, hlayout) + hlayout = r._dock_add_layout(vertical=False) + r._dock_add_button( + "๐Ÿข…", callback=partial(set_view, view=1), layout=hlayout, style="pushbutton" + ) + r._dock_add_button( + "๐Ÿข", callback=partial(set_view, view=2), layout=hlayout, style="pushbutton" + ) + r._dock_add_button( + "๐Ÿข„", callback=partial(set_view, view=3), layout=hlayout, style="pushbutton" + ) + r._layout_add_widget(layout, hlayout) + + r.plotter.add_key_event("1", partial(set_view, view=1)) + r.plotter.add_key_event("2", partial(set_view, view=2)) + r.plotter.add_key_event("3", partial(set_view, view=3)) + r.plotter.add_key_event("4", partial(set_view, view=4)) + r.plotter.add_key_event("5", partial(set_view, view=5)) + r.plotter.add_key_event("6", partial(set_view, view=6)) + r.plotter.add_key_event("7", partial(set_view, view=7)) + r.plotter.add_key_event("8", partial(set_view, view=8)) + r.plotter.add_key_event("9", partial(set_view, view=9)) + + +add_view_buttons(renderer) +renderer._dock_initialize(name="Dipole fitting", area="right") +renderer._dock_add_button("Sensor data", on_sensor_data) +renderer._dock_add_button("Fit dipole", on_fit_dipole) + +dipole_box = renderer._dock_add_group_box(name="Dipoles") +renderer._dock_add_stretch() + +renderer.set_camera(focalpoint=mne.bem.fit_sphere_to_headshape(evoked.info)[1]) +mne.viz.ui_events.subscribe(fig, "time_change", on_time_change) + + +# gfp_ax = canvas.fig.axes[0].twinx() +# gfp_ax.plot( +# evoked.times, np.mean(fig._surf_maps[0]['data'] ** 2, axis=0), color='maroon', +# ) +# canvas.update_plot() From 8ef1179c599912b01c19b5fcb70b1b79122aae12 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 18 Oct 2023 15:02:19 +0300 Subject: [PATCH 10/65] Small steps --- mne/gui/__init__.pyi | 3 +- mne/gui/_coreg.py | 1 - mne/gui/_xfit.py | 713 +++++++++++++++++++++++-------------------- mne/viz/_3d.py | 1 + 4 files changed, 391 insertions(+), 327 deletions(-) diff --git a/mne/gui/__init__.pyi b/mne/gui/__init__.pyi index 086c51a4904..8f66ad387eb 100644 --- a/mne/gui/__init__.pyi +++ b/mne/gui/__init__.pyi @@ -1,2 +1,3 @@ -__all__ = ["_GUIScraper", "coregistration"] +__all__ = ["_GUIScraper", "coregistration", "DipoleFitUI"] from ._gui import _GUIScraper, coregistration +from ._xfit import DipoleFitUI diff --git a/mne/gui/_coreg.py b/mne/gui/_coreg.py index 1e3cf69c9f2..e07848db693 100644 --- a/mne/gui/_coreg.py +++ b/mne/gui/_coreg.py @@ -864,7 +864,6 @@ def _redraw(self, *, verbose=None): mri_fids=self._add_mri_fiducials, hsp=self._add_head_shape_points, hpi=self._add_hpi_coils, - eeg=self._add_eeg_fnirs_channels, sensors=self._add_channels, head_fids=self._add_head_fiducials, helmet=self._add_helmet, diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index f528c44acf3..d6795ba6382 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -1,339 +1,402 @@ -import mne -import pyvista -import numpy as np -from mne.transforms import _get_trans, _get_transforms_to_coord_frame from functools import partial -data_path = mne.datasets.sample.data_path() -evoked = mne.read_evokeds( - f"{data_path}/MEG/sample/sample_audvis-ave.fif", condition="Left Auditory" -) -trans = mne.read_trans(f"{data_path}/MEG/sample/sample_audvis_raw-trans.fif") -head_mri_t = _get_trans(trans, "head", "mri")[0] -to_cf_t = _get_transforms_to_coord_frame(evoked.info, head_mri_t, coord_frame="head") - -evoked.apply_baseline((-0.2, 0)) -field_map = mne.make_field_map(evoked, trans=None) -# cov = mne.read_cov( -# f"{data_path}/MEG/sample/sample_audvis-cov.fif" -# ) -# bem = mne.read_bem_solution( -# f"{data_path}/subjects/sample/bem/sample-5120-5120-5120-bem-sol.fif" -# ) -cov = mne.make_ad_hoc_cov(evoked.info) -bem = mne.make_sphere_model("auto", "auto", evoked.info) - -fig = mne.viz.create_3d_figure((1500, 1500), bgcolor="white", show=True) -fig = mne.viz.plot_alignment( - evoked.info, - surfaces=dict(seghead=0.2, pial=0.5), - meg=False, - eeg=False, - subject="sample", - subjects_dir=data_path / "subjects", - trans=trans, - coord_frame="head", - fig=fig, +import numpy as np +import pyvista + +from .. import pick_types +from ..bem import ( + ConductorModel, + _ensure_bem_surfaces, + fit_sphere_to_headshape, + make_sphere_model, ) -fig = mne.viz.EvokedField( +from ..cov import make_ad_hoc_cov +from ..dipole import fit_dipole +from ..forward import make_field_map, make_forward_dipole +from ..minimum_norm import apply_inverse, make_inverse_operator +from ..transforms import _get_trans, _get_transforms_to_coord_frame +from ..utils import _check_option, fill_doc, verbose +from ..viz import EvokedField, create_3d_figure +from ..viz._3d import _plot_head_surface, _plot_sensors +from ..viz.ui_events import subscribe +from ..viz.utils import _get_color_list + + +@fill_doc +@verbose +def dipolefit( evoked, - field_map, - time=0.096, - interpolation="linear", - alpha=0, - show_density=False, - foreground="black", - fig=fig, -) -renderer = fig._renderer -sensor_actors = mne.viz._3d._plot_sensors( - renderer=renderer, - info=evoked.info, - to_cf_t=to_cf_t, - picks=mne.pick_types(evoked.info, meg=True), - meg=True, - eeg=False, - fnirs=False, - warn_meg=False, - head_surf=None, - units="m", - sensor_opacity=0.1, - orient_glyphs=False, - scale_by_distance=False, - project_points=False, - surf=None, - check_inside=None, - nearest=None, - sensor_colors="black", -)["meg"] -fig.set_contour_line_width(2) -fig.separate_canvas = False - -helmet = fig._surf_maps[0]["mesh"]._polydata -helmet.compute_normals(inplace=True) - -fig_sensors = list() -dips = list() -dip_timecourses = list() -dip_lines = list() -dipole_actors = list() -green_arrows = list() -green_arrow_coords = list() -green_arrow_pos = list() -green_arrow_actors = list() -colors = mne.viz.utils._get_color_list() -time_line = list() -dipole_box = None - -vertices = np.array( - [ - [0.0, 1.0, 0.0], - [0.3, 0.7, 0.0], - [0.1, 0.7, 0.0], - [0.1, -1.0, 0.0], - [-0.1, -1.0, 0.0], - [-0.1, 0.7, 0.0], - [-0.3, 0.7, 0.0], - ] -) -faces = np.array([[7, 0, 1, 2, 3, 4, 5, 6]]) - - -def setup_mplcanvas(): - if renderer._mplcanvas is None: - renderer._mplcanvas = renderer._window_get_mplcanvas(fig, 0.5, False, False) - renderer._window_adjust_mplcanvas_layout() - if len(time_line) == 0: - time_line.append( - renderer._mplcanvas.plot_time_line( - fig._current_time, - label="time", - color=fig._fg_color, - ) - ) - return renderer._mplcanvas - - -def show_dipole(show, dip_num): - show = bool(show) - if dip_num >= len(dips): - return - dip_lines[dip_num].set_visible(show) - green_arrow_actors[dip_num].visibility = show - renderer._update() - renderer._mplcanvas.update_plot() - - -def on_fit_dipole(): - print("Fitting dipole...") - evoked_picked = evoked - cov_picked = cov - if len(fig_sensors) > 0: - picks = fig_sensors[0].lasso.selection - if len(picks) > 0: - evoked_picked = evoked.copy().pick(picks) - evoked_picked.info.normalize_proj() - cov_picked = cov.copy().pick_channels(picks, ordered=False) - cov_picked["projs"] = evoked_picked.info["projs"] - - dip = mne.fit_dipole( - evoked_picked.copy().crop(fig._current_time, fig._current_time), - cov_picked, - bem, - trans=trans, - min_dist=0, - verbose=False, - )[0] - dips.append(dip) - dip_num = len(dips) - 1 - renderer.plotter.add_arrows(dip.pos, dip.ori, color=colors[dip_num], mag=0.05) - dip_timecourse = mne.fit_dipole( - evoked_picked, - cov_picked, - bem, - pos=dip.pos[0], - ori=dip.ori[0], + cov=None, + bem=None, + initial_time=None, + trans=None, + show_density=True, + subject=None, + subjects_dir=None, + n_jobs=None, + verbose=None, +): + """GUI for interactive dipole fitting, inspired by MEGIN's XFit program. + + Parameters + ---------- + evoked : instance of Evoked + Evoked data to show fieldmap of and fit dipoles to. + cov : instance of Covariance | None + Noise covariance matrix. If None, an ad-hoc covariance matrix is used. + bem : instance of ConductorModel | None + Boundary element model. If None, a spherical model is used. + initial_time : float | None + Initial time point to show. If None, the time point of the maximum + field strength is used. + trans : instance of Transform | None + The transformation from head coordinates to MRI coordinates. If None, + the identity matrix is used. + show_density : bool + Whether to show the density of the fieldmap. + subject : str | None + The subject name. If None, no MRI data is shown. + %(subjects_dir)s + %(n_jobs)s + %(verbose)s + """ + return DipoleFitUI( + evoked=evoked, + cov=cov, + bem=bem, + initial_time=initial_time, trans=trans, - verbose=False, - )[0].data[0] - dip_timecourses.append(dip_timecourse) - draw_arrow(dip, dip_timecourse, color=colors[dip_num]) - - canvas = setup_mplcanvas() - dip_lines.append( - canvas.plot( - evoked.times, dip_timecourse, label=f"dip{dip_num}", color=colors[dip_num] - ) - ) - renderer._dock_add_check_box( - name=f"dip{dip_num}", - value=True, - callback=partial(show_dipole, dip_num=dip_num), - layout=dipole_box, + show_density=show_density, + subject=subject, + subjects_dir=subjects_dir, + n_jobs=n_jobs, + verbose=verbose, ) -def on_channels_select(event): - selected_channels = set(event.ch_names) - for act, ch_name in zip(sensor_actors, evoked.ch_names): - if ch_name in selected_channels: - act.prop.SetColor(0, 1, 0) - act.prop.SetOpacity(0.5) - else: - act.prop.SetColor(0, 0, 0) - act.prop.SetOpacity(0.1) - renderer._update() - - -def on_sensor_data(): - fig = evoked.plot_topo() - mne.viz.ui_events.subscribe(fig, "channels_select", on_channels_select) - fig_sensors[:] = [fig] - - -def on_time_change(event): - new_time = (np.clip(event.time, evoked.times[0], evoked.times[-1]),) - for i in range(len(green_arrows)): - arrow = green_arrows[i] - arrow_coords = green_arrow_coords[i] - arrow_position = green_arrow_pos[i] - dip_timecourse = dip_timecourses[i] - scaling = ( - np.interp( - new_time, - evoked.times, - dip_timecourse, - ) - * 1e6 +@fill_doc +class DipoleFitUI: + """GUI for interactive dipole fitting, inspired by MEGIN's XFit program. + + Parameters + ---------- + evoked : instance of Evoked + Evoked data to show fieldmap of and fit dipoles to. + cov : instance of Covariance | None + Noise covariance matrix. If None, an ad-hoc covariance matrix is used. + bem : instance of ConductorModel | None + Boundary element model. If None, a spherical model is used. + initial_time : float | None + Initial time point to show. If None, the time point of the maximum + field strength is used. + trans : instance of Transform | None + The transformation from head coordinates to MRI coordinates. If None, + the identity matrix is used. + show_density : bool + Whether to show the density of the fieldmap. + subject : str | None + The subject name. If None, no MRI data is shown. + %(subjects_dir)s + %(n_jobs)s + %(verbose)s + """ + + def __init__( + self, + evoked, + cov=None, + bem=None, + initial_time=None, + trans=None, + show_density=True, + subject=None, + subjects_dir=None, + n_jobs=None, + verbose=None, + ): + field_map = make_field_map( + evoked, + ch_type="meg", + trans=trans, + subject=subject, + subjects_dir=subjects_dir, + n_jobs=n_jobs, + verbose=verbose, + ) + if cov is None: + cov = make_ad_hoc_cov(evoked.info) + if bem is None: + bem = make_sphere_model("auto", "auto", evoked.info) + bem = _ensure_bem_surfaces(bem, extra_allow=(ConductorModel, None)) + + if initial_time is None: + data = evoked.copy().pick(field_map[0]["ch_names"]).data + initial_time = evoked.times[np.argmax(np.mean(data**2, axis=0))] + + # Get transforms to convert all the various meshes to head space + head_mri_t = _get_trans(trans, "head", "mri")[0] + to_cf_t = _get_transforms_to_coord_frame( + evoked.info, head_mri_t, coord_frame="head" ) - arrow.points = (vertices * scaling) @ arrow_coords + arrow_position - if len(time_line) > 0: - time_line[0].set_xdata([new_time]) - renderer._mplcanvas.update_plot() - renderer._update() - - -def draw_arrow(dip, dip_timecourse, color): - dip_position = dip.pos[0] - - # Get the closest vertex (=point) of the helmet mesh - distances = ((helmet.points - dip_position) * helmet.point_normals).sum(axis=1) - closest_point = np.argmin(distances) - - # Compute the position of the projected dipole - norm = helmet.point_normals[closest_point] - arrow_position = dip_position + (distances[closest_point] + 0.003) * norm - - # Create a cartesian coordinate system where X and Y are tangential to the helmet - tan_coords = mne.surface._normal_orth(norm) - - # Project the orientation of the dipole tangential to the helmet - dip_ori_tan = tan_coords[:2] @ dip.ori[0] @ tan_coords[:2] - # Rotate the coordinate system such that Y lies along the dipole orientation - arrow_coords = np.array([np.cross(dip_ori_tan, norm), dip_ori_tan, norm]) - arrow_coords /= np.linalg.norm(arrow_coords, axis=1, keepdims=True) + self._actors = dict() + self._bem = bem + self._cov = cov + self._current_time = initial_time + self._dips = dict() + self._dips_active = set() + self._dips_colors = dict() + self._dips_timecourses = dict() + self._dips_lines = dict() + self._evoked = evoked + self._field_map = field_map + self._fig_sensors = None + self._n_jobs = n_jobs + self._show_density = show_density + self._subject = subject + self._subjects_dir = subjects_dir + self._to_cf_t = to_cf_t + self._trans = trans + self._verbose = verbose + + # Configure the GUI + self._renderer = self._configure_main_display() + self._configure_dock() + + def _configure_main_display(self): + """Configure main 3D display of the GUI.""" + fig = create_3d_figure((1900, 1020), bgcolor="white", show=True) + fig = EvokedField( + self._evoked, + self._field_map, + time=self._current_time, + interpolation="linear", + alpha=1, + show_density=self._show_density, + foreground="black", + fig=fig, + ) + fig.set_contour_line_width(2) + fig._renderer.set_camera( + focalpoint=fit_sphere_to_headshape(self._evoked.info)[1] + ) + self._actors["helmet"] = fig._surf_maps[0]["mesh"]._actor + + self._actors["sensors"] = _plot_sensors( + renderer=fig._renderer, + info=self._evoked.info, + to_cf_t=self._to_cf_t, + picks=pick_types(self._evoked.info, meg=True), + meg=True, + eeg=False, + fnirs=False, + warn_meg=False, + head_surf=None, + units="m", + sensor_opacity=0.1, + orient_glyphs=False, + scale_by_distance=False, + project_points=False, + surf=None, + check_inside=None, + nearest=None, + sensor_colors="black", + )["meg"] + + self._actors["head"], _, _ = _plot_head_surface( + renderer=fig._renderer, + head="head", + subject=self._subject, + subjects_dir=self._subjects_dir, + bem=self._bem, + coord_frame="head", + to_cf_t=self._to_cf_t, + alpha=1.0, + ) - # Place the arrow inside the new coordinate system - scaling = np.interp(fig._current_time, evoked.times, dip_timecourse) * 1e6 - arrow = pyvista.PolyData( - (vertices * scaling) @ arrow_coords + arrow_position, faces - ) - green_arrows.append(arrow) - green_arrow_coords.append(arrow_coords) - green_arrow_pos.append(arrow_position) - - # Render the arrow - green_arrow_actors.append(renderer.plotter.add_mesh(arrow, color=color)) - - -def set_view(view): - kwargs = dict() - if view == 1: - kwargs = dict(azimuth=-135, roll=45, elevation=60, distance="auto") - elif view == 2: - kwargs = dict(azimuth=270, roll=180, elevation=90, distance="auto") - elif view == 3: - kwargs = dict(azimuth=-45, roll=-45, elevation=60, distance="auto") - elif view == 4: - kwargs = dict(azimuth=180, roll=90, elevation=90, distance="auto") - elif view == 5: - kwargs = dict(azimuth=0, roll=0, elevation=0, distance="auto") - elif view == 6: - kwargs = dict(azimuth=0, roll=-90, elevation=90, distance="auto") - elif view == 7: - kwargs = dict(azimuth=135, roll=90, elevation=60, distance="auto") - elif view == 8: - kwargs = dict(azimuth=90, roll=0, elevation=90, distance="auto") - elif view == 9: - kwargs = dict(azimuth=45, roll=-90, elevation=60, distance="auto") - renderer.set_camera(**kwargs) - - -def add_view_buttons(r): - layout = r._dock_add_group_box("Views") - - hlayout = r._dock_add_layout(vertical=False) - r._dock_add_button( - "๐Ÿข†", callback=partial(set_view, view=7), layout=hlayout, style="pushbutton" - ) - r._dock_add_button( - "๐Ÿขƒ", callback=partial(set_view, view=8), layout=hlayout, style="pushbutton" - ) - r._dock_add_button( - "๐Ÿข‡", callback=partial(set_view, view=9), layout=hlayout, style="pushbutton" - ) - r._layout_add_widget(layout, hlayout) + self._fig = fig + return fig._renderer + + def _configure_dock(self): + """Configure the left and right dock areas of the GUI.""" + r = self._renderer + + # Toggle buttons for various meshes + layout = r._dock_add_group_box("Meshes") + for actor_name in self._actors.keys(): + r._dock_add_check_box( + name=actor_name, + value=True, + callback=partial(self.toggle_mesh, name=actor_name), + layout=layout, + ) - hlayout = r._dock_add_layout(vertical=False) - r._dock_add_button( - "๐Ÿข‚", callback=partial(set_view, view=4), layout=hlayout, style="pushbutton" - ) - r._dock_add_button( - "โŠ™", callback=partial(set_view, view=5), layout=hlayout, style="pushbutton" - ) - r._dock_add_button( - "๐Ÿข€", callback=partial(set_view, view=6), layout=hlayout, style="pushbutton" - ) + # Add view buttons + layout = r._dock_add_group_box("Views") + hlayout = None + views = zip( + [7, 8, 9, 4, 5, 6, 1, 2, 3], # numpad order + ["๐Ÿข†", "๐Ÿขƒ", "๐Ÿข‡", "๐Ÿข‚", "โŠ™", "๐Ÿข€", "๐Ÿข…", "๐Ÿข", "๐Ÿข„"], + ) + for i, (view, label) in enumerate(views): + if i % 3 == 0: # show in groups of 3 + hlayout = r._dock_add_layout(vertical=False) + r._layout_add_widget(layout, hlayout) + r._dock_add_button( + label, + callback=partial(self._set_view, view=view), + layout=hlayout, + style="pushbutton", + ) + r.plotter.add_key_event(str(view), partial(self._set_view, view=view)) + + # Right dock + r._dock_initialize(name="Dipole fitting", area="right") + r._dock_add_button("Sensor data", self._on_sensor_data) + r._dock_add_button("Fit dipole", self._on_fit_dipole) + r._dock_add_button("Fit multi-dipole", self._on_fit_multi) + r._dock_add_combo_box( + "Method", + value="MNE", + rng=["MNE", "Single-dipole", "LCMV"], + callback=self._on_select_method, + ) + self._dipole_box = r._dock_add_group_box(name="Dipoles") + r._dock_add_stretch() + + def toggle_mesh(self, name, show=None): + """Toggle a mesh on or off. + + Parameters + ---------- + name : "helmet" + Name of the mesh to toggle. + show : bool | None + Whether to show the mesh. If None, the visibility of the mesh is toggled. + """ + _check_option("name", name, self._actors.keys()) + actors = self._actors[name] + # self._actors[name] is sometimes a list and sometimes not. Make it + # always be a list to simplify the code. + if isinstance(actors, list): + actors = [actors] + if show is None: + show = not actors[0].GetVisibility() + for act in actors: + act.SetVisibility(show) + self._renderer._update() + + def _set_view(self, view): + kwargs = dict() + if view == 1: + kwargs = dict(azimuth=-135, roll=45, elevation=60, distance="auto") + elif view == 2: + kwargs = dict(azimuth=270, roll=180, elevation=90, distance="auto") + elif view == 3: + kwargs = dict(azimuth=-45, roll=-45, elevation=60, distance="auto") + elif view == 4: + kwargs = dict(azimuth=180, roll=90, elevation=90, distance="auto") + elif view == 5: + kwargs = dict(azimuth=0, roll=0, elevation=0, distance="auto") + elif view == 6: + kwargs = dict(azimuth=0, roll=-90, elevation=90, distance="auto") + elif view == 7: + kwargs = dict(azimuth=135, roll=90, elevation=60, distance="auto") + elif view == 8: + kwargs = dict(azimuth=90, roll=0, elevation=90, distance="auto") + elif view == 9: + kwargs = dict(azimuth=45, roll=-90, elevation=60, distance="auto") + self._renderer.set_camera(**kwargs) + + def _on_sensor_data(self): + """Show sensor data.""" + if self._fig_sensors is not None: + return + fig = self._evoked.plot_topo() + fig.canvas.mpl_connect("close_event", self._on_sensor_data_close) + subscribe(fig, "channels_select", self._on_channels_select) + self._fig_sensors = fig + + def _on_sensor_data_close(self, event): + self._fig_sensors = None + + def _on_channels_select(self, event): + """Show selected channels.""" + print(event) + + def _on_fit_dipole(self): + print("Fitting dipole...") + evoked_picked = self._evoked + cov_picked = self._cov + if self._fig_sensors is not None: + picks = self._fig_sensors[0].lasso.selection + if len(picks) > 0: + evoked_picked = evoked_picked.copy().pick(picks) + evoked_picked.info.normalize_proj() + cov_picked = cov_picked.copy().pick_channels(picks, ordered=False) + cov_picked["projs"] = evoked_picked.info["projs"] + evoked_picked.crop(self._current_time, self._current_time) + + dip = fit_dipole( + evoked_picked, + cov_picked, + self._bem, + trans=self._trans, + min_dist=0, + verbose=False, + )[0] + dip_name = f"dip{len(self._dips)}" + self._dips[dip_name] = dip + self._dips_active.add(dip_name) + colors = _get_color_list() + self._dips_colors[dip_name] = colors[(len(self._dips) - 1) % len(colors)] + + def _on_fit_multi(self): + print("Fitting dipoles", self._dips_active) + fwd, _ = make_forward_dipole( + [self._dips[d] for d in self._dips_active], self._bem, self._evoked.info + ) - r._layout_add_widget(layout, hlayout) - hlayout = r._dock_add_layout(vertical=False) - r._dock_add_button( - "๐Ÿข…", callback=partial(set_view, view=1), layout=hlayout, style="pushbutton" - ) - r._dock_add_button( - "๐Ÿข", callback=partial(set_view, view=2), layout=hlayout, style="pushbutton" - ) - r._dock_add_button( - "๐Ÿข„", callback=partial(set_view, view=3), layout=hlayout, style="pushbutton" + inv = make_inverse_operator( + self._evoked.info, fwd, self._cov, fixed=True, depth=0 + ) + stc = apply_inverse(self._evoked, inv, method="MNE", lambda2=0) + timecourses = stc.data + + canvas = self._setup_mplcanvas() + ymin, ymax = 0, 0 + for dip_name, timecourse in zip(self._dips_active, timecourses): + self._dip_timecourses[dip_name] = timecourse + if dip_name in self._dip_lines: + self._dip_lines[dip_name].set_ydata(timecourse) + else: + self._dip_lines[dip_name] = canvas.plot( + self._evoked.times, + timecourse, + label=dip_name, + color=self._dips_colors[dip_name], + ) + ymin = min(ymin, 1.1 * timecourse.min()) + ymax = max(ymax, 1.1 * timecourse.max()) + canvas.axes.set_ylim(ymin, ymax) + canvas.update_plot() + + def _on_select_method(self): + print("Select method") + + +def _arrow_mesh(): + """Obtain a PyVista mesh of an arrow.""" + vertices = np.array( + [ + [0.0, 1.0, 0.0], + [0.3, 0.7, 0.0], + [0.1, 0.7, 0.0], + [0.1, -1.0, 0.0], + [-0.1, -1.0, 0.0], + [-0.1, 0.7, 0.0], + [-0.3, 0.7, 0.0], + ] ) - r._layout_add_widget(layout, hlayout) - - r.plotter.add_key_event("1", partial(set_view, view=1)) - r.plotter.add_key_event("2", partial(set_view, view=2)) - r.plotter.add_key_event("3", partial(set_view, view=3)) - r.plotter.add_key_event("4", partial(set_view, view=4)) - r.plotter.add_key_event("5", partial(set_view, view=5)) - r.plotter.add_key_event("6", partial(set_view, view=6)) - r.plotter.add_key_event("7", partial(set_view, view=7)) - r.plotter.add_key_event("8", partial(set_view, view=8)) - r.plotter.add_key_event("9", partial(set_view, view=9)) - - -add_view_buttons(renderer) -renderer._dock_initialize(name="Dipole fitting", area="right") -renderer._dock_add_button("Sensor data", on_sensor_data) -renderer._dock_add_button("Fit dipole", on_fit_dipole) - -dipole_box = renderer._dock_add_group_box(name="Dipoles") -renderer._dock_add_stretch() - -renderer.set_camera(focalpoint=mne.bem.fit_sphere_to_headshape(evoked.info)[1]) -mne.viz.ui_events.subscribe(fig, "time_change", on_time_change) - - -# gfp_ax = canvas.fig.axes[0].twinx() -# gfp_ax.plot( -# evoked.times, np.mean(fig._surf_maps[0]['data'] ** 2, axis=0), color='maroon', -# ) -# canvas.update_plot() + faces = np.array([[7, 0, 1, 2, 3, 4, 5, 6]]) + return pyvista.PolyData(vertices, faces) diff --git a/mne/viz/_3d.py b/mne/viz/_3d.py index 0e186c3b01f..41ab452e5f2 100644 --- a/mne/viz/_3d.py +++ b/mne/viz/_3d.py @@ -818,6 +818,7 @@ def plot_alignment( renderer.set_interaction(interaction) # plot head + print(head, bem, subject) _, _, head_surf = _plot_head_surface( renderer, head, From 458c509e160c694e74264f6c2b994d113f0a85f9 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Thu, 19 Oct 2023 16:08:50 +0300 Subject: [PATCH 11/65] Continue work on the "fit dipole" button --- mne/gui/_xfit.py | 180 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 148 insertions(+), 32 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index d6795ba6382..e1bb96388cb 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -14,6 +14,7 @@ from ..dipole import fit_dipole from ..forward import make_field_map, make_forward_dipole from ..minimum_norm import apply_inverse, make_inverse_operator +from ..surface import _normal_orth from ..transforms import _get_trans, _get_transforms_to_coord_frame from ..utils import _check_option, fill_doc, verbose from ..viz import EvokedField, create_3d_figure @@ -140,21 +141,20 @@ def __init__( ) self._actors = dict() + self._arrows = list() self._bem = bem self._cov = cov self._current_time = initial_time - self._dips = dict() - self._dips_active = set() - self._dips_colors = dict() - self._dips_timecourses = dict() - self._dips_lines = dict() + self._dipoles = list() self._evoked = evoked self._field_map = field_map self._fig_sensors = None + self._multi_dipole_method = "MNE" self._n_jobs = n_jobs self._show_density = show_density - self._subject = subject self._subjects_dir = subjects_dir + self._subject = subject + self._time_line = None self._to_cf_t = to_cf_t self._trans = trans self._verbose = verbose @@ -165,7 +165,7 @@ def __init__( def _configure_main_display(self): """Configure main 3D display of the GUI.""" - fig = create_3d_figure((1900, 1020), bgcolor="white", show=True) + fig = create_3d_figure((1500, 1020), bgcolor="white", show=True) fig = EvokedField( self._evoked, self._field_map, @@ -174,13 +174,17 @@ def _configure_main_display(self): alpha=1, show_density=self._show_density, foreground="black", + background="white", fig=fig, ) + fig.separate_canvas = False # needed to plot the timeline later fig.set_contour_line_width(2) fig._renderer.set_camera( focalpoint=fit_sphere_to_headshape(self._evoked.info)[1] ) - self._actors["helmet"] = fig._surf_maps[0]["mesh"]._actor + helmet_mesh = fig._surf_maps[0]["mesh"] + helmet_mesh._polydata.compute_normals() # needed later + self._actors["helmet"] = helmet_mesh._actor self._actors["sensors"] = _plot_sensors( renderer=fig._renderer, @@ -200,7 +204,7 @@ def _configure_main_display(self): surf=None, check_inside=None, nearest=None, - sensor_colors="black", + sensor_colors="white", )["meg"] self._actors["head"], _, _ = _plot_head_surface( @@ -214,6 +218,7 @@ def _configure_main_display(self): alpha=1.0, ) + subscribe(fig, "time_change", self._on_time_change) self._fig = fig return fig._renderer @@ -269,7 +274,7 @@ def toggle_mesh(self, name, show=None): Parameters ---------- - name : "helmet" + name : str Name of the mesh to toggle. show : bool | None Whether to show the mesh. If None, the visibility of the mesh is toggled. @@ -278,7 +283,7 @@ def toggle_mesh(self, name, show=None): actors = self._actors[name] # self._actors[name] is sometimes a list and sometimes not. Make it # always be a list to simplify the code. - if isinstance(actors, list): + if not isinstance(actors, list): actors = [actors] if show is None: show = not actors[0].GetVisibility() @@ -308,8 +313,26 @@ def _set_view(self, view): kwargs = dict(azimuth=45, roll=-90, elevation=60, distance="auto") self._renderer.set_camera(**kwargs) + def _on_time_change(self, event): + new_time = (np.clip(event.time, self._evoked.times[0], self._evoked.times[-1]),) + # for i in range(len(green_arrows)): + # arrow = green_arrows[i] + # arrow_coords = green_arrow_coords[i] + # arrow_position = green_arrow_pos[i] + # dip_timecourse = dip_timecourses[i] + # scaling = np.interp( + # new_time, + # evoked.times, + # dip_timecourse, + # ) * (0.05 / np.max(np.abs(dip_timecourses))) + # arrow.points = (vertices * scaling) @ arrow_coords + arrow_position + if self._time_line is not None: + self._time_line.set_xdata([new_time]) + self._renderer._mplcanvas.update_plot() + self._renderer._update() + def _on_sensor_data(self): - """Show sensor data.""" + """Show sensor data and allow sensor selection.""" if self._fig_sensors is not None: return fig = self._evoked.plot_topo() @@ -318,18 +341,30 @@ def _on_sensor_data(self): self._fig_sensors = fig def _on_sensor_data_close(self, event): + """Handle closing of the sensor selection window.""" self._fig_sensors = None + if "sensors" in self._actors: + for act in self._actors["sensors"]: + act.prop.SetColor(1, 1, 1) + self._renderer._update() def _on_channels_select(self, event): - """Show selected channels.""" - print(event) + """Color selected sensor meshes.""" + selected_channels = set(event.ch_names) + if "sensors" in self._actors: + for act, ch_name in zip(self._actors["sensors"], self._evoked.ch_names): + if ch_name in selected_channels: + act.prop.SetColor(0, 1, 0) + else: + act.prop.SetColor(1, 1, 1) + self._renderer._update() def _on_fit_dipole(self): - print("Fitting dipole...") - evoked_picked = self._evoked + """Fit a single dipole.""" + evoked_picked = self._evoked.copy() cov_picked = self._cov if self._fig_sensors is not None: - picks = self._fig_sensors[0].lasso.selection + picks = self._fig_sensors.lasso.selection if len(picks) > 0: evoked_picked = evoked_picked.copy().pick(picks) evoked_picked.info.normalize_proj() @@ -345,16 +380,28 @@ def _on_fit_dipole(self): min_dist=0, verbose=False, )[0] - dip_name = f"dip{len(self._dips)}" - self._dips[dip_name] = dip - self._dips_active.add(dip_name) colors = _get_color_list() - self._dips_colors[dip_name] = colors[(len(self._dips) - 1) % len(colors)] + dip_num = len(self._dipoles) + dipole_dict = dict( + dip=dip, + num=dip_num, + name=f"dip{dip_num}", + active=True, + color=colors[dip_num % len(colors)], + ) + self._dipoles.append(dipole_dict) + + # Draw the arrow on the helmet + self._draw_arrow(dipole_dict) + + # Compute dipole timecourse + self._on_fit_multi() def _on_fit_multi(self): - print("Fitting dipoles", self._dips_active) + """Compute dipole timecourses using a multi-dipole model.""" + active_dips = [d for d in self._dipoles if d["active"]] fwd, _ = make_forward_dipole( - [self._dips[d] for d in self._dips_active], self._bem, self._evoked.info + [d["dip"] for d in active_dips], self._bem, self._evoked.info ) inv = make_inverse_operator( @@ -365,24 +412,93 @@ def _on_fit_multi(self): canvas = self._setup_mplcanvas() ymin, ymax = 0, 0 - for dip_name, timecourse in zip(self._dips_active, timecourses): - self._dip_timecourses[dip_name] = timecourse - if dip_name in self._dip_lines: - self._dip_lines[dip_name].set_ydata(timecourse) + for d, timecourse in zip(active_dips, timecourses): + d["timecourse"] = timecourse + if "line_artist" in d: + d["line_artist"].set_ydata(timecourse) else: - self._dip_lines[dip_name] = canvas.plot( + d["line_artist"] = canvas.plot( self._evoked.times, timecourse, - label=dip_name, - color=self._dips_colors[dip_name], + label=d["name"], + color=d["color"], ) ymin = min(ymin, 1.1 * timecourse.min()) ymax = max(ymax, 1.1 * timecourse.max()) canvas.axes.set_ylim(ymin, ymax) canvas.update_plot() - def _on_select_method(self): - print("Select method") + # Render the arrows in the correct size and orientation + active_arrows = [self._arrows[d["name"]] for d in active_dips] + arrow_scaling = 0.05 / np.max(np.abs(timecourses)) + for a, timecourse in zip(active_arrows, timecourses): + arrow = a["actor"].GetMapper().GetInput() + dip_moment = np.interp(self._current_time, self._evoked.times, timecourse) + arrow_size = dip_moment * arrow_scaling + arrow.points = (arrow.points * arrow_size) @ a["dip_coords"] + a["pos"] + arrow.SetVisibility(True) + self._renderer._update() + + def _on_select_method(self, method): + self._multi_dipole_method = method + + def _setup_mplcanvas(self): + """Configure the matplotlib canvas at the bottom of the window.""" + if self._renderer._mplcanvas is None: + self._renderer._mplcanvas = self._renderer._window_get_mplcanvas( + self._fig, 0.3, False, False + ) + self._renderer._window_adjust_mplcanvas_layout() + if self._time_line is None: + self._time_line = self._renderer._mplcanvas.plot_time_line( + self._current_time, + label="time", + color="black", + ) + return self._renderer._mplcanvas + + def _draw_arrow(self, dipole): + """Draw an arrow showing the dipole orientation tangential to the helmet.""" + dip_pos = dipole["dip"].pos[0] + dip_ori = dipole["dip"].ori[0] + + # Get the closest vertex (=point) of the helmet mesh + helmet = self._actors["helmet"].GetMapper().GetInput() + distances = ((helmet.points - dip_pos) * helmet.point_normals).sum(axis=1) + closest_point = np.argmin(distances) + + # Compute the position of the projected dipole + norm = helmet.point_normals[closest_point] + arrow_position = dip_pos + (distances[closest_point] + 0.003) * norm + + # Create a coordinate system where X and Y are tangential to the helmet + helmet_coords = _normal_orth(norm) + + # Project the orientation of the dipole tangential to the helmet + dip_ori_helmet = helmet_coords[:2] @ dip_ori @ helmet_coords[:2] + + # Rotate the coordinate system such that Y lies along the dipole orientation + dip_coords = np.array([np.cross(dip_ori_helmet, norm), dip_ori_helmet, norm]) + dip_coords /= np.linalg.norm(dip_coords, axis=1, keepdims=True) + + # Draw the arrow, and collect all relevant information in a dict + arrow = _arrow_mesh() + vertices, faces = arrow.points.copy(), arrow.faces.copy() + actor = self._renderer.plotter.add_mesh(arrow, color=dipole["color"]) + actor.SetVisibility(False) # hide for the moment + if "arrows" not in self._actors: + self._actors["arrows"] = [actor] + else: + self._actors["arrows"].append(actor) + return dict( + name=dipole["name"], + vertices=vertices, + faces=faces, + actor=actor, + pos=arrow_position, + helmet_coords=helmet_coords, + dip_coords=dip_coords, + ) def _arrow_mesh(): From 3c91d050bc195b5c6d2cb5a659440ace283d852b Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 31 Oct 2023 16:33:13 +0200 Subject: [PATCH 12/65] More progress --- mne/gui/_xfit.py | 365 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 256 insertions(+), 109 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index e1bb96388cb..f8ac13c9685 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -4,6 +4,7 @@ import pyvista from .. import pick_types +from ..beamformer import apply_lcmv, make_lcmv from ..bem import ( ConductorModel, _ensure_bem_surfaces, @@ -12,10 +13,14 @@ ) from ..cov import make_ad_hoc_cov from ..dipole import fit_dipole -from ..forward import make_field_map, make_forward_dipole +from ..forward import convert_forward_solution, make_field_map, make_forward_dipole from ..minimum_norm import apply_inverse, make_inverse_operator from ..surface import _normal_orth -from ..transforms import _get_trans, _get_transforms_to_coord_frame +from ..transforms import ( + _get_trans, + _get_transforms_to_coord_frame, + transform_surface_to, +) from ..utils import _check_option, fill_doc, verbose from ..viz import EvokedField, create_3d_figure from ..viz._3d import _plot_head_surface, _plot_sensors @@ -85,6 +90,8 @@ class DipoleFitUI: Evoked data to show fieldmap of and fit dipoles to. cov : instance of Covariance | None Noise covariance matrix. If None, an ad-hoc covariance matrix is used. + cov_data : instance of Covariance | None + Data covariance matrix. If None, LCMV method will be unavailable. bem : instance of ConductorModel | None Boundary element model. If None, a spherical model is used. initial_time : float | None @@ -106,29 +113,32 @@ def __init__( self, evoked, cov=None, + cov_data=None, bem=None, initial_time=None, trans=None, show_density=True, subject=None, subjects_dir=None, + ch_type=None, n_jobs=None, verbose=None, ): + if cov is None: + cov = make_ad_hoc_cov(evoked.info) + if bem is None: + bem = make_sphere_model("auto", "auto", evoked.info) + bem = _ensure_bem_surfaces(bem, extra_allow=(ConductorModel, None)) field_map = make_field_map( evoked, - ch_type="meg", + ch_type=ch_type, trans=trans, + origin=bem["r0"] if bem["is_sphere"] else "auto", subject=subject, subjects_dir=subjects_dir, n_jobs=n_jobs, verbose=verbose, ) - if cov is None: - cov = make_ad_hoc_cov(evoked.info) - if bem is None: - bem = make_sphere_model("auto", "auto", evoked.info) - bem = _ensure_bem_surfaces(bem, extra_allow=(ConductorModel, None)) if initial_time is None: data = evoked.copy().pick(field_map[0]["ch_names"]).data @@ -140,10 +150,19 @@ def __init__( evoked.info, head_mri_t, coord_frame="head" ) + # Transform the fieldmap surfaces to head space if needed + if trans is not None: + for fm in field_map: + fm["surf"] = transform_surface_to( + fm["surf"], "head", [to_cf_t["mri"], to_cf_t["head"]], copy=False + ) + self._actors = dict() self._arrows = list() self._bem = bem + self._ch_type = ch_type self._cov = cov + self._cov_data = cov_data self._current_time = initial_time self._dipoles = list() self._evoked = evoked @@ -171,7 +190,7 @@ def _configure_main_display(self): self._field_map, time=self._current_time, interpolation="linear", - alpha=1, + alpha=0, show_density=self._show_density, foreground="black", background="white", @@ -186,16 +205,41 @@ def _configure_main_display(self): helmet_mesh._polydata.compute_normals() # needed later self._actors["helmet"] = helmet_mesh._actor - self._actors["sensors"] = _plot_sensors( + show_meg = (self._ch_type is None or self._ch_type == "meg") and any( + [m["kind"] == "meg" for m in self._field_map] + ) + show_eeg = (self._ch_type is None or self._ch_type == "eeg") and any( + [m["kind"] == "eeg" for m in self._field_map] + ) + + print(f"{show_meg=} {show_eeg=}") + + for m in self._field_map: + if m["kind"] == "eeg": + head_surf = m["surf"] + break + else: + self._actors["head"], _, head_surf = _plot_head_surface( + renderer=fig._renderer, + head="head", + subject=self._subject, + subjects_dir=self._subjects_dir, + bem=self._bem, + coord_frame="head", + to_cf_t=self._to_cf_t, + alpha=0.2, + ) + + sensors = _plot_sensors( renderer=fig._renderer, info=self._evoked.info, to_cf_t=self._to_cf_t, - picks=pick_types(self._evoked.info, meg=True), - meg=True, - eeg=False, + picks=pick_types(self._evoked.info, meg=show_meg, eeg=show_eeg), + meg=show_meg, + eeg=["projected"] if show_eeg else False, fnirs=False, warn_meg=False, - head_surf=None, + head_surf=head_surf, units="m", sensor_opacity=0.1, orient_glyphs=False, @@ -204,19 +248,11 @@ def _configure_main_display(self): surf=None, check_inside=None, nearest=None, - sensor_colors="white", - )["meg"] - - self._actors["head"], _, _ = _plot_head_surface( - renderer=fig._renderer, - head="head", - subject=self._subject, - subjects_dir=self._subjects_dir, - bem=self._bem, - coord_frame="head", - to_cf_t=self._to_cf_t, - alpha=1.0, + sensor_colors=dict(meg="white", eeg="white"), ) + self._actors["sensors"] = list() + for s in sensors.values(): + self._actors["sensors"].extend(s) subscribe(fig, "time_change", self._on_time_change) self._fig = fig @@ -259,11 +295,13 @@ def _configure_dock(self): r._dock_initialize(name="Dipole fitting", area="right") r._dock_add_button("Sensor data", self._on_sensor_data) r._dock_add_button("Fit dipole", self._on_fit_dipole) - r._dock_add_button("Fit multi-dipole", self._on_fit_multi) + methods = ["MNE", "Single-dipole"] + if self._cov_data is not None: + methods.append("LCMV") r._dock_add_combo_box( - "Method", + "Dipole model", value="MNE", - rng=["MNE", "Single-dipole", "LCMV"], + rng=methods, callback=self._on_select_method, ) self._dipole_box = r._dock_add_group_box(name="Dipoles") @@ -314,22 +352,12 @@ def _set_view(self, view): self._renderer.set_camera(**kwargs) def _on_time_change(self, event): - new_time = (np.clip(event.time, self._evoked.times[0], self._evoked.times[-1]),) - # for i in range(len(green_arrows)): - # arrow = green_arrows[i] - # arrow_coords = green_arrow_coords[i] - # arrow_position = green_arrow_pos[i] - # dip_timecourse = dip_timecourses[i] - # scaling = np.interp( - # new_time, - # evoked.times, - # dip_timecourse, - # ) * (0.05 / np.max(np.abs(dip_timecourses))) - # arrow.points = (vertices * scaling) @ arrow_coords + arrow_position + new_time = np.clip(event.time, self._evoked.times[0], self._evoked.times[-1]) + self._current_time = new_time if self._time_line is not None: self._time_line.set_xdata([new_time]) self._renderer._mplcanvas.update_plot() - self._renderer._update() + self._update_arrows() def _on_sensor_data(self): """Show sensor data and allow sensor selection.""" @@ -380,39 +408,138 @@ def _on_fit_dipole(self): min_dist=0, verbose=False, )[0] + + # Coordinates needed to draw the big arrow on the helmet. + helmet_coords, helmet_pos = self._get_helmet_coords(dip) + + # Collect all relevant information on the dipole in a dict colors = _get_color_list() dip_num = len(self._dipoles) + dip_name = f"dip{dip_num}" + dip_color = colors[dip_num % len(colors)] + arrow_mesh = pyvista.PolyData(*_arrow_mesh()) dipole_dict = dict( + active=True, + arrow_actor=None, + arrow_mesh=arrow_mesh, + color=dip_color, dip=dip, + helmet_coords=helmet_coords, + helmet_pos=helmet_pos, + name=dip_name, num=dip_num, - name=f"dip{dip_num}", - active=True, - color=colors[dip_num % len(colors)], ) self._dipoles.append(dipole_dict) - # Draw the arrow on the helmet - self._draw_arrow(dipole_dict) + # Add a row to the dipole list + r = self._renderer + hlayout = r._dock_add_layout(vertical=False) + r._dock_add_check_box( + name=dip_name, + value=True, + callback=partial(self._on_dipole_toggle, dip_name=dip_name), + layout=hlayout, + ) + r._dock_add_check_box( + name="Fix pos", + value=True, + callback=partial(self._on_dipole_toggle_fix_position, dip_name=dip_name), + layout=hlayout, + ) + r._dock_add_check_box( + name="Fix ori", + value=True, + callback=partial(self._on_dipole_toggle_fix_orientation, dip_name=dip_name), + layout=hlayout, + ) + r._layout_add_widget(self._dipole_box, hlayout) + + # Compute dipole timecourse, update arrow size + self._fit_timecourses() + + # Show the dipole and arrow in the 3D view + self._renderer.plotter.add_arrows( + dip.pos[0], dip.ori[0], color=dip_color, mag=0.05 + ) + dipole_dict["arrow_actor"] = self._renderer.plotter.add_mesh( + arrow_mesh, color=dip_color + ) + + def _get_helmet_coords(self, dip): + """Compute the coordinate system used for drawing the big arrows on the helmet. - # Compute dipole timecourse - self._on_fit_multi() + In this coordinate system, Z is normal to the helmet surface, and XY + are tangential to the helmet surface. + """ + dip_pos = dip.pos[0] + + # Get the closest vertex (=point) of the helmet mesh + helmet = self._actors["helmet"].GetMapper().GetInput() + distances = ((helmet.points - dip_pos) * helmet.point_normals).sum(axis=1) + closest_point = np.argmin(distances) - def _on_fit_multi(self): + # Compute the position of the projected dipole on the helmet + norm = helmet.point_normals[closest_point] + helmet_pos = dip_pos + (distances[closest_point] + 0.003) * norm + + # Create a coordinate system where X and Y are tangential to the helmet + helmet_coords = _normal_orth(norm) + + return helmet_coords, helmet_pos + + def _fit_timecourses(self): """Compute dipole timecourses using a multi-dipole model.""" active_dips = [d for d in self._dipoles if d["active"]] - fwd, _ = make_forward_dipole( - [d["dip"] for d in active_dips], self._bem, self._evoked.info - ) + if len(active_dips) == 0: + return - inv = make_inverse_operator( - self._evoked.info, fwd, self._cov, fixed=True, depth=0 + fwd, _ = make_forward_dipole( + [d["dip"] for d in active_dips], + self._bem, + self._evoked.info, + trans=self._trans, ) - stc = apply_inverse(self._evoked, inv, method="MNE", lambda2=0) - timecourses = stc.data - + fwd = convert_forward_solution(fwd, surf_ori=False) + + if self._multi_dipole_method == "MNE": + inv = make_inverse_operator( + self._evoked.info, + fwd, + self._cov, + fixed=False, + loose=1.0, + depth=0, + ) + stc = apply_inverse( + self._evoked, inv, method="MNE", lambda2=0, pick_ori="vector" + ) + timecourses = np.linalg.norm(stc.data, axis=1) + orientations = stc.data / timecourses[:, np.newaxis, :] + elif self._multi_dipole_method == "LCMV": + lcmv = make_lcmv( + self._evoked.info, fwd, self._cov_data, reg=0.05, noise_cov=self._cov + ) + stc = apply_lcmv(self._evoked, lcmv) + timecourses = stc.data + elif self._multi_dipole_method == "Single-dipole": + timecourses = list() + for dip in active_dips: + dip_timecourse = fit_dipole( + self._evoked, + self._cov, + self._bem, + pos=dip["dip"].pos[0], + ori=dip["dip"].ori[0], + trans=self._trans, + verbose=False, + )[0].data[0] + timecourses.append(dip_timecourse) + + # Update matplotlib canvas at the bottom of the window canvas = self._setup_mplcanvas() ymin, ymax = 0, 0 - for d, timecourse in zip(active_dips, timecourses): + for d, ori, timecourse in zip(active_dips, orientations, timecourses): + d["ori"] = ori d["timecourse"] = timecourse if "line_artist" in d: d["line_artist"].set_ydata(timecourse) @@ -427,20 +554,83 @@ def _on_fit_multi(self): ymax = max(ymax, 1.1 * timecourse.max()) canvas.axes.set_ylim(ymin, ymax) canvas.update_plot() + self._update_arrows() - # Render the arrows in the correct size and orientation - active_arrows = [self._arrows[d["name"]] for d in active_dips] + def _update_arrows(self): + """Update the arrows to have the correct size and orientation.""" + active_dips = [d for d in self._dipoles if d["active"]] + if len(active_dips) == 0: + return + orientations = [d["ori"] for d in active_dips] + timecourses = [d["timecourse"] for d in active_dips] arrow_scaling = 0.05 / np.max(np.abs(timecourses)) - for a, timecourse in zip(active_arrows, timecourses): - arrow = a["actor"].GetMapper().GetInput() + for d, ori, timecourse in zip(active_dips, orientations, timecourses): + dip_ori = [ + np.interp(self._current_time, self._evoked.times, o) for o in ori + ] dip_moment = np.interp(self._current_time, self._evoked.times, timecourse) arrow_size = dip_moment * arrow_scaling - arrow.points = (arrow.points * arrow_size) @ a["dip_coords"] + a["pos"] - arrow.SetVisibility(True) + arrow_mesh = d["arrow_mesh"] + + # Project the orientation of the dipole tangential to the helmet + helmet_coords = d["helmet_coords"] + dip_ori_tan = helmet_coords[:2] @ dip_ori @ helmet_coords[:2] + + # Rotate the coordinate system such that Y lies along the dipole + # orientation, now we have our desired coordinate system for the + # arrows. + arrow_coords = np.array( + [np.cross(dip_ori_tan, helmet_coords[2]), dip_ori_tan, helmet_coords[2]] + ) + arrow_coords /= np.linalg.norm(arrow_coords, axis=1, keepdims=True) + + # Update the arrow mesh to point in the right directions + arrow_mesh.points = (_arrow_mesh()[0] * arrow_size) @ arrow_coords + arrow_mesh.points += d["helmet_pos"] self._renderer._update() def _on_select_method(self, method): + """Select the method to use for multi-dipole timecourse fitting.""" self._multi_dipole_method = method + self._fit_timecourses() + + def _on_dipole_toggle(self, active, dip_name): + """Toggle a dipole on or off.""" + for d in self._dipoles: + if d["name"] == dip_name: + dipole = d + break + else: + raise ValueError(f"Unknown dipole {dip_name}") + active = bool(active) + dipole["active"] = active + dipole["line_artist"].set_visible(active) + dipole["arrow_actor"].visibility = active + self._fit_timecourses() + self._renderer._update() + self._renderer._mplcanvas.update_plot() + + def _on_dipole_toggle_fix_position(self, fix, dip_name): + """Fix dipole position when fitting timecourse.""" + for d in self._dipoles: + if d["name"] == dip_name: + dipole = d + break + else: + raise ValueError(f"Unknown dipole {dip_name}") + dipole["fix_position"] = bool(fix) + self._fit_timecourses() + + def _on_dipole_toggle_fix_orientation(self, fix, dip_name): + """Fix dipole orientation when fitting timecourse.""" + for d in self._dipoles: + if d["name"] == dip_name: + dipole = d + break + else: + raise ValueError(f"Unknown dipole {dip_name}") + dipole["fix_position"] = bool(fix) + self._fit_timecourses() def _setup_mplcanvas(self): """Configure the matplotlib canvas at the bottom of the window.""" @@ -457,49 +647,6 @@ def _setup_mplcanvas(self): ) return self._renderer._mplcanvas - def _draw_arrow(self, dipole): - """Draw an arrow showing the dipole orientation tangential to the helmet.""" - dip_pos = dipole["dip"].pos[0] - dip_ori = dipole["dip"].ori[0] - - # Get the closest vertex (=point) of the helmet mesh - helmet = self._actors["helmet"].GetMapper().GetInput() - distances = ((helmet.points - dip_pos) * helmet.point_normals).sum(axis=1) - closest_point = np.argmin(distances) - - # Compute the position of the projected dipole - norm = helmet.point_normals[closest_point] - arrow_position = dip_pos + (distances[closest_point] + 0.003) * norm - - # Create a coordinate system where X and Y are tangential to the helmet - helmet_coords = _normal_orth(norm) - - # Project the orientation of the dipole tangential to the helmet - dip_ori_helmet = helmet_coords[:2] @ dip_ori @ helmet_coords[:2] - - # Rotate the coordinate system such that Y lies along the dipole orientation - dip_coords = np.array([np.cross(dip_ori_helmet, norm), dip_ori_helmet, norm]) - dip_coords /= np.linalg.norm(dip_coords, axis=1, keepdims=True) - - # Draw the arrow, and collect all relevant information in a dict - arrow = _arrow_mesh() - vertices, faces = arrow.points.copy(), arrow.faces.copy() - actor = self._renderer.plotter.add_mesh(arrow, color=dipole["color"]) - actor.SetVisibility(False) # hide for the moment - if "arrows" not in self._actors: - self._actors["arrows"] = [actor] - else: - self._actors["arrows"].append(actor) - return dict( - name=dipole["name"], - vertices=vertices, - faces=faces, - actor=actor, - pos=arrow_position, - helmet_coords=helmet_coords, - dip_coords=dip_coords, - ) - def _arrow_mesh(): """Obtain a PyVista mesh of an arrow.""" @@ -515,4 +662,4 @@ def _arrow_mesh(): ] ) faces = np.array([[7, 0, 1, 2, 3, 4, 5, 6]]) - return pyvista.PolyData(vertices, faces) + return vertices, faces From 41087bb7b4fc478c9543845ab7a88d6032613d22 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 09:45:48 +0200 Subject: [PATCH 13/65] Fix divide by zero --- mne/viz/_3d_overlay.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/_3d_overlay.py b/mne/viz/_3d_overlay.py index 8eb7c7313f7..367d3a5b2c3 100644 --- a/mne/viz/_3d_overlay.py +++ b/mne/viz/_3d_overlay.py @@ -101,7 +101,7 @@ def _compute_over(self, B, A): C[:, :3] *= A_w C[:, :3] += B[:, :3] * B_w C[:, 3:] += B_w - C[:, :3] /= C[:, 3:] + C[:, :3] /= np.maximum(1e-20, C[:, 3:]) # avoid divide by zero return np.clip(C, 0, 1, out=C) def _compose_overlays(self): From 568560bd9833abbb154a4c65b18615070d3c708d Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 09:46:13 +0200 Subject: [PATCH 14/65] Fix sensor picking --- mne/gui/_xfit.py | 59 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index f8ac13c9685..a0aeaea53a4 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -23,7 +23,7 @@ ) from ..utils import _check_option, fill_doc, verbose from ..viz import EvokedField, create_3d_figure -from ..viz._3d import _plot_head_surface, _plot_sensors +from ..viz._3d import _plot_head_surface, _plot_sensors_3d from ..viz.ui_events import subscribe from ..viz.utils import _get_color_list @@ -201,9 +201,16 @@ def _configure_main_display(self): fig._renderer.set_camera( focalpoint=fit_sphere_to_headshape(self._evoked.info)[1] ) - helmet_mesh = fig._surf_maps[0]["mesh"] - helmet_mesh._polydata.compute_normals() # needed later - self._actors["helmet"] = helmet_mesh._actor + + for surf_map in fig._surf_maps: + if surf_map["map_kind"] == "meg": + helmet_mesh = surf_map["mesh"] + helmet_mesh._polydata.compute_normals() # needed later + self._actors["helmet"] = helmet_mesh._actor + elif surf_map["map_kind"] == "eeg": + head_mesh = surf_map["mesh"] + head_mesh._polydata.compute_normals() # needed later + self._actors["head"] = head_mesh._actor show_meg = (self._ch_type is None or self._ch_type == "meg") and any( [m["kind"] == "meg" for m in self._field_map] @@ -211,6 +218,10 @@ def _configure_main_display(self): show_eeg = (self._ch_type is None or self._ch_type == "eeg") and any( [m["kind"] == "eeg" for m in self._field_map] ) + meg_picks = pick_types(self._evoked.info, meg=show_meg) + eeg_picks = pick_types(self._evoked.info, meg=False, eeg=show_eeg) + picks = np.concatenate((meg_picks, eeg_picks)) + self._ch_names = [self._evoked.ch_names[i] for i in picks] print(f"{show_meg=} {show_eeg=}") @@ -230,25 +241,29 @@ def _configure_main_display(self): alpha=0.2, ) - sensors = _plot_sensors( + sensors = _plot_sensors_3d( renderer=fig._renderer, info=self._evoked.info, to_cf_t=self._to_cf_t, - picks=pick_types(self._evoked.info, meg=show_meg, eeg=show_eeg), + picks=picks, meg=show_meg, - eeg=["projected"] if show_eeg else False, + eeg=["original"] if show_eeg else False, fnirs=False, warn_meg=False, head_surf=head_surf, units="m", - sensor_opacity=0.1, + sensor_alpha=dict(meg=0.1, eeg=1.0), orient_glyphs=False, scale_by_distance=False, project_points=False, surf=None, check_inside=None, nearest=None, - sensor_colors=dict(meg="white", eeg="white"), + # sensor_colors=dict(meg="white", eeg="white"), + sensor_colors=dict( + meg=["white" for _ in meg_picks], + eeg=["white" for _ in eeg_picks], + ), ) self._actors["sensors"] = list() for s in sensors.values(): @@ -380,7 +395,7 @@ def _on_channels_select(self, event): """Color selected sensor meshes.""" selected_channels = set(event.ch_names) if "sensors" in self._actors: - for act, ch_name in zip(self._actors["sensors"], self._evoked.ch_names): + for act, ch_name in zip(self._actors["sensors"], self._ch_names): if ch_name in selected_channels: act.prop.SetColor(0, 1, 0) else: @@ -417,13 +432,18 @@ def _on_fit_dipole(self): dip_num = len(self._dipoles) dip_name = f"dip{dip_num}" dip_color = colors[dip_num % len(colors)] - arrow_mesh = pyvista.PolyData(*_arrow_mesh()) + if helmet_coords is not None: + arrow_mesh = pyvista.PolyData(*_arrow_mesh()) + else: + arrow_mesh = None dipole_dict = dict( active=True, arrow_actor=None, arrow_mesh=arrow_mesh, color=dip_color, dip=dip, + fix_ori=True, + fix_position=True, helmet_coords=helmet_coords, helmet_pos=helmet_pos, name=dip_name, @@ -461,9 +481,10 @@ def _on_fit_dipole(self): self._renderer.plotter.add_arrows( dip.pos[0], dip.ori[0], color=dip_color, mag=0.05 ) - dipole_dict["arrow_actor"] = self._renderer.plotter.add_mesh( - arrow_mesh, color=dip_color - ) + if arrow_mesh is not None: + dipole_dict["arrow_actor"] = self._renderer.plotter.add_mesh( + arrow_mesh, color=dip_color + ) def _get_helmet_coords(self, dip): """Compute the coordinate system used for drawing the big arrows on the helmet. @@ -471,9 +492,11 @@ def _get_helmet_coords(self, dip): In this coordinate system, Z is normal to the helmet surface, and XY are tangential to the helmet surface. """ - dip_pos = dip.pos[0] + if "helmet" not in self._actors: + return None, None # Get the closest vertex (=point) of the helmet mesh + dip_pos = dip.pos[0] helmet = self._actors["helmet"].GetMapper().GetInput() distances = ((helmet.points - dip_pos) * helmet.point_normals).sum(axis=1) closest_point = np.argmin(distances) @@ -565,6 +588,9 @@ def _update_arrows(self): timecourses = [d["timecourse"] for d in active_dips] arrow_scaling = 0.05 / np.max(np.abs(timecourses)) for d, ori, timecourse in zip(active_dips, orientations, timecourses): + helmet_coords = d["helmet_coords"] + if helmet_coords is None: + continue dip_ori = [ np.interp(self._current_time, self._evoked.times, o) for o in ori ] @@ -573,7 +599,6 @@ def _update_arrows(self): arrow_mesh = d["arrow_mesh"] # Project the orientation of the dipole tangential to the helmet - helmet_coords = d["helmet_coords"] dip_ori_tan = helmet_coords[:2] @ dip_ori @ helmet_coords[:2] # Rotate the coordinate system such that Y lies along the dipole @@ -629,7 +654,7 @@ def _on_dipole_toggle_fix_orientation(self, fix, dip_name): break else: raise ValueError(f"Unknown dipole {dip_name}") - dipole["fix_position"] = bool(fix) + dipole["fix_ori"] = bool(fix) self._fit_timecourses() def _setup_mplcanvas(self): From f6738e71b8f73639dcfd5a4d5d3c837c22ed988f Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 10:32:45 +0200 Subject: [PATCH 15/65] Fix bug --- mne/viz/tests/test_raw.py | 10 +++++----- mne/viz/utils.py | 4 +++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index a4c73e76075..ce3eb510ffb 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -1063,8 +1063,8 @@ def test_plot_sensors(raw): assert fig.lasso.selection == [] _fake_click(fig, ax, (0.65, 1), xform="ax", kind="motion") _fake_click(fig, ax, (0.65, 0.7), xform="ax", kind="motion") - _fake_keypress(fig, "control") - _fake_click(fig, ax, (0, 0.7), xform="ax", kind="release", key="control") + _fake_keypress(fig, "shift") + _fake_click(fig, ax, (0, 0.7), xform="ax", kind="release", key="shift") assert fig.lasso.selection == ["MEG 0121"] # check that point appearance changes @@ -1073,11 +1073,11 @@ def test_plot_sensors(raw): assert (fc[:, -1] == [0.5, 1.0, 0.5]).all() assert (ec[:, -1] == [0.25, 1.0, 0.25]).all() - _fake_click(fig, ax, (0.7, 1), xform="ax", kind="motion", key="control") + _fake_click(fig, ax, (0.7, 1), xform="ax", kind="motion", key="shift") xy = ax.collections[0].get_offsets() - _fake_click(fig, ax, xy[2], xform="data", key="control") # single sel + _fake_click(fig, ax, xy[2], xform="data", key="shift") # single sel assert fig.lasso.selection == ["MEG 0121", "MEG 0131"] - _fake_click(fig, ax, xy[2], xform="data", key="control") # deselect + _fake_click(fig, ax, xy[2], xform="data", key="shift") # deselect assert fig.lasso.selection == ["MEG 0121"] plt.close("all") diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 4d9a6de3c71..159ad3082b3 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1746,7 +1746,9 @@ def on_select(self, verts): def _on_channels_select(self, event): ch_inds = {name: i for i, name in enumerate(self.names)} self.selection = [name for name in event.ch_names if name in ch_inds] - self.selection_inds = [ch_inds[name] for name in self.selection] + self.selection_inds = np.array( + [ch_inds[name] for name in self.selection] + ).astype("int") self.style_objects(self.selection_inds) def select_one(self, ind): From a83b8fd4dc77807f85022ba5cb47fbe3f27586c4 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 10:51:41 +0200 Subject: [PATCH 16/65] Fix more renames --- mne/viz/_figure.py | 4 ++-- mne/viz/_mpl_figure.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index 7f958657876..574ea4ea515 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -494,11 +494,11 @@ def _create_ch_location_fig(self, pick): show=False, ) # highlight desired channel & disable interactivity - inds = np.isin(fig.lasso.ch_names, [ch_name]) + inds = np.isin(fig.lasso.names, [ch_name]) fig.lasso.disconnect() fig.lasso.alpha_other = 0.3 fig.lasso.linewidth_selected = 3 - fig.lasso.style_sensors(inds) + fig.lasso.style_objects(inds) return fig diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index b0a059c97cf..223a76b76eb 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -1553,7 +1553,7 @@ def _update_selection(self): def _update_highlighted_sensors(self): """Update the sensor plot to show what is selected.""" inds = np.isin( - self.mne.fig_selection.lasso.ch_names, self.mne.ch_names[self.mne.picks] + self.mne.fig_selection.lasso.names, self.mne.ch_names[self.mne.picks] ).nonzero()[0] self.mne.fig_selection.lasso.select_many(inds) From 3bfe2a5f92bf461e20f82cb7057990245075095d Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 11:24:52 +0200 Subject: [PATCH 17/65] Don't draw patches for channels that do not exist --- mne/viz/topo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index a980e4b41f7..7f98d496a92 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -235,12 +235,13 @@ def format_coord_multiaxis(x, y, ch_name=None): if unified: under_ax._mne_axs = axs # Create a PolyCollection for the axis backgrounds + sel_pos = pos[[i[0] for i in iter_ch]] verts = np.transpose( [ - pos[:, :2], - pos[:, :2] + pos[:, 2:] * [1, 0], - pos[:, :2] + pos[:, 2:], - pos[:, :2] + pos[:, 2:] * [0, 1], + sel_pos[:, :2], + sel_pos[:, :2] + sel_pos[:, 2:] * [1, 0], + sel_pos[:, :2] + sel_pos[:, 2:], + sel_pos[:, :2] + sel_pos[:, 2:] * [0, 1], ], [1, 0, 2], ) From 2e94752d9062b5ec25b0ae2c46e0d696db48a3f2 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 12:00:55 +0200 Subject: [PATCH 18/65] Move the ChannelsSelect ui-event one abstraction layer higher --- mne/viz/topo.py | 13 +++++++++++ mne/viz/utils.py | 60 ++++++++++++++++++++++++++++++------------------ 2 files changed, 51 insertions(+), 22 deletions(-) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 7f98d496a92..353222d2431 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -16,6 +16,7 @@ from .._fiff.pick import channel_type, pick_types from ..defaults import _handle_default from ..utils import Bunch, _check_option, _clean_names, _to_rgb, fill_doc +from .ui_events import ChannelsSelect, disable_ui_events, publish, subscribe from .utils import ( DraggableColorbar, SelectFromCollection, @@ -261,6 +262,18 @@ def format_coord_multiaxis(x, y, ch_name=None): linewidth_nonselected=0, linewidth_selected=0.7, ) + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + ch_inds = {name: i for i, name in enumerate(ch_names)} + selection_inds = [ch_inds[name] for name in event.ch_names] + with disable_ui_events(fig): + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) for ax in axs: yield ax, ax._mne_ch_idx diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 159ad3082b3..52889b283ae 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -64,7 +64,13 @@ verbose, warn, ) -from .ui_events import ChannelsSelect, ColormapRange, publish, subscribe +from .ui_events import ( + ChannelsSelect, + ColormapRange, + disable_ui_events, + publish, + subscribe, +) _channel_type_prettyprint = { "eeg": "EEG channel", @@ -1304,6 +1310,18 @@ def _plot_sensors_2d( ) if kind == "select": fig.lasso = SelectFromCollection(ax, pts, names=ch_names) + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + ch_inds = {name: i for i, name in enumerate(ch_names)} + selection_inds = [ch_inds[name] for name in event.ch_names] + with disable_ui_events(fig): + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) else: fig.lasso = None @@ -1718,12 +1736,15 @@ def __init__( self.lasso = LassoSelector(ax, onselect=self.on_select, **line_kw) self.selection = list() self.selection_inds = np.array([], dtype="int") + self.callbacks = list() # Deselect everything in the beginning. - self.style_objects([]) + self.style_objects() - # Respond to UI-Events - subscribe(self.fig, "channels_select", self._on_channels_select) + def notify(self): + """Notify listeners that a selection has been made.""" + for callback in self.callbacks: + callback() def on_select(self, verts): """Select a subset from the collection.""" @@ -1740,16 +1761,9 @@ def on_select(self, verts): self.selection_inds = np.setdiff1d(self.selection_inds, inds) else: self.selection_inds = inds - ch_names = [self.names[i] for i in self.selection_inds] - publish(self.fig, ChannelsSelect(ch_names=ch_names)) - - def _on_channels_select(self, event): - ch_inds = {name: i for i, name in enumerate(self.names)} - self.selection = [name for name in event.ch_names if name in ch_inds] - self.selection_inds = np.array( - [ch_inds[name] for name in self.selection] - ).astype("int") - self.style_objects(self.selection_inds) + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() + self.notify() def select_one(self, ind): """Select or deselect one sensor.""" @@ -1759,25 +1773,27 @@ def select_one(self, ind): self.selection_inds = np.setdiff1d(self.selection_inds, [ind]) else: return # don't notify() - ch_names = [self.names[i] for i in self.selection_inds] - publish(self.fig, ChannelsSelect(ch_names=ch_names)) + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() + self.notify() def select_many(self, inds): """Select many sensors using indices (for predefined selections).""" self.selected_inds = inds - ch_names = [self.names[i] for i in self.selection_inds] - publish(self.fig, ChannelsSelect(ch_names=ch_names)) + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() + self.notify() - def style_objects(self, inds): + def style_objects(self): """Style selected sensors as "active".""" # reset self.fc[:, -1] = self.alpha_nonselected self.ec[:, -1] = self.alpha_nonselected / 2 self.lw[:] = self.linewidth_nonselected # style sensors at `inds` - self.fc[inds, -1] = self.alpha_selected - self.ec[inds, -1] = self.alpha_selected - self.lw[inds] = self.linewidth_selected + self.fc[self.selection_inds, -1] = self.alpha_selected + self.ec[self.selection_inds, -1] = self.alpha_selected + self.lw[self.selection_inds] = self.linewidth_selected self.collection.set_facecolors(self.fc) self.collection.set_edgecolors(self.ec) self.collection.set_linewidths(self.lw) From 3c6b73cf9f353648e5572154a10c4fce2a52f34b Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 13:45:51 +0200 Subject: [PATCH 19/65] Some more fixes --- mne/viz/_figure.py | 4 ++-- mne/viz/utils.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index 574ea4ea515..7cc0d059826 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -494,11 +494,11 @@ def _create_ch_location_fig(self, pick): show=False, ) # highlight desired channel & disable interactivity - inds = np.isin(fig.lasso.names, [ch_name]) + fig.lasst.selection_inds = np.isin(fig.lasso.names, [ch_name]) fig.lasso.disconnect() fig.lasso.alpha_other = 0.3 fig.lasso.linewidth_selected = 3 - fig.lasso.style_objects(inds) + fig.lasso.style_objects() return fig diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 52889b283ae..f69dceffb90 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1741,6 +1741,11 @@ def __init__( # Deselect everything in the beginning. self.style_objects() + # For backwards compatibility + @property + def ch_names(self): + return self.names + def notify(self): """Notify listeners that a selection has been made.""" for callback in self.callbacks: From 9b6bd6007d235e1e2b1405cfcdaafb311bb73931 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 13:48:48 +0200 Subject: [PATCH 20/65] select_many should not notify() --- mne/viz/topo.py | 5 ++--- mne/viz/utils.py | 12 ++---------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 353222d2431..cdb5750e84c 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -16,7 +16,7 @@ from .._fiff.pick import channel_type, pick_types from ..defaults import _handle_default from ..utils import Bunch, _check_option, _clean_names, _to_rgb, fill_doc -from .ui_events import ChannelsSelect, disable_ui_events, publish, subscribe +from .ui_events import ChannelsSelect, publish, subscribe from .utils import ( DraggableColorbar, SelectFromCollection, @@ -269,8 +269,7 @@ def on_select(): def on_channels_select(event): ch_inds = {name: i for i, name in enumerate(ch_names)} selection_inds = [ch_inds[name] for name in event.ch_names] - with disable_ui_events(fig): - fig.lasso.select_many(selection_inds) + fig.lasso.select_many(selection_inds) fig.lasso.callbacks.append(on_select) subscribe(fig, "channels_select", on_channels_select) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index f69dceffb90..5d5dae51c26 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -64,13 +64,7 @@ verbose, warn, ) -from .ui_events import ( - ChannelsSelect, - ColormapRange, - disable_ui_events, - publish, - subscribe, -) +from .ui_events import ChannelsSelect, ColormapRange, publish, subscribe _channel_type_prettyprint = { "eeg": "EEG channel", @@ -1317,8 +1311,7 @@ def on_select(): def on_channels_select(event): ch_inds = {name: i for i, name in enumerate(ch_names)} selection_inds = [ch_inds[name] for name in event.ch_names] - with disable_ui_events(fig): - fig.lasso.select_many(selection_inds) + fig.lasso.select_many(selection_inds) fig.lasso.callbacks.append(on_select) subscribe(fig, "channels_select", on_channels_select) @@ -1787,7 +1780,6 @@ def select_many(self, inds): self.selected_inds = inds self.selection = [self.names[i] for i in self.selection_inds] self.style_objects() - self.notify() def style_objects(self): """Style selected sensors as "active".""" From 573cb40151d4934d95bbdd5f757049f0164d67d7 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 14:23:27 +0200 Subject: [PATCH 21/65] Add "select" parameter to enable/disable the lasso selection tool --- mne/viz/_figure.py | 2 +- mne/viz/evoked.py | 8 +++++- mne/viz/topo.py | 66 +++++++++++++++++++++++++++++++++------------- 3 files changed, 55 insertions(+), 21 deletions(-) diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index 7cc0d059826..d292732c276 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -494,7 +494,7 @@ def _create_ch_location_fig(self, pick): show=False, ) # highlight desired channel & disable interactivity - fig.lasst.selection_inds = np.isin(fig.lasso.names, [ch_name]) + fig.lasso.selection_inds = np.isin(fig.lasso.names, [ch_name]) fig.lasso.disconnect() fig.lasso.alpha_other = 0.3 fig.lasso.linewidth_selected = 3 diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 1c6712a6bec..9ebee6e69f4 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1178,6 +1178,7 @@ def plot_evoked_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """Plot 2D topography of evoked responses. @@ -1248,6 +1249,10 @@ def plot_evoked_topo( exclude : list of str | 'bads' Channels names to exclude from the plot. If 'bads', the bad channels are excluded. By default, exclude is set to 'bads'. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. show : bool Show figure if True. @@ -1304,10 +1309,11 @@ def plot_evoked_topo( font_color=font_color, merge_channels=merge_grads, legend=legend, + noise_cov=noise_cov, axes=axes, exclude=exclude, + select=select, show=show, - noise_cov=noise_cov, ) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index cdb5750e84c..fe815c3ce45 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -42,6 +42,7 @@ def iter_topography( axis_spinecolor="k", layout_scale=None, legend=False, + select=False, ): """Create iterator over channel positions. @@ -77,6 +78,10 @@ def iter_topography( If True, an additional axis is created in the bottom right corner that can be used to, e.g., construct a legend. The index of this axis will be -1. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. Returns ------- @@ -98,6 +103,7 @@ def iter_topography( axis_spinecolor, layout_scale, legend=legend, + select=select, ) @@ -133,6 +139,7 @@ def _iter_topography( img=False, axes=None, legend=False, + select=False, ): """Iterate over topography. @@ -251,28 +258,32 @@ def format_coord_multiaxis(x, y, ch_name=None): verts, facecolor=axis_facecolor, edgecolor=axis_spinecolor, + linewidth=1.0, ) under_ax.add_collection(collection) - fig.lasso = SelectFromCollection( - ax=under_ax, - collection=collection, - names=shown_ch_names, - alpha_nonselected=0, - alpha_selected=1, - linewidth_nonselected=0, - linewidth_selected=0.7, - ) - def on_select(): - publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + if select: + # Configure the lasso-selection tool + fig.lasso = SelectFromCollection( + ax=under_ax, + collection=collection, + names=shown_ch_names, + alpha_nonselected=0, + alpha_selected=1, + linewidth_nonselected=0, + linewidth_selected=0.7, + ) + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) - def on_channels_select(event): - ch_inds = {name: i for i, name in enumerate(ch_names)} - selection_inds = [ch_inds[name] for name in event.ch_names] - fig.lasso.select_many(selection_inds) + def on_channels_select(event): + ch_inds = {name: i for i, name in enumerate(ch_names)} + selection_inds = [ch_inds[name] for name in event.ch_names] + fig.lasso.select_many(selection_inds) - fig.lasso.callbacks.append(on_select) - subscribe(fig, "channels_select", on_channels_select) + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) for ax in axs: yield ax, ax._mne_ch_idx @@ -299,6 +310,7 @@ def _plot_topo( unified=False, img=False, axes=None, + select=False, ): """Plot on sensor layout.""" import matplotlib.pyplot as plt @@ -351,6 +363,7 @@ def _plot_topo( unified=unified, img=img, axes=axes, + select=select, ) for ax, ch_idx in my_topo_plot: @@ -873,9 +886,10 @@ def _plot_evoked_topo( merge_channels=False, legend=True, axes=None, + noise_cov=None, exclude="bads", + select=False, show=True, - noise_cov=None, ): """Plot 2D topography of evoked responses. @@ -947,6 +961,10 @@ def _plot_evoked_topo( exclude : list of str | 'bads' Channels names to exclude from being shown. If 'bads', the bad channels are excluded. By default, exclude is set to 'bads'. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. show : bool Show figure if True. @@ -1124,6 +1142,7 @@ def _plot_evoked_topo( y_label=y_label, unified=True, axes=axes, + select=select, ) add_background_image(fig, fig_background) @@ -1131,7 +1150,10 @@ def _plot_evoked_topo( if legend is not False: legend_loc = 0 if legend is True else legend labels = [e.comment if e.comment else "Unknown" for e in evoked] - handles = fig.axes[0].lines[: len(evoked)] + if select: + handles = fig.axes[0].lines[1 : len(evoked) + 1] + else: + handles = fig.axes[0].lines[: len(evoked)] legend = plt.legend( labels=labels, handles=handles, loc=legend_loc, prop={"size": 10} ) @@ -1190,6 +1212,7 @@ def plot_topo_image_epochs( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): """Plot Event Related Potential / Fields image on topographies. @@ -1237,6 +1260,10 @@ def plot_topo_image_epochs( :func:`matplotlib.pyplot.imshow`. Defaults to ``None``. font_color : color The color of tick labels in the colorbar. Defaults to white. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. show : bool Whether to show the figure. Defaults to ``True``. @@ -1326,6 +1353,7 @@ def plot_topo_image_epochs( y_label="Epoch", unified=True, img=True, + select=select, ) add_background_image(fig, fig_background) plt_show(show) From facd394b5f05ff5effeced8d7466be520e0a4826 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 14 Nov 2023 14:51:02 +0200 Subject: [PATCH 22/65] Add select parameter to relevant methods --- mne/epochs.py | 2 ++ mne/evoked.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/mne/epochs.py b/mne/epochs.py index b7afada3d1a..e3d625b3401 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1345,6 +1345,7 @@ def plot_topo_image( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): return plot_topo_image_epochs( @@ -1363,6 +1364,7 @@ def plot_topo_image( fig_facecolor=fig_facecolor, fig_background=fig_background, font_color=font_color, + select=select, show=show, ) diff --git a/mne/evoked.py b/mne/evoked.py index f2c75f3754b..a0150a0c2dc 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -555,6 +555,7 @@ def plot_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """ @@ -580,6 +581,7 @@ def plot_topo( background_color=background_color, noise_cov=noise_cov, exclude=exclude, + select=select, show=show, ) From a0069d82f9fb052b4b4a2d33de4a5ab380a477e5 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 15 Nov 2023 10:37:25 +0200 Subject: [PATCH 23/65] Update test --- mne/viz/tests/test_raw.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index ce3eb510ffb..f930d7cbf9d 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -1061,10 +1061,10 @@ def test_plot_sensors(raw): _fake_click(fig, ax, (0, 1), xform="ax") fig.canvas.draw() assert fig.lasso.selection == [] - _fake_click(fig, ax, (0.65, 1), xform="ax", kind="motion") - _fake_click(fig, ax, (0.65, 0.7), xform="ax", kind="motion") + _fake_click(fig, ax, (-0.11, 0.14), xform="data", kind="motion") + _fake_click(fig, ax, (-0.11, 0.065), xform="data", kind="motion") _fake_keypress(fig, "shift") - _fake_click(fig, ax, (0, 0.7), xform="ax", kind="release", key="shift") + _fake_click(fig, ax, (-0.15, 0.065), xform="data", kind="release", key="shift") assert fig.lasso.selection == ["MEG 0121"] # check that point appearance changes @@ -1073,11 +1073,11 @@ def test_plot_sensors(raw): assert (fc[:, -1] == [0.5, 1.0, 0.5]).all() assert (ec[:, -1] == [0.25, 1.0, 0.25]).all() - _fake_click(fig, ax, (0.7, 1), xform="ax", kind="motion", key="shift") + _fake_click(fig, ax, (-0.11, 0.065), xform="data", kind="motion", key="shift") xy = ax.collections[0].get_offsets() _fake_click(fig, ax, xy[2], xform="data", key="shift") # single sel assert fig.lasso.selection == ["MEG 0121", "MEG 0131"] - _fake_click(fig, ax, xy[2], xform="data", key="shift") # deselect + _fake_click(fig, ax, xy[2], xform="data", key="alt") # deselect assert fig.lasso.selection == ["MEG 0121"] plt.close("all") From b0e6cb20ef01acf716edb12d2a2edd05e9383237 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Mon, 11 Dec 2023 16:13:17 +0200 Subject: [PATCH 24/65] Enable sensor selection again --- mne/gui/_xfit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index a0aeaea53a4..92157dada60 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -378,7 +378,7 @@ def _on_sensor_data(self): """Show sensor data and allow sensor selection.""" if self._fig_sensors is not None: return - fig = self._evoked.plot_topo() + fig = self._evoked.plot_topo(select=True) fig.canvas.mpl_connect("close_event", self._on_sensor_data_close) subscribe(fig, "channels_select", self._on_channels_select) self._fig_sensors = fig From bb4fd2c2fe04e9d636715778905dd637b36d5e66 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Fri, 17 May 2024 10:50:45 +0300 Subject: [PATCH 25/65] Implement toggling fixed orientation with MNE solution --- mne/gui/_xfit.py | 18 ++++++++++++++++-- mne/viz/ui_events.py | 1 - 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index a0aeaea53a4..2714fcc91ef 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -378,7 +378,7 @@ def _on_sensor_data(self): """Show sensor data and allow sensor selection.""" if self._fig_sensors is not None: return - fig = self._evoked.plot_topo() + fig = self._evoked.plot_topo(select=True) fig.canvas.mpl_connect("close_event", self._on_sensor_data_close) subscribe(fig, "channels_select", self._on_channels_select) self._fig_sensors = fig @@ -536,16 +536,27 @@ def _fit_timecourses(self): stc = apply_inverse( self._evoked, inv, method="MNE", lambda2=0, pick_ori="vector" ) - timecourses = np.linalg.norm(stc.data, axis=1) + + timecourses = stc.magnitude().data + fixed_timecourses = stc.project( + np.array([dip["dip"].ori[0] for dip in active_dips]) + )[0].data orientations = stc.data / timecourses[:, np.newaxis, :] + + for i, dip in enumerate(active_dips): + if dip["fix_ori"]: + timecourses[i] = fixed_timecourses[i] + orientations[i] = np.array([dip["dip"].ori[0]] * len(stc.times)).T elif self._multi_dipole_method == "LCMV": lcmv = make_lcmv( self._evoked.info, fwd, self._cov_data, reg=0.05, noise_cov=self._cov ) stc = apply_lcmv(self._evoked, lcmv) timecourses = stc.data + orientations = [dip["dip"].ori[0] for dip in active_dips] elif self._multi_dipole_method == "Single-dipole": timecourses = list() + orientations = list() for dip in active_dips: dip_timecourse = fit_dipole( self._evoked, @@ -557,6 +568,9 @@ def _fit_timecourses(self): verbose=False, )[0].data[0] timecourses.append(dip_timecourse) + orientations.append( + np.array([dip["dip"].ori[0]] * len(dip_timecourse)).T + ) # Update matplotlib canvas at the bottom of the window canvas = self._setup_mplcanvas() diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index e27810b49f5..795ae442df6 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -233,7 +233,6 @@ class ChannelsSelect(UIEvent): """ ch_names: list[str] - contours: list[str] def _get_event_channel(fig): From b498c15fac54d1f40320bcdb1cd7b10599552d52 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 4 Jun 2024 11:23:26 +0300 Subject: [PATCH 26/65] small fixes --- mne/gui/_xfit.py | 2 +- mne/viz/ui_events.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index a0aeaea53a4..92157dada60 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -378,7 +378,7 @@ def _on_sensor_data(self): """Show sensor data and allow sensor selection.""" if self._fig_sensors is not None: return - fig = self._evoked.plot_topo() + fig = self._evoked.plot_topo(select=True) fig.canvas.mpl_connect("close_event", self._on_sensor_data_close) subscribe(fig, "channels_select", self._on_channels_select) self._fig_sensors = fig diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index e27810b49f5..795ae442df6 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -233,7 +233,6 @@ class ChannelsSelect(UIEvent): """ ch_names: list[str] - contours: list[str] def _get_event_channel(fig): From ca1b99e7b399d2545bd45c184188e89095082772 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Sun, 14 Jul 2024 10:45:35 +0300 Subject: [PATCH 27/65] Remove features not for v1 --- mne/gui/_xfit.py | 245 +++++++++++++++++++++++++++-------------------- 1 file changed, 139 insertions(+), 106 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 2714fcc91ef..74f5830378e 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -4,7 +4,8 @@ import pyvista from .. import pick_types -from ..beamformer import apply_lcmv, make_lcmv + +# from ..beamformer import apply_lcmv, make_lcmv from ..bem import ( ConductorModel, _ensure_bem_surfaces, @@ -113,7 +114,7 @@ def __init__( self, evoked, cov=None, - cov_data=None, + # cov_data=None, bem=None, initial_time=None, trans=None, @@ -162,13 +163,13 @@ def __init__( self._bem = bem self._ch_type = ch_type self._cov = cov - self._cov_data = cov_data + # self._cov_data = cov_data self._current_time = initial_time self._dipoles = list() self._evoked = evoked self._field_map = field_map self._fig_sensors = None - self._multi_dipole_method = "MNE" + self._multi_dipole_method = "Multi dipole (MNE)" self._n_jobs = n_jobs self._show_density = show_density self._subjects_dir = subjects_dir @@ -182,6 +183,11 @@ def __init__( self._renderer = self._configure_main_display() self._configure_dock() + @property + def dipoles(self): + """A list of the fitted dipoles.""" + return [d["dip"] for d in self._dipoles] + def _configure_main_display(self): """Configure main 3D display of the GUI.""" fig = create_3d_figure((1500, 1020), bgcolor="white", show=True) @@ -287,35 +293,35 @@ def _configure_dock(self): layout=layout, ) - # Add view buttons - layout = r._dock_add_group_box("Views") - hlayout = None - views = zip( - [7, 8, 9, 4, 5, 6, 1, 2, 3], # numpad order - ["๐Ÿข†", "๐Ÿขƒ", "๐Ÿข‡", "๐Ÿข‚", "โŠ™", "๐Ÿข€", "๐Ÿข…", "๐Ÿข", "๐Ÿข„"], - ) - for i, (view, label) in enumerate(views): - if i % 3 == 0: # show in groups of 3 - hlayout = r._dock_add_layout(vertical=False) - r._layout_add_widget(layout, hlayout) - r._dock_add_button( - label, - callback=partial(self._set_view, view=view), - layout=hlayout, - style="pushbutton", - ) - r.plotter.add_key_event(str(view), partial(self._set_view, view=view)) + # # Add view buttons + # layout = r._dock_add_group_box("Views") + # hlayout = None + # views = zip( + # [7, 8, 9, 4, 5, 6, 1, 2, 3], # numpad order + # ["๐Ÿข†", "๐Ÿขƒ", "๐Ÿข‡", "๐Ÿข‚", "โŠ™", "๐Ÿข€", "๐Ÿข…", "๐Ÿข", "๐Ÿข„"], + # ) + # for i, (view, label) in enumerate(views): + # if i % 3 == 0: # show in groups of 3 + # hlayout = r._dock_add_layout(vertical=False) + # r._layout_add_widget(layout, hlayout) + # r._dock_add_button( + # label, + # callback=partial(self._set_view, view=view), + # layout=hlayout, + # style="pushbutton", + # ) + # r.plotter.add_key_event(str(view), partial(self._set_view, view=view)) # Right dock r._dock_initialize(name="Dipole fitting", area="right") r._dock_add_button("Sensor data", self._on_sensor_data) r._dock_add_button("Fit dipole", self._on_fit_dipole) - methods = ["MNE", "Single-dipole"] - if self._cov_data is not None: - methods.append("LCMV") + methods = ["Multi dipole (MNE)", "Single dipole"] + # if self._cov_data is not None: + # methods.append("LCMV") r._dock_add_combo_box( "Dipole model", - value="MNE", + value="Multi dipole (MNE)", rng=methods, callback=self._on_select_method, ) @@ -344,27 +350,27 @@ def toggle_mesh(self, name, show=None): act.SetVisibility(show) self._renderer._update() - def _set_view(self, view): - kwargs = dict() - if view == 1: - kwargs = dict(azimuth=-135, roll=45, elevation=60, distance="auto") - elif view == 2: - kwargs = dict(azimuth=270, roll=180, elevation=90, distance="auto") - elif view == 3: - kwargs = dict(azimuth=-45, roll=-45, elevation=60, distance="auto") - elif view == 4: - kwargs = dict(azimuth=180, roll=90, elevation=90, distance="auto") - elif view == 5: - kwargs = dict(azimuth=0, roll=0, elevation=0, distance="auto") - elif view == 6: - kwargs = dict(azimuth=0, roll=-90, elevation=90, distance="auto") - elif view == 7: - kwargs = dict(azimuth=135, roll=90, elevation=60, distance="auto") - elif view == 8: - kwargs = dict(azimuth=90, roll=0, elevation=90, distance="auto") - elif view == 9: - kwargs = dict(azimuth=45, roll=-90, elevation=60, distance="auto") - self._renderer.set_camera(**kwargs) + # def _set_view(self, view): + # kwargs = dict() + # if view == 1: + # kwargs = dict(azimuth=-135, roll=45, elevation=60, distance="auto") + # elif view == 2: + # kwargs = dict(azimuth=270, roll=180, elevation=90, distance="auto") + # elif view == 3: + # kwargs = dict(azimuth=-45, roll=-45, elevation=60, distance="auto") + # elif view == 4: + # kwargs = dict(azimuth=180, roll=90, elevation=90, distance="auto") + # elif view == 5: + # kwargs = dict(azimuth=0, roll=0, elevation=0, distance="auto") + # elif view == 6: + # kwargs = dict(azimuth=0, roll=-90, elevation=90, distance="auto") + # elif view == 7: + # kwargs = dict(azimuth=135, roll=90, elevation=60, distance="auto") + # elif view == 8: + # kwargs = dict(azimuth=90, roll=0, elevation=90, distance="auto") + # elif view == 9: + # kwargs = dict(azimuth=45, roll=-90, elevation=60, distance="auto") + # self._renderer.set_camera(**kwargs) def _on_time_change(self, event): new_time = np.clip(event.time, self._evoked.times[0], self._evoked.times[-1]) @@ -430,7 +436,7 @@ def _on_fit_dipole(self): # Collect all relevant information on the dipole in a dict colors = _get_color_list() dip_num = len(self._dipoles) - dip_name = f"dip{dip_num}" + dip.name = f"dip{dip_num}" dip_color = colors[dip_num % len(colors)] if helmet_coords is not None: arrow_mesh = pyvista.PolyData(*_arrow_mesh()) @@ -446,8 +452,8 @@ def _on_fit_dipole(self): fix_position=True, helmet_coords=helmet_coords, helmet_pos=helmet_pos, - name=dip_name, num=dip_num, + fit_time=self._current_time, ) self._dipoles.append(dipole_dict) @@ -455,21 +461,28 @@ def _on_fit_dipole(self): r = self._renderer hlayout = r._dock_add_layout(vertical=False) r._dock_add_check_box( - name=dip_name, + name="", value=True, - callback=partial(self._on_dipole_toggle, dip_name=dip_name), + callback=partial(self._on_dipole_toggle, dip_num=dip_num), layout=hlayout, ) - r._dock_add_check_box( - name="Fix pos", - value=True, - callback=partial(self._on_dipole_toggle_fix_position, dip_name=dip_name), + r._dock_add_text( + name=dip.name, + value=dip.name, + placeholder="name", + callback=partial(self._on_dipole_set_name, dip_num=dip_num), layout=hlayout, ) + # r._dock_add_check_box( + # name="Fix pos", + # value=True, + # callback=partial(self._on_dipole_toggle_fix_position, dip_name=dip_name), + # layout=hlayout, + # ) r._dock_add_check_box( name="Fix ori", value=True, - callback=partial(self._on_dipole_toggle_fix_orientation, dip_name=dip_name), + callback=partial(self._on_dipole_toggle_fix_orientation, dip_num=dip_num), layout=hlayout, ) r._layout_add_widget(self._dipole_box, hlayout) @@ -511,11 +524,19 @@ def _get_helmet_coords(self, dip): return helmet_coords, helmet_pos def _fit_timecourses(self): - """Compute dipole timecourses using a multi-dipole model.""" + """Compute (or re-compute) dipole timecourses. + + Called whenever a dipole is (de)-activated or the "Fix pos" box is toggled. + """ active_dips = [d for d in self._dipoles if d["active"]] + print([d["dip"] for d in active_dips]) if len(active_dips) == 0: return + # Restrict the dipoles to only the time at which they were fitted. + for d in active_dips: + d["dip"] = d["dip"].crop(d["fit_time"], d["fit_time"]) + fwd, _ = make_forward_dipole( [d["dip"] for d in active_dips], self._bem, @@ -524,7 +545,7 @@ def _fit_timecourses(self): ) fwd = convert_forward_solution(fwd, surf_ori=False) - if self._multi_dipole_method == "MNE": + if self._multi_dipole_method == "Multi dipole (MNE)": inv = make_inverse_operator( self._evoked.info, fwd, @@ -538,23 +559,24 @@ def _fit_timecourses(self): ) timecourses = stc.magnitude().data + orientations = stc.data / timecourses[:, np.newaxis, :] + print(orientations.shape) fixed_timecourses = stc.project( np.array([dip["dip"].ori[0] for dip in active_dips]) )[0].data - orientations = stc.data / timecourses[:, np.newaxis, :] for i, dip in enumerate(active_dips): if dip["fix_ori"]: timecourses[i] = fixed_timecourses[i] - orientations[i] = np.array([dip["dip"].ori[0]] * len(stc.times)).T - elif self._multi_dipole_method == "LCMV": - lcmv = make_lcmv( - self._evoked.info, fwd, self._cov_data, reg=0.05, noise_cov=self._cov - ) - stc = apply_lcmv(self._evoked, lcmv) - timecourses = stc.data - orientations = [dip["dip"].ori[0] for dip in active_dips] - elif self._multi_dipole_method == "Single-dipole": + orientations[i] = dip["dip"].ori.repeat(len(stc.times), axis=0).T + # elif self._multi_dipole_method == "LCMV": + # lcmv = make_lcmv( + # self._evoked.info, fwd, self._cov_data, reg=0.05, noise_cov=self._cov + # ) + # stc = apply_lcmv(self._evoked, lcmv) + # timecourses = stc.data + # orientations = [dip["dip"].ori[0] for dip in active_dips] + elif self._multi_dipole_method == "Single dipole": timecourses = list() orientations = list() for dip in active_dips: @@ -562,29 +584,43 @@ def _fit_timecourses(self): self._evoked, self._cov, self._bem, - pos=dip["dip"].pos[0], - ori=dip["dip"].ori[0], + pos=dip["dip"].pos[0], # position is always fixed + ori=dip["dip"].ori[0] if dip["fix_ori"] else None, trans=self._trans, verbose=False, )[0].data[0] timecourses.append(dip_timecourse) - orientations.append( - np.array([dip["dip"].ori[0]] * len(dip_timecourse)).T - ) + + if dip["fix_ori"]: + orientations.append( + dip["dip"].ori.repeat(len(dip_timecourse), axis=0) + ) + else: + orientations.append(dip["dip"].ori) + + for o in orientations: + print(o.shape) + + # Store the timecourse and orientation in the Dipole object + for d, timecourse, orientation in zip(active_dips, timecourses, orientations): + dip = d["dip"] + dip.amplitude = timecourse + dip.ori = orientation.T + dip._set_times(self._evoked.times) + if len(dip.pos) != len(dip.times): + dip.pos = dip.pos[[0]].repeat(len(dip.times), axis=0) # Update matplotlib canvas at the bottom of the window canvas = self._setup_mplcanvas() ymin, ymax = 0, 0 - for d, ori, timecourse in zip(active_dips, orientations, timecourses): - d["ori"] = ori - d["timecourse"] = timecourse + for d in active_dips: if "line_artist" in d: d["line_artist"].set_ydata(timecourse) else: d["line_artist"] = canvas.plot( self._evoked.times, - timecourse, - label=d["name"], + d["dip"].amplitude, + label=d["dip"].name, color=d["color"], ) ymin = min(ymin, 1.1 * timecourse.min()) @@ -598,15 +634,15 @@ def _update_arrows(self): active_dips = [d for d in self._dipoles if d["active"]] if len(active_dips) == 0: return - orientations = [d["ori"] for d in active_dips] - timecourses = [d["timecourse"] for d in active_dips] + orientations = [d["dip"].ori for d in active_dips] + timecourses = [d["dip"].amplitude for d in active_dips] arrow_scaling = 0.05 / np.max(np.abs(timecourses)) for d, ori, timecourse in zip(active_dips, orientations, timecourses): helmet_coords = d["helmet_coords"] if helmet_coords is None: continue dip_ori = [ - np.interp(self._current_time, self._evoked.times, o) for o in ori + np.interp(self._current_time, self._evoked.times, o) for o in ori.T ] dip_moment = np.interp(self._current_time, self._evoked.times, timecourse) arrow_size = dip_moment * arrow_scaling @@ -633,14 +669,9 @@ def _on_select_method(self, method): self._multi_dipole_method = method self._fit_timecourses() - def _on_dipole_toggle(self, active, dip_name): + def _on_dipole_toggle(self, active, dip_num): """Toggle a dipole on or off.""" - for d in self._dipoles: - if d["name"] == dip_name: - dipole = d - break - else: - raise ValueError(f"Unknown dipole {dip_name}") + dipole = self._dipoles[dip_num] active = bool(active) dipole["active"] = active dipole["line_artist"].set_visible(active) @@ -649,26 +680,28 @@ def _on_dipole_toggle(self, active, dip_name): self._renderer._update() self._renderer._mplcanvas.update_plot() - def _on_dipole_toggle_fix_position(self, fix, dip_name): - """Fix dipole position when fitting timecourse.""" - for d in self._dipoles: - if d["name"] == dip_name: - dipole = d - break - else: - raise ValueError(f"Unknown dipole {dip_name}") - dipole["fix_position"] = bool(fix) - self._fit_timecourses() + def _on_dipole_set_name(self, name, dip_num): + """Set the name of a dipole.""" + self._dipoles[dip_num]["dip"].name = name + self._dipoles[dip_num]["line_artist"].set_label(name) + self._renderer._mplcanvas.update_plot() - def _on_dipole_toggle_fix_orientation(self, fix, dip_name): + # def _on_dipole_toggle_fix_position(self, fix, dip_name): + # """Fix dipole position when fitting timecourse.""" + # for d in self._dipoles: + # if d["name"] == dip_name: + # dipole = d + # break + # else: + # raise ValueError(f"Unknown dipole {dip_name}") + # dipole["fix_position"] = bool(fix) + # self._fit_timecourses() + + def _on_dipole_toggle_fix_orientation(self, fix, dip_num): """Fix dipole orientation when fitting timecourse.""" - for d in self._dipoles: - if d["name"] == dip_name: - dipole = d - break - else: - raise ValueError(f"Unknown dipole {dip_name}") - dipole["fix_ori"] = bool(fix) + self._dipoles[dip_num]["fix_ori"] = bool(fix) + active_dips = [d for d in self._dipoles if d["active"]] + print([d["dip"] for d in active_dips]) self._fit_timecourses() def _setup_mplcanvas(self): From 1b877dffabaf1c83669a77213aeec329138b63a7 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Thu, 18 Jul 2024 09:13:40 +0300 Subject: [PATCH 28/65] Work more on xfit --- mne/gui/_xfit.py | 73 +----------------------------------------------- 1 file changed, 1 insertion(+), 72 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 74f5830378e..88b8ee6ce6e 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -91,8 +91,6 @@ class DipoleFitUI: Evoked data to show fieldmap of and fit dipoles to. cov : instance of Covariance | None Noise covariance matrix. If None, an ad-hoc covariance matrix is used. - cov_data : instance of Covariance | None - Data covariance matrix. If None, LCMV method will be unavailable. bem : instance of ConductorModel | None Boundary element model. If None, a spherical model is used. initial_time : float | None @@ -163,7 +161,6 @@ def __init__( self._bem = bem self._ch_type = ch_type self._cov = cov - # self._cov_data = cov_data self._current_time = initial_time self._dipoles = list() self._evoked = evoked @@ -265,7 +262,6 @@ def _configure_main_display(self): surf=None, check_inside=None, nearest=None, - # sensor_colors=dict(meg="white", eeg="white"), sensor_colors=dict( meg=["white" for _ in meg_picks], eeg=["white" for _ in eeg_picks], @@ -293,32 +289,11 @@ def _configure_dock(self): layout=layout, ) - # # Add view buttons - # layout = r._dock_add_group_box("Views") - # hlayout = None - # views = zip( - # [7, 8, 9, 4, 5, 6, 1, 2, 3], # numpad order - # ["๐Ÿข†", "๐Ÿขƒ", "๐Ÿข‡", "๐Ÿข‚", "โŠ™", "๐Ÿข€", "๐Ÿข…", "๐Ÿข", "๐Ÿข„"], - # ) - # for i, (view, label) in enumerate(views): - # if i % 3 == 0: # show in groups of 3 - # hlayout = r._dock_add_layout(vertical=False) - # r._layout_add_widget(layout, hlayout) - # r._dock_add_button( - # label, - # callback=partial(self._set_view, view=view), - # layout=hlayout, - # style="pushbutton", - # ) - # r.plotter.add_key_event(str(view), partial(self._set_view, view=view)) - # Right dock r._dock_initialize(name="Dipole fitting", area="right") r._dock_add_button("Sensor data", self._on_sensor_data) r._dock_add_button("Fit dipole", self._on_fit_dipole) methods = ["Multi dipole (MNE)", "Single dipole"] - # if self._cov_data is not None: - # methods.append("LCMV") r._dock_add_combo_box( "Dipole model", value="Multi dipole (MNE)", @@ -350,28 +325,6 @@ def toggle_mesh(self, name, show=None): act.SetVisibility(show) self._renderer._update() - # def _set_view(self, view): - # kwargs = dict() - # if view == 1: - # kwargs = dict(azimuth=-135, roll=45, elevation=60, distance="auto") - # elif view == 2: - # kwargs = dict(azimuth=270, roll=180, elevation=90, distance="auto") - # elif view == 3: - # kwargs = dict(azimuth=-45, roll=-45, elevation=60, distance="auto") - # elif view == 4: - # kwargs = dict(azimuth=180, roll=90, elevation=90, distance="auto") - # elif view == 5: - # kwargs = dict(azimuth=0, roll=0, elevation=0, distance="auto") - # elif view == 6: - # kwargs = dict(azimuth=0, roll=-90, elevation=90, distance="auto") - # elif view == 7: - # kwargs = dict(azimuth=135, roll=90, elevation=60, distance="auto") - # elif view == 8: - # kwargs = dict(azimuth=90, roll=0, elevation=90, distance="auto") - # elif view == 9: - # kwargs = dict(azimuth=45, roll=-90, elevation=60, distance="auto") - # self._renderer.set_camera(**kwargs) - def _on_time_change(self, event): new_time = np.clip(event.time, self._evoked.times[0], self._evoked.times[-1]) self._current_time = new_time @@ -473,12 +426,6 @@ def _on_fit_dipole(self): callback=partial(self._on_dipole_set_name, dip_num=dip_num), layout=hlayout, ) - # r._dock_add_check_box( - # name="Fix pos", - # value=True, - # callback=partial(self._on_dipole_toggle_fix_position, dip_name=dip_name), - # layout=hlayout, - # ) r._dock_add_check_box( name="Fix ori", value=True, @@ -569,13 +516,6 @@ def _fit_timecourses(self): if dip["fix_ori"]: timecourses[i] = fixed_timecourses[i] orientations[i] = dip["dip"].ori.repeat(len(stc.times), axis=0).T - # elif self._multi_dipole_method == "LCMV": - # lcmv = make_lcmv( - # self._evoked.info, fwd, self._cov_data, reg=0.05, noise_cov=self._cov - # ) - # stc = apply_lcmv(self._evoked, lcmv) - # timecourses = stc.data - # orientations = [dip["dip"].ori[0] for dip in active_dips] elif self._multi_dipole_method == "Single dipole": timecourses = list() orientations = list() @@ -686,17 +626,6 @@ def _on_dipole_set_name(self, name, dip_num): self._dipoles[dip_num]["line_artist"].set_label(name) self._renderer._mplcanvas.update_plot() - # def _on_dipole_toggle_fix_position(self, fix, dip_name): - # """Fix dipole position when fitting timecourse.""" - # for d in self._dipoles: - # if d["name"] == dip_name: - # dipole = d - # break - # else: - # raise ValueError(f"Unknown dipole {dip_name}") - # dipole["fix_position"] = bool(fix) - # self._fit_timecourses() - def _on_dipole_toggle_fix_orientation(self, fix, dip_num): """Fix dipole orientation when fitting timecourse.""" self._dipoles[dip_num]["fix_ori"] = bool(fix) @@ -721,7 +650,7 @@ def _setup_mplcanvas(self): def _arrow_mesh(): - """Obtain a PyVista mesh of an arrow.""" + """Obtain a mesh of an arrow.""" vertices = np.array( [ [0.0, 1.0, 0.0], From f7bd8cc26ebdcbe3d153e25c288407236a4848bc Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Thu, 18 Jul 2024 16:02:54 +0300 Subject: [PATCH 29/65] Fixes --- mne/gui/_xfit.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 88b8ee6ce6e..5326c7adf32 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -226,8 +226,6 @@ def _configure_main_display(self): picks = np.concatenate((meg_picks, eeg_picks)) self._ch_names = [self._evoked.ch_names[i] for i in picks] - print(f"{show_meg=} {show_eeg=}") - for m in self._field_map: if m["kind"] == "eeg": head_surf = m["surf"] @@ -443,7 +441,9 @@ def _on_fit_dipole(self): ) if arrow_mesh is not None: dipole_dict["arrow_actor"] = self._renderer.plotter.add_mesh( - arrow_mesh, color=dip_color + arrow_mesh, + color=dip_color, + culling="front", ) def _get_helmet_coords(self, dip): @@ -476,13 +476,13 @@ def _fit_timecourses(self): Called whenever a dipole is (de)-activated or the "Fix pos" box is toggled. """ active_dips = [d for d in self._dipoles if d["active"]] - print([d["dip"] for d in active_dips]) if len(active_dips) == 0: return # Restrict the dipoles to only the time at which they were fitted. for d in active_dips: - d["dip"] = d["dip"].crop(d["fit_time"], d["fit_time"]) + if len(d["dip"].times) > 1: + d["dip"] = d["dip"].crop(d["fit_time"], d["fit_time"]) fwd, _ = make_forward_dipole( [d["dip"] for d in active_dips], @@ -507,7 +507,6 @@ def _fit_timecourses(self): timecourses = stc.magnitude().data orientations = stc.data / timecourses[:, np.newaxis, :] - print(orientations.shape) fixed_timecourses = stc.project( np.array([dip["dip"].ori[0] for dip in active_dips]) )[0].data @@ -538,24 +537,28 @@ def _fit_timecourses(self): else: orientations.append(dip["dip"].ori) - for o in orientations: - print(o.shape) - # Store the timecourse and orientation in the Dipole object for d, timecourse, orientation in zip(active_dips, timecourses, orientations): dip = d["dip"] dip.amplitude = timecourse dip.ori = orientation.T dip._set_times(self._evoked.times) - if len(dip.pos) != len(dip.times): - dip.pos = dip.pos[[0]].repeat(len(dip.times), axis=0) + + # Pad out all the other values to be defined for each timepoint. + for attr in ["pos", "gof", "khi2", "nfree"]: + setattr( + dip, attr, getattr(dip, attr)[[0]].repeat(len(dip.times), axis=0) + ) + for key in dip.conf.keys(): + dip.conf[key] = dip.conf[key][[0]].repeat(len(dip.times), axis=0) # Update matplotlib canvas at the bottom of the window canvas = self._setup_mplcanvas() ymin, ymax = 0, 0 for d in active_dips: + dip = d["dip"] if "line_artist" in d: - d["line_artist"].set_ydata(timecourse) + d["line_artist"].set_ydata(dip.amplitude) else: d["line_artist"] = canvas.plot( self._evoked.times, @@ -563,8 +566,8 @@ def _fit_timecourses(self): label=d["dip"].name, color=d["color"], ) - ymin = min(ymin, 1.1 * timecourse.min()) - ymax = max(ymax, 1.1 * timecourse.max()) + ymin = min(ymin, 1.1 * dip.amplitude.min()) + ymax = max(ymax, 1.1 * dip.amplitude.max()) canvas.axes.set_ylim(ymin, ymax) canvas.update_plot() self._update_arrows() @@ -629,8 +632,6 @@ def _on_dipole_set_name(self, name, dip_num): def _on_dipole_toggle_fix_orientation(self, fix, dip_num): """Fix dipole orientation when fitting timecourse.""" self._dipoles[dip_num]["fix_ori"] = bool(fix) - active_dips = [d for d in self._dipoles if d["active"]] - print([d["dip"] for d in active_dips]) self._fit_timecourses() def _setup_mplcanvas(self): From 5f594dde36d1f6f92c24f010bebcfdc56f0c658a Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Fri, 19 Jul 2024 17:02:16 +0300 Subject: [PATCH 30/65] more fixes --- mne/gui/_xfit.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 5326c7adf32..53a97ed2aa0 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -506,7 +506,7 @@ def _fit_timecourses(self): ) timecourses = stc.magnitude().data - orientations = stc.data / timecourses[:, np.newaxis, :] + orientations = (stc.data / timecourses[:, np.newaxis, :]).transpose(0, 2, 1) fixed_timecourses = stc.project( np.array([dip["dip"].ori[0] for dip in active_dips]) )[0].data @@ -514,12 +514,12 @@ def _fit_timecourses(self): for i, dip in enumerate(active_dips): if dip["fix_ori"]: timecourses[i] = fixed_timecourses[i] - orientations[i] = dip["dip"].ori.repeat(len(stc.times), axis=0).T + orientations[i] = dip["dip"].ori.repeat(len(stc.times), axis=0) elif self._multi_dipole_method == "Single dipole": timecourses = list() orientations = list() for dip in active_dips: - dip_timecourse = fit_dipole( + dip_with_timecourse, _ = fit_dipole( self._evoked, self._cov, self._bem, @@ -527,21 +527,21 @@ def _fit_timecourses(self): ori=dip["dip"].ori[0] if dip["fix_ori"] else None, trans=self._trans, verbose=False, - )[0].data[0] - timecourses.append(dip_timecourse) - + ) if dip["fix_ori"]: + timecourses.append(dip_with_timecourse.data[0]) orientations.append( - dip["dip"].ori.repeat(len(dip_timecourse), axis=0) + dip["dip"].ori.repeat(len(dip_with_timecourse.times), axis=0) ) else: - orientations.append(dip["dip"].ori) + timecourses.append(dip_with_timecourse.amplitude) + orientations.append(dip_with_timecourse.ori) # Store the timecourse and orientation in the Dipole object for d, timecourse, orientation in zip(active_dips, timecourses, orientations): dip = d["dip"] dip.amplitude = timecourse - dip.ori = orientation.T + dip.ori = orientation dip._set_times(self._evoked.times) # Pad out all the other values to be defined for each timepoint. From eaa081efb4fbede0ce1e3ef2233bc2b0219aaf12 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 24 Jul 2024 17:01:25 +0300 Subject: [PATCH 31/65] fix multi-dipole model --- mne/gui/_xfit.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 53a97ed2aa0..2e7f676eb28 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -484,15 +484,15 @@ def _fit_timecourses(self): if len(d["dip"].times) > 1: d["dip"] = d["dip"].crop(d["fit_time"], d["fit_time"]) - fwd, _ = make_forward_dipole( - [d["dip"] for d in active_dips], - self._bem, - self._evoked.info, - trans=self._trans, - ) - fwd = convert_forward_solution(fwd, surf_ori=False) - if self._multi_dipole_method == "Multi dipole (MNE)": + fwd, _ = make_forward_dipole( + [d["dip"] for d in active_dips], + self._bem, + self._evoked.info, + trans=self._trans, + ) + fwd = convert_forward_solution(fwd, surf_ori=False) + inv = make_inverse_operator( self._evoked.info, fwd, @@ -502,7 +502,7 @@ def _fit_timecourses(self): depth=0, ) stc = apply_inverse( - self._evoked, inv, method="MNE", lambda2=0, pick_ori="vector" + self._evoked, inv, method="MNE", lambda2=1e-6, pick_ori="vector" ) timecourses = stc.magnitude().data From ffb2fe1f97bd0f9009de66290426f555b3bb3736 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 24 Jul 2024 18:22:13 +0300 Subject: [PATCH 32/65] Add save option --- mne/gui/_xfit.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 2e7f676eb28..def31fa38f4 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -1,4 +1,5 @@ from functools import partial +from pathlib import Path import numpy as np import pyvista @@ -22,7 +23,7 @@ _get_transforms_to_coord_frame, transform_surface_to, ) -from ..utils import _check_option, fill_doc, verbose +from ..utils import _check_option, fill_doc, logger, verbose from ..viz import EvokedField, create_3d_figure from ..viz._3d import _plot_head_surface, _plot_sensors_3d from ..viz.ui_events import subscribe @@ -299,6 +300,15 @@ def _configure_dock(self): callback=self._on_select_method, ) self._dipole_box = r._dock_add_group_box(name="Dipoles") + r._dock_add_file_button( + name="save_dipoles", + desc="Save dipoles", + save=True, + func=self.save, + tooltip="Save the dipoles to disk", + filter_="Dipole files (*.bdip)", + initial_directory=".", + ) r._dock_add_stretch() def toggle_mesh(self, name, show=None): @@ -572,6 +582,14 @@ def _fit_timecourses(self): canvas.update_plot() self._update_arrows() + def save(self, fname): + logger.info("Saving dipoles as:") + fname = Path(fname) + for dip in self.dipoles: + dip_fname = fname.parent / f"{fname.stem}-{dip.name}{fname.suffix}" + logger.info(f" {dip_fname}") + dip.save(dip_fname) + def _update_arrows(self): """Update the arrows to have the correct size and orientation.""" active_dips = [d for d in self._dipoles if d["active"]] From 2f801a4757e11f182fdc080f8c54a5059dcec3b1 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 24 Jul 2024 21:25:27 +0300 Subject: [PATCH 33/65] style tweaks --- mne/gui/_xfit.py | 51 +++++++++++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index def31fa38f4..f3f3a75c94b 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -51,19 +51,20 @@ def dipolefit( evoked : instance of Evoked Evoked data to show fieldmap of and fit dipoles to. cov : instance of Covariance | None - Noise covariance matrix. If None, an ad-hoc covariance matrix is used. + Noise covariance matrix. If ``None``, an ad-hoc covariance matrix is used. bem : instance of ConductorModel | None - Boundary element model. If None, a spherical model is used. + Boundary element model to use in forward calculations. If ``None``, a spherical + model is used. initial_time : float | None - Initial time point to show. If None, the time point of the maximum - field strength is used. + Initial time point to show. If None, the time point of the maximum field + strength is used. trans : instance of Transform | None - The transformation from head coordinates to MRI coordinates. If None, + The transformation from head coordinates to MRI coordinates. If ``None``, the identity matrix is used. show_density : bool Whether to show the density of the fieldmap. subject : str | None - The subject name. If None, no MRI data is shown. + The subject name. If ``None``, no MRI data is shown. %(subjects_dir)s %(n_jobs)s %(verbose)s @@ -91,19 +92,20 @@ class DipoleFitUI: evoked : instance of Evoked Evoked data to show fieldmap of and fit dipoles to. cov : instance of Covariance | None - Noise covariance matrix. If None, an ad-hoc covariance matrix is used. + Noise covariance matrix. If ``None``, an ad-hoc covariance matrix is used. bem : instance of ConductorModel | None - Boundary element model. If None, a spherical model is used. + Boundary element model to use in forward calculations. If ``None``, a spherical + model is used. initial_time : float | None - Initial time point to show. If None, the time point of the maximum - field strength is used. + Initial time point to show. If ``None``, the time point of the maximum field + strength is used. trans : instance of Transform | None - The transformation from head coordinates to MRI coordinates. If None, + The transformation from head coordinates to MRI coordinates. If ``None``, the identity matrix is used. show_density : bool Whether to show the density of the fieldmap. subject : str | None - The subject name. If None, no MRI data is shown. + The subject name. If ``None``, no MRI data is shown. %(subjects_dir)s %(n_jobs)s %(verbose)s @@ -113,7 +115,6 @@ def __init__( self, evoked, cov=None, - # cov_data=None, bem=None, initial_time=None, trans=None, @@ -141,22 +142,24 @@ def __init__( ) if initial_time is None: + # Set initial time to moment of maximum field power. data = evoked.copy().pick(field_map[0]["ch_names"]).data initial_time = evoked.times[np.argmax(np.mean(data**2, axis=0))] - # Get transforms to convert all the various meshes to head space + # Get transforms to convert all the various meshes to head space. head_mri_t = _get_trans(trans, "head", "mri")[0] to_cf_t = _get_transforms_to_coord_frame( evoked.info, head_mri_t, coord_frame="head" ) - # Transform the fieldmap surfaces to head space if needed + # Transform the fieldmap surfaces to head space if needed. if trans is not None: for fm in field_map: fm["surf"] = transform_surface_to( fm["surf"], "head", [to_cf_t["mri"], to_cf_t["head"]], copy=False ) + # Initialize all the private attributes. self._actors = dict() self._arrows = list() self._bem = bem @@ -177,14 +180,14 @@ def __init__( self._trans = trans self._verbose = verbose - # Configure the GUI + # Configure the GUI. self._renderer = self._configure_main_display() self._configure_dock() @property def dipoles(self): - """A list of the fitted dipoles.""" - return [d["dip"] for d in self._dipoles] + """A list of all the fitted dipoles that are enabled in the GUI.""" + return [d["dip"] for d in self._dipoles if d["active"]] def _configure_main_display(self): """Configure main 3D display of the GUI.""" @@ -372,13 +375,13 @@ def _on_channels_select(self, event): def _on_fit_dipole(self): """Fit a single dipole.""" evoked_picked = self._evoked.copy() - cov_picked = self._cov + cov_picked = self._cov.copy() if self._fig_sensors is not None: picks = self._fig_sensors.lasso.selection if len(picks) > 0: - evoked_picked = evoked_picked.copy().pick(picks) + evoked_picked = evoked_picked.pick(picks) evoked_picked.info.normalize_proj() - cov_picked = cov_picked.copy().pick_channels(picks, ordered=False) + cov_picked = cov_picked.pick_channels(picks, ordered=False) cov_picked["projs"] = evoked_picked.info["projs"] evoked_picked.crop(self._current_time, self._current_time) @@ -394,7 +397,7 @@ def _on_fit_dipole(self): # Coordinates needed to draw the big arrow on the helmet. helmet_coords, helmet_pos = self._get_helmet_coords(dip) - # Collect all relevant information on the dipole in a dict + # Collect all relevant information on the dipole in a dict. colors = _get_color_list() dip_num = len(self._dipoles) dip.name = f"dip{dip_num}" @@ -442,10 +445,10 @@ def _on_fit_dipole(self): ) r._layout_add_widget(self._dipole_box, hlayout) - # Compute dipole timecourse, update arrow size + # Compute dipole timecourse, update arrow size. self._fit_timecourses() - # Show the dipole and arrow in the 3D view + # Show the dipole and arrow in the 3D view. self._renderer.plotter.add_arrows( dip.pos[0], dip.ori[0], color=dip_color, mag=0.05 ) From d913fffe705f41ed71d2e618b5f2912958163683 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 24 Jul 2024 22:03:38 +0300 Subject: [PATCH 34/65] fix bugs (thanks vulture!) --- mne/viz/_figure.py | 2 +- mne/viz/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index 6b93347b606..2b891e0ebc3 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -503,7 +503,7 @@ def _create_ch_location_fig(self, pick): # highlight desired channel & disable interactivity fig.lasso.selection_inds = np.isin(fig.lasso.names, [ch_name]) fig.lasso.disconnect() - fig.lasso.alpha_other = 0.3 + fig.lasso.alpha_nonselected = 0.3 fig.lasso.linewidth_selected = 3 fig.lasso.style_objects() diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 8af8e60e3ca..d326c5d6875 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1744,7 +1744,7 @@ def select_one(self, ind): def select_many(self, inds): """Select many sensors using indices (for predefined selections).""" - self.selected_inds = inds + self.selection_inds = inds self.selection = [self.names[i] for i in self.selection_inds] self.style_objects() From 13877b442c4625b4fc427e1f3f0257932a237bd5 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Mon, 29 Jul 2024 10:45:05 +0300 Subject: [PATCH 35/65] Add rank parameter --- mne/gui/_xfit.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 53a97ed2aa0..6bc9c87dae6 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -37,6 +37,7 @@ def dipolefit( bem=None, initial_time=None, trans=None, + rank=None, show_density=True, subject=None, subjects_dir=None, @@ -59,6 +60,7 @@ def dipolefit( trans : instance of Transform | None The transformation from head coordinates to MRI coordinates. If None, the identity matrix is used. + %(rank)s show_density : bool Whether to show the density of the fieldmap. subject : str | None @@ -73,6 +75,7 @@ def dipolefit( bem=bem, initial_time=initial_time, trans=trans, + rank=rank, show_density=show_density, subject=subject, subjects_dir=subjects_dir, @@ -99,6 +102,7 @@ class DipoleFitUI: trans : instance of Transform | None The transformation from head coordinates to MRI coordinates. If None, the identity matrix is used. + %(rank)s show_density : bool Whether to show the density of the fieldmap. subject : str | None @@ -116,6 +120,7 @@ def __init__( bem=None, initial_time=None, trans=None, + rank=None, show_density=True, subject=None, subjects_dir=None, @@ -174,6 +179,7 @@ def __init__( self._time_line = None self._to_cf_t = to_cf_t self._trans = trans + self._rank = rank self._verbose = verbose # Configure the GUI @@ -500,6 +506,7 @@ def _fit_timecourses(self): fixed=False, loose=1.0, depth=0, + rank=self._rank, ) stc = apply_inverse( self._evoked, inv, method="MNE", lambda2=0, pick_ori="vector" From e342e43134d46c946942056e6324972f11c9fb8c Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 21 Aug 2024 08:24:53 +0300 Subject: [PATCH 36/65] start work on dipole deletion --- mne/gui/_xfit.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 3303d7a107b..bd217ed2dc7 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -171,7 +171,7 @@ def __init__( self._ch_type = ch_type self._cov = cov self._current_time = initial_time - self._dipoles = list() + self._dipoles = dict() self._evoked = evoked self._field_map = field_map self._fig_sensors = None @@ -193,7 +193,7 @@ def __init__( @property def dipoles(self): """A list of all the fitted dipoles that are enabled in the GUI.""" - return [d["dip"] for d in self._dipoles if d["active"]] + return [d["dip"] for d in self._dipoles.values() if d["active"]] def _configure_main_display(self): """Configure main 3D display of the GUI.""" @@ -414,7 +414,8 @@ def _on_fit_dipole(self): arrow_mesh = None dipole_dict = dict( active=True, - arrow_actor=None, + brain_arrow_actor=None, + helmet_arrow_actor=None, arrow_mesh=arrow_mesh, color=dip_color, dip=dip, @@ -425,7 +426,7 @@ def _on_fit_dipole(self): num=dip_num, fit_time=self._current_time, ) - self._dipoles.append(dipole_dict) + self._dipoles[dip_num] = dipole_dict # Add a row to the dipole list r = self._renderer @@ -449,17 +450,22 @@ def _on_fit_dipole(self): callback=partial(self._on_dipole_toggle_fix_orientation, dip_num=dip_num), layout=hlayout, ) + r._dock_add_button( + name="D", + callback=partial(self._on_dipole_delete, dip_num=dip_num), + layout=hlayout, + ) r._layout_add_widget(self._dipole_box, hlayout) # Compute dipole timecourse, update arrow size. self._fit_timecourses() # Show the dipole and arrow in the 3D view. - self._renderer.plotter.add_arrows( + dipole_dict["brain_arrow_actor"] = self._renderer.plotter.add_arrows( dip.pos[0], dip.ori[0], color=dip_color, mag=0.05 ) if arrow_mesh is not None: - dipole_dict["arrow_actor"] = self._renderer.plotter.add_mesh( + dipole_dict["helmet_arrow_actor"] = self._renderer.plotter.add_mesh( arrow_mesh, color=dip_color, culling="front", @@ -494,7 +500,7 @@ def _fit_timecourses(self): Called whenever a dipole is (de)-activated or the "Fix pos" box is toggled. """ - active_dips = [d for d in self._dipoles if d["active"]] + active_dips = [d for d in self._dipoles.values() if d["active"]] if len(active_dips) == 0: return @@ -602,7 +608,7 @@ def save(self, fname): def _update_arrows(self): """Update the arrows to have the correct size and orientation.""" - active_dips = [d for d in self._dipoles if d["active"]] + active_dips = [d for d in self._dipoles.values() if d["active"]] if len(active_dips) == 0: return orientations = [d["dip"].ori for d in active_dips] @@ -646,7 +652,8 @@ def _on_dipole_toggle(self, active, dip_num): active = bool(active) dipole["active"] = active dipole["line_artist"].set_visible(active) - dipole["arrow_actor"].visibility = active + dipole["brain_arrow_actor"].visibility = active + dipole["helmet_arrow_actor"].visibility = active self._fit_timecourses() self._renderer._update() self._renderer._mplcanvas.update_plot() @@ -662,6 +669,17 @@ def _on_dipole_toggle_fix_orientation(self, fix, dip_num): self._dipoles[dip_num]["fix_ori"] = bool(fix) self._fit_timecourses() + def _on_dipole_delete(self, dip_num): + """Delete previously fitted dipole.""" + dipole = self._dipoles[dip_num] + dipole["line_artist"].remove() + dipole["brain_arrow_actor"].visibility = False + dipole["helmet_arrow_actor"].visibility = False + del self._dipoles[dip_num] + self._fit_timecourses() + self._renderer._update() + self._renderer._mplcanvas.update_plot() + def _setup_mplcanvas(self): """Configure the matplotlib canvas at the bottom of the window.""" if self._renderer._mplcanvas is None: From 5bdb9e1c90fd66fde4124f1e8541661cf9d829f9 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 2 Oct 2024 14:15:04 +0300 Subject: [PATCH 37/65] Finish dipole deletion, add occlusion mesh --- mne/gui/_xfit.py | 76 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 21 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index bd217ed2dc7..d12cbd38d11 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -1,3 +1,4 @@ +from copy import deepcopy from functools import partial from pathlib import Path @@ -219,10 +220,22 @@ def _configure_main_display(self): if surf_map["map_kind"] == "meg": helmet_mesh = surf_map["mesh"] helmet_mesh._polydata.compute_normals() # needed later + helmet_mesh._actor.prop.culling = "back" self._actors["helmet"] = helmet_mesh._actor + # For MEG fieldlines, we want to occlude the ones not facing us, + # otherwise it's hard to interpret them. Since the "contours" object + # does not support backface culling, we create an opaque mesh to put in + # front of the contour lines with frontface culling. + occl_surf = deepcopy(surf_map["surf"]) + occl_surf["rr"] -= 1e-3 * occl_surf["nn"] + occl_act, _ = fig._renderer.surface(occl_surf, color="white") + occl_act.prop.culling = "front" + occl_act.prop.lighting = False + self._actors["occlusion_surf"] = occl_act elif surf_map["map_kind"] == "eeg": head_mesh = surf_map["mesh"] head_mesh._polydata.compute_normals() # needed later + head_mesh._actor.prop.culling = "back" self._actors["head"] = head_mesh._actor show_meg = (self._ch_type is None or self._ch_type == "meg") and any( @@ -251,6 +264,7 @@ def _configure_main_display(self): to_cf_t=self._to_cf_t, alpha=0.2, ) + self._actors["head"].prop.culling = "back" sensors = _plot_sensors_3d( renderer=fig._renderer, @@ -290,6 +304,8 @@ def _configure_dock(self): # Toggle buttons for various meshes layout = r._dock_add_group_box("Meshes") for actor_name in self._actors.keys(): + if actor_name == "occlusion_surf": + continue r._dock_add_check_box( name=actor_name, value=True, @@ -405,7 +421,10 @@ def _on_fit_dipole(self): # Collect all relevant information on the dipole in a dict. colors = _get_color_list() - dip_num = len(self._dipoles) + if len(self._dipoles) == 0: + dip_num = 0 + else: + dip_num = max(self._dipoles.keys()) + 1 dip.name = f"dip{dip_num}" dip_color = colors[dip_num % len(colors)] if helmet_coords is not None: @@ -431,30 +450,43 @@ def _on_fit_dipole(self): # Add a row to the dipole list r = self._renderer hlayout = r._dock_add_layout(vertical=False) - r._dock_add_check_box( - name="", - value=True, - callback=partial(self._on_dipole_toggle, dip_num=dip_num), - layout=hlayout, + widgets = [] + widgets.append( + r._dock_add_check_box( + name="", + value=True, + callback=partial(self._on_dipole_toggle, dip_num=dip_num), + layout=hlayout, + ) ) - r._dock_add_text( - name=dip.name, - value=dip.name, - placeholder="name", - callback=partial(self._on_dipole_set_name, dip_num=dip_num), - layout=hlayout, + widgets.append( + r._dock_add_text( + name=dip.name, + value=dip.name, + placeholder="name", + callback=partial(self._on_dipole_set_name, dip_num=dip_num), + layout=hlayout, + ) ) - r._dock_add_check_box( - name="Fix ori", - value=True, - callback=partial(self._on_dipole_toggle_fix_orientation, dip_num=dip_num), - layout=hlayout, + widgets.append( + r._dock_add_check_box( + name="Fix ori", + value=True, + callback=partial( + self._on_dipole_toggle_fix_orientation, dip_num=dip_num + ), + layout=hlayout, + ) ) - r._dock_add_button( - name="D", - callback=partial(self._on_dipole_delete, dip_num=dip_num), - layout=hlayout, + widgets.append( + r._dock_add_button( + name="", + icon="clear", + callback=partial(self._on_dipole_delete, dip_num=dip_num), + layout=hlayout, + ) ) + dipole_dict["widgets"] = widgets r._layout_add_widget(self._dipole_box, hlayout) # Compute dipole timecourse, update arrow size. @@ -675,6 +707,8 @@ def _on_dipole_delete(self, dip_num): dipole["line_artist"].remove() dipole["brain_arrow_actor"].visibility = False dipole["helmet_arrow_actor"].visibility = False + for widget in dipole["widgets"]: + widget.hide() del self._dipoles[dip_num] self._fit_timecourses() self._renderer._update() From 2aa51d6f8aadda3a91865bf01e319e8c1f2620cc Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 8 Oct 2024 09:41:20 +0300 Subject: [PATCH 38/65] set rank --- mne/gui/_xfit.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index d12cbd38d11..86688660854 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -39,7 +39,7 @@ def dipolefit( bem=None, initial_time=None, trans=None, - rank=None, + rank="info", show_density=True, subject=None, subjects_dir=None, @@ -123,7 +123,7 @@ def __init__( bem=None, initial_time=None, trans=None, - rank=None, + rank="info", show_density=True, subject=None, subjects_dir=None, @@ -412,7 +412,7 @@ def _on_fit_dipole(self): cov_picked, self._bem, trans=self._trans, - min_dist=0, + rank=self._rank, verbose=False, )[0] @@ -584,6 +584,7 @@ def _fit_timecourses(self): pos=dip["dip"].pos[0], # position is always fixed ori=dip["dip"].ori[0] if dip["fix_ori"] else None, trans=self._trans, + rank=self._rank, verbose=False, ) if dip["fix_ori"]: From bcef66e2cb444a03b775ae7b1c5d0c91c0a40cff Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 8 Oct 2024 12:45:56 +0300 Subject: [PATCH 39/65] attempt to fix tests --- mne/viz/tests/test_raw.py | 63 +++++++++++++++++++++++++-------------- mne/viz/utils.py | 10 +++---- 2 files changed, 46 insertions(+), 27 deletions(-) diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index dffeb80e98f..163e55a4dfe 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -1090,63 +1090,82 @@ def test_plot_sensors(raw): pytest.raises(TypeError, plot_sensors, raw) # needs to be info pytest.raises(ValueError, plot_sensors, raw.info, kind="sasaasd") plt.close("all") + fig, sels = raw.plot_sensors("select", show_names=True) ax = fig.axes[0] - # Click with no sensors - _fake_click(fig, ax, (0.0, 0.0), xform="data") - _fake_click(fig, ax, (0, 0.0), xform="data", kind="release") + # Lasso with no sensors. + _fake_click(fig, ax, (-0.14, 0.14), xform="data") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.14), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.14), xform="data", kind="release") assert fig.lasso.selection == [] - # Lasso with 1 sensor (upper left) - _fake_click(fig, ax, (0, 1), xform="ax") - fig.canvas.draw() + # Lasso with 1 sensor (upper left). + _fake_click(fig, ax, (-0.13, 0.14), xform="data") assert fig.lasso.selection == [] _fake_click(fig, ax, (-0.11, 0.14), xform="data", kind="motion") _fake_click(fig, ax, (-0.11, 0.065), xform="data", kind="motion") - _fake_keypress(fig, "shift") - _fake_click(fig, ax, (-0.15, 0.065), xform="data", kind="release", key="shift") + _fake_click(fig, ax, (-0.13, 0.065), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.14), xform="ax", kind="motion") + _fake_click(fig, ax, (-0.13, 0.14), xform="ax", kind="release") assert fig.lasso.selection == ["MEG 0121"] - # check that point appearance changes + # Use SHIFT key to lasso an additional sensor. + _fake_keypress(fig, "shift") + _fake_click(fig, ax, (-0.17, 0.07), xform="data") + _fake_click(fig, ax, (-0.17, 0.05), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.05), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="release") + _fake_keypress(fig, "shift", kind="release") + assert fig.lasso.selection == ["MEG 0111", "MEG 0121"] + + # Check that the two selected sensors have a different appearance. fc = fig.lasso.collection.get_facecolors() ec = fig.lasso.collection.get_edgecolors() - assert (fc[:, -1] == [0.5, 1.0, 0.5]).all() - assert (ec[:, -1] == [0.25, 1.0, 0.25]).all() - - _fake_click(fig, ax, (-0.11, 0.065), xform="data", kind="motion", key="shift") - xy = ax.collections[0].get_offsets() - _fake_click(fig, ax, xy[2], xform="data", key="shift") # single sel - assert fig.lasso.selection == ["MEG 0121", "MEG 0131"] - _fake_click(fig, ax, xy[2], xform="data", key="alt") # deselect + assert (fc[2:, -1] == 0.5).all() + assert (ec[2:, -1] == 0.25).all() + assert (fc[:2, -1] == 1.0).all() + assert (ec[:2:, -1] == 1.0).all() + + # Use ALT key to remove a sensor from the lasso. + _fake_keypress(fig, "alt") + _fake_click(fig, ax, (-0.17, 0.07), xform="data") + _fake_click(fig, ax, (-0.17, 0.05), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.05), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="motion") + _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="release") + _fake_keypress(fig, "alt", kind="release") assert fig.lasso.selection == ["MEG 0121"] + plt.close("all") raw.info["dev_head_t"] = None # like empty room with pytest.warns(RuntimeWarning, match="identity"): raw.plot_sensors() - # Test plotting with sphere='eeglab' + # Test plotting with sphere='eeglab'. info = create_info(ch_names=["Fpz", "Oz", "T7", "T8"], sfreq=100, ch_types="eeg") data = 1e-6 * np.random.rand(4, 100) raw_eeg = RawArray(data=data, info=info) raw_eeg.set_montage("biosemi64") raw_eeg.plot_sensors(sphere="eeglab") - # Should work with "FPz" as well + # Should work with "FPz" as well. raw_eeg.rename_channels({"Fpz": "FPz"}) raw_eeg.plot_sensors(sphere="eeglab") - # Should still work without Fpz/FPz, as long as we still have Oz + # Should still work without Fpz/FPz, as long as we still have Oz. raw_eeg.drop_channels("FPz") raw_eeg.plot_sensors(sphere="eeglab") - # Should raise if Oz is missing too, as we cannot reconstruct Fpz anymore + # Should raise if Oz is missing too, as we cannot reconstruct Fpz anymore. raw_eeg.drop_channels("Oz") with pytest.raises(ValueError, match="could not find: Fpz"): raw_eeg.plot_sensors(sphere="eeglab") - # Should raise if we don't have a montage + # Should raise if we don't have a montage. chs = deepcopy(raw_eeg.info["chs"]) raw_eeg.set_montage(None) with raw_eeg.info._unlock(): diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 00d37461f62..e3f26224fd5 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -807,12 +807,12 @@ def _fake_click(fig, ax, point, xform="ax", button=1, kind="press", key=None): ) -def _fake_keypress(fig, key): +def _fake_keypress(fig, key, kind="press"): from matplotlib import backend_bases fig.canvas.callbacks.process( - "key_press_event", - backend_bases.KeyEvent(name="key_press_event", canvas=fig.canvas, key=key), + f"key_{kind}_event", + backend_bases.KeyEvent(name=f"key_{kind}_event", canvas=fig.canvas, key=key), ) @@ -1715,9 +1715,9 @@ def on_select(self, verts): path = Path(verts) inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0] if self.canvas._key == "shift": # Appending selection. - self.selection_inds = np.union1d(self.selection_inds, inds) + self.selection_inds = np.union1d(self.selection_inds, inds).astype("int") elif self.canvas._key == "alt": # Removing selection. - self.selection_inds = np.setdiff1d(self.selection_inds, inds) + self.selection_inds = np.setdiff1d(self.selection_inds, inds).astype("int") else: self.selection_inds = inds self.selection = [self.names[i] for i in self.selection_inds] From 8efcb8cce1ad2af9b2431929d2a8131d34102ee0 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 22 Oct 2024 12:48:21 +0300 Subject: [PATCH 40/65] further attempts to fix tests --- mne/viz/tests/test_raw.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index 163e55a4dfe..8a3282b55b7 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -1090,25 +1090,23 @@ def test_plot_sensors(raw): pytest.raises(TypeError, plot_sensors, raw) # needs to be info pytest.raises(ValueError, plot_sensors, raw.info, kind="sasaasd") plt.close("all") - fig, sels = raw.plot_sensors("select", show_names=True) ax = fig.axes[0] - # Lasso with no sensors. + # Click with no sensors _fake_click(fig, ax, (-0.14, 0.14), xform="data") _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") _fake_click(fig, ax, (-0.13, 0.14), xform="data", kind="motion") _fake_click(fig, ax, (-0.13, 0.14), xform="data", kind="release") assert fig.lasso.selection == [] - # Lasso with 1 sensor (upper left). - _fake_click(fig, ax, (-0.13, 0.14), xform="data") - assert fig.lasso.selection == [] - _fake_click(fig, ax, (-0.11, 0.14), xform="data", kind="motion") - _fake_click(fig, ax, (-0.11, 0.065), xform="data", kind="motion") - _fake_click(fig, ax, (-0.13, 0.065), xform="data", kind="motion") - _fake_click(fig, ax, (-0.13, 0.14), xform="ax", kind="motion") - _fake_click(fig, ax, (-0.13, 0.14), xform="ax", kind="release") + # Lasso with 1 sensor (upper left) + _fake_click(fig, ax, (-0.13, 0.13), xform="data") + _fake_click(fig, ax, (-0.11, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.11, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="release") assert fig.lasso.selection == ["MEG 0121"] # Use SHIFT key to lasso an additional sensor. @@ -1137,7 +1135,6 @@ def test_plot_sensors(raw): _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="motion") _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="release") _fake_keypress(fig, "alt", kind="release") - assert fig.lasso.selection == ["MEG 0121"] plt.close("all") @@ -1145,27 +1142,27 @@ def test_plot_sensors(raw): with pytest.warns(RuntimeWarning, match="identity"): raw.plot_sensors() - # Test plotting with sphere='eeglab'. + # Test plotting with sphere='eeglab' info = create_info(ch_names=["Fpz", "Oz", "T7", "T8"], sfreq=100, ch_types="eeg") data = 1e-6 * np.random.rand(4, 100) raw_eeg = RawArray(data=data, info=info) raw_eeg.set_montage("biosemi64") raw_eeg.plot_sensors(sphere="eeglab") - # Should work with "FPz" as well. + # Should work with "FPz" as well raw_eeg.rename_channels({"Fpz": "FPz"}) raw_eeg.plot_sensors(sphere="eeglab") - # Should still work without Fpz/FPz, as long as we still have Oz. + # Should still work without Fpz/FPz, as long as we still have Oz raw_eeg.drop_channels("FPz") raw_eeg.plot_sensors(sphere="eeglab") - # Should raise if Oz is missing too, as we cannot reconstruct Fpz anymore. + # Should raise if Oz is missing too, as we cannot reconstruct Fpz anymore raw_eeg.drop_channels("Oz") with pytest.raises(ValueError, match="could not find: Fpz"): raw_eeg.plot_sensors(sphere="eeglab") - # Should raise if we don't have a montage. + # Should raise if we don't have a montage chs = deepcopy(raw_eeg.info["chs"]) raw_eeg.set_montage(None) with raw_eeg.info._unlock(): From 87f72e2ae2ff8f2c18e3725bd930f98abc367d20 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 22 Oct 2024 13:07:24 +0300 Subject: [PATCH 41/65] Add what's new entry --- doc/changes/devel/12071.newfeature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 doc/changes/devel/12071.newfeature.rst diff --git a/doc/changes/devel/12071.newfeature.rst b/doc/changes/devel/12071.newfeature.rst new file mode 100644 index 00000000000..4e7995e3beb --- /dev/null +++ b/doc/changes/devel/12071.newfeature.rst @@ -0,0 +1 @@ +Add new ``select`` parameter to :func:`mne.viz.plot_evoked_topo` and :meth:`mne.Evoked.plot_topo` to toggle lasso selection of sensors, by `Marijn van Vliet`_. From 5f5666ac2b10d0b2d6c90c916cf047d1a76b4e15 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 7 Jan 2025 10:09:30 +0200 Subject: [PATCH 42/65] also show field strength input fields when not plotting density --- mne/viz/evoked_field.py | 73 ++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index ef5af1115d3..acc2a3f75e6 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -406,45 +406,44 @@ def _configure_dock(self): # Fieldline configuration layout = r._dock_add_group_box("Fieldlines") - if self._show_density: - r._dock_add_label(value="max value", align=True, layout=layout) - - @_auto_weakref - def _callback(vmax, kind, scaling): - self.set_vmax(vmax / scaling, kind=kind) - - for surf_map in self._surf_maps: - if surf_map["map_kind"] == "meg": - scaling = DEFAULTS["scalings"]["grad"] - else: - scaling = DEFAULTS["scalings"]["eeg"] - rng = [0, np.max(np.abs(surf_map["data"])) * scaling] - hlayout = r._dock_add_layout(vertical=False) - - self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = ( - r._dock_add_slider( - name=surf_map["map_kind"].upper(), - value=surf_map["map_vmax"] * scaling, - rng=rng, - callback=partial( - _callback, kind=surf_map["map_kind"], scaling=scaling - ), - double=True, - layout=hlayout, - ) + r._dock_add_label(value="max value", align=True, layout=layout) + + @_auto_weakref + def _callback(vmax, kind, scaling): + self.set_vmax(vmax / scaling, kind=kind) + + for surf_map in self._surf_maps: + if surf_map["map_kind"] == "meg": + scaling = DEFAULTS["scalings"]["grad"] + else: + scaling = DEFAULTS["scalings"]["eeg"] + rng = [0, np.max(np.abs(surf_map["data"])) * scaling] + hlayout = r._dock_add_layout(vertical=False) + + self._widgets[f"vmax_slider_{surf_map['map_kind']}"] = ( + r._dock_add_slider( + name=surf_map["map_kind"].upper(), + value=surf_map["map_vmax"] * scaling, + rng=rng, + callback=partial( + _callback, kind=surf_map["map_kind"], scaling=scaling + ), + double=True, + layout=hlayout, ) - self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = ( - r._dock_add_spin_box( - name="", - value=surf_map["map_vmax"] * scaling, - rng=rng, - callback=partial( - _callback, kind=surf_map["map_kind"], scaling=scaling - ), - layout=hlayout, - ) + ) + self._widgets[f"vmax_spin_{surf_map['map_kind']}"] = ( + r._dock_add_spin_box( + name="", + value=surf_map["map_vmax"] * scaling, + rng=rng, + callback=partial( + _callback, kind=surf_map["map_kind"], scaling=scaling + ), + layout=hlayout, ) - r._layout_add_widget(layout, hlayout) + ) + r._layout_add_widget(layout, hlayout) hlayout = r._dock_add_layout(vertical=False) r._dock_add_label( From 0277981d4618d08fe17e00704814fdefc10454df Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 7 Jan 2025 10:10:01 +0200 Subject: [PATCH 43/65] Update unit tests for lasso select --- mne/viz/tests/test_topo.py | 7 +++++++ mne/viz/tests/test_utils.py | 1 + 2 files changed, 8 insertions(+) diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 85b4b43dcf8..2e7e19d2f8f 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -295,6 +295,13 @@ def test_plot_topo_image_epochs(): assert len(qm_cmap) >= 1 assert qm_cmap[0] is cmap +def test_plot_topo_select(): + """Test selecting sensors in an ERP topography plot.""" + # Show topography + evoked = _get_epochs().average() + plot_evoked_topo(evoked, select=True) + + def test_plot_tfr_topo(): """Test plotting of TFR data.""" diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 59e2976e464..3672129bdb0 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -27,6 +27,7 @@ centers_to_edges, compare_fiff, concatenate_images, + SelectFromCollection, ) base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" From bea101da3f161e0015cb1541439ed6f3729c805a Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 7 Jan 2025 10:10:01 +0200 Subject: [PATCH 44/65] Update unit tests for lasso select --- mne/viz/tests/test_topo.py | 7 +++++++ mne/viz/tests/test_utils.py | 1 + 2 files changed, 8 insertions(+) diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 85b4b43dcf8..2e7e19d2f8f 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -295,6 +295,13 @@ def test_plot_topo_image_epochs(): assert len(qm_cmap) >= 1 assert qm_cmap[0] is cmap +def test_plot_topo_select(): + """Test selecting sensors in an ERP topography plot.""" + # Show topography + evoked = _get_epochs().average() + plot_evoked_topo(evoked, select=True) + + def test_plot_tfr_topo(): """Test plotting of TFR data.""" diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 59e2976e464..3672129bdb0 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -27,6 +27,7 @@ centers_to_edges, compare_fiff, concatenate_images, + SelectFromCollection, ) base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" From 2ae07bfe22306a416b274aa156a080641f140fe7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 08:11:16 +0000 Subject: [PATCH 45/65] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/viz/tests/test_topo.py | 2 +- mne/viz/tests/test_utils.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 2e7e19d2f8f..9974976ea67 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -295,6 +295,7 @@ def test_plot_topo_image_epochs(): assert len(qm_cmap) >= 1 assert qm_cmap[0] is cmap + def test_plot_topo_select(): """Test selecting sensors in an ERP topography plot.""" # Show topography @@ -302,7 +303,6 @@ def test_plot_topo_select(): plot_evoked_topo(evoked, select=True) - def test_plot_tfr_topo(): """Test plotting of TFR data.""" epochs = _get_epochs() diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 3672129bdb0..59e2976e464 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -27,7 +27,6 @@ centers_to_edges, compare_fiff, concatenate_images, - SelectFromCollection, ) base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" From e857a2e97dccb9c409d0f568acc2fdb60b992fdb Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Mon, 20 Jan 2025 11:30:00 +0200 Subject: [PATCH 46/65] Move large lasso test to test_utils.py and have smaller tests in test_raw.py and test_topo.py --- mne/viz/tests/test_raw.py | 42 +++-------------------------- mne/viz/tests/test_topo.py | 10 +++++++ mne/viz/tests/test_utils.py | 53 +++++++++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 38 deletions(-) diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index 0ca52649228..ee5dbf8e4e5 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -28,7 +28,7 @@ set_config, ) from mne.viz import plot_raw, plot_sensors -from mne.viz.utils import _fake_click, _fake_keypress +from mne.viz.utils import _fake_click base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" @@ -1088,17 +1088,11 @@ def test_plot_sensors(raw): pytest.raises(TypeError, plot_sensors, raw) # needs to be info pytest.raises(ValueError, plot_sensors, raw.info, kind="sasaasd") plt.close("all") + + # Test lasso selection. fig, sels = raw.plot_sensors("select", show_names=True) ax = fig.axes[0] - - # Click with no sensors - _fake_click(fig, ax, (-0.14, 0.14), xform="data") - _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") - _fake_click(fig, ax, (-0.13, 0.14), xform="data", kind="motion") - _fake_click(fig, ax, (-0.13, 0.14), xform="data", kind="release") - assert fig.lasso.selection == [] - - # Lasso with 1 sensor (upper left) + # Lasso a single sensor. _fake_click(fig, ax, (-0.13, 0.13), xform="data") _fake_click(fig, ax, (-0.11, 0.13), xform="data", kind="motion") _fake_click(fig, ax, (-0.11, 0.06), xform="data", kind="motion") @@ -1106,34 +1100,6 @@ def test_plot_sensors(raw): _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="release") assert fig.lasso.selection == ["MEG 0121"] - - # Use SHIFT key to lasso an additional sensor. - _fake_keypress(fig, "shift") - _fake_click(fig, ax, (-0.17, 0.07), xform="data") - _fake_click(fig, ax, (-0.17, 0.05), xform="data", kind="motion") - _fake_click(fig, ax, (-0.15, 0.05), xform="data", kind="motion") - _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="motion") - _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="release") - _fake_keypress(fig, "shift", kind="release") - assert fig.lasso.selection == ["MEG 0111", "MEG 0121"] - - # Check that the two selected sensors have a different appearance. - fc = fig.lasso.collection.get_facecolors() - ec = fig.lasso.collection.get_edgecolors() - assert (fc[2:, -1] == 0.5).all() - assert (ec[2:, -1] == 0.25).all() - assert (fc[:2, -1] == 1.0).all() - assert (ec[:2:, -1] == 1.0).all() - - # Use ALT key to remove a sensor from the lasso. - _fake_keypress(fig, "alt") - _fake_click(fig, ax, (-0.17, 0.07), xform="data") - _fake_click(fig, ax, (-0.17, 0.05), xform="data", kind="motion") - _fake_click(fig, ax, (-0.15, 0.05), xform="data", kind="motion") - _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="motion") - _fake_click(fig, ax, (-0.15, 0.07), xform="data", kind="release") - _fake_keypress(fig, "alt", kind="release") - plt.close("all") raw.info["dev_head_t"] = None # like empty room diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 9974976ea67..dbc29832c09 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -231,6 +231,16 @@ def test_plot_topo(): break plt.close("all") + # Test plot_topo with selection of channels enabled. + fig = evoked.plot_topo(select=True) + ax = fig.axes[0] + _fake_click(fig, ax, (0.05, 0.62), xform="data") + _fake_click(fig, ax, (0.2, 0.62), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0113", "MEG 0112", "MEG 0111"] + def test_plot_topo_nirs(fnirs_evoked): """Test plotting of ERP topography for nirs data.""" diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 59e2976e464..59bf06fe16c 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -16,6 +16,7 @@ from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap from mne.viz.ui_events import ColormapRange, link, subscribe from mne.viz.utils import ( + SelectFromCollection, _compute_scalings, _fake_click, _fake_keypress, @@ -274,3 +275,55 @@ def callback(event): cmap_new1 = fig.axes[0].CB.mappable.get_cmap().name cmap_new2 = fig2.axes[0].CB.mappable.get_cmap().name assert cmap_new1 == cmap_new2 == cmap_want != cmap_old + + +def test_select_from_collection(): + """Test the lasso selector for matplotlib figures.""" + fig, ax = plt.subplots() + collection = ax.scatter([1, 2, 2, 1], [1, 1, 0, 0], color="black", edgecolor="red") + ax.set_xlim(-1, 4) + ax.set_ylim(-1, 2) + lasso = SelectFromCollection(ax, collection, names=["A", "B", "C", "D"]) + assert lasso.selection == [] + + # Make a selection with no patches inside of it. + _fake_click(fig, ax, (0, 0), xform="data") + _fake_click(fig, ax, (0.5, 0), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="release") + assert lasso.selection == [] + + # Make a selection with two patches in it. + _fake_click(fig, ax, (0, 0.5), xform="data") + _fake_click(fig, ax, (3, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (3, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="release") + assert lasso.selection == ["A", "B"] + + # Use SHIFT key to lasso an additional patch. + _fake_keypress(fig, "shift") + _fake_click(fig, ax, (0.5, -0.5), xform="data") + _fake_click(fig, ax, (1.5, -0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="release") + _fake_keypress(fig, "shift", kind="release") + assert lasso.selection == ["A", "B", "D"] + + # Use ALT key to remove a patch. + _fake_keypress(fig, "alt") + _fake_click(fig, ax, (0.5, 0.5), xform="data") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="release") + _fake_keypress(fig, "alt", kind="release") + assert lasso.selection == ["B", "D"] + + # Check that the two selected patches have a different appearance. + fc = lasso.collection.get_facecolors() + ec = lasso.collection.get_edgecolors() + assert (fc[:, -1] == [0.5, 1.0, 0.5, 1.0]).all() + assert (ec[:, -1] == [0.25, 1.0, 0.25, 1.0]).all() From 822f761b84e6b6f3b0e860a9118c3936bab673cf Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Mon, 20 Jan 2025 13:22:32 +0200 Subject: [PATCH 47/65] select from proper list of channels --- mne/viz/topo.py | 6 ++++-- mne/viz/utils.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 4db685a459b..5d604610342 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -273,8 +273,10 @@ def on_select(): publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) def on_channels_select(event): - ch_inds = {name: i for i, name in enumerate(ch_names)} - selection_inds = [ch_inds[name] for name in event.ch_names] + ch_inds = {name: i for i, name in enumerate(shown_ch_names)} + selection_inds = [ + ch_inds[name] for name in event.ch_names if name in ch_inds + ] fig.lasso.select_many(selection_inds) fig.lasso.callbacks.append(on_select) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 7e411ee20a5..ff628c4c41c 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1279,7 +1279,9 @@ def on_select(): def on_channels_select(event): ch_inds = {name: i for i, name in enumerate(ch_names)} - selection_inds = [ch_inds[name] for name in event.ch_names] + selection_inds = [ + ch_inds[name] for name in event.ch_names if name in ch_inds + ] fig.lasso.select_many(selection_inds) fig.lasso.callbacks.append(on_select) From 51efe6c11cf40f545e9df1f27e1b161a7924e550 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Mon, 20 Jan 2025 13:59:34 +0200 Subject: [PATCH 48/65] fix version --- mne/viz/evoked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index 2fc668949c3..96ee0684e6e 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1224,7 +1224,7 @@ def plot_evoked_topo( channels. The selected channels will be available in ``fig.lasso.selection``. - .. versionadded:: 1.9.0 + .. versionadded:: 1.10.0 exclude : list of str | ``'bads'`` Channels names to exclude from the plot. If ``'bads'``, the bad channels are excluded. By default, exclude is set to ``'bads'``. From e806634e212b38ae3c7f9f5e6be14bde45435b8a Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Mon, 20 Jan 2025 14:01:52 +0200 Subject: [PATCH 49/65] more versionadded annotations --- mne/viz/topo.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 5d604610342..f2073435003 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -80,6 +80,8 @@ def iter_topography( channels. The selected channels will be available in ``fig.lasso.selection``. + .. versionadded:: 1.10.0 + Returns ------- gen : generator @@ -1260,6 +1262,8 @@ def plot_topo_image_epochs( Whether to enable the lasso-selection tool to enable the user to select channels. The selected channels will be available in ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 show : bool Whether to show the figure. Defaults to ``True``. From 8b1a014182d8554ebdb698e6e5d26f7a5b1ca5a6 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 21 Jan 2025 16:55:16 +0200 Subject: [PATCH 50/65] Save and load dipoles Xfit style. Fix legend when hiding dipoles. --- mne/gui/_xfit.py | 273 +++++++++++++++++++++++++++-------------------- 1 file changed, 160 insertions(+), 113 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 86688660854..51771d4676c 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -15,7 +15,7 @@ make_sphere_model, ) from ..cov import make_ad_hoc_cov -from ..dipole import fit_dipole +from ..dipole import Dipole, fit_dipole from ..forward import convert_forward_solution, make_field_map, make_forward_dipole from ..minimum_norm import apply_inverse, make_inverse_operator from ..surface import _normal_orth @@ -271,7 +271,7 @@ def _configure_main_display(self): info=self._evoked.info, to_cf_t=self._to_cf_t, picks=picks, - meg=show_meg, + meg=["sensors"] if show_meg else False, eeg=["original"] if show_eeg else False, fnirs=False, warn_meg=False, @@ -416,92 +416,111 @@ def _on_fit_dipole(self): verbose=False, )[0] - # Coordinates needed to draw the big arrow on the helmet. - helmet_coords, helmet_pos = self._get_helmet_coords(dip) + self._add_dipole(dip) - # Collect all relevant information on the dipole in a dict. - colors = _get_color_list() - if len(self._dipoles) == 0: - dip_num = 0 - else: - dip_num = max(self._dipoles.keys()) + 1 - dip.name = f"dip{dip_num}" - dip_color = colors[dip_num % len(colors)] - if helmet_coords is not None: - arrow_mesh = pyvista.PolyData(*_arrow_mesh()) - else: - arrow_mesh = None - dipole_dict = dict( - active=True, - brain_arrow_actor=None, - helmet_arrow_actor=None, - arrow_mesh=arrow_mesh, - color=dip_color, - dip=dip, - fix_ori=True, - fix_position=True, - helmet_coords=helmet_coords, - helmet_pos=helmet_pos, - num=dip_num, - fit_time=self._current_time, - ) - self._dipoles[dip_num] = dipole_dict + def add_dipole(self, dipole): + """Add a dipole (or multiple dipoles) to the GUI. - # Add a row to the dipole list - r = self._renderer - hlayout = r._dock_add_layout(vertical=False) - widgets = [] - widgets.append( - r._dock_add_check_box( - name="", - value=True, - callback=partial(self._on_dipole_toggle, dip_num=dip_num), - layout=hlayout, + Parameters + ---------- + dipole : Dipole + The dipole to add. If the ``Dipole`` object defines multiple dipoles, they + will all be added. + """ + new_dipoles = list() + for dip_i in range(len(dipole)): + dip = dipole[dip_i] + + # Coordinates needed to draw the big arrow on the helmet. + helmet_coords, helmet_pos = self._get_helmet_coords(dip) + + # Collect all relevant information on the dipole in a dict. + colors = _get_color_list() + if len(self._dipoles) == 0: + dip_num = 0 + else: + dip_num = max(self._dipoles.keys()) + 1 + if dip.name is None: + dip.name = f"dip{dip_num}" + dip_color = colors[dip_num % len(colors)] + if helmet_coords is not None: + arrow_mesh = pyvista.PolyData(*_arrow_mesh()) + else: + arrow_mesh = None + dipole_dict = dict( + active=True, + brain_arrow_actor=None, + helmet_arrow_actor=None, + arrow_mesh=arrow_mesh, + color=dip_color, + dip=dip, + fix_ori=True, + fix_position=True, + helmet_coords=helmet_coords, + helmet_pos=helmet_pos, + num=dip_num, + fit_time=self._current_time, ) - ) - widgets.append( - r._dock_add_text( - name=dip.name, - value=dip.name, - placeholder="name", - callback=partial(self._on_dipole_set_name, dip_num=dip_num), - layout=hlayout, + self._dipoles[dip_num] = dipole_dict + + # Add a row to the dipole list + r = self._renderer + hlayout = r._dock_add_layout(vertical=False) + widgets = [] + widgets.append( + r._dock_add_check_box( + name="", + value=True, + callback=partial(self._on_dipole_toggle, dip_num=dip_num), + layout=hlayout, + ) ) - ) - widgets.append( - r._dock_add_check_box( - name="Fix ori", - value=True, - callback=partial( - self._on_dipole_toggle_fix_orientation, dip_num=dip_num - ), - layout=hlayout, + widgets.append( + r._dock_add_text( + name=dip.name, + value=dip.name, + placeholder="name", + callback=partial(self._on_dipole_set_name, dip_num=dip_num), + layout=hlayout, + ) ) - ) - widgets.append( - r._dock_add_button( - name="", - icon="clear", - callback=partial(self._on_dipole_delete, dip_num=dip_num), - layout=hlayout, + widgets.append( + r._dock_add_check_box( + name="Fix ori", + value=True, + callback=partial( + self._on_dipole_toggle_fix_orientation, dip_num=dip_num + ), + layout=hlayout, + ) ) - ) - dipole_dict["widgets"] = widgets - r._layout_add_widget(self._dipole_box, hlayout) + widgets.append( + r._dock_add_button( + name="", + icon="clear", + callback=partial(self._on_dipole_delete, dip_num=dip_num), + layout=hlayout, + ) + ) + dipole_dict["widgets"] = widgets + r._layout_add_widget(self._dipole_box, hlayout) + new_dipoles.append(dipole_dict) - # Compute dipole timecourse, update arrow size. + # Show the dipoles and arrows in the 3D view. Only do this after + # `_fit_timecourses` so that they have the correct size straight away. self._fit_timecourses() - - # Show the dipole and arrow in the 3D view. - dipole_dict["brain_arrow_actor"] = self._renderer.plotter.add_arrows( - dip.pos[0], dip.ori[0], color=dip_color, mag=0.05 - ) - if arrow_mesh is not None: - dipole_dict["helmet_arrow_actor"] = self._renderer.plotter.add_mesh( - arrow_mesh, - color=dip_color, - culling="front", + for dipole_dict in new_dipoles: + dip = dipole_dict["dip"] + dipole_dict["brain_arrow_actor"] = self._renderer.plotter.add_arrows( + dip.pos[0], dip.ori[0], color=dipole_dict["color"], mag=0.05 ) + if arrow_mesh is not None: + dipole_dict["helmet_arrow_actor"] = self._renderer.plotter.add_mesh( + arrow_mesh, + color=dipole_dict["color"], + culling="front", + ) + self._update_arrows() def _get_helmet_coords(self, dip): """Compute the coordinate system used for drawing the big arrows on the helmet. @@ -597,58 +616,84 @@ def _fit_timecourses(self): orientations.append(dip_with_timecourse.ori) # Store the timecourse and orientation in the Dipole object - for d, timecourse, orientation in zip(active_dips, timecourses, orientations): - dip = d["dip"] - dip.amplitude = timecourse - dip.ori = orientation - dip._set_times(self._evoked.times) - - # Pad out all the other values to be defined for each timepoint. - for attr in ["pos", "gof", "khi2", "nfree"]: - setattr( - dip, attr, getattr(dip, attr)[[0]].repeat(len(dip.times), axis=0) - ) - for key in dip.conf.keys(): - dip.conf[key] = dip.conf[key][[0]].repeat(len(dip.times), axis=0) + for dip, timecourse, orientation in zip(active_dips, timecourses, orientations): + dip["timecourse"] = timecourse + dip["orientation"] = orientation + dip["times"] = self._evoked.times + # dip = d["dip"] + # dip.amplitude = timecourse + # dip.ori = orientation + # dip._set_times(self._evoked.times) + + # # Pad out all the other values to be defined for each timepoint. + # for attr in ["pos", "gof", "khi2", "nfree"]: + # setattr( + # dip, attr, getattr(dip, attr)[[0]].repeat(len(dip.times), axis=0) + # ) + # for key in dip.conf.keys(): + # dip.conf[key] = dip.conf[key][[0]].repeat(len(dip.times), axis=0) # Update matplotlib canvas at the bottom of the window canvas = self._setup_mplcanvas() ymin, ymax = 0, 0 - for d in active_dips: - dip = d["dip"] - if "line_artist" in d: - d["line_artist"].set_ydata(dip.amplitude) + for dip in active_dips: + if "line_artist" in dip: + dip["line_artist"].set_ydata(dip["timecourse"]) else: - d["line_artist"] = canvas.plot( + dip["line_artist"] = canvas.plot( self._evoked.times, - d["dip"].amplitude, - label=d["dip"].name, - color=d["color"], + dip["timecourse"], + label=dip["dip"].name, + color=dip["color"], ) - ymin = min(ymin, 1.1 * dip.amplitude.min()) - ymax = max(ymax, 1.1 * dip.amplitude.max()) + ymin = min(ymin, 1.1 * dip["timecourse"].min()) + ymax = max(ymax, 1.1 * dip["timecourse"].max()) canvas.axes.set_ylim(ymin, ymax) canvas.update_plot() self._update_arrows() - def save(self, fname): + @verbose + @fill_doc + def save(self, fname, verbose=None): + """Save the fitted dipoles to a file. + + Parameters + ---------- + fname : path-like + The name of the file. Should end in ``'.dip'`` to save in plain text format, + or in ``'.bdip'`` to save in binary format. + %(verbose)s + """ logger.info("Saving dipoles as:") fname = Path(fname) - for dip in self.dipoles: - dip_fname = fname.parent / f"{fname.stem}-{dip.name}{fname.suffix}" - logger.info(f" {dip_fname}") - dip.save(dip_fname) + + # Pack the dipoles into a single mne.Dipole object. + dip = Dipole( + times=np.array([d.times[0] for d in self.dipoles]), + pos=np.array([d.pos[0] for d in self.dipoles]), + amplitude=np.array([d.amplitude[0] for d in self.dipoles]), + ori=np.array([d.ori[0] for d in self.dipoles]), + gof=np.array([d.gof[0] for d in self.dipoles]), + khi2=np.array([d.khi2[0] for d in self.dipoles]), + nfree=np.array([d.nfree[0] for d in self.dipoles]), + conf={ + key: np.array([d.conf[key][0] for d in self.dipoles]) + for key in self.dipoles[0].conf.keys() + }, + name=",".join(d.name for d in self.dipoles), + ) + dip.save(fname, overwrite=True, verbose=verbose) def _update_arrows(self): """Update the arrows to have the correct size and orientation.""" active_dips = [d for d in self._dipoles.values() if d["active"]] if len(active_dips) == 0: return - orientations = [d["dip"].ori for d in active_dips] - timecourses = [d["dip"].amplitude for d in active_dips] + orientations = [dip["orientation"] for dip in active_dips] + timecourses = [dip["timecourse"] for dip in active_dips] arrow_scaling = 0.05 / np.max(np.abs(timecourses)) - for d, ori, timecourse in zip(active_dips, orientations, timecourses): - helmet_coords = d["helmet_coords"] + for dip, ori, timecourse in zip(active_dips, orientations, timecourses): + helmet_coords = dip["helmet_coords"] if helmet_coords is None: continue dip_ori = [ @@ -656,7 +701,7 @@ def _update_arrows(self): ] dip_moment = np.interp(self._current_time, self._evoked.times, timecourse) arrow_size = dip_moment * arrow_scaling - arrow_mesh = d["arrow_mesh"] + arrow_mesh = dip["arrow_mesh"] # Project the orientation of the dipole tangential to the helmet dip_ori_tan = helmet_coords[:2] @ dip_ori @ helmet_coords[:2] @@ -671,7 +716,7 @@ def _update_arrows(self): # Update the arrow mesh to point in the right directions arrow_mesh.points = (_arrow_mesh()[0] * arrow_size) @ arrow_coords - arrow_mesh.points += d["helmet_pos"] + arrow_mesh.points += dip["helmet_pos"] self._renderer._update() def _on_select_method(self, method): @@ -685,6 +730,8 @@ def _on_dipole_toggle(self, active, dip_num): active = bool(active) dipole["active"] = active dipole["line_artist"].set_visible(active) + # Labels starting with "_" are hidden from the legend. + dipole["line_artist"].set_label(("" if active else "_") + dipole["dip"].name) dipole["brain_arrow_actor"].visibility = active dipole["helmet_arrow_actor"].visibility = active self._fit_timecourses() From 0066f0be37d5ab8cb9e3f422958bd8f638b318d6 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 21 Jan 2025 17:00:13 +0200 Subject: [PATCH 51/65] cleanup --- mne/gui/_xfit.py | 39 +++++++++++---------------------------- 1 file changed, 11 insertions(+), 28 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 51771d4676c..b11fa86f507 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -416,7 +416,7 @@ def _on_fit_dipole(self): verbose=False, )[0] - self._add_dipole(dip) + self.add_dipole(dip) def add_dipole(self, dipole): """Add a dipole (or multiple dipoles) to the GUI. @@ -590,11 +590,12 @@ def _fit_timecourses(self): for i, dip in enumerate(active_dips): if dip["fix_ori"]: - timecourses[i] = fixed_timecourses[i] - orientations[i] = dip["dip"].ori.repeat(len(stc.times), axis=0) + dip["timecourse"] = fixed_timecourses[i] + dip["orientation"] = dip["dip"].ori.repeat(len(stc.times), axis=0) + else: + dip["timecourse"] = timecourses[i] + dip["orientation"] = orientations[i] elif self._multi_dipole_method == "Single dipole": - timecourses = list() - orientations = list() for dip in active_dips: dip_with_timecourse, _ = fit_dipole( self._evoked, @@ -607,31 +608,13 @@ def _fit_timecourses(self): verbose=False, ) if dip["fix_ori"]: - timecourses.append(dip_with_timecourse.data[0]) - orientations.append( - dip["dip"].ori.repeat(len(dip_with_timecourse.times), axis=0) + dip["timecourse"] = dip_with_timecourse.data[0] + dip["orientation"] = dip["dip"].ori.repeat( + len(dip_with_timecourse.times), axis=0 ) else: - timecourses.append(dip_with_timecourse.amplitude) - orientations.append(dip_with_timecourse.ori) - - # Store the timecourse and orientation in the Dipole object - for dip, timecourse, orientation in zip(active_dips, timecourses, orientations): - dip["timecourse"] = timecourse - dip["orientation"] = orientation - dip["times"] = self._evoked.times - # dip = d["dip"] - # dip.amplitude = timecourse - # dip.ori = orientation - # dip._set_times(self._evoked.times) - - # # Pad out all the other values to be defined for each timepoint. - # for attr in ["pos", "gof", "khi2", "nfree"]: - # setattr( - # dip, attr, getattr(dip, attr)[[0]].repeat(len(dip.times), axis=0) - # ) - # for key in dip.conf.keys(): - # dip.conf[key] = dip.conf[key][[0]].repeat(len(dip.times), axis=0) + dip["timecourse"] = dip_with_timecourse.amplitude[0] + dip["orientation"] = dip_with_timecourse.ori[0] # Update matplotlib canvas at the bottom of the window canvas = self._setup_mplcanvas() From f1a036419a1a6756da6f204ca854204d1559256a Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 21 Jan 2025 17:00:40 +0200 Subject: [PATCH 52/65] more cleanup --- mne/gui/_xfit.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index b11fa86f507..4bf3ed0a4aa 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -6,8 +6,6 @@ import pyvista from .. import pick_types - -# from ..beamformer import apply_lcmv, make_lcmv from ..bem import ( ConductorModel, _ensure_bem_surfaces, From 727fa2959b22df1e38c7a7f08789ad0bae277327 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 21 Jan 2025 18:25:37 +0200 Subject: [PATCH 53/65] Fix single dipole fitting with loose orientation --- mne/gui/_xfit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 4bf3ed0a4aa..c57db5ad08d 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -611,8 +611,8 @@ def _fit_timecourses(self): len(dip_with_timecourse.times), axis=0 ) else: - dip["timecourse"] = dip_with_timecourse.amplitude[0] - dip["orientation"] = dip_with_timecourse.ori[0] + dip["timecourse"] = dip_with_timecourse.amplitude + dip["orientation"] = dip_with_timecourse.ori # Update matplotlib canvas at the bottom of the window canvas = self._setup_mplcanvas() From 710b77d2daa8701ff59187ef3683ce376a7ad37e Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Tue, 21 Jan 2025 17:42:21 +0000 Subject: [PATCH 54/65] [autofix.ci] apply automated fixes --- mne/gui/_xfit.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index c57db5ad08d..917444ffe6e 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -1,3 +1,7 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + from copy import deepcopy from functools import partial from pathlib import Path From 4b71df066958b7c10fa47848620afc972b577ee3 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 21 Jan 2025 22:27:48 +0200 Subject: [PATCH 55/65] fixes proposed by vulture --- mne/gui/_xfit.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 917444ffe6e..3d4887e5d59 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -169,7 +169,6 @@ def __init__( # Initialize all the private attributes. self._actors = dict() - self._arrows = list() self._bem = bem self._ch_type = ch_type self._cov = cov @@ -179,7 +178,6 @@ def __init__( self._field_map = field_map self._fig_sensors = None self._multi_dipole_method = "Multi dipole (MNE)" - self._n_jobs = n_jobs self._show_density = show_density self._subjects_dir = subjects_dir self._subject = subject From dbabf053456acef55c55de6c0d957f3b9aa9d635 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 22 Jan 2025 10:48:34 +0200 Subject: [PATCH 56/65] Properly implement and test single channel picking --- mne/viz/tests/test_raw.py | 11 ++++++++++- mne/viz/tests/test_topo.py | 21 +++++++++++++++++++-- mne/viz/tests/test_utils.py | 28 ++++++++++++++++++++++------ mne/viz/topo.py | 17 +++++++++++------ mne/viz/utils.py | 27 +++++++++++++++------------ 5 files changed, 77 insertions(+), 27 deletions(-) diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index ee5dbf8e4e5..339b4f49c56 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -28,7 +28,7 @@ set_config, ) from mne.viz import plot_raw, plot_sensors -from mne.viz.utils import _fake_click +from mne.viz.utils import _fake_click, _fake_keypress base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" @@ -1089,6 +1089,8 @@ def test_plot_sensors(raw): pytest.raises(ValueError, plot_sensors, raw.info, kind="sasaasd") plt.close("all") + print(raw.ch_names) + # Test lasso selection. fig, sels = raw.plot_sensors("select", show_names=True) ax = fig.axes[0] @@ -1100,6 +1102,13 @@ def test_plot_sensors(raw): _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="release") assert fig.lasso.selection == ["MEG 0121"] + + # Add another sensor with a single click. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert fig.lasso.selection == ["MEG 0121", "MEG 0131"] plt.close("all") raw.info["dev_head_t"] = None # like empty room diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index dbc29832c09..48d031739b9 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -23,7 +23,7 @@ ) from mne.viz.evoked import _line_plot_onselect from mne.viz.topo import _imshow_tfr, _plot_update_evoked_topo_proj, iter_topography -from mne.viz.utils import _fake_click +from mne.viz.utils import _fake_click, _fake_keypress base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" evoked_fname = base_dir / "test-ave.fif" @@ -310,7 +310,24 @@ def test_plot_topo_select(): """Test selecting sensors in an ERP topography plot.""" # Show topography evoked = _get_epochs().average() - plot_evoked_topo(evoked, select=True) + fig = plot_evoked_topo(evoked, select=True) + ax = fig.axes[0] + + # Lasso select 3 out of the 6 sensors. + _fake_click(fig, ax, (0.05, 0.5), xform="data") + _fake_click(fig, ax, (0.2, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0132", "MEG 0133", "MEG 0131"] + + # Add another sensor with a single click. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (0.11, 0.65), xform="data") + _fake_click(fig, ax, (0.21, 0.65), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert fig.lasso.selection == ["MEG 0111", "MEG 0132", "MEG 0133", "MEG 0131"] def test_plot_tfr_topo(): diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 59bf06fe16c..55dc0f1e65c 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -293,6 +293,10 @@ def test_select_from_collection(): _fake_click(fig, ax, (0.5, 1), xform="data", kind="release") assert lasso.selection == [] + # Doing a single click on a patch should not select it. + _fake_click(fig, ax, (1, 1), xform="data") + assert lasso.selection == [] + # Make a selection with two patches in it. _fake_click(fig, ax, (0, 0.5), xform="data") _fake_click(fig, ax, (3, 0.5), xform="data", kind="motion") @@ -302,24 +306,24 @@ def test_select_from_collection(): _fake_click(fig, ax, (0, 0.5), xform="data", kind="release") assert lasso.selection == ["A", "B"] - # Use SHIFT key to lasso an additional patch. - _fake_keypress(fig, "shift") + # Use Control key to lasso an additional patch. + _fake_keypress(fig, "control") _fake_click(fig, ax, (0.5, -0.5), xform="data") _fake_click(fig, ax, (1.5, -0.5), xform="data", kind="motion") _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="motion") _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="release") - _fake_keypress(fig, "shift", kind="release") + _fake_keypress(fig, "control", kind="release") assert lasso.selection == ["A", "B", "D"] - # Use ALT key to remove a patch. - _fake_keypress(fig, "alt") + # Use CTRL+SHIFT to remove a patch. + _fake_keypress(fig, "ctrl+shift") _fake_click(fig, ax, (0.5, 0.5), xform="data") _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") _fake_click(fig, ax, (1.5, 1.5), xform="data", kind="motion") _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="motion") _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="release") - _fake_keypress(fig, "alt", kind="release") + _fake_keypress(fig, "ctrl+shift", kind="release") assert lasso.selection == ["B", "D"] # Check that the two selected patches have a different appearance. @@ -327,3 +331,15 @@ def test_select_from_collection(): ec = lasso.collection.get_edgecolors() assert (fc[:, -1] == [0.5, 1.0, 0.5, 1.0]).all() assert (ec[:, -1] == [0.25, 1.0, 0.25, 1.0]).all() + + # Test adding and removing single channels. + lasso.select_one(2) # should not do anything without modifier keys + assert lasso.selection == ["B", "D"] + _fake_keypress(fig, "control") + lasso.select_one(2) # add to selection + _fake_keypress(fig, "control", kind="release") + assert lasso.selection == ["B", "C", "D"] + _fake_keypress(fig, "ctrl+shift") + lasso.select_one(1) # remove from selection + assert lasso.selection == ["C", "D"] + _fake_keypress(fig, "ctrl+shift", kind="release") diff --git a/mne/viz/topo.py b/mne/viz/topo.py index f2073435003..6a4e5ff1079 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -275,10 +275,9 @@ def on_select(): publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) def on_channels_select(event): - ch_inds = {name: i for i, name in enumerate(shown_ch_names)} - selection_inds = [ - ch_inds[name] for name in event.ch_names if name in ch_inds - ] + selection_inds = np.flatnonzero( + np.isin(shown_ch_names, event.ch_names) + ) fig.lasso.select_many(selection_inds) fig.lasso.callbacks.append(on_select) @@ -381,9 +380,15 @@ def _plot_topo( def _plot_topo_onpick(event, show_func): """Onpick callback that shows a single channel in a new figure.""" - # make sure that the swipe gesture in OS-X doesn't open many figures orig_ax = event.inaxes - if orig_ax.figure.canvas._key in ["shift", "alt"]: + fig = orig_ax.figure + + # If we are doing lasso select, allow it to handle the click instead. + if fig.lasso is not None and event.key in ["control", "ctrl+shift"]: + return + + # make sure that the swipe gesture in OS-X doesn't open many figures + if fig.canvas._key in ["shift", "alt"]: return import matplotlib.pyplot as plt diff --git a/mne/viz/utils.py b/mne/viz/utils.py index ff628c4c41c..b01462acfda 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -952,7 +952,7 @@ def plot_sensors( Whether to plot the sensors as 3d, topomap or as an interactive sensor selection dialog. Available options ``'topomap'``, ``'3d'``, ``'select'``. If ``'select'``, a set of channels can be selected - interactively by using lasso selector or clicking while holding the shift + interactively by using lasso selector or clicking while holding the control key. The selected channels are returned along with the figure instance. Defaults to ``'topomap'``. ch_type : None | str @@ -1163,10 +1163,10 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): if event.mouseevent.inaxes != ax: return - if event.mouseevent.key in ["shift", "alt"] and fig.lasso is not None: + if fig.lasso is not None and event.mouseevent.key in ["control", "ctrl+shift"]: + # Add the sensor to the selection instead of showing its name. for ind in event.ind: fig.lasso.select_one(ind) - return if show_names: return # channel names already visible @@ -1278,10 +1278,7 @@ def on_select(): publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) def on_channels_select(event): - ch_inds = {name: i for i, name in enumerate(ch_names)} - selection_inds = [ - ch_inds[name] for name in event.ch_names if name in ch_inds - ] + selection_inds = np.flatnonzero(np.isin(ch_names, event.ch_names)) fig.lasso.select_many(selection_inds) fig.lasso.callbacks.append(on_select) @@ -1614,6 +1611,9 @@ class SelectFromCollection: This tool highlights selected objects by fading other objects out (i.e., reducing their alpha values). + Holding down the Control key will add to the current selection, and holding down + Control+Shift will remove from the current selection. + Parameters ---------- ax : instance of Axes @@ -1711,14 +1711,17 @@ def on_select(self, verts): """Select a subset from the collection.""" from matplotlib.path import Path - if len(verts) <= 3: # Seems to be a good way to exclude single clicks. + # Don't respond to single clicks without extra keys being hold down. + # Figures like plot_evoked_topo want to do something else with them. + print(verts, self.canvas._key) + if len(verts) <= 3 and self.canvas._key not in ["control", "ctrl+shift"]: return path = Path(verts) inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0] - if self.canvas._key == "shift": # Appending selection. + if self.canvas._key == "control": # Appending selection. self.selection_inds = np.union1d(self.selection_inds, inds).astype("int") - elif self.canvas._key == "alt": # Removing selection. + elif self.canvas._key == "ctrl+shift": self.selection_inds = np.setdiff1d(self.selection_inds, inds).astype("int") else: self.selection_inds = inds @@ -1728,9 +1731,9 @@ def on_select(self, verts): def select_one(self, ind): """Select or deselect one sensor.""" - if self.canvas._key == "shift": + if self.canvas._key == "control": self.selection_inds = np.union1d(self.selection_inds, [ind]) - elif self.canvas._key == "alt": + elif self.canvas._key == "ctrl+shift": self.selection_inds = np.setdiff1d(self.selection_inds, [ind]) else: return # don't notify() From e90887dda0ab135965e7fe8285161e3c5d179d74 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 22 Jan 2025 10:55:55 +0200 Subject: [PATCH 57/65] Add logging message --- mne/viz/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index b01462acfda..f9d64c49ec8 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -1647,6 +1647,7 @@ def __init__( alpha_nonselected=0.5, linewidth_selected=1, linewidth_nonselected=0.5, + verbose=None, ): from matplotlib.widgets import LassoSelector @@ -1704,6 +1705,7 @@ def ch_names(self): def notify(self): """Notify listeners that a selection has been made.""" + logger.info(f"Selected channels: {self.selection}") for callback in self.callbacks: callback() @@ -1713,7 +1715,6 @@ def on_select(self, verts): # Don't respond to single clicks without extra keys being hold down. # Figures like plot_evoked_topo want to do something else with them. - print(verts, self.canvas._key) if len(verts) <= 3 and self.canvas._key not in ["control", "ctrl+shift"]: return From 622ff54f8175499c32c450bcf3c3e8156330dee4 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 22 Jan 2025 11:37:28 +0200 Subject: [PATCH 58/65] small fix --- mne/viz/topo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 6a4e5ff1079..5c43d4de48e 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -384,7 +384,7 @@ def _plot_topo_onpick(event, show_func): fig = orig_ax.figure # If we are doing lasso select, allow it to handle the click instead. - if fig.lasso is not None and event.key in ["control", "ctrl+shift"]: + if hasattr(fig, "lasso") and event.key in ["control", "ctrl+shift"]: return # make sure that the swipe gesture in OS-X doesn't open many figures From d878f60a7f270b9d9eabe98b5621129a01f32406 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Fri, 31 Jan 2025 15:12:52 +0200 Subject: [PATCH 59/65] fix to dipole --- mne/dipole.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/dipole.py b/mne/dipole.py index 9b86f7cd875..33a99ca87c5 100644 --- a/mne/dipole.py +++ b/mne/dipole.py @@ -842,7 +842,7 @@ def _write_dipole_bdip(fname, dip): fid.write(np.array(has_errors, ">i4").tobytes()) # has_errors fid.write(np.zeros(1, ">f4").tobytes()) # noise level for key in _BDIP_ERROR_KEYS: - val = dip.conf[key][ti] if key in dip.conf else 0.0 + val = dip.conf[key][ti] if key in dip.conf else np.array(0.0) assert val.shape == () fid.write(np.array(val, ">f4").tobytes()) fid.write(np.zeros(25, ">f4").tobytes()) From b1ee63a364e3488de319e62a09367f248d3747bb Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 4 Feb 2025 10:01:26 +0200 Subject: [PATCH 60/65] Add possibility to show an stc with the fieldmap --- mne/gui/_xfit.py | 73 +++++++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 3d4887e5d59..9b1c55be843 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -21,15 +21,11 @@ from ..forward import convert_forward_solution, make_field_map, make_forward_dipole from ..minimum_norm import apply_inverse, make_inverse_operator from ..surface import _normal_orth -from ..transforms import ( - _get_trans, - _get_transforms_to_coord_frame, - transform_surface_to, -) +from ..transforms import _get_trans, _get_transforms_to_coord_frame, apply_trans from ..utils import _check_option, fill_doc, logger, verbose from ..viz import EvokedField, create_3d_figure from ..viz._3d import _plot_head_surface, _plot_sensors_3d -from ..viz.ui_events import subscribe +from ..viz.ui_events import link, subscribe from ..viz.utils import _get_color_list @@ -41,7 +37,7 @@ def dipolefit( bem=None, initial_time=None, trans=None, - rank="info", + rank=None, show_density=True, subject=None, subjects_dir=None, @@ -60,11 +56,13 @@ def dipolefit( Boundary element model to use in forward calculations. If ``None``, a spherical model is used. initial_time : float | None - Initial time point to show. If None, the time point of the maximum field + Initial time point to show. If ``None``, the time point of the maximum field strength is used. trans : instance of Transform | None - The transformation from head coordinates to MRI coordinates. If ``None``, - the identity matrix is used. + The transformation from head coordinates to MRI coordinates. If ``None``, the + identity matrix is used. + stc : instance of SourceEstimate | None + An optional distributed source estimate to show alongside the fieldmap. %(rank)s show_density : bool Whether to show the density of the fieldmap. @@ -80,6 +78,7 @@ def dipolefit( bem=bem, initial_time=initial_time, trans=trans, + stc=None, rank=rank, show_density=show_density, subject=subject, @@ -108,6 +107,8 @@ class DipoleFitUI: trans : instance of Transform | None The transformation from head coordinates to MRI coordinates. If ``None``, the identity matrix is used. + stc : instance of SourceEstimate | None + An optional distributed source estimate to show alongside the fieldmap. %(rank)s show_density : bool Whether to show the density of the fieldmap. @@ -125,6 +126,7 @@ def __init__( bem=None, initial_time=None, trans=None, + stc=None, rank="info", show_density=True, subject=None, @@ -154,19 +156,12 @@ def __init__( data = evoked.copy().pick(field_map[0]["ch_names"]).data initial_time = evoked.times[np.argmax(np.mean(data**2, axis=0))] - # Get transforms to convert all the various meshes to head space. + # Get transforms to convert all the various meshes to MRI space. head_mri_t = _get_trans(trans, "head", "mri")[0] to_cf_t = _get_transforms_to_coord_frame( - evoked.info, head_mri_t, coord_frame="head" + evoked.info, head_mri_t, coord_frame="mri" ) - # Transform the fieldmap surfaces to head space if needed. - if trans is not None: - for fm in field_map: - fm["surf"] = transform_surface_to( - fm["surf"], "head", [to_cf_t["mri"], to_cf_t["head"]], copy=False - ) - # Initialize all the private attributes. self._actors = dict() self._bem = bem @@ -179,11 +174,12 @@ def __init__( self._fig_sensors = None self._multi_dipole_method = "Multi dipole (MNE)" self._show_density = show_density + self._stc = stc self._subjects_dir = subjects_dir self._subject = subject self._time_line = None + self._head_mri_t = head_mri_t self._to_cf_t = to_cf_t - self._trans = trans self._rank = rank self._verbose = verbose @@ -199,6 +195,22 @@ def dipoles(self): def _configure_main_display(self): """Configure main 3D display of the GUI.""" fig = create_3d_figure((1500, 1020), bgcolor="white", show=True) + + self._fig_stc = None + if self._stc is not None: + self._fig_stc = self._stc.plot( + subject=self._subject, + subjects_dir=self._subjects_dir, + surface="white", + hemi="both", + time_viewer=False, + initial_time=self._current_time, + brain_kwargs=dict(units="m"), + figure=fig, + ) + fig = self._fig_stc + self._actors["brain"] = fig._actors["data"] + fig = EvokedField( self._evoked, self._field_map, @@ -216,6 +228,9 @@ def _configure_main_display(self): focalpoint=fit_sphere_to_headshape(self._evoked.info)[1] ) + if self._stc is not None: + link(self._fig_stc, fig) + for surf_map in fig._surf_maps: if surf_map["map_kind"] == "meg": helmet_mesh = surf_map["mesh"] @@ -260,7 +275,7 @@ def _configure_main_display(self): subject=self._subject, subjects_dir=self._subjects_dir, bem=self._bem, - coord_frame="head", + coord_frame="mri", to_cf_t=self._to_cf_t, alpha=0.2, ) @@ -405,13 +420,16 @@ def _on_fit_dipole(self): evoked_picked.info.normalize_proj() cov_picked = cov_picked.pick_channels(picks, ordered=False) cov_picked["projs"] = evoked_picked.info["projs"] + # Do we need to set the rank? + # for k, v in self._rank.items(): + # self._rank[k] = min(v, len(cov_picked.ch_names)) evoked_picked.crop(self._current_time, self._current_time) dip = fit_dipole( evoked_picked, cov_picked, self._bem, - trans=self._trans, + trans=self._head_mri_t, rank=self._rank, verbose=False, )[0] @@ -512,7 +530,10 @@ def add_dipole(self, dipole): for dipole_dict in new_dipoles: dip = dipole_dict["dip"] dipole_dict["brain_arrow_actor"] = self._renderer.plotter.add_arrows( - dip.pos[0], dip.ori[0], color=dipole_dict["color"], mag=0.05 + apply_trans(self._head_mri_t, dip.pos[0]), + dip.ori[0], + color=dipole_dict["color"], + mag=0.05, ) if arrow_mesh is not None: dipole_dict["helmet_arrow_actor"] = self._renderer.plotter.add_mesh( @@ -532,7 +553,7 @@ def _get_helmet_coords(self, dip): return None, None # Get the closest vertex (=point) of the helmet mesh - dip_pos = dip.pos[0] + dip_pos = apply_trans(self._head_mri_t, dip.pos[0]) helmet = self._actors["helmet"].GetMapper().GetInput() distances = ((helmet.points - dip_pos) * helmet.point_normals).sum(axis=1) closest_point = np.argmin(distances) @@ -565,7 +586,7 @@ def _fit_timecourses(self): [d["dip"] for d in active_dips], self._bem, self._evoked.info, - trans=self._trans, + trans=self._head_mri_t, ) fwd = convert_forward_solution(fwd, surf_ori=False) @@ -603,7 +624,7 @@ def _fit_timecourses(self): self._bem, pos=dip["dip"].pos[0], # position is always fixed ori=dip["dip"].ori[0] if dip["fix_ori"] else None, - trans=self._trans, + trans=self._head_mri_t, rank=self._rank, verbose=False, ) From 15519ebafaf0344a15af5f5f9f56b89bf36d803e Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 4 Feb 2025 10:02:52 +0200 Subject: [PATCH 61/65] Take units (m or mm) into account when showing fieldmaps on top of brain figures --- mne/viz/evoked_field.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index cf5a9996216..b1df34c907e 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -7,6 +7,7 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from copy import deepcopy from functools import partial import numpy as np @@ -185,6 +186,7 @@ def __init__( if isinstance(fig, Brain): self._renderer = fig._renderer self._in_brain_figure = True + self._units = fig._units if _get_3d_backend() == "notebook": raise NotImplementedError( "Plotting on top of an existing Brain figure " @@ -195,6 +197,7 @@ def __init__( fig, bgcolor=(0.0, 0.0, 0.0), size=(600, 600) ) self._in_brain_figure = False + self._units = "m" self.plotter = self._renderer.plotter self.interaction = interaction @@ -276,8 +279,8 @@ def _prepare_surf_map(self, surf_map, color, alpha): current_data = data_interp(self._current_time) # Make a solid surface - surf = surf_map["surf"] - if self._in_brain_figure: + surf = deepcopy(surf_map["surf"]) + if self._units == "mm": surf["rr"] *= 1000 map_vmax = self._vmax.get(surf_map["kind"]) if map_vmax is None: From 5108607b2be60cebbe7aa8807da03d25f4d6782f Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 4 Feb 2025 10:27:32 +0200 Subject: [PATCH 62/65] Add unit test and towncrier --- doc/changes/devel/13101.bugfix.rst | 1 + mne/viz/tests/test_3d.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 doc/changes/devel/13101.bugfix.rst diff --git a/doc/changes/devel/13101.bugfix.rst b/doc/changes/devel/13101.bugfix.rst new file mode 100644 index 00000000000..d24e55b5056 --- /dev/null +++ b/doc/changes/devel/13101.bugfix.rst @@ -0,0 +1 @@ +Take units (m or mm) into account when drawing :func:`~mne.viz.plot_evoked_field` on top of :class:`~mne.viz.Brain`, by `Marijn van Vliet`_. diff --git a/mne/viz/tests/test_3d.py b/mne/viz/tests/test_3d.py index 34022d59768..e3e4a2143d2 100644 --- a/mne/viz/tests/test_3d.py +++ b/mne/viz/tests/test_3d.py @@ -192,9 +192,18 @@ def test_plot_evoked_field(renderer): ) evoked.plot_field(maps, time=0.1, n_contours=n_contours) - # Test plotting inside an existing Brain figure. - brain = Brain("fsaverage", "lh", "inflated", subjects_dir=subjects_dir) - fig = evoked.plot_field(maps, time=0.1, fig=brain) + # Test plotting inside an existing Brain figure. Check that units are taken into + # account. + for units in ["mm", "m"]: + brain = Brain( + "fsaverage", "lh", "inflated", units=units, subjects_dir=subjects_dir + ) + fig = evoked.plot_field(maps, time=0.1, fig=brain) + assert brain._units == fig._units + scale = 1000 if units == "mm" else 1 + assert ( + fig._surf_maps[0]["surf"]["rr"][0, 0] == scale * maps[0]["surf"]["rr"][0, 0] + ) # Test some methods fig = evoked.plot_field(maps, time_viewer=True) From e243310678b98c30d1044e2e571f439f5a632e43 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 4 Feb 2025 10:29:32 +0200 Subject: [PATCH 63/65] Don't make unnecessary copy --- mne/viz/evoked_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index b1df34c907e..2138f800136 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -279,8 +279,8 @@ def _prepare_surf_map(self, surf_map, color, alpha): current_data = data_interp(self._current_time) # Make a solid surface - surf = deepcopy(surf_map["surf"]) if self._units == "mm": + surf = deepcopy(surf_map["surf"]) surf["rr"] *= 1000 map_vmax = self._vmax.get(surf_map["kind"]) if map_vmax is None: From a1f5328cfa3cf368d6e7af82b18782a1ddb0ae39 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Tue, 4 Feb 2025 10:30:33 +0200 Subject: [PATCH 64/65] Fix --- mne/viz/evoked_field.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/viz/evoked_field.py b/mne/viz/evoked_field.py index 2138f800136..839259ee117 100644 --- a/mne/viz/evoked_field.py +++ b/mne/viz/evoked_field.py @@ -279,8 +279,9 @@ def _prepare_surf_map(self, surf_map, color, alpha): current_data = data_interp(self._current_time) # Make a solid surface + surf = surf_map["surf"] if self._units == "mm": - surf = deepcopy(surf_map["surf"]) + surf = deepcopy(surf) surf["rr"] *= 1000 map_vmax = self._vmax.get(surf_map["kind"]) if map_vmax is None: From c703e3260ac8b6ff592f89e332a35d2cedf0d4e9 Mon Sep 17 00:00:00 2001 From: Marijn van Vliet Date: Wed, 5 Feb 2025 13:42:23 +0200 Subject: [PATCH 65/65] some checks on inputs and doc --- mne/gui/_xfit.py | 52 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/mne/gui/_xfit.py b/mne/gui/_xfit.py index 9b1c55be843..dabb4a020dc 100644 --- a/mne/gui/_xfit.py +++ b/mne/gui/_xfit.py @@ -96,8 +96,10 @@ class DipoleFitUI: ---------- evoked : instance of Evoked Evoked data to show fieldmap of and fit dipoles to. - cov : instance of Covariance | None - Noise covariance matrix. If ``None``, an ad-hoc covariance matrix is used. + cov : instance of Covariance | "baseline" | None + Noise covariance matrix. If ``None``, an ad-hoc covariance matrix is used with + default values for the diagonal elements (see Notes). If ``"baseline"``, the + diagonal elements is estimated from the baseline period of the evoked data. bem : instance of ConductorModel | None Boundary element model to use in forward calculations. If ``None``, a spherical model is used. @@ -106,17 +108,26 @@ class DipoleFitUI: strength is used. trans : instance of Transform | None The transformation from head coordinates to MRI coordinates. If ``None``, - the identity matrix is used. + the identity matrix is used and everything will be done in head coordinates. stc : instance of SourceEstimate | None - An optional distributed source estimate to show alongside the fieldmap. - %(rank)s - show_density : bool - Whether to show the density of the fieldmap. + An optional distributed source estimate to show alongside the fieldmap. The time + samples need to match those of the evoked data. subject : str | None The subject name. If ``None``, no MRI data is shown. %(subjects_dir)s + %(rank)s + show_density : bool + Whether to show the density of the fieldmap. + ch_type : "meg" | "eeg" | None + Type of channels to use for the dipole fitting. By default (``None``) both MEG + and EEG channels will be used. %(n_jobs)s %(verbose)s + + Notes + ----- + When using ``cov=None`` the default noise values are 5 fT/cm, 20 fT, and 0.2 ยตV for + gradiometers, magnetometers, and EEG channels respectively. """ def __init__( @@ -127,16 +138,22 @@ def __init__( initial_time=None, trans=None, stc=None, - rank="info", - show_density=True, subject=None, subjects_dir=None, + rank="info", + show_density=True, ch_type=None, n_jobs=None, verbose=None, ): if cov is None: cov = make_ad_hoc_cov(evoked.info) + elif cov == "baseline": + std = dict() + for typ in set(evoked.get_channel_types(only_data_chs=True)): + baseline = evoked.copy().pick(typ).crop(*evoked.baseline) + std[typ] = baseline.data.std(axis=1).mean() + cov = make_ad_hoc_cov(evoked.info, std) if bem is None: bem = make_sphere_model("auto", "auto", evoked.info) bem = _ensure_bem_surfaces(bem, extra_allow=(ConductorModel, None)) @@ -156,6 +173,18 @@ def __init__( data = evoked.copy().pick(field_map[0]["ch_names"]).data initial_time = evoked.times[np.argmax(np.mean(data**2, axis=0))] + if stc is not None: + if not np.allclose(stc.times, evoked.times): + raise ValueError( + "The time samples of the source estimate do not match those of the " + "evoked data." + ) + if trans is None: + raise ValueError( + "`trans` cannot be `None` when showing the fieldlines in " + "combination with a source estimate." + ) + # Get transforms to convert all the various meshes to MRI space. head_mri_t = _get_trans(trans, "head", "mri")[0] to_cf_t = _get_transforms_to_coord_frame( @@ -412,7 +441,7 @@ def _on_channels_select(self, event): def _on_fit_dipole(self): """Fit a single dipole.""" evoked_picked = self._evoked.copy() - cov_picked = self._cov.copy() + cov_picked = self._cov.copy().as_diag() # FIXME: as_diag necessary? if self._fig_sensors is not None: picks = self._fig_sensors.lasso.selection if len(picks) > 0: @@ -420,9 +449,6 @@ def _on_fit_dipole(self): evoked_picked.info.normalize_proj() cov_picked = cov_picked.pick_channels(picks, ordered=False) cov_picked["projs"] = evoked_picked.info["projs"] - # Do we need to set the rank? - # for k, v in self._rank.items(): - # self._rank[k] = min(v, len(cov_picked.ch_names)) evoked_picked.crop(self._current_time, self._current_time) dip = fit_dipole(