Skip to content

Commit 6c32c77

Browse files
committed
Two side probe integration.
1 parent ee7f0e6 commit 6c32c77

File tree

4 files changed

+117
-19
lines changed

4 files changed

+117
-19
lines changed

src/probeinterface/plotting.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def plot_probe(
102102
ylims: tuple | None = None,
103103
zlims: tuple | None = None,
104104
show_channel_on_click: bool = False,
105+
side=None,
105106
):
106107
"""Plot a Probe object.
107108
Generates a 2D or 3D axis, depending on Probe.ndim
@@ -138,6 +139,8 @@ def plot_probe(
138139
Limits for z dimension
139140
show_channel_on_click : bool, default: False
140141
If True, the channel information is shown upon click
142+
side : None | "front" | "back
143+
If the probe is two side, then the side must be given otherwise this raises an error.
141144
142145
Returns
143146
-------
@@ -148,6 +151,14 @@ def plot_probe(
148151
"""
149152
import matplotlib.pyplot as plt
150153

154+
if probe.contact_sides is not None:
155+
if side is None or side not in ('front', 'back'):
156+
raise ValueError("The probe has two side, you must give which one to plot. plot_probe(probe, side='front'|'back')")
157+
mask = probe.contact_sides == side
158+
probe = probe.get_slice(mask)
159+
probe._contact_sides = None
160+
161+
151162
if ax is None:
152163
if probe.ndim == 2:
153164
fig, ax = plt.subplots()

src/probeinterface/probe.py

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def __init__(
101101
self.probe_planar_contour = None
102102

103103
# This handles the shank id per contact
104+
# If None then one shank only
104105
self._shank_ids = None
105106

106107
# This handles the wiring to device : channel index on device side.
@@ -112,6 +113,10 @@ def __init__(
112113
# This must be unique at Probe AND ProbeGroup level
113114
self._contact_ids = None
114115

116+
# Handle contact side for double face probes
117+
# If None then one face only
118+
self._contact_sides = None
119+
115120
# annotation: a dict that contains all meta information about
116121
# the probe (name, manufacturor, date of production, ...)
117122
self.annotations = dict()
@@ -153,6 +158,10 @@ def contact_ids(self):
153158
def shank_ids(self):
154159
return self._shank_ids
155160

161+
@property
162+
def contact_sides(self):
163+
return self._contact_sides
164+
156165
@property
157166
def name(self):
158167
return self.annotations.get("name", None)
@@ -237,6 +246,8 @@ def get_title(self) -> str:
237246
if self.shank_ids is not None:
238247
num_shank = self.get_shank_count()
239248
txt += f" - {num_shank}shanks"
249+
if self._contact_sides is not None:
250+
txt += f" - 2 sides"
240251
return txt
241252

242253
def __repr__(self):
@@ -291,7 +302,7 @@ def get_shank_count(self) -> int:
291302
return n
292303

293304
def set_contacts(
294-
self, positions, shapes="circle", shape_params={"radius": 10}, plane_axes=None, contact_ids=None, shank_ids=None
305+
self, positions, shapes="circle", shape_params={"radius": 10}, plane_axes=None, contact_ids=None, shank_ids=None, contact_sides=None
295306
):
296307
"""Sets contacts to a Probe.
297308
@@ -320,16 +331,29 @@ def set_contacts(
320331
shank_ids : array[str] | None, default: None
321332
Defines the shank ids for the contacts. If None, then
322333
these are assigned to a unique Shank.
334+
contact_sides : array[str] | None, default: None
335+
If probe is double sided, defines sides by a vector of ['front' | 'back']
323336
"""
324337
positions = np.array(positions)
325338
if positions.shape[1] != self.ndim:
326339
raise ValueError(f"positions.shape[1]: {positions.shape[1]} and ndim: {self.ndim} do not match!")
327340

328-
# Check for duplicate positions
329-
unique_positions = np.unique(positions, axis=0)
330-
positions_are_not_unique = unique_positions.shape[0] != positions.shape[0]
331-
if positions_are_not_unique:
332-
_raise_non_unique_positions_error(positions)
341+
342+
if contact_sides is None:
343+
# Check for duplicate positions
344+
unique_positions = np.unique(positions, axis=0)
345+
positions_are_not_unique = unique_positions.shape[0] != positions.shape[0]
346+
if positions_are_not_unique:
347+
_raise_non_unique_positions_error(positions)
348+
else:
349+
# Check for duplicate positions side by side
350+
contact_sides = np.asarray(contact_sides).astype(str)
351+
for side in ("front", "back"):
352+
mask = contact_sides == "font"
353+
unique_positions = np.unique(positions[mask], axis=0)
354+
positions_are_not_unique = unique_positions.shape[0] != positions[mask].shape[0]
355+
if positions_are_not_unique:
356+
_raise_non_unique_positions_error(positions[mask])
333357

334358
self._contact_positions = positions
335359
n = positions.shape[0]
@@ -355,6 +379,15 @@ def set_contacts(
355379
self._shank_ids = np.asarray(shank_ids).astype(str)
356380
if self.shank_ids.size != n:
357381
raise ValueError(f"shank_ids have wrong size: {self.shanks.ids.size} != {n}")
382+
383+
if contact_sides is None:
384+
self._contact_sides = contact_sides
385+
else:
386+
self._contact_sides = contact_sides
387+
if self._contact_sides.size != n:
388+
raise ValueError(f"contact_sides have wrong size: {self._contact_sides.ids.size} != {n}")
389+
if not np.all(np.isin(self._contact_sides, ["front", "back"])):
390+
raise ValueError(f"contact_sides must 'front' or 'back'")
358391

359392
# shape
360393
if isinstance(shapes, str):
@@ -592,6 +625,13 @@ def __eq__(self, other):
592625
):
593626
return False
594627

628+
if self._contact_sides is None:
629+
if other._contact_sides is not None:
630+
return False
631+
else:
632+
if not np.array_equal(self._contact_sides, other._contact_sides):
633+
return False
634+
595635
# Compare contact_annotations dictionaries
596636
if self.contact_annotations.keys() != other.contact_annotations.keys():
597637
return False
@@ -842,6 +882,7 @@ def rotate_contacts(self, thetas: float | np.array[float] | list[float]):
842882
"device_channel_indices",
843883
"_contact_ids",
844884
"_shank_ids",
885+
"_contact_sides",
845886
]
846887

847888
def to_dict(self, array_as_list: bool = False) -> dict:
@@ -895,6 +936,9 @@ def from_dict(d: dict) -> "Probe":
895936
plane_axes=d["contact_plane_axes"],
896937
shapes=d["contact_shapes"],
897938
shape_params=d["contact_shape_params"],
939+
contact_ids=d.get("contact_ids", None),
940+
shank_ids=d.get("shank_ids", None),
941+
contact_sides=d.get("contact_sides", None),
898942
)
899943

900944
v = d.get("probe_planar_contour", None)
@@ -905,14 +949,6 @@ def from_dict(d: dict) -> "Probe":
905949
if v is not None:
906950
probe.set_device_channel_indices(v)
907951

908-
v = d.get("shank_ids", None)
909-
if v is not None:
910-
probe.set_shank_ids(v)
911-
912-
v = d.get("contact_ids", None)
913-
if v is not None:
914-
probe.set_contact_ids(v)
915-
916952
if "annotations" in d:
917953
probe.annotate(**d["annotations"])
918954
if "contact_annotations" in d:
@@ -955,6 +991,7 @@ def to_numpy(self, complete: bool = False) -> np.array:
955991
...
956992
('shank_ids', 'U64'),
957993
('contact_ids', 'U64'),
994+
('contact_sides', 'U8'),
958995
959996
# The rest is added only if `complete=True`
960997
('device_channel_indices', 'int64', optional),
@@ -991,6 +1028,9 @@ def to_numpy(self, complete: bool = False) -> np.array:
9911028
dtype += [(k, "float64")]
9921029
dtype += [("shank_ids", "U64"), ("contact_ids", "U64")]
9931030

1031+
if self._contact_sides is not None:
1032+
dtype += [("contact_sides", "U8"), ]
1033+
9941034
if complete:
9951035
dtype += [("device_channel_indices", "int64")]
9961036
dtype += [("si_units", "U64")]
@@ -1014,6 +1054,11 @@ def to_numpy(self, complete: bool = False) -> np.array:
10141054

