1212from .utils import get_auto_lims
1313
1414
15+ def create_probe_collections (
16+ probe ,
17+ contacts_colors : list | None = None ,
18+ contacts_values : np .ndarray | None = None ,
19+ cmap : str = "viridis" ,
20+ contacts_kargs : dict = {},
21+ probe_shape_kwargs : dict = {},
22+ ):
23+ """Create PolyCollection objects for a Probe.
24+
25+ Parameters
26+ ----------
27+ probe : Probe
28+ The probe object
29+ contacts_colors : matplotlib color | None, default: None
30+ The color of the contacts
31+ contacts_values : np.ndarray | None, default: None
32+ Values to color the contacts with
33+ cmap : str, default: "viridis"
34+ A colormap color
35+ contacts_kargs : dict, default: {}
36+ Dict with kwargs for contacts (e.g. alpha, edgecolor, lw)
37+ probe_shape_kwargs : dict, default: {}
38+ Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw)
39+
40+ Returns
41+ -------
42+ poly : PolyCollection
43+ The polygon collection for contacts
44+ poly_contour : PolyCollection | None
45+ The polygon collection for the probe shape
46+ """
47+ if probe .ndim == 2 :
48+ from matplotlib .collections import PolyCollection
49+ Collection = PolyCollection
50+ elif probe .ndim == 3 :
51+ from mpl_toolkits .mplot3d .art3d import Poly3DCollection
52+ Collection = Poly3DCollection
53+ else :
54+ raise ValueError (f"Unexpected probe.ndim: { probe .ndim } " )
55+
56+ _probe_shape_kwargs = dict (facecolor = "green" , edgecolor = "k" , lw = 0.5 , alpha = 0.3 )
57+ _probe_shape_kwargs .update (probe_shape_kwargs )
58+
59+ _contacts_kargs = dict (alpha = 0.7 , edgecolor = [0.3 , 0.3 , 0.3 ], lw = 0.5 )
60+ _contacts_kargs .update (contacts_kargs )
61+
62+ n = probe .get_contact_count ()
63+
64+ if contacts_colors is None and contacts_values is None :
65+ contacts_colors = ["orange" ] * n
66+ elif contacts_colors is not None :
67+ contacts_colors = contacts_colors
68+ elif contacts_values is not None :
69+ contacts_colors = None
70+
71+ vertices = probe .get_contact_vertices ()
72+ poly = Collection (vertices , color = contacts_colors , ** _contacts_kargs )
73+
74+ if contacts_values is not None :
75+ poly .set_array (contacts_values )
76+ poly .set_cmap (cmap )
77+
78+ # probe shape
79+ poly_contour = None
80+ planar_contour = probe .probe_planar_contour
81+ if planar_contour is not None :
82+ poly_contour = Collection ([planar_contour ], ** _probe_shape_kwargs )
83+
84+ return poly , poly_contour
85+
1586def plot_probe (
1687 probe ,
1788 ax = None ,
@@ -28,7 +99,6 @@ def plot_probe(
2899 ylims : tuple | None = None ,
29100 zlims : tuple | None = None ,
30101 show_channel_on_click : bool = False ,
31- add_to_axis : bool = True ,
32102):
33103 """Plot a Probe object.
34104 Generates a 2D or 3D axis, depending on Probe.ndim
@@ -65,9 +135,6 @@ def plot_probe(
65135 Limits for z dimension
66136 show_channel_on_click : bool, default: False
67137 If True, the channel information is shown upon click
68- add_to_axis : bool, default: True
69- If True, collections are added to the axis. If False, collections are
70- only returned without being added to the axis.
71138
72139 Returns
73140 -------
@@ -78,51 +145,37 @@ def plot_probe(
78145 """
79146 import matplotlib .pyplot as plt
80147
81- if probe .ndim == 2 :
82- from matplotlib .collections import PolyCollection
83- elif probe .ndim == 3 :
84- from mpl_toolkits .mplot3d .art3d import Poly3DCollection
85-
86- if ax is None and add_to_axis :
148+ if ax is None :
87149 if probe .ndim == 2 :
88150 fig , ax = plt .subplots ()
89151 ax .set_aspect ("equal" )
90152 else :
91153 fig = plt .figure ()
92154 ax = fig .add_subplot (1 , 1 , 1 , projection = "3d" )
93- elif ax is not None :
155+ else :
94156 fig = ax .get_figure ()
95157
96- _probe_shape_kwargs = dict (facecolor = "green" , edgecolor = "k" , lw = 0.5 , alpha = 0.3 )
97- _probe_shape_kwargs .update (probe_shape_kwargs )
98-
99- _contacts_kargs = dict (alpha = 0.7 , edgecolor = [0.3 , 0.3 , 0.3 ], lw = 0.5 )
100- _contacts_kargs .update (contacts_kargs )
101-
102- n = probe .get_contact_count ()
103-
104- if contacts_colors is None and contacts_values is None :
105- contacts_colors = ["orange" ] * n
106- elif contacts_colors is not None :
107- contacts_colors = contacts_colors
108- elif contacts_values is not None :
109- contacts_colors = None
158+ # Create collections (contacts, probe shape)
159+ poly , poly_contour = create_probe_collections (
160+ probe ,
161+ contacts_colors = contacts_colors ,
162+ contacts_values = contacts_values ,
163+ cmap = cmap ,
164+ contacts_kargs = contacts_kargs ,
165+ probe_shape_kwargs = probe_shape_kwargs ,
166+ )
110167
111- vertices = probe . get_contact_vertices ()
168+ # Add collections to the axis
112169 if probe .ndim == 2 :
113- poly = PolyCollection ( vertices , color = contacts_colors , ** _contacts_kargs )
114- if add_to_axis and ax is not None :
115- ax .add_collection (poly )
170+ ax . add_collection ( poly )
171+ if poly_contour is not None :
172+ ax .add_collection (poly_contour )
116173 elif probe .ndim == 3 :
117- poly = Poly3DCollection (vertices , color = contacts_colors , ** _contacts_kargs )
118- if add_to_axis and ax is not None :
119- ax .add_collection3d (poly )
120-
121- if contacts_values is not None :
122- poly .set_array (contacts_values )
123- poly .set_cmap (cmap )
124-
125- if show_channel_on_click and add_to_axis :
174+ ax .add_collection3d (poly )
175+ if poly_contour is not None :
176+ ax .add_collection3d (poly_contour )
177+
178+ if show_channel_on_click :
126179 assert probe .ndim == 2 , "show_channel_on_click works only for ndim=2"
127180
128181 def on_press (event ):
@@ -131,64 +184,51 @@ def on_press(event):
131184 fig .canvas .mpl_connect ("button_press_event" , on_press )
132185 fig .canvas .mpl_connect ("button_release_event" , on_release )
133186
134- # probe shape
135- poly_contour = None
136- planar_contour = probe .probe_planar_contour
137- if planar_contour is not None :
138- if probe .ndim == 2 :
139- poly_contour = PolyCollection ([planar_contour ], ** _probe_shape_kwargs )
140- if add_to_axis and ax is not None :
141- ax .add_collection (poly_contour )
142- elif probe .ndim == 3 :
143- poly_contour = Poly3DCollection ([planar_contour ], ** _probe_shape_kwargs )
144- if add_to_axis and ax is not None :
145- ax .add_collection3d (poly_contour )
146-
147- if add_to_axis and ax is not None :
148- if text_on_contact is not None :
149- text_on_contact = np .asarray (text_on_contact )
150- assert text_on_contact .size == probe .get_contact_count ()
151-
152- if with_contact_id or with_device_index or text_on_contact is not None :
153- if probe .ndim == 3 :
154- raise NotImplementedError ("Channel index is 2d only" )
155- for i in range (n ):
156- txt = []
157- if with_contact_id and probe .contact_ids is not None :
158- contact_id = probe .contact_ids [i ]
159- txt .append (f"id{ contact_id } " )
160- if with_device_index and probe .device_channel_indices is not None :
161- chan_ind = probe .device_channel_indices [i ]
162- txt .append (f"dev{ chan_ind } " )
163- if text_on_contact is not None :
164- txt .append (f"{ text_on_contact [i ]} " )
165-
166- txt = "\n " .join (txt )
167- x , y = probe .contact_positions [i ]
168- ax .text (x , y , txt , ha = "center" , va = "center" , clip_on = True )
169-
170- if xlims is None or ylims is None or (zlims is None and probe .ndim == 3 ):
171- xlims , ylims , zlims = get_auto_lims (probe )
172-
173- ax .set_xlim (* xlims )
174- ax .set_ylim (* ylims )
175-
176- if probe .si_units == "um" :
177- unit_str = "($\\ mu m$)"
178- else :
179- unit_str = f"({ probe .si_units } )"
180- ax .set_xlabel (f"x { unit_str } " , fontsize = 15 )
181- ax .set_ylabel (f"y { unit_str } " , fontsize = 15 )
187+ if text_on_contact is not None :
188+ text_on_contact = np .asarray (text_on_contact )
189+ assert text_on_contact .size == probe .get_contact_count ()
182190
191+ n = probe .get_contact_count ()
192+ if with_contact_id or with_device_index or text_on_contact is not None :
183193 if probe .ndim == 3 :
184- ax .set_zlim (zlims )
185- ax .set_zlabel ("z" )
194+ raise NotImplementedError ("Channel index is 2d only" )
195+ for i in range (n ):
196+ txt = []
197+ if with_contact_id and probe .contact_ids is not None :
198+ contact_id = probe .contact_ids [i ]
199+ txt .append (f"id{ contact_id } " )
200+ if with_device_index and probe .device_channel_indices is not None :
201+ chan_ind = probe .device_channel_indices [i ]
202+ txt .append (f"dev{ chan_ind } " )
203+ if text_on_contact is not None :
204+ txt .append (f"{ text_on_contact [i ]} " )
205+
206+ txt = "\n " .join (txt )
207+ x , y = probe .contact_positions [i ]
208+ ax .text (x , y , txt , ha = "center" , va = "center" , clip_on = True )
209+
210+ if xlims is None or ylims is None or (zlims is None and probe .ndim == 3 ):
211+ xlims , ylims , zlims = get_auto_lims (probe )
212+
213+ ax .set_xlim (* xlims )
214+ ax .set_ylim (* ylims )
215+
216+ if probe .si_units == "um" :
217+ unit_str = "($\\ mu m$)"
218+ else :
219+ unit_str = f"({ probe .si_units } )"
220+ ax .set_xlabel (f"x { unit_str } " , fontsize = 15 )
221+ ax .set_ylabel (f"y { unit_str } " , fontsize = 15 )
186222
187- if probe .ndim == 2 :
188- ax .set_aspect ("equal" )
223+ if probe .ndim == 3 :
224+ ax .set_zlim (zlims )
225+ ax .set_zlabel ("z" )
226+
227+ if probe .ndim == 2 :
228+ ax .set_aspect ("equal" )
189229
190- if title :
191- ax .set_title (probe .get_title ())
230+ if title :
231+ ax .set_title (probe .get_title ())
192232
193233 return poly , poly_contour
194234
0 commit comments