Skip to content

Commit 9f26792

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6c32c77 commit 9f26792

File tree

4 files changed

+50
-20
lines changed

4 files changed

+50
-20
lines changed

src/probeinterface/plotting.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,14 @@ def plot_probe(
152152
import matplotlib.pyplot as plt
153153

154154
if probe.contact_sides is not None:
155-
if side is None or side not in ('front', 'back'):
156-
raise ValueError("The probe has two side, you must give which one to plot. plot_probe(probe, side='front'|'back')")
155+
if side is None or side not in ("front", "back"):
156+
raise ValueError(
157+
"The probe has two side, you must give which one to plot. plot_probe(probe, side='front'|'back')"
158+
)
157159
mask = probe.contact_sides == side
158160
probe = probe.get_slice(mask)
159161
probe._contact_sides = None
160162

161-
162163
if ax is None:
163164
if probe.ndim == 2:
164165
fig, ax = plt.subplots()

src/probeinterface/probe.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,14 @@ def get_shank_count(self) -> int:
302302
return n
303303

304304
def set_contacts(
305-
self, positions, shapes="circle", shape_params={"radius": 10}, plane_axes=None, contact_ids=None, shank_ids=None, contact_sides=None
305+
self,
306+
positions,
307+
shapes="circle",
308+
shape_params={"radius": 10},
309+
plane_axes=None,
310+
contact_ids=None,
311+
shank_ids=None,
312+
contact_sides=None,
306313
):
307314
"""Sets contacts to a Probe.
308315
@@ -338,7 +345,6 @@ def set_contacts(
338345
if positions.shape[1] != self.ndim:
339346
raise ValueError(f"positions.shape[1]: {positions.shape[1]} and ndim: {self.ndim} do not match!")
340347

341-
342348
if contact_sides is None:
343349
# Check for duplicate positions
344350
unique_positions = np.unique(positions, axis=0)
@@ -353,7 +359,7 @@ def set_contacts(
353359
unique_positions = np.unique(positions[mask], axis=0)
354360
positions_are_not_unique = unique_positions.shape[0] != positions[mask].shape[0]
355361
if positions_are_not_unique:
356-
_raise_non_unique_positions_error(positions[mask])
362+
_raise_non_unique_positions_error(positions[mask])
357363

358364
self._contact_positions = positions
359365
n = positions.shape[0]
@@ -379,7 +385,7 @@ def set_contacts(
379385
self._shank_ids = np.asarray(shank_ids).astype(str)
380386
if self.shank_ids.size != n:
381387
raise ValueError(f"shank_ids have wrong size: {self.shanks.ids.size} != {n}")
382-
388+
383389
if contact_sides is None:
384390
self._contact_sides = contact_sides
385391
else:
@@ -1029,7 +1035,9 @@ def to_numpy(self, complete: bool = False) -> np.array:
10291035
dtype += [("shank_ids", "U64"), ("contact_ids", "U64")]
10301036

10311037
if self._contact_sides is not None:
1032-
dtype += [("contact_sides", "U8"), ]
1038+
dtype += [
1039+
("contact_sides", "U8"),
1040+
]
10331041

10341042
if complete:
10351043
dtype += [("device_channel_indices", "int64")]
@@ -1057,8 +1065,6 @@ def to_numpy(self, complete: bool = False) -> np.array:
10571065
if self._contact_sides is not None:
10581066
arr["contact_sides"] = self.contact_sides
10591067

1060-
1061-
10621068
if self.contact_ids is None:
10631069
arr["contact_ids"] = [""] * self.get_contact_count()
10641070
else:
@@ -1164,19 +1170,25 @@ def from_numpy(arr: np.ndarray) -> "Probe":
11641170
else:
11651171
plane_axes = None
11661172

1167-
11681173
shank_ids = arr["shank_ids"] if "shank_ids" in fields else None
11691174
contact_sides = arr["contact_sides"] if "contact_sides" in fields else None
11701175

1171-
probe.set_contacts(positions=positions, plane_axes=plane_axes, shapes=shapes, shape_params=shape_params, shank_ids=shank_ids, contact_sides=contact_sides)
1176+
probe.set_contacts(
1177+
positions=positions,
1178+
plane_axes=plane_axes,
1179+
shapes=shapes,
1180+
shape_params=shape_params,
1181+
shank_ids=shank_ids,
1182+
contact_sides=contact_sides,
1183+
)
11721184

11731185
if "device_channel_indices" in fields:
11741186
dev_channel_indices = arr["device_channel_indices"]
11751187
if not np.all(dev_channel_indices == -1):
11761188
probe.set_device_channel_indices(dev_channel_indices)
11771189
if "contact_ids" in fields:
11781190
probe.set_contact_ids(arr["contact_ids"])
1179-
1191+
11801192
# contact annotations
11811193
for k in contact_annotation_fields:
11821194
probe.annotate_contacts(**{k: arr[k]})

tests/test_plotting.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,19 @@ def test_plot_probegroup():
5353
def test_plot_probe_two_side():
5454
probe = Probe()
5555
probe.set_contacts(
56-
positions=np.array([[0, 0], [0, 10], [0, 20],[0, 0], [0, 10], [0, 20],]),
56+
positions=np.array(
57+
[
58+
[0, 0],
59+
[0, 10],
60+
[0, 20],
61+
[0, 0],
62+
[0, 10],
63+
[0, 20],
64+
]
65+
),
5766
shapes="circle",
5867
contact_ids=["F1", "F2", "F3", "B1", "B2", "B3"],
59-
contact_sides=["front", "front", "front", "back", "back","back"]
68+
contact_sides=["front", "front", "front", "back", "back", "back"],
6069
)
6170

6271
plot_probe(probe, with_contact_id=True, side="front")

tests/test_probe.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,25 +201,33 @@ def test_double_side_probe():
201201

202202
probe = Probe()
203203
probe.set_contacts(
204-
positions=np.array([[0, 0], [0, 10], [0, 20],[0, 0], [0, 10], [0, 20],]),
204+
positions=np.array(
205+
[
206+
[0, 0],
207+
[0, 10],
208+
[0, 20],
209+
[0, 0],
210+
[0, 10],
211+
[0, 20],
212+
]
213+
),
205214
shapes="circle",
206-
contact_sides=["front", "front", "front", "back", "back","back"]
215+
contact_sides=["front", "front", "front", "back", "back", "back"],
207216
)
208217
print(probe)
209218

210219
assert "contact_sides" in probe.to_dict()
211220

212221
probe2 = Probe.from_dict(probe.to_dict())
213-
assert probe2 == probe
222+
assert probe2 == probe
214223

215224
probe3 = Probe.from_numpy(probe.to_numpy())
216-
assert probe3 == probe
225+
assert probe3 == probe
217226

218227
probe4 = Probe.from_dataframe(probe.to_dataframe())
219228
assert probe4 == probe
220229

221230

222-
223231
if __name__ == "__main__":
224232
test_probe()
225233

0 commit comments

Comments
 (0)