@@ -101,6 +101,7 @@ def __init__(
101101 self .probe_planar_contour = None
102102
103103 # This handles the shank id per contact
104+ # If None then one shank only
104105 self ._shank_ids = None
105106
106107 # This handles the wiring to device : channel index on device side.
@@ -112,6 +113,10 @@ def __init__(
112113 # This must be unique at Probe AND ProbeGroup level
113114 self ._contact_ids = None
114115
116+ # Handle contact side for double face probes
117+ # If None then one face only
118+ self ._contact_sides = None
119+
115120 # annotation: a dict that contains all meta information about
116121 # the probe (name, manufacturor, date of production, ...)
117122 self .annotations = dict ()
@@ -153,6 +158,10 @@ def contact_ids(self):
153158 def shank_ids (self ):
154159 return self ._shank_ids
155160
161+ @property
162+ def contact_sides (self ):
163+ return self ._contact_sides
164+
156165 @property
157166 def name (self ):
158167 return self .annotations .get ("name" , None )
@@ -237,6 +246,8 @@ def get_title(self) -> str:
237246 if self .shank_ids is not None :
238247 num_shank = self .get_shank_count ()
239248 txt += f" - { num_shank } shanks"
249+ if self ._contact_sides is not None :
250+ txt += f" - 2 sides"
240251 return txt
241252
242253 def __repr__ (self ):
@@ -291,7 +302,7 @@ def get_shank_count(self) -> int:
291302 return n
292303
293304 def set_contacts (
294- self , positions , shapes = "circle" , shape_params = {"radius" : 10 }, plane_axes = None , contact_ids = None , shank_ids = None
305+ self , positions , shapes = "circle" , shape_params = {"radius" : 10 }, plane_axes = None , contact_ids = None , shank_ids = None , contact_sides = None
295306 ):
296307 """Sets contacts to a Probe.
297308
@@ -320,16 +331,29 @@ def set_contacts(
320331 shank_ids : array[str] | None, default: None
321332 Defines the shank ids for the contacts. If None, then
322333 these are assigned to a unique Shank.
334+ contact_sides : array[str] | None, default: None
335+ If probe is double sided, defines sides by a vector of ['front' | 'back']
323336 """
324337 positions = np .array (positions )
325338 if positions .shape [1 ] != self .ndim :
326339 raise ValueError (f"positions.shape[1]: { positions .shape [1 ]} and ndim: { self .ndim } do not match!" )
327340
328- # Check for duplicate positions
329- unique_positions = np .unique (positions , axis = 0 )
330- positions_are_not_unique = unique_positions .shape [0 ] != positions .shape [0 ]
331- if positions_are_not_unique :
332- _raise_non_unique_positions_error (positions )
341+
342+ if contact_sides is None :
343+ # Check for duplicate positions
344+ unique_positions = np .unique (positions , axis = 0 )
345+ positions_are_not_unique = unique_positions .shape [0 ] != positions .shape [0 ]
346+ if positions_are_not_unique :
347+ _raise_non_unique_positions_error (positions )
348+ else :
349+ # Check for duplicate positions side by side
350+ contact_sides = np .asarray (contact_sides ).astype (str )
351+ for side in ("front" , "back" ):
352+ mask = contact_sides == "font"
353+ unique_positions = np .unique (positions [mask ], axis = 0 )
354+ positions_are_not_unique = unique_positions .shape [0 ] != positions [mask ].shape [0 ]
355+ if positions_are_not_unique :
356+ _raise_non_unique_positions_error (positions [mask ])
333357
334358 self ._contact_positions = positions
335359 n = positions .shape [0 ]
@@ -355,6 +379,15 @@ def set_contacts(
355379 self ._shank_ids = np .asarray (shank_ids ).astype (str )
356380 if self .shank_ids .size != n :
357381 raise ValueError (f"shank_ids have wrong size: { self .shanks .ids .size } != { n } " )
382+
383+ if contact_sides is None :
384+ self ._contact_sides = contact_sides
385+ else :
386+ self ._contact_sides = contact_sides
387+ if self ._contact_sides .size != n :
388+ raise ValueError (f"contact_sides have wrong size: { self ._contact_sides .ids .size } != { n } " )
389+ if not np .all (np .isin (self ._contact_sides , ["front" , "back" ])):
390+ raise ValueError (f"contact_sides must 'front' or 'back'" )
358391
359392 # shape
360393 if isinstance (shapes , str ):
@@ -592,6 +625,13 @@ def __eq__(self, other):
592625 ):
593626 return False
594627
628+ if self ._contact_sides is None :
629+ if other ._contact_sides is not None :
630+ return False
631+ else :
632+ if not np .array_equal (self ._contact_sides , other ._contact_sides ):
633+ return False
634+
595635 # Compare contact_annotations dictionaries
596636 if self .contact_annotations .keys () != other .contact_annotations .keys ():
597637 return False
@@ -842,6 +882,7 @@ def rotate_contacts(self, thetas: float | np.array[float] | list[float]):
842882 "device_channel_indices" ,
843883 "_contact_ids" ,
844884 "_shank_ids" ,
885+ "_contact_sides" ,
845886 ]
846887
847888 def to_dict (self , array_as_list : bool = False ) -> dict :
@@ -895,6 +936,9 @@ def from_dict(d: dict) -> "Probe":
895936 plane_axes = d ["contact_plane_axes" ],
896937 shapes = d ["contact_shapes" ],
897938 shape_params = d ["contact_shape_params" ],
939+ contact_ids = d .get ("contact_ids" , None ),
940+ shank_ids = d .get ("shank_ids" , None ),
941+ contact_sides = d .get ("contact_sides" , None ),
898942 )
899943
900944 v = d .get ("probe_planar_contour" , None )
@@ -905,14 +949,6 @@ def from_dict(d: dict) -> "Probe":
905949 if v is not None :
906950 probe .set_device_channel_indices (v )
907951
908- v = d .get ("shank_ids" , None )
909- if v is not None :
910- probe .set_shank_ids (v )
911-
912- v = d .get ("contact_ids" , None )
913- if v is not None :
914- probe .set_contact_ids (v )
915-
916952 if "annotations" in d :
917953 probe .annotate (** d ["annotations" ])
918954 if "contact_annotations" in d :
@@ -955,6 +991,7 @@ def to_numpy(self, complete: bool = False) -> np.array:
955991 ...
956992 ('shank_ids', 'U64'),
957993 ('contact_ids', 'U64'),
994+ ('contact_sides', 'U8'),
958995
959996 # The rest is added only if `complete=True`
960997 ('device_channel_indices', 'int64', optional),
@@ -991,6 +1028,9 @@ def to_numpy(self, complete: bool = False) -> np.array:
9911028 dtype += [(k , "float64" )]
9921029 dtype += [("shank_ids" , "U64" ), ("contact_ids" , "U64" )]
9931030
1031+ if self ._contact_sides is not None :
1032+ dtype += [("contact_sides" , "U8" ), ]
1033+
9941034 if complete :
9951035 dtype += [("device_channel_indices" , "int64" )]
9961036 dtype += [("si_units" , "U64" )]
@@ -1014,6 +1054,11 @@ def to_numpy(self, complete: bool = False) -> np.array:
10141054
10151055 arr ["shank_ids" ] = self .shank_ids
10161056
1057+ if self ._contact_sides is not None :
1058+ arr ["contact_sides" ] = self .contact_sides
1059+
1060+
1061+
10171062 if self .contact_ids is None :
10181063 arr ["contact_ids" ] = ["" ] * self .get_contact_count ()
10191064 else :
@@ -1062,6 +1107,7 @@ def from_numpy(arr: np.ndarray) -> "Probe":
10621107 "contact_shapes" ,
10631108 "shank_ids" ,
10641109 "contact_ids" ,
1110+ "contact_sides" ,
10651111 "device_channel_indices" ,
10661112 "radius" ,
10671113 "width" ,
@@ -1118,17 +1164,19 @@ def from_numpy(arr: np.ndarray) -> "Probe":
11181164 else :
11191165 plane_axes = None
11201166
1121- probe .set_contacts (positions = positions , plane_axes = plane_axes , shapes = shapes , shape_params = shape_params )
1167+
1168+ shank_ids = arr ["shank_ids" ] if "shank_ids" in fields else None
1169+ contact_sides = arr ["contact_sides" ] if "contact_sides" in fields else None
1170+
1171+ probe .set_contacts (positions = positions , plane_axes = plane_axes , shapes = shapes , shape_params = shape_params , shank_ids = shank_ids , contact_sides = contact_sides )
11221172
11231173 if "device_channel_indices" in fields :
11241174 dev_channel_indices = arr ["device_channel_indices" ]
11251175 if not np .all (dev_channel_indices == - 1 ):
11261176 probe .set_device_channel_indices (dev_channel_indices )
1127- if "shank_ids" in fields :
1128- probe .set_shank_ids (arr ["shank_ids" ])
11291177 if "contact_ids" in fields :
11301178 probe .set_contact_ids (arr ["contact_ids" ])
1131-
1179+
11321180 # contact annotations
11331181 for k in contact_annotation_fields :
11341182 probe .annotate_contacts (** {k : arr [k ]})
0 commit comments