Skip to content

Montage.plot(axes=ax, scale=0.1) crashes when ax is from a subplot #13438

@christian-oreilly

Description

@christian-oreilly

Description of the problem

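As per the title, `DigMontage.plot()` crashes when `scale != 1.0` and the `axes` argument is an Axes object that is not the first axes of its figure (e.g., the second panel returned by `plt.subplots`). The scaling code always looks at `fig.axes[0]` instead of the axes that was actually plotted into.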

Steps to reproduce

import mne
import matplotlib.pyplot as plt
import numpy as np

# Load one of MNE's built-in standard montages; any montage reproduces the bug,
# this one is simply used as a concrete example.
montage = mne.channels.make_standard_montage('GSN-HydroCel-129')

def plot_montage_topo(ax, picks, montage):
    """Draw the sensor positions of ``montage`` restricted to ``picks`` on ``ax``.

    A one-sample, all-zero EEG ``RawArray`` spanning every channel of the
    montage is built so the montage can be attached and then subset with
    ``pick``. The resulting montage is plotted with ``scale=0.1``, which is
    the code path that fails when ``ax`` is not the figure's first axes.
    """
    channel_names = montage.ch_names
    eeg_info = mne.create_info(channel_names, sfreq=256, ch_types="eeg")
    placeholder = np.zeros((len(channel_names), 1))
    dummy_raw = mne.io.RawArray(
        placeholder, eeg_info, copy=None, verbose=False
    ).set_montage(montage)
    picked_montage = dummy_raw.pick(picks).get_montage()
    picked_montage.plot(axes=ax, show_names=False, scale=0.1)


# Two stacked subplots. Drawing into the *second* one (axes[1]) triggers the
# IndexError: plot_montage rescales via fig.axes[0].collections[0], and the
# first (empty) subplot has no collections.
fig, axes = plt.subplots(2, 1)
plot_montage_topo(axes[1], montage.ch_names, montage)

Link to data

No response

Expected results

The montage is drawn on the given subplot with scaled markers and labels, and no exception is raised.

Actual results

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[2], [line 17](vscode-notebook-cell:?execution_count=2&line=17)
     13     raw.pick(picks).get_montage().plot(axes=ax, show_names=False, scale=0.1)    
     16 fig, axes = plt.subplots(2, 1)
---> [17](vscode-notebook-cell:?execution_count=2&line=17) plot_montage_topo(axes[1], montage.ch_names, montage)

Cell In[2], [line 13](vscode-notebook-cell:?execution_count=2&line=13)
      9 info = mne.create_info(montage.ch_names, sfreq=256, ch_types="eeg")
     10 raw = mne.io.RawArray(
     11         np.zeros((len(montage.ch_names), 1)), info, copy=None, verbose=False
     12     ).set_montage(montage)
---> [13](vscode-notebook-cell:?execution_count=2&line=13) raw.pick(picks).get_montage().plot(axes=ax, show_names=False, scale=0.1)

File ~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/mne/channels/montage.py:372, in DigMontage.plot(self, scale, show_names, kind, show, sphere, axes, verbose)
    360 @copy_function_doc_to_method_doc(plot_montage)
    361 def plot(
    362     self,
   (...)
    370     verbose=None,
    371 ):
--> [372](https://file+.vscode-resource.vscode-cdn.net/Users/christian/Code/eog-learn/paper_2025/~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/mne/channels/montage.py:372)     return plot_montage(
    373         self,
    374         scale=scale,
    375         show_names=show_names,
    376         kind=kind,
    377         show=show,
    378         sphere=sphere,
    379         axes=axes,
    380     )

File <decorator-gen-82>:12, in plot_montage(montage, scale, show_names, kind, show, sphere, axes, verbose)

File ~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/mne/viz/montage.py:105, in plot_montage(montage, scale, show_names, kind, show, sphere, axes, verbose)
     93 fig = plot_sensors(
     94     info,
     95     kind=kind,
   (...)
    100     axes=axes,
    101 )
    103 if scale != 1.0:
    104     # scale points
--> [105](https://file+.vscode-resource.vscode-cdn.net/Users/christian/Code/eog-learn/paper_2025/~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/mne/viz/montage.py:105)     collection = fig.axes[0].collections[0]
    106     collection.set_sizes([scale * 10])
    108     # scale labels

File ~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/matplotlib/axes/_base.py:1453, in _AxesBase.ArtistList.__getitem__(self, key)
   1452 def __getitem__(self, key):
-> [1453](https://file+.vscode-resource.vscode-cdn.net/Users/christian/Code/eog-learn/paper_2025/~/opt/anaconda3/envs/base12/lib/python3.12/site-packages/matplotlib/axes/_base.py:1453)     return [artist
   1454             for artist in self._axes._children
   1455             if self._type_check(artist)][key]

IndexError: list index out of range

Additional information

The following git-diff patch (against `mne/viz/montage.py`) fixes this issue on my system:

@@ -101,17 +101,23 @@ def plot_montage(
     )
 
     if scale != 1.0:
+        # plot_sensors draws into ``axes`` when one is given, and that axes is
+        # not necessarily fig.axes[0] (e.g. a panel from plt.subplots). Only
+        # touch text belonging to it so other subplots are left unscaled.
+        plot_ax = fig.axes[0] if axes is None else axes
+        text_owner = fig if axes is None else plot_ax
+
         # scale points
-        collection = fig.axes[0].collections[0]
+        collection = plot_ax.collections[0]
         collection.set_sizes([scale * 10])
 
         # scale labels
-        labels = fig.findobj(match=plt.Text)
-        x_label, y_label = fig.axes[0].xaxis.label, fig.axes[0].yaxis.label
-        z_label = fig.axes[0].zaxis.label if kind == "3d" else None
-        tick_labels = fig.axes[0].get_xticklabels() + fig.axes[0].get_yticklabels()
+        labels = text_owner.findobj(match=plt.Text)
+        x_label, y_label = plot_ax.xaxis.label, plot_ax.yaxis.label
+        z_label = plot_ax.zaxis.label if kind == "3d" else None
+        tick_labels = plot_ax.get_xticklabels() + plot_ax.get_yticklabels()
         if kind == "3d":
-            tick_labels += fig.axes[0].get_zticklabels()
+            tick_labels += plot_ax.get_zticklabels()
         for label in labels:
             if label not in [x_label, y_label, z_label] + tick_labels:
                 label.set_fontsize(label.get_fontsize() * scale)

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions