1212from .utils import get_auto_lims
1313
1414
15+ def create_probe_polygons (
16+ probe ,
17+ contacts_colors : list | None = None ,
18+ contacts_values : np .ndarray | None = None ,
19+ cmap : str = "viridis" ,
20+ contacts_kargs : dict = {},
21+ probe_shape_kwargs : dict = {},
22+ ):
23+ """Create PolyCollection objects for a Probe.
24+
25+ Parameters
26+ ----------
27+ probe : Probe
28+ The probe object
29+ contacts_colors : matplotlib color | None, default: None
30+ The color of the contacts
31+ contacts_values : np.ndarray | None, default: None
32+ Values to color the contacts with
33+ cmap : str, default: "viridis"
34+ A colormap color
35+ contacts_kargs : dict, default: {}
36+ Dict with kwargs for contacts (e.g. alpha, edgecolor, lw)
37+ probe_shape_kwargs : dict, default: {}
38+ Dict with kwargs for probe shape (e.g. alpha, edgecolor, lw)
39+
40+ Returns
41+ -------
42+ poly : PolyCollection
43+ The polygon collection for contacts
44+ poly_contour : PolyCollection | None
45+ The polygon collection for the probe shape
46+ """
47+ if probe .ndim == 2 :
48+ from matplotlib .collections import PolyCollection
49+
50+ Collection = PolyCollection
51+ elif probe .ndim == 3 :
52+ from mpl_toolkits .mplot3d .art3d import Poly3DCollection
53+
54+ Collection = Poly3DCollection
55+ else :
56+ raise ValueError (f"Unexpected probe.ndim: { probe .ndim } " )
57+
58+ _probe_shape_kwargs = dict (facecolor = "green" , edgecolor = "k" , lw = 0.5 , alpha = 0.3 )
59+ _probe_shape_kwargs .update (probe_shape_kwargs )
60+
61+ _contacts_kargs = dict (alpha = 0.7 , edgecolor = [0.3 , 0.3 , 0.3 ], lw = 0.5 )
62+ _contacts_kargs .update (contacts_kargs )
63+
64+ n = probe .get_contact_count ()
65+
66+ if contacts_colors is None and contacts_values is None :
67+ contacts_colors = ["orange" ] * n
68+ elif contacts_colors is not None :
69+ contacts_colors = contacts_colors
70+ elif contacts_values is not None :
71+ contacts_colors = None
72+
73+ vertices = probe .get_contact_vertices ()
74+ poly = Collection (vertices , color = contacts_colors , ** _contacts_kargs )
75+
76+ if contacts_values is not None :
77+ poly .set_array (contacts_values )
78+ poly .set_cmap (cmap )
79+
80+ # probe shape
81+ poly_contour = None
82+ planar_contour = probe .probe_planar_contour
83+ if planar_contour is not None :
84+ poly_contour = Collection ([planar_contour ], ** _probe_shape_kwargs )
85+
86+ return poly , poly_contour
87+
88+
1589def plot_probe (
1690 probe ,
1791 ax = None ,
@@ -74,11 +148,6 @@ def plot_probe(
74148 """
75149 import matplotlib .pyplot as plt
76150
77- if probe .ndim == 2 :
78- from matplotlib .collections import PolyCollection
79- elif probe .ndim == 3 :
80- from mpl_toolkits .mplot3d .art3d import Poly3DCollection
81-
82151 if ax is None :
83152 if probe .ndim == 2 :
84153 fig , ax = plt .subplots ()
@@ -89,32 +158,25 @@ def plot_probe(
89158 else :
90159 fig = ax .get_figure ()
91160
92- _probe_shape_kwargs = dict (facecolor = "green" , edgecolor = "k" , lw = 0.5 , alpha = 0.3 )
93- _probe_shape_kwargs .update (probe_shape_kwargs )
94-
95- _contacts_kargs = dict (alpha = 0.7 , edgecolor = [0.3 , 0.3 , 0.3 ], lw = 0.5 )
96- _contacts_kargs .update (contacts_kargs )
97-
98- n = probe .get_contact_count ()
99-
100- if contacts_colors is None and contacts_values is None :
101- contacts_colors = ["orange" ] * n
102- elif contacts_colors is not None :
103- contacts_colors = contacts_colors
104- elif contacts_values is not None :
105- contacts_colors = None
161+ # Create collections (contacts, probe shape)
162+ poly , poly_contour = create_probe_polygons (
163+ probe ,
164+ contacts_colors = contacts_colors ,
165+ contacts_values = contacts_values ,
166+ cmap = cmap ,
167+ contacts_kargs = contacts_kargs ,
168+ probe_shape_kwargs = probe_shape_kwargs ,
169+ )
106170
107- vertices = probe . get_contact_vertices ()
171+ # Add collections to the axis
108172 if probe .ndim == 2 :
109- poly = PolyCollection (vertices , color = contacts_colors , ** _contacts_kargs )
110173 ax .add_collection (poly )
174+ if poly_contour is not None :
175+ ax .add_collection (poly_contour )
111176 elif probe .ndim == 3 :
112- poly = Poly3DCollection (vertices , color = contacts_colors , ** _contacts_kargs )
113177 ax .add_collection3d (poly )
114-
115- if contacts_values is not None :
116- poly .set_array (contacts_values )
117- poly .set_cmap (cmap )
178+ if poly_contour is not None :
179+ ax .add_collection3d (poly_contour )
118180
119181 if show_channel_on_click :
120182 assert probe .ndim == 2 , "show_channel_on_click works only for ndim=2"
@@ -125,22 +187,11 @@ def on_press(event):
125187 fig .canvas .mpl_connect ("button_press_event" , on_press )
126188 fig .canvas .mpl_connect ("button_release_event" , on_release )
127189
128- # probe shape
129- planar_contour = probe .probe_planar_contour
130- if planar_contour is not None :
131- if probe .ndim == 2 :
132- poly_contour = PolyCollection ([planar_contour ], ** _probe_shape_kwargs )
133- ax .add_collection (poly_contour )
134- elif probe .ndim == 3 :
135- poly_contour = Poly3DCollection ([planar_contour ], ** _probe_shape_kwargs )
136- ax .add_collection3d (poly_contour )
137- else :
138- poly_contour = None
139-
140190 if text_on_contact is not None :
141191 text_on_contact = np .asarray (text_on_contact )
142192 assert text_on_contact .size == probe .get_contact_count ()
143193
194+ n = probe .get_contact_count ()
144195 if with_contact_id or with_device_index or text_on_contact is not None :
145196 if probe .ndim == 3 :
146197 raise NotImplementedError ("Channel index is 2d only" )
0 commit comments