Skip to content

Commit 4e787cb

Browse files
committed
Refactor polycollection creation to separate function
1 parent 95c1356 commit 4e787cb

File tree

1 file changed

+132
-92
lines changed

1 file changed

+132
-92
lines changed

src/probeinterface/plotting.py

Lines changed: 132 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,77 @@
1212
from .utils import get_auto_lims
1313

1414

15+
def create_probe_collections(
16+
probe,
17+
contacts_colors: list | None = None,
18+
contacts_values: np.ndarray | None = None,
19+
cmap: str = "viridis",
20+
contacts_kargs: dict = {},
21+
probe_shape_kwargs: dict = {},
22+
):
23+
"""Create PolyCollection objects for a Probe.
24+
25+
Parameters
26+
----------
27+
probe : Probe
28+
The probe object
29+
contacts_colors : matplotlib color | None, default: None
30+
The color of the contacts
31+
contacts_values : np.ndarray | None, default: None
32+
Values to color the contacts with
33+
cmap : str, default: "viridis"
34+
A colormap color
35+
contacts_kargs : dict, default: {}
36+
Dict with kwargs for contacts (e.g. alpha, edgecolor, lw)
37+
probe_shape_kwargs : dict, default: {}
38+
Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw)
39+
40+
Returns
41+
-------
42+
poly : PolyCollection
43+
The polygon collection for contacts
44+
poly_contour : PolyCollection | None
45+
The polygon collection for the probe shape
46+
"""
47+
if probe.ndim == 2:
48+
from matplotlib.collections import PolyCollection
49+
Collection = PolyCollection
50+
elif probe.ndim == 3:
51+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
52+
Collection = Poly3DCollection
53+
else:
54+
raise ValueError(f"Unexpected probe.ndim: {probe.ndim}")
55+
56+
_probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3)
57+
_probe_shape_kwargs.update(probe_shape_kwargs)
58+
59+
_contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5)
60+
_contacts_kargs.update(contacts_kargs)
61+
62+
n = probe.get_contact_count()
63+
64+
if contacts_colors is None and contacts_values is None:
65+
contacts_colors = ["orange"] * n
66+
elif contacts_colors is not None:
67+
contacts_colors = contacts_colors
68+
elif contacts_values is not None:
69+
contacts_colors = None
70+
71+
vertices = probe.get_contact_vertices()
72+
poly = Collection(vertices, color=contacts_colors, **_contacts_kargs)
73+
74+
if contacts_values is not None:
75+
poly.set_array(contacts_values)
76+
poly.set_cmap(cmap)
77+
78+
# probe shape
79+
poly_contour = None
80+
planar_contour = probe.probe_planar_contour
81+
if planar_contour is not None:
82+
poly_contour = Collection([planar_contour], **_probe_shape_kwargs)
83+
84+
return poly, poly_contour
85+
1586
def plot_probe(
1687
probe,
1788
ax=None,
@@ -28,7 +99,6 @@ def plot_probe(
2899
ylims: tuple | None = None,
29100
zlims: tuple | None = None,
30101
show_channel_on_click: bool = False,
31-
add_to_axis: bool = True,
32102
):
33103
"""Plot a Probe object.
34104
Generates a 2D or 3D axis, depending on Probe.ndim
@@ -65,9 +135,6 @@ def plot_probe(
65135
Limits for z dimension
66136
show_channel_on_click : bool, default: False
67137
If True, the channel information is shown upon click
68-
add_to_axis : bool, default: True
69-
If True, collections are added to the axis. If False, collections are
70-
only returned without being added to the axis.
71138
72139
Returns
73140
-------
@@ -78,51 +145,37 @@ def plot_probe(
78145
"""
79146
import matplotlib.pyplot as plt
80147

81-
if probe.ndim == 2:
82-
from matplotlib.collections import PolyCollection
83-
elif probe.ndim == 3:
84-
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
85-
86-
if ax is None and add_to_axis:
148+
if ax is None:
87149
if probe.ndim == 2:
88150
fig, ax = plt.subplots()
89151
ax.set_aspect("equal")
90152
else:
91153
fig = plt.figure()
92154
ax = fig.add_subplot(1, 1, 1, projection="3d")
93-
elif ax is not None:
155+
else:
94156
fig = ax.get_figure()
95157

96-
_probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3)
97-
_probe_shape_kwargs.update(probe_shape_kwargs)
98-
99-
_contacts_kargs = dict(alpha=0.7, edgecolor=[0.3, 0.3, 0.3], lw=0.5)
100-
_contacts_kargs.update(contacts_kargs)
101-
102-
n = probe.get_contact_count()
103-
104-
if contacts_colors is None and contacts_values is None:
105-
contacts_colors = ["orange"] * n
106-
elif contacts_colors is not None:
107-
contacts_colors = contacts_colors
108-
elif contacts_values is not None:
109-
contacts_colors = None
158+
# Create collections (contacts, probe shape)
159+
poly, poly_contour = create_probe_collections(
160+
probe,
161+
contacts_colors=contacts_colors,
162+
contacts_values=contacts_values,
163+
cmap=cmap,
164+
contacts_kargs=contacts_kargs,
165+
probe_shape_kwargs=probe_shape_kwargs,
166+
)
110167

