-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Open
Description
Description of the problem
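The `scale` handling in `DigMontage.plot()` (which calls `mne.viz.plot_montage`) has a bug: when `scale != 1.0` and an `Axes` object from a subplot other than the figure's first axes is passed via `axes`, it crashes with an `IndexError`, because the scaling code always looks at `fig.axes[0]`.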
Steps to reproduce
import mne
import matplotlib.pyplot as plt
import numpy as np

montage = mne.channels.make_standard_montage('GSN-HydroCel-129')


def plot_montage_topo(ax, picks, montage):
    """Plot the sensor positions of ``montage``, restricted to ``picks``, into ``ax``."""
    info = mne.create_info(montage.ch_names, sfreq=256, ch_types="eeg")
    # A one-sample dummy recording is enough to attach the montage and pick channels.
    raw = mne.io.RawArray(
        np.zeros((len(montage.ch_names), 1)), info, copy=None, verbose=False
    ).set_montage(montage)
    # scale != 1.0 enters the failing code path; ``ax`` is the *second* subplot,
    # so fig.axes[0] inside plot_montage is not the axes the sensors are drawn into.
    raw.pick(picks).get_montage().plot(axes=ax, show_names=False, scale=0.1)


fig, axes = plt.subplots(2, 1)
plot_montage_topo(axes[1], montage.ch_names, montage)
Link to data
No response
Expected results
The montage is plotted into the given subplot with the requested scaling, and no exception is raised.
Actual results
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[2], [line 17](vscode-notebook-cell:?execution_count=2&line=17)
13 raw.pick(picks).get_montage().plot(axes=ax, show_names=False, scale=0.1)
16 fig, axes = plt.subplots(2, 1)
---> [17](vscode-notebook-cell:?execution_count=2&line=17) plot_montage_topo(axes[1], montage.ch_names, montage)
Cell In[2], [line 13](vscode-notebook-cell:?execution_count=2&line=13)
9 info = mne.create_info(montage.ch_names, sfreq=256, ch_types="eeg")
10 raw = mne.io.RawArray(
11 np.zeros((len(montage.ch_names), 1)), info, copy=None, verbose=False
12 ).set_montage(montage)
---> [13](vscode-notebook-cell:?execution_count=2&line=13) raw.pick(picks).get_montage().plot(axes=ax, show_names=False, scale=0.1)
File ~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/mne/channels/montage.py:372, in DigMontage.plot(self, scale, show_names, kind, show, sphere, axes, verbose)
360 @copy_function_doc_to_method_doc(plot_montage)
361 def plot(
362 self,
(...)
370 verbose=None,
371 ):
--> [372](https://file+.vscode-resource.vscode-cdn.net/Users/christian/Code/eog-learn/paper_2025/~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/mne/channels/montage.py:372) return plot_montage(
373 self,
374 scale=scale,
375 show_names=show_names,
376 kind=kind,
377 show=show,
378 sphere=sphere,
379 axes=axes,
380 )
File <decorator-gen-82>:12, in plot_montage(montage, scale, show_names, kind, show, sphere, axes, verbose)
File ~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/mne/viz/montage.py:105, in plot_montage(montage, scale, show_names, kind, show, sphere, axes, verbose)
93 fig = plot_sensors(
94 info,
95 kind=kind,
(...)
100 axes=axes,
101 )
103 if scale != 1.0:
104 # scale points
--> [105](https://file+.vscode-resource.vscode-cdn.net/Users/christian/Code/eog-learn/paper_2025/~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/mne/viz/montage.py:105) collection = fig.axes[0].collections[0]
106 collection.set_sizes([scale * 10])
108 # scale labels
File ~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/matplotlib/axes/_base.py:1453, in _AxesBase.ArtistList.__getitem__(self, key)
1452 def __getitem__(self, key):
-> [1453](https://file+.vscode-resource.vscode-cdn.net/Users/christian/Code/eog-learn/paper_2025/~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/matplotlib/axes/_base.py:1453) return [artist
1454 for artist in self._axes._children
1455 if self._type_check(artist)][key]
IndexError: list index out of range
Additional information
The following patch to `mne/viz/montage.py` fixes this issue on my system:
@@ -101,17 +101,24 @@ def plot_montage(
     )
 
     if scale != 1.0:
+        # Use the axes the sensors were actually drawn into. When the caller
+        # passes ``axes`` (e.g. one subplot of a larger figure), fig.axes[0]
+        # can be a different subplot with no collections -> IndexError.
+        plot_ax = fig.axes[0] if axes is None else axes
+
         # scale points
-        collection = fig.axes[0].collections[0]
+        collection = plot_ax.collections[0]
         collection.set_sizes([scale * 10])
 
         # scale labels
-        labels = fig.findobj(match=plt.Text)
-        x_label, y_label = fig.axes[0].xaxis.label, fig.axes[0].yaxis.label
-        z_label = fig.axes[0].zaxis.label if kind == "3d" else None
-        tick_labels = fig.axes[0].get_xticklabels() + fig.axes[0].get_yticklabels()
+        # With caller-supplied axes, only search those axes so text in sibling
+        # subplots (e.g. their tick labels) is not rescaled as well.
+        labels = (fig if axes is None else plot_ax).findobj(match=plt.Text)
+        x_label, y_label = plot_ax.xaxis.label, plot_ax.yaxis.label
+        z_label = plot_ax.zaxis.label if kind == "3d" else None
+        tick_labels = plot_ax.get_xticklabels() + plot_ax.get_yticklabels()
         if kind == "3d":
-            tick_labels += fig.axes[0].get_zticklabels()
+            tick_labels += plot_ax.get_zticklabels()
         for label in labels:
             if label not in [x_label, y_label, z_label] + tick_labels:
                 label.set_fontsize(label.get_fontsize() * scale)