10151055
arr["shank_ids"] = self.shank_ids
10161056

1057+
if self._contact_sides is not None:
1058+
arr["contact_sides"] = self.contact_sides
1059+
1060+
1061+
10171062
if self.contact_ids is None:
10181063
arr["contact_ids"] = [""] * self.get_contact_count()
10191064
else:
@@ -1062,6 +1107,7 @@ def from_numpy(arr: np.ndarray) -> "Probe":
10621107
"contact_shapes",
10631108
"shank_ids",
10641109
"contact_ids",
1110+
"contact_sides",
10651111
"device_channel_indices",
10661112
"radius",
10671113
"width",
@@ -1118,17 +1164,19 @@ def from_numpy(arr: np.ndarray) -> "Probe":
11181164
else:
11191165
plane_axes = None
11201166

1121-
probe.set_contacts(positions=positions, plane_axes=plane_axes, shapes=shapes, shape_params=shape_params)
1167+
1168+
shank_ids = arr["shank_ids"] if "shank_ids" in fields else None
1169+
contact_sides = arr["contact_sides"] if "contact_sides" in fields else None
1170+
1171+
probe.set_contacts(positions=positions, plane_axes=plane_axes, shapes=shapes, shape_params=shape_params, shank_ids=shank_ids, contact_sides=contact_sides)
11221172

11231173
if "device_channel_indices" in fields:
11241174
dev_channel_indices = arr["device_channel_indices"]
11251175
if not np.all(dev_channel_indices == -1):
11261176
probe.set_device_channel_indices(dev_channel_indices)
1127-
if "shank_ids" in fields:
1128-
probe.set_shank_ids(arr["shank_ids"])
11291177
if "contact_ids" in fields:
11301178
probe.set_contact_ids(arr["contact_ids"])
1131-
1179+
11321180
# contact annotations
11331181
for k in contact_annotation_fields:
11341182
probe.annotate_contacts(**{k: arr[k]})

tests/test_plotting.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,21 @@ def test_plot_probegroup():
5050
plot_probegroup(probegroup_3d, same_axes=True)
5151

5252

53+
def test_plot_probe_two_side():
54+
probe = Probe()
55+
probe.set_contacts(
56+
positions=np.array([[0, 0], [0, 10], [0, 20],[0, 0], [0, 10], [0, 20],]),
57+
shapes="circle",
58+
contact_ids=["F1", "F2", "F3", "B1", "B2", "B3"],
59+
contact_sides=["front", "front", "front", "back", "back","back"]
60+
)
61+
62+
plot_probe(probe, with_contact_id=True, side="front")
63+
plot_probe(probe, with_contact_id=True, side="back")
64+
65+
5366
if __name__ == "__main__":
54-
test_plot_probe()
67+
# test_plot_probe()
5568
# test_plot_probe_group()
69+
test_plot_probe_two_side()
5670
plt.show()

tests/test_probe.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,34 @@ def test_position_uniqueness():
197197
probe.set_contacts(positions=positions_with_dups, shapes="circle", shape_params={"radius": 5})
198198

199199

200+
def test_double_side_probe():
201+
202+
probe = Probe()
203+
probe.set_contacts(
204+
positions=np.array([[0, 0], [0, 10], [0, 20],[0, 0], [0, 10], [0, 20],]),
205+
shapes="circle",
206+
contact_sides=["front", "front", "front", "back", "back","back"]
207+
)
208+
print(probe)
209+
210+
assert "contact_sides" in probe.to_dict()
211+
212+
probe2 = Probe.from_dict(probe.to_dict())
213+
assert probe2 == probe
214+
215+
probe3 = Probe.from_numpy(probe.to_numpy())
216+
assert probe3 == probe
217+
218+
probe4 = Probe.from_dataframe(probe.to_dataframe())
219+
assert probe4 == probe
220+
221+
222+
200223
if __name__ == "__main__":
201224
test_probe()
202225

203226
tmp_path = Path("tmp")
204227
tmp_path.mkdir(exist_ok=True)
205228
test_save_to_zarr(tmp_path)
229+
230+
test_double_side_probe()

0 commit comments

Comments
 (0)