111-
vertices = probe.get_contact_vertices()
168+
# Add collections to the axis
112169
if probe.ndim == 2:
113-
poly = PolyCollection(vertices, color=contacts_colors, **_contacts_kargs)
114-
if add_to_axis and ax is not None:
115-
ax.add_collection(poly)
170+
ax.add_collection(poly)
171+
if poly_contour is not None:
172+
ax.add_collection(poly_contour)
116173
elif probe.ndim == 3:
117-
poly = Poly3DCollection(vertices, color=contacts_colors, **_contacts_kargs)
118-
if add_to_axis and ax is not None:
119-
ax.add_collection3d(poly)
120-
121-
if contacts_values is not None:
122-
poly.set_array(contacts_values)
123-
poly.set_cmap(cmap)
124-
125-
if show_channel_on_click and add_to_axis:
174+
ax.add_collection3d(poly)
175+
if poly_contour is not None:
176+
ax.add_collection3d(poly_contour)
177+
178+
if show_channel_on_click:
126179
assert probe.ndim == 2, "show_channel_on_click works only for ndim=2"
127180

128181
def on_press(event):
@@ -131,64 +184,51 @@ def on_press(event):
131184
fig.canvas.mpl_connect("button_press_event", on_press)
132185
fig.canvas.mpl_connect("button_release_event", on_release)
133186

134-
# probe shape
135-
poly_contour = None
136-
planar_contour = probe.probe_planar_contour
137-
if planar_contour is not None:
138-
if probe.ndim == 2:
139-
poly_contour = PolyCollection([planar_contour], **_probe_shape_kwargs)
140-
if add_to_axis and ax is not None:
141-
ax.add_collection(poly_contour)
142-
elif probe.ndim == 3:
143-
poly_contour = Poly3DCollection([planar_contour], **_probe_shape_kwargs)
144-
if add_to_axis and ax is not None:
145-
ax.add_collection3d(poly_contour)
146-
147-
if add_to_axis and ax is not None:
148-
if text_on_contact is not None:
149-
text_on_contact = np.asarray(text_on_contact)
150-
assert text_on_contact.size == probe.get_contact_count()
151-
152-
if with_contact_id or with_device_index or text_on_contact is not None:
153-
if probe.ndim == 3:
154-
raise NotImplementedError("Channel index is 2d only")
155-
for i in range(n):
156-
txt = []
157-
if with_contact_id and probe.contact_ids is not None:
158-
contact_id = probe.contact_ids[i]
159-
txt.append(f"id{contact_id}")
160-
if with_device_index and probe.device_channel_indices is not None:
161-
chan_ind = probe.device_channel_indices[i]
162-
txt.append(f"dev{chan_ind}")
163-
if text_on_contact is not None:
164-
txt.append(f"{text_on_contact[i]}")
165-
166-
txt = "\n".join(txt)
167-
x, y = probe.contact_positions[i]
168-
ax.text(x, y, txt, ha="center", va="center", clip_on=True)
169-
170-
if xlims is None or ylims is None or (zlims is None and probe.ndim == 3):
171-
xlims, ylims, zlims = get_auto_lims(probe)
172-
173-
ax.set_xlim(*xlims)
174-
ax.set_ylim(*ylims)
175-
176-
if probe.si_units == "um":
177-
unit_str = "($\\mu m$)"
178-
else:
179-
unit_str = f"({probe.si_units})"
180-
ax.set_xlabel(f"x {unit_str}", fontsize=15)
181-
ax.set_ylabel(f"y {unit_str}", fontsize=15)
187+
if text_on_contact is not None:
188+
text_on_contact = np.asarray(text_on_contact)
189+
assert text_on_contact.size == probe.get_contact_count()
182190

191+
n = probe.get_contact_count()
192+
if with_contact_id or with_device_index or text_on_contact is not None:
183193
if probe.ndim == 3:
184-
ax.set_zlim(zlims)
185-
ax.set_zlabel("z")
194+
raise NotImplementedError("Channel index is 2d only")
195+
for i in range(n):
196+
txt = []
197+
if with_contact_id and probe.contact_ids is not None:
198+
contact_id = probe.contact_ids[i]
199+
txt.append(f"id{contact_id}")
200+
if with_device_index and probe.device_channel_indices is not None:
201+
chan_ind = probe.device_channel_indices[i]
202+
txt.append(f"dev{chan_ind}")
203+
if text_on_contact is not None:
204+
txt.append(f"{text_on_contact[i]}")
205+
206+
txt = "\n".join(txt)
207+
x, y = probe.contact_positions[i]
208+
ax.text(x, y, txt, ha="center", va="center", clip_on=True)
209+
210+
if xlims is None or ylims is None or (zlims is None and probe.ndim == 3):
211+
xlims, ylims, zlims = get_auto_lims(probe)
212+
213+
ax.set_xlim(*xlims)
214+
ax.set_ylim(*ylims)
215+
216+
if probe.si_units == "um":
217+
unit_str = "($\\mu m$)"
218+
else:
219+
unit_str = f"({probe.si_units})"
220+
ax.set_xlabel(f"x {unit_str}", fontsize=15)
221+
ax.set_ylabel(f"y {unit_str}", fontsize=15)
186222

187-
if probe.ndim == 2:
188-
ax.set_aspect("equal")
223+
if probe.ndim == 3:
224+
ax.set_zlim(zlims)
225+
ax.set_zlabel("z")
226+
227+
if probe.ndim == 2:
228+
ax.set_aspect("equal")
189229

190-
if title:
191-
ax.set_title(probe.get_title())
230+
if title:
231+
ax.set_title(probe.get_title())
192232

193233
return poly, poly_contour
194234

0 commit comments

Comments
 (0)